diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..ed389d0 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,148 @@ +cmake_minimum_required(VERSION 3.6) + +if(ANDROID) + include_directories(${OPENSSL_INCLUDE_DIR}) + link_directories(${OPENSSL_LINK_DIR}) +else() + find_package(OpenSSL REQUIRED) +endif () + +include_directories(${OPENSSL_INCLUDE_DIR} ${INC_DIR}/workflow) + +if (KAFKA STREQUAL "y") + find_path(SNAPPY_INCLUDE_PATH NAMES snappy.h) + find_library(SNAPPY_LIB NAMES snappy) + if ((NOT SNAPPY_INCLUDE_PATH) OR (NOT SNAPPY_LIB)) + message(FATAL_ERROR "Fail to find snappy with KAFKA=y") + endif () + include_directories(${SNAPPY_INCLUDE_PATH}) +endif () + +add_subdirectory(kernel) +add_subdirectory(util) +add_subdirectory(manager) +add_subdirectory(protocol) +add_subdirectory(factory) +add_subdirectory(nameservice) +add_subdirectory(server) +add_subdirectory(client) + +add_dependencies(kernel LINK_HEADERS) +add_dependencies(util LINK_HEADERS) +add_dependencies(manager LINK_HEADERS) +add_dependencies(protocol LINK_HEADERS) +add_dependencies(factory LINK_HEADERS) +add_dependencies(nameservice LINK_HEADERS) +add_dependencies(server LINK_HEADERS) +add_dependencies(client LINK_HEADERS) + +set(STATIC_LIB_NAME ${PROJECT_NAME}-static) +set(SHARED_LIB_NAME ${PROJECT_NAME}-shared) + +add_library( + ${STATIC_LIB_NAME} STATIC + $ + $ + $ + $ + $ + $ + $ + $ +) + +add_library( + ${SHARED_LIB_NAME} SHARED + $ + $ + $ + $ + $ + $ + $ + $ +) + +if(ANDROID) + target_link_libraries(${SHARED_LIB_NAME} ssl crypto c) + target_link_libraries(${STATIC_LIB_NAME} ssl crypto c) +else() + target_link_libraries(${SHARED_LIB_NAME} OpenSSL::SSL OpenSSL::Crypto pthread) + target_link_libraries(${STATIC_LIB_NAME} OpenSSL::SSL OpenSSL::Crypto pthread) +endif () + +set_target_properties(${STATIC_LIB_NAME} PROPERTIES OUTPUT_NAME ${PROJECT_NAME}) +set_target_properties(${SHARED_LIB_NAME} PROPERTIES OUTPUT_NAME ${PROJECT_NAME} VERSION ${PROJECT_VERSION} 
SOVERSION ${PROJECT_VERSION_MAJOR}) + +if (KAFKA STREQUAL "y") + add_dependencies(client_kafka LINK_HEADERS) + add_dependencies(util_kafka LINK_HEADERS) + add_dependencies(protocol_kafka LINK_HEADERS) + add_dependencies(factory_kafka LINK_HEADERS) + + set(KAFKA_STATIC_LIB_NAME "wfkafka-static") + add_library( + ${KAFKA_STATIC_LIB_NAME} STATIC + $ + $ + $ + $ + ) + set_target_properties(${KAFKA_STATIC_LIB_NAME} PROPERTIES OUTPUT_NAME "wfkafka") + + set(KAFKA_SHARED_LIB_NAME "wfkafka-shared") + add_library( + ${KAFKA_SHARED_LIB_NAME} SHARED + $ + $ + $ + $ + ) + if (APPLE) + target_link_libraries(${KAFKA_SHARED_LIB_NAME} + ${SHARED_LIB_NAME} z lz4 zstd ${SNAPPY_LIB}) + else () + target_link_libraries(${KAFKA_SHARED_LIB_NAME} ${SHARED_LIB_NAME}) + endif () + set_target_properties(${KAFKA_SHARED_LIB_NAME} PROPERTIES OUTPUT_NAME "wfkafka" VERSION ${PROJECT_VERSION} SOVERSION ${PROJECT_VERSION_MAJOR}) +endif () + +install( + TARGETS ${STATIC_LIB_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT devel +) + +install( + TARGETS ${SHARED_LIB_NAME} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT devel +) + +if (KAFKA STREQUAL "y") + install( + TARGETS ${KAFKA_STATIC_LIB_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT devel + ) + install( + TARGETS ${KAFKA_SHARED_LIB_NAME} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT devel + ) +endif () + +target_include_directories(${STATIC_LIB_NAME} BEFORE PUBLIC + "$" + "$") +target_include_directories(${SHARED_LIB_NAME} BEFORE PUBLIC + "$" + "$") +if (KAFKA STREQUAL "y") + target_include_directories(${KAFKA_STATIC_LIB_NAME} BEFORE PUBLIC + "$" + "$") + target_include_directories(${KAFKA_SHARED_LIB_NAME} BEFORE PUBLIC + "$" + "$") +endif () diff --git a/src/client/CMakeLists.txt b/src/client/CMakeLists.txt new file mode 100644 index 0000000..5f074f1 --- /dev/null +++ b/src/client/CMakeLists.txt @@ -0,0 +1,37 @@ +cmake_minimum_required(VERSION 3.6) +project(client) + +set(SRC + 
WFDnsClient.cc +) + +if (NOT REDIS STREQUAL "n") + set(SRC + ${SRC} + WFRedisSubscriber.cc + ) +endif () + +if (NOT MYSQL STREQUAL "n") + set(SRC + ${SRC} + WFMySQLConnection.cc + ) +endif () + +if (NOT CONSUL STREQUAL "n") + set(SRC + ${SRC} + WFConsulClient.cc + ) +endif () + +add_library(${PROJECT_NAME} OBJECT ${SRC}) + +if (KAFKA STREQUAL "y") + set(SRC + WFKafkaClient.cc + ) + add_library("client_kafka" OBJECT ${SRC}) +endif () + diff --git a/src/client/WFConsulClient.cc b/src/client/WFConsulClient.cc new file mode 100644 index 0000000..7bcf946 --- /dev/null +++ b/src/client/WFConsulClient.cc @@ -0,0 +1,1252 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "json_parser.h" +#include "StringUtil.h" +#include "URIParser.h" +#include "HttpUtil.h" +#include "WFConsulClient.h" + +using namespace protocol; + +WFConsulTask::WFConsulTask(const std::string& proxy_url, + const std::string& service_namespace, + const std::string& service_name, + const std::string& service_id, + int retry_max, + consul_callback_t&& cb) : + proxy_url(proxy_url), + callback(std::move(cb)) +{ + this->service.service_name = service_name; + this->service.service_namespace = service_namespace; + this->service.service_id = service_id; + this->api_type = CONSUL_API_TYPE_UNKNOWN; + this->retry_max = retry_max; + this->finish = false; + this->consul_index = 0; +} + +void WFConsulTask::set_service(const struct protocol::ConsulService *service) +{ + this->service.tags = service->tags; + this->service.meta = service->meta; + this->service.tag_override = service->tag_override; + this->service.service_address = service->service_address; + this->service.lan = service->lan; + this->service.lan_ipv4 = service->lan_ipv4; + this->service.lan_ipv6 = service->lan_ipv6; + this->service.virtual_address = service->virtual_address; + this->service.wan = service->wan; + this->service.wan_ipv4 = service->wan_ipv4; + this->service.wan_ipv6 = service->wan_ipv6; +} + +static bool parse_discover_result(const json_value_t *root, + std::vector& result); +static bool parse_list_service_result(const json_value_t *root, + std::vector& result); + +bool WFConsulTask::get_discover_result( + std::vector& result) +{ + json_value_t *root; + int errno_bak; + bool ret; + + if (this->api_type != CONSUL_API_TYPE_DISCOVER) + { + errno = EPERM; + return false; + } + + errno_bak = errno; + errno = EBADMSG; + std::string body = HttpUtil::decode_chunked_body(&this->http_resp); + root = json_value_parse(body.c_str()); + if (!root) + return false; + + ret = 
parse_discover_result(root, result); + json_value_destroy(root); + if (ret) + errno = errno_bak; + + return ret; +} + +bool WFConsulTask::get_list_service_result( + std::vector& result) +{ + json_value_t *root; + int errno_bak; + bool ret; + + if (this->api_type != CONSUL_API_TYPE_LIST_SERVICE) + { + errno = EPERM; + return false; + } + + errno_bak = errno; + errno = EBADMSG; + std::string body = HttpUtil::decode_chunked_body(&this->http_resp); + root = json_value_parse(body.c_str()); + if (!root) + return false; + + ret = parse_list_service_result(root, result); + json_value_destroy(root); + if (ret) + errno = errno_bak; + + return ret; +} + +void WFConsulTask::dispatch() +{ + WFHttpTask *task; + + if (this->finish) + { + this->subtask_done(); + return; + } + + switch(this->api_type) + { + case CONSUL_API_TYPE_DISCOVER: + task = create_discover_task(); + break; + + case CONSUL_API_TYPE_LIST_SERVICE: + task = create_list_service_task(); + break; + + case CONSUL_API_TYPE_DEREGISTER: + task = create_deregister_task(); + break; + + case CONSUL_API_TYPE_REGISTER: + task = create_register_task(); + if (task) + break; + + if (1) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + } + else + { + default: + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_CONSUL_API_UNKNOWN; + } + + this->finish = true; + this->subtask_done(); + return; + } + + series_of(this)->push_front(this); + series_of(this)->push_front(task); + this->subtask_done(); +} + +SubTask *WFConsulTask::done() +{ + SeriesWork *series = series_of(this); + + if (finish) + { + if (this->callback) + this->callback(this); + + delete this; + } + + return series->pop(); +} + +static std::string convert_time_to_str(int milliseconds) +{ + std::string str_time; + int seconds = milliseconds / 1000; + + if (seconds >= 180) + str_time = std::to_string(seconds / 60) + "m"; + else + str_time = std::to_string(seconds) + "s"; + + return str_time; +} + +std::string 
WFConsulTask::generate_discover_request() +{ + std::string url = this->proxy_url; + + url += "/v1/health/service/" + this->service.service_name; + url += "?dc=" + this->config.get_datacenter(); + url += "&ns=" + this->service.service_namespace; + std::string passing = this->config.get_passing() ? "true" : "false"; + url += "&passing=" + passing; + url += "&token=" + this->config.get_token(); + url += "&filter=" + this->config.get_filter_expr(); + + //consul blocking query + if (this->config.blocking_query()) + { + url += "&index=" + std::to_string(this->get_consul_index()); + url += "&wait=" + convert_time_to_str(this->config.get_wait_ttl()); + } + + return url; +} + +WFHttpTask *WFConsulTask::create_discover_task() +{ + std::string url = generate_discover_request(); + WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max, + discover_callback); + HttpRequest *req = task->get_req(); + + req->add_header_pair("Content-Type", "application/json"); + task->user_data = this; + return task; +} + +WFHttpTask *WFConsulTask::create_list_service_task() +{ + std::string url = this->proxy_url; + + url += "/v1/catalog/services?token=" + this->config.get_token(); + url += "&dc=" + this->config.get_datacenter(); + url += "&ns=" + this->service.service_namespace; + + WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max, + list_service_callback); + HttpRequest *req = task->get_req(); + + req->add_header_pair("Content-Type", "application/json"); + task->user_data = this; + return task; +} + +static void print_json_value(const json_value_t *val, int depth, + std::string& json_str); + +static bool create_register_request(const json_value_t *root, + const struct ConsulService *service, + const ConsulConfig& config); + +WFHttpTask *WFConsulTask::create_register_task() +{ + std::string payload; + + std::string url = this->proxy_url; + url += "/v1/agent/service/register?replace-existing-checks="; + url += this->config.get_replace_checks() ? 
"true" : "false"; + + WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max, + register_callback); + HttpRequest *req = task->get_req(); + + req->set_method(HttpMethodPut); + req->add_header_pair("Content-Type", "application/json"); + + if (!this->config.get_token().empty()) + req->add_header_pair("X-Consul-Token", this->config.get_token()); + + json_value_t *root = json_value_create(JSON_VALUE_OBJECT); + if (root) + { + if (create_register_request(root, &this->service, this->config)) + print_json_value(root, 0, payload); + + json_value_destroy(root); + if (!payload.empty() && req->append_output_body(payload)) + { + task->user_data = this; + return task; + } + } + + task->dismiss(); + return NULL; +} + +WFHttpTask *WFConsulTask::create_deregister_task() +{ + std::string url = this->proxy_url; + + url += "/v1/agent/service/deregister/" + this->service.service_id; + url += "?ns=" + this->service.service_namespace; + + WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max, + register_callback); + HttpRequest *req = task->get_req(); + + req->set_method(HttpMethodPut); + req->add_header_pair("Content-Type", "application/json"); + + std::string token = this->config.get_token(); + if (!token.empty()) + req->add_header_pair("X-Consul-Token", token); + + task->user_data = this; + return task; +} + +bool WFConsulTask::check_task_result(WFHttpTask *task, WFConsulTask *consul_task) +{ + if (task->get_state() != WFT_STATE_SUCCESS) + { + consul_task->state = task->get_state(); + consul_task->error = task->get_error(); + return false; + } + + protocol::HttpResponse *resp = task->get_resp(); + if (strcmp(resp->get_status_code(), "200") != 0) + { + consul_task->state = WFT_STATE_TASK_ERROR; + consul_task->error = WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED; + return false; + } + + return true; +} + +long long WFConsulTask::get_consul_index(HttpResponse *resp) +{ + long long consul_index = 0; + + // get consul-index from http header + 
protocol::HttpHeaderCursor cursor(resp); + std::string consul_index_str; + if (cursor.find("X-Consul-Index", consul_index_str)) + { + consul_index = strtoll(consul_index_str.c_str(), NULL, 10); + if (consul_index < 0) + consul_index = 0; + } + + return consul_index; +} + +void WFConsulTask::discover_callback(WFHttpTask *task) +{ + WFConsulTask *t = (WFConsulTask*)task->user_data; + + if (WFConsulTask::check_task_result(task, t)) + { + protocol::HttpResponse *resp = task->get_resp(); + long long consul_index = t->get_consul_index(resp); + long long last_consul_index = t->get_consul_index(); + t->set_consul_index(consul_index < last_consul_index ? 0 : consul_index); + t->state = task->get_state(); + } + + t->http_resp = std::move(*task->get_resp()); + t->finish = true; +} + +void WFConsulTask::list_service_callback(WFHttpTask *task) +{ + WFConsulTask *t = (WFConsulTask*)task->user_data; + + if (WFConsulTask::check_task_result(task, t)) + t->state = task->get_state(); + + t->http_resp = std::move(*task->get_resp()); + t->finish = true; +} + +void WFConsulTask::register_callback(WFHttpTask *task) +{ + WFConsulTask *t = (WFConsulTask *)task->user_data; + + if (WFConsulTask::check_task_result(task, t)) + t->state = task->get_state(); + + t->http_resp = std::move(*task->get_resp()); + t->finish = true; +} + +int WFConsulClient::init(const std::string& proxy_url, ConsulConfig config) +{ + ParsedURI uri; + + if (URIParser::parse(proxy_url, uri) >= 0) + { + this->proxy_url = uri.scheme; + this->proxy_url += "://"; + this->proxy_url += uri.host; + if (uri.port) + { + this->proxy_url += ":"; + this->proxy_url += uri.port; + } + + this->config = std::move(config); + return 0; + } + else if (uri.state == URI_STATE_INVALID) + errno = EINVAL; + + return -1; +} + +int WFConsulClient::init(const std::string& proxy_url) +{ + return this->init(proxy_url, ConsulConfig()); +} + +WFConsulTask *WFConsulClient::create_discover_task( + const std::string& service_namespace, + const 
std::string& service_name, + int retry_max, + consul_callback_t cb) +{ + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + service_name, "", retry_max, + std::move(cb)); + task->set_api_type(CONSUL_API_TYPE_DISCOVER); + task->set_config(this->config); + return task; +} + +WFConsulTask *WFConsulClient::create_list_service_task( + const std::string& service_namespace, + int retry_max, + consul_callback_t cb) +{ + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + "", "", retry_max, + std::move(cb)); + task->set_api_type(CONSUL_API_TYPE_LIST_SERVICE); + task->set_config(this->config); + return task; +} + +WFConsulTask *WFConsulClient::create_register_task( + const std::string& service_namespace, + const std::string& service_name, + const std::string& service_id, + int retry_max, + consul_callback_t cb) +{ + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + service_name, service_id, retry_max, + std::move(cb)); + task->set_api_type(CONSUL_API_TYPE_REGISTER); + task->set_config(this->config); + return task; +} + +WFConsulTask *WFConsulClient::create_deregister_task( + const std::string& service_namespace, + const std::string& service_id, + int retry_max, + consul_callback_t cb) +{ + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + "", service_id, retry_max, + std::move(cb)); + task->set_api_type(CONSUL_API_TYPE_DEREGISTER); + task->set_config(this->config); + return task; +} + +static bool create_tagged_address(const ConsulAddress& consul_address, + const std::string& name, + json_object_t *tagged_obj) +{ + if (consul_address.first.empty()) + return true; + + const json_value_t *val = json_object_append(tagged_obj, name.c_str(), + JSON_VALUE_OBJECT); + if (!val) + return false; + + json_object_t *obj = json_value_object(val); + if (!obj) + return false; + + if (!json_object_append(obj, "Address", JSON_VALUE_STRING, + consul_address.first.c_str())) + return false; + + 
if (!json_object_append(obj, "Port", JSON_VALUE_NUMBER, + (double)consul_address.second)) + return false; + + return true; +} + +static bool create_health_check(const ConsulConfig& config, json_object_t *obj) +{ + const json_value_t *val; + std::string str; + + if (!config.get_health_check()) + return true; + + val = json_object_append(obj, "Check", JSON_VALUE_OBJECT); + if (!val) + return false; + + obj = json_value_object(val); + if (!obj) + return false; + + str = config.get_check_name(); + if (!json_object_append(obj, "Name", JSON_VALUE_STRING, str.c_str())) + return false; + + str = config.get_check_notes(); + if (!json_object_append(obj, "Notes", JSON_VALUE_STRING, str.c_str())) + return false; + + str = config.get_check_http_url(); + if (!str.empty()) + { + if (!json_object_append(obj, "HTTP", JSON_VALUE_STRING, str.c_str())) + return false; + + str = config.get_check_http_method(); + if (!json_object_append(obj, "Method", JSON_VALUE_STRING, str.c_str())) + return false; + + str = config.get_http_body(); + if (!json_object_append(obj, "Body", JSON_VALUE_STRING, str.c_str())) + return false; + + val = json_object_append(obj, "Header", JSON_VALUE_OBJECT); + if (!val) + return false; + + json_object_t *header_obj = json_value_object(val); + if (!header_obj) + return false; + + for (const auto& header : *config.get_http_headers()) + { + val = json_object_append(header_obj, header.first.c_str(), + JSON_VALUE_ARRAY); + if (!val) + return false; + + json_array_t *arr = json_value_array(val); + if (!arr) + return false; + + for (const auto& value : header.second) + { + if (!json_array_append(arr, JSON_VALUE_STRING, value.c_str())) + return false; + } + } + } + + str = config.get_check_tcp(); + if (!str.empty()) + { + if (!json_object_append(obj, "TCP", JSON_VALUE_STRING, str.c_str())) + return false; + } + + str = config.get_initial_status(); + if (!json_object_append(obj, "Status", JSON_VALUE_STRING, str.c_str())) + return false; + + str = 
convert_time_to_str(config.get_auto_deregister_time()); + if (!json_object_append(obj, "DeregisterCriticalServiceAfter", + JSON_VALUE_STRING, str.c_str())) + return false; + + str = convert_time_to_str(config.get_check_interval()); + if (!json_object_append(obj, "Interval", JSON_VALUE_STRING, str.c_str())) + return false; + + str = convert_time_to_str(config.get_check_timeout()); + if (!json_object_append(obj, "Timeout", JSON_VALUE_STRING, str.c_str())) + return false; + + if (!json_object_append(obj, "SuccessBeforePassing", JSON_VALUE_NUMBER, + (double)config.get_success_times())) + return false; + + if (!json_object_append(obj, "FailuresBeforeCritical", JSON_VALUE_NUMBER, + (double)config.get_failure_times())) + return false; + + return true; +} + +static bool create_register_request(const json_value_t *root, + const struct ConsulService *service, + const ConsulConfig& config) +{ + const json_value_t *val; + json_object_t *obj; + + obj = json_value_object(root); + if (!obj) + return false; + + if (!json_object_append(obj, "ID", JSON_VALUE_STRING, + service->service_id.c_str())) + return false; + + if (!json_object_append(obj, "Name", JSON_VALUE_STRING, + service->service_name.c_str())) + return false; + + if (!service->service_namespace.empty()) + { + if (!json_object_append(obj, "ns", JSON_VALUE_STRING, + service->service_namespace.c_str())) + return false; + } + + val = json_object_append(obj, "Tags", JSON_VALUE_ARRAY); + if (!val) + return false; + + json_array_t *arr = json_value_array(val); + if (!arr) + return false; + + for (const auto& tag : service->tags) + { + if (!json_array_append(arr, JSON_VALUE_STRING, tag.c_str())) + return false; + } + + if (!json_object_append(obj, "Address", JSON_VALUE_STRING, + service->service_address.first.c_str())) + return false; + + if (!json_object_append(obj, "Port", JSON_VALUE_NUMBER, + (double)service->service_address.second)) + return false; + + val = json_object_append(obj, "Meta", JSON_VALUE_OBJECT); + if (!val) + 
return false; + + json_object_t *meta_obj = json_value_object(val); + if (!meta_obj) + return false; + + for (const auto& meta_kv : service->meta) + { + if (!json_object_append(meta_obj, meta_kv.first.c_str(), + JSON_VALUE_STRING, meta_kv.second.c_str())) + return false; + } + + int type = service->tag_override ? JSON_VALUE_TRUE : JSON_VALUE_FALSE; + if (!json_object_append(obj, "EnableTagOverride", type)) + return false; + + val = json_object_append(obj, "TaggedAddresses", JSON_VALUE_OBJECT); + if (!val) + return false; + + json_object_t *tagged_obj = json_value_object(val); + if (!tagged_obj) + return false; + + if (!create_tagged_address(service->lan, "lan", tagged_obj)) + return false; + + if (!create_tagged_address(service->lan_ipv4, "lan_ipv4", tagged_obj)) + return false; + + if (!create_tagged_address(service->lan_ipv6, "lan_ipv6", tagged_obj)) + return false; + + if (!create_tagged_address(service->virtual_address, "virtual", tagged_obj)) + return false; + + if (!create_tagged_address(service->wan, "wan", tagged_obj)) + return false; + + if (!create_tagged_address(service->wan_ipv4, "wan_ipv4", tagged_obj)) + return false; + + if (!create_tagged_address(service->wan_ipv6, "wan_ipv6", tagged_obj)) + return false; + + // create health check + if (!create_health_check(config, obj)) + return false; + + return true; +} + +static bool parse_list_service_result(const json_value_t *root, + std::vector& result) +{ + const json_object_t *obj; + const json_value_t *val; + const json_array_t *arr; + const char *key; + const char *str; + + obj = json_value_object(root); + if (!obj) + return false; + + json_object_for_each(key, val, obj) + { + struct ConsulServiceTags instance; + + instance.service_name = key; + arr = json_value_array(val); + if (!arr) + return false; + + const json_value_t *tag_val; + json_array_for_each(tag_val, arr) + { + str = json_value_string(tag_val); + if (!str) + return false; + + instance.tags.emplace_back(str); + } + + 
result.emplace_back(std::move(instance)); + } + + return true; +} + +static bool parse_discover_node(const json_object_t *obj, + struct ConsulServiceInstance *instance) +{ + const json_value_t *val; + const char *str; + + val = json_object_find("Node", obj); + if (!val) + return false; + + obj = json_value_object(val); + if (!obj) + return false; + + val = json_object_find("ID", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->node_id = str; + val = json_object_find("Node", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->node_name = str; + val = json_object_find("Address", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->node_address = str; + val = json_object_find("Datacenter", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->dc = str; + + val = json_object_find("Meta", obj); + if (!val) + return false; + + const json_object_t *meta_obj = json_value_object(val); + if (!meta_obj) + return false; + + const char *meta_k; + const json_value_t *meta_v; + + json_object_for_each(meta_k, meta_v, meta_obj) + { + str = json_value_string(meta_v); + if (!str) + return false; + + instance->node_meta[meta_k] = str; + } + + val = json_object_find("CreateIndex", obj); + if (val) + instance->create_index = json_value_number(val); + + val = json_object_find("ModifyIndex", obj); + if (val) + instance->modify_index = json_value_number(val); + + return true; +} + +static bool parse_tagged_address(const char *name, + const json_value_t *tagged_val, + ConsulAddress& tagged_address) +{ + const json_value_t *val; + const json_object_t *obj; + const char *str; + + obj = json_value_object(tagged_val); + if (!obj) + return false; + + val = json_object_find(name, obj); + if (!val) + return false; + + obj = json_value_object(val); + if (!obj) + return 
false; + + val = json_object_find("Address", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + tagged_address.first = str; + val = json_object_find("Port", obj); + if (!val) + return false; + + tagged_address.second = json_value_number(val); + return true; +} + +static bool parse_service(const json_object_t *obj, + struct ConsulService *service) +{ + const json_value_t *val; + const char *str; + + val = json_object_find("Service", obj); + if (!val) + return false; + + obj = json_value_object(val); + if (!obj) + return false; + + val = json_object_find("ID", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + service->service_id = str; + val = json_object_find("Service", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + service->service_name = str; + val = json_object_find("Namespace", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + service->service_namespace = str; + } + + val = json_object_find("Address", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + service->service_address.first = str; + + val = json_object_find("Port", obj); + if (!val) + return false; + + service->service_address.second = json_value_number(val); + val = json_object_find("TaggedAddresses", obj); + if (!val) + return false; + + parse_tagged_address("lan", val, service->lan); + parse_tagged_address("lan_ipv4", val, service->lan_ipv4); + parse_tagged_address("lan_ipv6", val, service->lan_ipv6); + parse_tagged_address("virtual", val, service->virtual_address); + parse_tagged_address("wan", val, service->wan); + parse_tagged_address("wan_ipv4", val, service->wan_ipv4); + parse_tagged_address("wan_ipv6", val, service->wan_ipv6); + + val = json_object_find("Tags", obj); + if (!val) + return false; + + const json_array_t *tags_arr = json_value_array(val); + if 
(tags_arr) + { + const json_value_t *tags_value; + json_array_for_each(tags_value, tags_arr) + { + str = json_value_string(tags_value); + if (!str) + return false; + + service->tags.emplace_back(str); + } + } + + val = json_object_find("Meta", obj); + if (!val) + return false; + + const json_object_t *meta_obj = json_value_object(val); + if (!meta_obj) + return false; + + const char *meta_k; + const json_value_t *meta_v; + json_object_for_each(meta_k, meta_v, meta_obj) + { + str = json_value_string(meta_v); + if (!str) + return false; + + service->meta[meta_k] = str; + } + + val = json_object_find("EnableTagOverride", obj); + if (val) + service->tag_override = (json_value_type(val) == JSON_VALUE_TRUE); + + return true; +} + +static bool parse_health_check(const json_object_t *obj, + struct ConsulServiceInstance *instance) +{ + const json_value_t *val; + const char *str; + + val = json_object_find("Checks", obj); + if (!val) + return false; + + const json_array_t *check_arr = json_value_array(val); + if (!check_arr) + return false; + + const json_value_t *arr_val; + json_array_for_each(arr_val, check_arr) + { + obj = json_value_object(arr_val); + if (!obj) + return false; + + val = json_object_find("ServiceName", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + std::string check_service_name = str; + val = json_object_find("ServiceID", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + std::string check_service_id = str; + if (check_service_id.empty() || check_service_name.empty()) + continue; + + val = json_object_find("CheckID", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->check_id = str; + + val = json_object_find("Name", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->check_name = str; + val = json_object_find("Status", obj); + if (!val) + 
return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->check_status = str; + val = json_object_find("Notes", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + instance->check_notes = str; + } + + val = json_object_find("Output", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + instance->check_output = str; + } + + val = json_object_find("Type", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + instance->check_type = str; + } + + break; //only one effective service health check + } + + return true; +} + +static bool parse_discover_result(const json_value_t *root, + std::vector& result) +{ + const json_array_t *arr = json_value_array(root); + const json_value_t *val; + const json_object_t *obj; + + if (!arr) + return false; + + json_array_for_each(val, arr) + { + struct ConsulServiceInstance instance; + + obj = json_value_object(val); + if (!obj) + return false; + + if (!parse_discover_node(obj, &instance)) + return false; + + if (!parse_service(obj, &instance.service)) + return false; + + parse_health_check(obj, &instance); + result.emplace_back(std::move(instance)); + } + + return true; +} + +static void print_json_object(const json_object_t *obj, int depth, + std::string& json_str) +{ + const char *name; + const json_value_t *val; + int n = 0; + int i; + + json_str += "{\n"; + json_object_for_each(name, val, obj) + { + if (n != 0) + json_str += ",\n"; + + n++; + for (i = 0; i < depth + 1; i++) + json_str += " "; + + json_str += "\""; + json_str += name; + json_str += "\": "; + print_json_value(val, depth + 1, json_str); + } + + json_str += "\n"; + for (i = 0; i < depth; i++) + json_str += " "; + + json_str += "}"; +} + +static void print_json_array(const json_array_t *arr, int depth, + std::string& json_str) +{ + const json_value_t *val; + int n = 0; + int i; + + json_str += "[\n"; + json_array_for_each(val, arr) + { + if (n 
!= 0) + json_str += ",\n"; + + n++; + for (i = 0; i < depth + 1; i++) + json_str += " "; + + print_json_value(val, depth + 1, json_str); + } + + json_str += "\n"; + for (i = 0; i < depth; i++) + json_str += " "; + + json_str += "]"; +} + +static void print_json_string(const char *str, std::string& json_str) +{ + json_str += "\""; + while (*str) + { + switch (*str) + { + case '\r': + json_str += "\\r"; + break; + case '\n': + json_str += "\\n"; + break; + case '\f': + json_str += "\\f"; + break; + case '\b': + json_str += "\\b"; + break; + case '\"': + json_str += "\\\""; + break; + case '\t': + json_str += "\\t"; + break; + case '\\': + json_str += "\\\\"; + break; + default: + json_str += *str; + break; + } + str++; + } + json_str += "\""; +} + +static void print_json_number(double number, std::string& json_str) +{ + long long integer = number; + + if (integer == number) + json_str += std::to_string(integer); + else + json_str += std::to_string(number); +} + +static void print_json_value(const json_value_t *val, int depth, + std::string& json_str) +{ + switch (json_value_type(val)) + { + case JSON_VALUE_STRING: + print_json_string(json_value_string(val), json_str); + break; + case JSON_VALUE_NUMBER: + print_json_number(json_value_number(val), json_str); + break; + case JSON_VALUE_OBJECT: + print_json_object(json_value_object(val), depth, json_str); + break; + case JSON_VALUE_ARRAY: + print_json_array(json_value_array(val), depth, json_str); + break; + case JSON_VALUE_TRUE: + json_str += "true"; + break; + case JSON_VALUE_FALSE: + json_str += "false"; + break; + case JSON_VALUE_NULL: + json_str += "null"; + break; + } +} + diff --git a/src/client/WFConsulClient.h b/src/client/WFConsulClient.h new file mode 100644 index 0000000..729eaff --- /dev/null +++ b/src/client/WFConsulClient.h @@ -0,0 +1,161 @@ +/* + Copyright (c) 2022 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) +*/ + +#ifndef _WFCONSULCLIENT_H_ +#define _WFCONSULCLIENT_H_ + +#include +#include +#include +#include +#include "HttpMessage.h" +#include "WFTaskFactory.h" +#include "ConsulDataTypes.h" + +class WFConsulTask; +using consul_callback_t = std::function; + +enum +{ + CONSUL_API_TYPE_UNKNOWN = 0, + CONSUL_API_TYPE_DISCOVER, + CONSUL_API_TYPE_LIST_SERVICE, + CONSUL_API_TYPE_REGISTER, + CONSUL_API_TYPE_DEREGISTER, +}; + +class WFConsulTask : public WFGenericTask +{ +public: + bool get_discover_result( + std::vector& result); + + bool get_list_service_result( + std::vector& result); + +public: + void set_service(const struct protocol::ConsulService *service); + + void set_api_type(int api_type) + { + this->api_type = api_type; + } + + int get_api_type() const + { + return this->api_type; + } + + void set_callback(consul_callback_t cb) + { + this->callback = std::move(cb); + } + + void set_consul_index(long long consul_index) + { + this->consul_index = consul_index; + } + long long get_consul_index() const { return this->consul_index; } + + const protocol::HttpResponse *get_http_resp() const + { + return &this->http_resp; + } + +protected: + void set_config(protocol::ConsulConfig conf) + { + this->config = std::move(conf); + } + +protected: + virtual void dispatch(); + virtual SubTask *done(); + + WFHttpTask *create_discover_task(); + WFHttpTask *create_list_service_task(); + 
WFHttpTask *create_register_task(); + WFHttpTask *create_deregister_task(); + + std::string generate_discover_request(); + long long get_consul_index(protocol::HttpResponse *resp); + + static bool check_task_result(WFHttpTask *task, WFConsulTask *consul_task); + static void discover_callback(WFHttpTask *task); + static void list_service_callback(WFHttpTask *task); + static void register_callback(WFHttpTask *task); + +protected: + protocol::ConsulConfig config; + struct protocol::ConsulService service; + std::string proxy_url; + int retry_max; + int api_type; + bool finish; + long long consul_index; + protocol::HttpResponse http_resp; + consul_callback_t callback; + +protected: + WFConsulTask(const std::string& proxy_url, + const std::string& service_namespace, + const std::string& service_name, + const std::string& service_id, + int retry_max, consul_callback_t&& cb); + virtual ~WFConsulTask() { } + friend class WFConsulClient; +}; + +class WFConsulClient +{ +public: + // example: http://127.0.0.1:8500 + int init(const std::string& proxy_url); + int init(const std::string& proxy_url, protocol::ConsulConfig config); + void deinit() { } + + WFConsulTask *create_discover_task(const std::string& service_namespace, + const std::string& service_name, + int retry_max, + consul_callback_t cb); + + WFConsulTask *create_list_service_task(const std::string& service_namespace, + int retry_max, + consul_callback_t cb); + + WFConsulTask *create_register_task(const std::string& service_namespace, + const std::string& service_name, + const std::string& service_id, + int retry_max, + consul_callback_t cb); + + WFConsulTask *create_deregister_task(const std::string& service_namespace, + const std::string& service_id, + int retry_max, + consul_callback_t cb); + +private: + std::string proxy_url; + protocol::ConsulConfig config; + +public: + virtual ~WFConsulClient() { } +}; + +#endif + diff --git a/src/client/WFDnsClient.cc b/src/client/WFDnsClient.cc new file mode 100644 index 
0000000..8ccf6c5 --- /dev/null +++ b/src/client/WFDnsClient.cc @@ -0,0 +1,277 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include "URIParser.h" +#include "StringUtil.h" +#include "WFDnsClient.h" + +using namespace protocol; + +using DnsCtx = std::function; +using ComplexTask = WFComplexClientTask; + +class DnsParams +{ +public: + struct dns_params + { + std::vector uris; + std::vector search_list; + int ndots; + int attempts; + bool rotate; + }; + +public: + DnsParams() + { + this->ref = new std::atomic(1); + this->params = new dns_params(); + } + + DnsParams(const DnsParams& p) + { + this->ref = p.ref; + this->params = p.params; + this->incref(); + } + + DnsParams& operator=(const DnsParams& p) + { + if (this != &p) + { + this->decref(); + this->ref = p.ref; + this->params = p.params; + this->incref(); + } + return *this; + } + + ~DnsParams() { this->decref(); } + + const dns_params *get_params() const { return this->params; } + dns_params *get_params() { return this->params; } + +private: + void incref() { (*this->ref)++; } + void decref() + { + if (--*this->ref == 0) + { + delete this->params; + delete this->ref; + } + } + +private: + dns_params *params; + std::atomic *ref; +}; + +enum +{ + DNS_STATUS_TRY_ORIGIN_DONE = 0, + DNS_STATUS_TRY_ORIGIN_FIRST = 1, + DNS_STATUS_TRY_ORIGIN_LAST = 2 +}; + +struct DnsStatus +{ + std::string origin_name; + 
std::string current_name; + size_t next_server; // next server to try + size_t last_server; // last server to try + size_t next_domain; // next search domain to try + int attempts_left; + int try_origin_state; +}; + +static int __get_ndots(const std::string& s) +{ + int ndots = 0; + for (size_t i = 0; i < s.size(); i++) + ndots += s[i] == '.'; + return ndots; +} + +static bool __has_next_name(const DnsParams::dns_params *p, + struct DnsStatus *s) +{ + if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST) + { + s->current_name = s->origin_name; + s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; + return true; + } + + if (s->next_domain < p->search_list.size()) + { + s->current_name = s->origin_name; + s->current_name.push_back('.'); + s->current_name.append(p->search_list[s->next_domain]); + + s->next_domain++; + return true; + } + + if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_LAST) + { + s->current_name = s->origin_name; + s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; + return true; + } + + return false; +} + +static void __callback_internal(WFDnsTask *task, const DnsParams& params, + struct DnsStatus& s) +{ + ComplexTask *ctask = static_cast(task); + int state = task->get_state(); + DnsRequest *req = task->get_req(); + DnsResponse *resp = task->get_resp(); + const auto *p = params.get_params(); + int rcode = resp->get_rcode(); + + bool try_next_server = state != WFT_STATE_SUCCESS || + rcode == DNS_RCODE_SERVER_FAILURE || + rcode == DNS_RCODE_NOT_IMPLEMENTED || + rcode == DNS_RCODE_REFUSED; + bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR || + rcode == DNS_RCODE_NAME_ERROR || + resp->get_ancount() == 0; + + if (try_next_server) + { + if (s.last_server == s.next_server) + s.attempts_left--; + if (s.attempts_left <= 0) + return; + + s.next_server = (s.next_server + 1) % p->uris.size(); + ctask->set_redirect(p->uris[s.next_server]); + return; + } + + if (try_next_name && __has_next_name(p, &s)) + { + req->set_question_name(s.current_name.c_str()); + 
ctask->set_redirect(p->uris[s.next_server]); + return; + } +} + +int WFDnsClient::init(const std::string& url) +{ + return this->init(url, "", 1, 2, false); +} + +int WFDnsClient::init(const std::string& url, const std::string& search_list, + int ndots, int attempts, bool rotate) +{ + std::vector hosts; + std::vector uris; + std::string host; + ParsedURI uri; + + this->id = 0; + hosts = StringUtil::split_filter_empty(url, ','); + + for (size_t i = 0; i < hosts.size(); i++) + { + host = hosts[i]; + if (strncasecmp(host.c_str(), "dns://", 6) != 0 && + strncasecmp(host.c_str(), "dnss://", 7) != 0) + { + host = "dns://" + host; + } + + if (URIParser::parse(host, uri) != 0) + return -1; + + uris.emplace_back(std::move(uri)); + } + + if (uris.empty() || ndots < 0 || attempts < 1) + { + errno = EINVAL; + return -1; + } + + this->params = new DnsParams; + DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params(); + q->uris = std::move(uris); + q->search_list = StringUtil::split_filter_empty(search_list, ','); + q->ndots = ndots > 15 ? 15 : ndots; + q->attempts = attempts > 5 ? 5 : attempts; + q->rotate = rotate; + + return 0; +} + +void WFDnsClient::deinit() +{ + delete (DnsParams *)this->params; + this->params = NULL; +} + +WFDnsTask *WFDnsClient::create_dns_task(const std::string& name, + dns_callback_t callback) +{ + DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params(); + struct DnsStatus status; + size_t next_server; + WFDnsTask *task; + DnsRequest *req; + + next_server = p->rotate ? 
this->id++ % p->uris.size() : 0; + + status.origin_name = name; + status.next_domain = 0; + status.attempts_left = p->attempts; + status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST; + + if (!name.empty() && name.back() == '.') + status.next_domain = p->search_list.size(); + else if (__get_ndots(name) < p->ndots) + status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST; + + __has_next_name(p, &status); + + task = WFTaskFactory::create_dns_task(p->uris[next_server], 0, + std::move(callback)); + status.next_server = next_server; + status.last_server = (next_server + p->uris.size() - 1) % p->uris.size(); + + req = task->get_req(); + req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN); + req->set_rd(1); + + ComplexTask *ctask = static_cast(task); + *ctask->get_mutable_ctx() = std::bind(__callback_internal, + std::placeholders::_1, + *(DnsParams *)params, status); + + return task; +} + diff --git a/src/client/WFDnsClient.h b/src/client/WFDnsClient.h new file mode 100644 index 0000000..489c171 --- /dev/null +++ b/src/client/WFDnsClient.h @@ -0,0 +1,47 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _WFDNSCLIENT_H_ +#define _WFDNSCLIENT_H_ + +#include +#include +#include "WFTaskFactory.h" +#include "DnsMessage.h" + +class WFDnsClient +{ +public: + int init(const std::string& url); + int init(const std::string& url, const std::string& search_list, + int ndots, int attempts, bool rotate); + void deinit(); + + WFDnsTask *create_dns_task(const std::string& name, + dns_callback_t callback); + +private: + void *params; + std::atomic id; + +public: + virtual ~WFDnsClient() { } +}; + +#endif + diff --git a/src/client/WFKafkaClient.cc b/src/client/WFKafkaClient.cc new file mode 100644 index 0000000..8339aed --- /dev/null +++ b/src/client/WFKafkaClient.cc @@ -0,0 +1,1694 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "WFTaskError.h" +#include "StringUtil.h" +#include "KafkaTaskImpl.inl" +#include "WFKafkaClient.h" + +#define KAFKA_HEARTBEAT_INTERVAL (3 * 1000 * 1000) + +#define KAFKA_CGROUP_UNINIT 0 +#define KAFKA_CGROUP_DOING 1 +#define KAFKA_CGROUP_DONE 2 +#define KAFKA_CGROUP_NONE 3 + +#define KAFKA_HEARTBEAT_UNINIT 0 +#define KAFKA_HEARTBEAT_DOING 1 +#define KAFKA_HEARTBEAT_DONE 2 + +#define KAFKA_DEINIT (1<<30) + +using namespace protocol; + +using ComplexKafkaTask = WFComplexClientTask; + +class KafkaMember +{ +public: + KafkaMember() : scheme("kafka://"), ref(1) + { + this->transport_type = TT_TCP; + this->cgroup_status = KAFKA_CGROUP_NONE; + this->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; + this->meta_doing = false; + this->cgroup_outdated = false; + this->client_deinit = false; + this->heartbeat_series = NULL; + } + + void incref() + { + ++this->ref; + } + + void decref() + { + if (--this->ref == 0) + delete this; + } + + enum TransportType transport_type; + std::string scheme; + std::vector broker_hosts; + SSL_CTX *ssl_ctx; + KafkaCgroup cgroup; + KafkaMetaList meta_list; + KafkaBrokerMap broker_map; + KafkaConfig config; + std::map meta_status; + std::mutex mutex; + char cgroup_status; + char heartbeat_status; + bool meta_doing; + bool cgroup_outdated; + bool client_deinit; + void *heartbeat_series; + size_t cgroup_wait_cnt; + size_t meta_wait_cnt; + std::atomic ref; +}; + +class KafkaClientTask : public WFKafkaTask +{ +public: + KafkaClientTask(const std::string& query, int retry_max, + kafka_callback_t&& callback, + WFKafkaClient *client) : + WFKafkaTask(retry_max, std::move(callback)) + { + this->api_type = Kafka_Unknown; + this->kafka_error = 0; + this->member = client->member; + this->query = query; + + this->member->incref(); + 
this->member->mutex.lock(); + this->config = client->member->config; + if (!this->member->broker_hosts.empty()) + { + int rpos = rand() % this->member->broker_hosts.size(); + this->url = this->member->broker_hosts.at(rpos); + } + this->member->mutex.unlock(); + + this->info_generated = false; + this->msg = NULL; + } + + virtual ~KafkaClientTask() + { + this->member->decref(); + } + + std::string *get_url() { return &this->url; } + +protected: + virtual bool add_topic(const std::string& topic); + + virtual bool add_toppar(const KafkaToppar& toppar); + + virtual bool add_produce_record(const std::string& topic, int partition, + KafkaRecord record); + + virtual bool add_offset_toppar(const KafkaToppar& toppar); + + virtual void dispatch(); + + virtual void parse_query(); + virtual void generate_info(); + +private: + static void kafka_meta_callback(__WFKafkaTask *task); + + static void kafka_merge_meta_list(KafkaMetaList *dst, + KafkaMetaList *src); + + static void kafka_merge_broker_list(const std::string& scheme, + std::vector *hosts, + KafkaBrokerMap *dst, + KafkaBrokerList *src); + + static void kafka_cgroup_callback(__WFKafkaTask *task); + + static void kafka_offsetcommit_callback(__WFKafkaTask *task); + + static void kafka_parallel_callback(const ParallelWork *pwork); + + static void kafka_timer_callback(WFTimerTask *task); + + static void kafka_heartbeat_callback(__WFKafkaTask *task); + + static void kafka_leavegroup_callback(__WFKafkaTask *task); + + static void kafka_rebalance_proc(KafkaMember *member, SeriesWork *series); + + static void kafka_rebalance_callback(__WFKafkaTask *task); + + void kafka_move_task_callback(__WFKafkaTask *task); + + void kafka_process_toppar_offset(KafkaToppar *task_toppar); + + bool compare_topics(KafkaClientTask *task); + + bool check_cgroup(); + + bool check_meta(); + + int arrange_toppar(int api_type); + + int arrange_produce(); + + int arrange_fetch(); + + int arrange_commit(); + + int arrange_offset(); + + int 
dispatch_locked(); + + KafkaBroker *get_broker(int node_id) + { + return this->member->broker_map.find_item(node_id); + } + + int get_node_id(const KafkaToppar *toppar); + + bool get_meta_status(KafkaMetaList **uninit_meta_list); + void set_meta_status(bool status); + + std::string get_userinfo() { return this->userinfo; } + +private: + KafkaMember *member; + KafkaBroker broker; + std::map toppar_list_map; + std::string url; + std::string query; + std::set topic_set; + std::string userinfo; + bool info_generated; + bool wait_cgroup; + void *msg; + + friend class WFKafkaClient; +}; + +int KafkaClientTask::get_node_id(const KafkaToppar *toppar) +{ + int preferred_read_replica = toppar->get_preferred_read_replica(); + if (preferred_read_replica >= 0) + return preferred_read_replica; + + bool flag = false; + this->member->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = this->member->meta_list.get_next()) != NULL) + { + if (strcmp(meta->get_topic(), toppar->get_topic()) == 0) + { + flag = true; + break; + } + } + + const kafka_broker_t *broker = NULL; + if (flag) + broker = meta->get_broker(toppar->get_partition()); + + if (!broker) + return -1; + + return broker->node_id; +} + +void KafkaClientTask::kafka_offsetcommit_callback(__WFKafkaTask *task) +{ + KafkaClientTask *t = (KafkaClientTask *)task->user_data; + if (task->get_state() == WFT_STATE_SUCCESS) + t->result.set_resp(std::move(*task->get_resp()), 0); + + t->finish = true; + t->state = task->get_state(); + t->error = task->get_error(); + t->kafka_error = *static_cast(task)->get_mutable_ctx(); +} + +void KafkaClientTask::kafka_leavegroup_callback(__WFKafkaTask *task) +{ + KafkaClientTask *t = (KafkaClientTask *)task->user_data; + t->finish = true; + t->state = task->get_state(); + t->error = task->get_error(); + t->kafka_error = *static_cast(task)->get_mutable_ctx(); +} + +void KafkaClientTask::kafka_rebalance_callback(__WFKafkaTask *task) +{ + KafkaMember *member = (KafkaMember *)task->user_data; + 
SeriesWork *series = series_of(task); + size_t max; + + member->mutex.lock(); + if (member->client_deinit) + { + member->mutex.unlock(); + member->decref(); + return; + } + + if (task->get_state() == WFT_STATE_SUCCESS) + { + member->cgroup_status = KAFKA_CGROUP_DONE; + member->cgroup = std::move(*(task->get_resp()->get_cgroup())); + + if (member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT) + { + __WFKafkaTask *kafka_task; + KafkaBroker *coordinator = member->cgroup.get_coordinator(); + kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + coordinator->get_host(), + coordinator->get_port(), + member->ssl_ctx, "", 0, + kafka_heartbeat_callback); + kafka_task->user_data = member; + kafka_task->get_req()->set_api_type(Kafka_Heartbeat); + kafka_task->get_req()->set_cgroup(member->cgroup); + kafka_task->get_req()->set_broker(*coordinator); + series->push_back(kafka_task); + + member->heartbeat_status = KAFKA_HEARTBEAT_DOING; + member->heartbeat_series = series; + } + + max = member->cgroup_wait_cnt; + char name[64]; + snprintf(name, 64, "%p.cgroup", member); + member->mutex.unlock(); + + WFTaskFactory::signal_by_name(name, NULL, max); + } + else + { + kafka_rebalance_proc(member, series); + member->mutex.unlock(); + } +} + +void KafkaClientTask::kafka_rebalance_proc(KafkaMember *member, SeriesWork *series) +{ + KafkaBroker *coordinator = member->cgroup.get_coordinator(); + __WFKafkaTask *task; + task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + coordinator->get_host(), + coordinator->get_port(), + member->ssl_ctx, "", 0, + kafka_rebalance_callback); + task->user_data = member; + task->get_req()->set_config(member->config); + task->get_req()->set_api_type(Kafka_FindCoordinator); + task->get_req()->set_cgroup(member->cgroup); + task->get_req()->set_meta_list(member->meta_list); + + member->cgroup_status = KAFKA_CGROUP_DOING; + member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; + member->cgroup_outdated = false; + + 
series->push_back(task); +} + +void KafkaClientTask::kafka_heartbeat_callback(__WFKafkaTask *task) +{ + KafkaMember *member = (KafkaMember *)task->user_data; + SeriesWork *series = series_of(task); + KafkaResponse *resp = task->get_resp(); + + member->mutex.lock(); + + if (member->client_deinit || member->heartbeat_series != series) + { + member->mutex.unlock(); + member->decref(); + return; + } + + if (resp->get_cgroup()->get_error() == 0) + { + member->heartbeat_status = KAFKA_HEARTBEAT_DONE; + WFTimerTask *timer_task; + timer_task = WFTaskFactory::create_timer_task(KAFKA_HEARTBEAT_INTERVAL, + kafka_timer_callback); + timer_task->user_data = member; + series->push_back(timer_task); + } + else + kafka_rebalance_proc(member, series); + + member->mutex.unlock(); +} + +void KafkaClientTask::kafka_timer_callback(WFTimerTask *task) +{ + KafkaMember *member = (KafkaMember *)task->user_data; + SeriesWork *series = series_of(task); + + member->mutex.lock(); + if (member->client_deinit || member->heartbeat_series != series) + { + member->mutex.unlock(); + member->decref(); + return; + } + + member->heartbeat_status = KAFKA_HEARTBEAT_DOING; + + __WFKafkaTask *kafka_task; + KafkaBroker *coordinator = member->cgroup.get_coordinator(); + kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + coordinator->get_host(), + coordinator->get_port(), + member->ssl_ctx, "", 0, + kafka_heartbeat_callback); + + kafka_task->user_data = member; + kafka_task->get_req()->set_config(member->config); + kafka_task->get_req()->set_api_type(Kafka_Heartbeat); + kafka_task->get_req()->set_cgroup(member->cgroup); + kafka_task->get_req()->set_broker(*coordinator); + series->push_back(kafka_task); + + member->mutex.unlock(); +} + +void KafkaClientTask::kafka_merge_meta_list(KafkaMetaList *dst, + KafkaMetaList *src) +{ + src->rewind(); + KafkaMeta *src_meta; + while ((src_meta = src->get_next()) != NULL) + { + dst->rewind(); + + KafkaMeta *dst_meta; + while ((dst_meta = 
dst->get_next()) != NULL) + { + if (strcmp(dst_meta->get_topic(), src_meta->get_topic()) == 0) + { + dst->del_cur(); + delete dst_meta; + break; + } + } + + dst->add_item(*src_meta); + } +} + +void KafkaClientTask::kafka_merge_broker_list(const std::string& scheme, + std::vector *hosts, + KafkaBrokerMap *dst, + KafkaBrokerList *src) +{ + hosts->clear(); + src->rewind(); + KafkaBroker *src_broker; + while ((src_broker = src->get_next()) != NULL) + { + std::string host = scheme + src_broker->get_host() + ":" + + std::to_string(src_broker->get_port()); + hosts->emplace_back(std::move(host)); + + if (!dst->find_item(src_broker->get_node_id())) + dst->add_item(*src_broker, src_broker->get_node_id()); + } +} + +void KafkaClientTask::kafka_meta_callback(__WFKafkaTask *task) +{ + KafkaClientTask *t = (KafkaClientTask *)task->user_data; + void *msg = NULL; + size_t max; + + t->member->mutex.lock(); + t->state = task->get_state(); + t->error = task->get_error(); + t->kafka_error = *static_cast(task)->get_mutable_ctx(); + if (t->state == WFT_STATE_SUCCESS) + { + kafka_merge_meta_list(&t->member->meta_list, + task->get_resp()->get_meta_list()); + + t->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = t->meta_list.get_next()) != NULL) + (t->member->meta_status)[meta->get_topic()] = true; + + kafka_merge_broker_list(t->member->scheme, + &t->member->broker_hosts, + &t->member->broker_map, + task->get_resp()->get_broker_list()); + } + else + { + t->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = t->meta_list.get_next()) != NULL) + (t->member->meta_status)[meta->get_topic()] = false; + + t->finish = true; + msg = t; + } + + t->member->meta_doing = false; + max = t->member->meta_wait_cnt; + char name[64]; + snprintf(name, 64, "%p.meta", t->member); + t->member->mutex.unlock(); + + WFTaskFactory::signal_by_name(name, msg, max); +} + +void KafkaClientTask::kafka_cgroup_callback(__WFKafkaTask *task) +{ + KafkaClientTask *t = (KafkaClientTask *)task->user_data; + 
KafkaMember *member = t->member; + SeriesWork *heartbeat_series = NULL; + void *msg = NULL; + size_t max; + + member->mutex.lock(); + t->state = task->get_state(); + t->error = task->get_error(); + t->kafka_error = *static_cast(task)->get_mutable_ctx(); + + if (t->state == WFT_STATE_SUCCESS) + { + member->cgroup = std::move(*(task->get_resp()->get_cgroup())); + + kafka_merge_meta_list(&member->meta_list, + task->get_resp()->get_meta_list()); + + t->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = t->meta_list.get_next()) != NULL) + (member->meta_status)[meta->get_topic()] = true; + + kafka_merge_broker_list(member->scheme, + &member->broker_hosts, + &member->broker_map, + task->get_resp()->get_broker_list()); + + member->cgroup_status = KAFKA_CGROUP_DONE; + + if (member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT) + { + __WFKafkaTask *kafka_task; + KafkaBroker *coordinator = member->cgroup.get_coordinator(); + kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + coordinator->get_host(), + coordinator->get_port(), + member->ssl_ctx, "", 0, + kafka_heartbeat_callback); + kafka_task->user_data = member; + member->incref(); + + kafka_task->get_req()->set_config(member->config); + kafka_task->get_req()->set_api_type(Kafka_Heartbeat); + kafka_task->get_req()->set_cgroup(member->cgroup); + kafka_task->get_req()->set_broker(*coordinator); + + heartbeat_series = Workflow::create_series_work(kafka_task, nullptr); + member->heartbeat_status = KAFKA_HEARTBEAT_DOING; + member->heartbeat_series = heartbeat_series; + } + } + else + { + member->cgroup_status = KAFKA_CGROUP_UNINIT; + member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; + member->heartbeat_series = NULL; + t->finish = true; + msg = t; + } + + max = member->cgroup_wait_cnt; + char name[64]; + snprintf(name, 64, "%p.cgroup", member); + member->mutex.unlock(); + + WFTaskFactory::signal_by_name(name, msg, max); + + if (heartbeat_series) + heartbeat_series->start(); +} + +void 
KafkaClientTask::kafka_parallel_callback(const ParallelWork *pwork) +{ + KafkaClientTask *t = (KafkaClientTask *)pwork->get_context(); + t->finish = true; + t->state = WFT_STATE_TASK_ERROR; + t->error = 0; + + std::pair *state_error; + bool flag = false; + int16_t state = WFT_STATE_SUCCESS; + int16_t error = 0; + int kafka_error = 0; + for (size_t i = 0; i < pwork->size(); i++) + { + state_error = (std::pair *)pwork->series_at(i)->get_context(); + if ((state_error->first >> 16) != WFT_STATE_SUCCESS) + { + if (!flag) + { + flag = true; + t->member->mutex.lock(); + t->set_meta_status(false); + t->member->mutex.unlock(); + } + state = state_error->first >> 16; + error = state_error->first & 0xffff; + kafka_error = state_error->second; + } + else + { + t->state = WFT_STATE_SUCCESS; + } + + delete state_error; + } + + if (t->state != WFT_STATE_SUCCESS) + { + t->state = state; + t->error = error; + t->kafka_error = kafka_error; + } +} + +void KafkaClientTask::kafka_process_toppar_offset(KafkaToppar *task_toppar) +{ + KafkaToppar *toppar; + + struct list_head *pos; + list_for_each(pos, this->member->cgroup.get_assigned_toppar_list()) + { + toppar = this->member->cgroup.get_assigned_toppar_by_pos(pos); + if (strcmp(toppar->get_topic(), task_toppar->get_topic()) == 0 && + toppar->get_partition() == task_toppar->get_partition()) + { + long long offset = task_toppar->get_offset() - 1; + KafkaRecord *last_record = task_toppar->get_tail_record(); + if (last_record) + offset = last_record->get_offset(); + toppar->set_offset(offset + 1); + toppar->set_low_watermark(task_toppar->get_low_watermark()); + toppar->set_high_watermark(task_toppar->get_high_watermark()); + } + } +} + +void KafkaClientTask::kafka_move_task_callback(__WFKafkaTask *task) +{ + auto *state_error = new std::pair; + int16_t state = task->get_state(); + int16_t error = task->get_error(); + + /* 'state' is always positive. 
*/ + state_error->first = (state << 16) | error; + state_error->second = *static_cast(task)->get_mutable_ctx(); + series_of(task)->set_context(state_error); + + KafkaTopparList *toppar_list = task->get_resp()->get_toppar_list(); + + if (task->get_state() == WFT_STATE_SUCCESS && + task->get_resp()->get_api_type() == Kafka_Fetch) + { + toppar_list->rewind(); + KafkaToppar *task_toppar; + + while ((task_toppar = toppar_list->get_next()) != NULL) + kafka_process_toppar_offset(task_toppar); + } + + if (task->get_state() == WFT_STATE_SUCCESS) + { + long idx = (long)(task->user_data); + this->result.set_resp(std::move(*task->get_resp()), idx); + } +} + +void KafkaClientTask::generate_info() +{ + if (this->info_generated) + return; + + if (this->config.get_sasl_mech()) + { + const char *username = this->config.get_sasl_username(); + const char *password = this->config.get_sasl_password(); + + this->userinfo.clear(); + if (username) + this->userinfo += StringUtil::url_encode_component(username); + this->userinfo += ":"; + if (password) + this->userinfo += StringUtil::url_encode_component(password); + this->userinfo += ":"; + this->userinfo += this->config.get_sasl_mech(); + this->userinfo += ":"; + this->userinfo += std::to_string((intptr_t)this->member); + } + else + { + char buf[64]; + snprintf(buf, 64, "user:pass:sasl:%p", this->member); + this->userinfo = buf; + } + + const char *hostport = this->url.c_str() + this->member->scheme.size(); + this->url = this->member->scheme + this->userinfo + "@" + hostport; + this->info_generated = true; +} + +void KafkaClientTask::parse_query() +{ + auto query_kv = URIParser::split_query_strict(this->query); + int api_type = this->api_type; + for (const auto &kv : query_kv) + { + if (strcasecmp(kv.first.c_str(), "api") == 0 && + api_type == Kafka_Unknown) + { + for (auto& v : kv.second) + { + if (strcasecmp(v.c_str(), "fetch") == 0) + this->api_type = Kafka_Fetch; + else if (strcasecmp(v.c_str(), "produce") == 0) + this->api_type = 
Kafka_Produce; + else if (strcasecmp(v.c_str(), "commit") == 0) + this->api_type = Kafka_OffsetCommit; + else if (strcasecmp(v.c_str(), "meta") == 0) + this->api_type = Kafka_Metadata; + else if (strcasecmp(v.c_str(), "leavegroup") == 0) + this->api_type = Kafka_LeaveGroup; + else if (strcasecmp(v.c_str(), "listoffsets") == 0) + this->api_type = Kafka_ListOffsets; + } + } + else if (strcasecmp(kv.first.c_str(), "topic") == 0) + { + for (auto& v : kv.second) + this->add_topic(v); + } + } +} + +bool KafkaClientTask::get_meta_status(KafkaMetaList **uninit_meta_list) +{ + this->meta_list.rewind(); + KafkaMeta *meta; + std::set unique; + bool status = true; + + while ((meta = this->meta_list.get_next()) != NULL) + { + if (!unique.insert(meta->get_topic()).second) + continue; + + if (!this->member->meta_status[meta->get_topic()]) + { + if (status) + { + *uninit_meta_list = new KafkaMetaList; + status = false; + } + + (*uninit_meta_list)->add_item(*meta); + } + } + + return status; +} + +void KafkaClientTask::set_meta_status(bool status) +{ + this->member->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = this->member->meta_list.get_next()) != NULL) + this->member->meta_status[meta->get_topic()] = false; +} + +bool KafkaClientTask::compare_topics(KafkaClientTask *task) +{ + auto first1 = topic_set.cbegin(), last1 = topic_set.cend(); + auto first2 = task->topic_set.cbegin(), last2 = task->topic_set.cend(); + int cmp; + + // check whether task->topic_set is a subset of topic_set + while (first1 != last1 && first2 != last2) + { + cmp = first1->compare(*first2); + if (cmp == 0) + { + ++first1; + ++first2; + } + else if (cmp < 0) + ++first1; + else + return false; + } + + return first2 == last2; +} + +bool KafkaClientTask::check_cgroup() +{ + KafkaMember *member = this->member; + + if (member->cgroup_outdated && member->cgroup_status != KAFKA_CGROUP_DOING) + { + member->cgroup_outdated = false; + member->cgroup_status = KAFKA_CGROUP_UNINIT; + member->heartbeat_series = 
NULL; + member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; + } + + if (member->cgroup_status == KAFKA_CGROUP_DOING) + { + WFConditional *cond; + char name[64]; + snprintf(name, 64, "%p.cgroup", this->member); + this->wait_cgroup = true; + cond = WFTaskFactory::create_conditional(name, this, &this->msg); + series_of(this)->push_front(cond); + member->cgroup_wait_cnt++; + return false; + } + + if ((this->api_type == Kafka_Fetch || this->api_type == Kafka_OffsetCommit) && + (member->cgroup_status == KAFKA_CGROUP_UNINIT)) + { + __WFKafkaTask *task; + + task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx, + this->retry_max, + kafka_cgroup_callback); + task->user_data = this; + task->get_req()->set_config(this->config); + task->get_req()->set_api_type(Kafka_FindCoordinator); + task->get_req()->set_cgroup(member->cgroup); + task->get_req()->set_meta_list(member->meta_list); + series_of(this)->push_front(this); + series_of(this)->push_front(task); + member->cgroup_status = KAFKA_CGROUP_DOING; + member->cgroup_wait_cnt = 0; + return false; + } + + return true; +} + +bool KafkaClientTask::check_meta() +{ + KafkaMember *member = this->member; + KafkaMetaList *uninit_meta_list; + + if (this->get_meta_status(&uninit_meta_list)) + return true; + + if (member->meta_doing) + { + WFConditional *cond; + char name[64]; + snprintf(name, 64, "%p.meta", this->member); + this->wait_cgroup = false; + cond = WFTaskFactory::create_conditional(name, this, &this->msg); + series_of(this)->push_front(cond); + member->meta_wait_cnt++; + } + else + { + __WFKafkaTask *task; + + task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx, + this->retry_max, + kafka_meta_callback); + task->user_data = this; + task->get_req()->set_config(this->config); + task->get_req()->set_api_type(Kafka_Metadata); + task->get_req()->set_meta_list(*uninit_meta_list); + series_of(this)->push_front(this); + series_of(this)->push_front(task); + member->meta_wait_cnt = 0; + 
member->meta_doing = true; + } + + delete uninit_meta_list; + return false; +} + +int KafkaClientTask::dispatch_locked() +{ + KafkaMember *member = this->member; + KafkaBroker *coordinator; + __WFKafkaTask *task; + ParallelWork *parallel; + SeriesWork *series; + + if (this->check_cgroup() == false) + return member->cgroup_wait_cnt > 0; + + if (this->check_meta() == false) + return member->meta_wait_cnt > 0; + + if (arrange_toppar(this->api_type) < 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_KAFKA_ARRANGE_FAILED; + this->finish = true; + return 0; + } + + if (this->member->cgroup_outdated) + { + series_of(this)->push_front(this); + return 0; + } + + switch(this->api_type) + { + case Kafka_Produce: + if (this->toppar_list_map.size() == 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_KAFKA_PRODUCE_FAILED; + this->finish = true; + break; + } + + parallel = Workflow::create_parallel_work(kafka_parallel_callback); + this->result.create(this->toppar_list_map.size()); + parallel->set_context(this); + for (auto &v : this->toppar_list_map) + { + auto cb = std::bind(&KafkaClientTask::kafka_move_task_callback, this, + std::placeholders::_1); + KafkaBroker *broker = get_broker(v.first); + task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + broker->get_host(), + broker->get_port(), + member->ssl_ctx, + this->get_userinfo(), + this->retry_max, + std::move(cb)); + task->get_req()->set_config(this->config); + task->get_req()->set_toppar_list(v.second); + task->get_req()->set_broker(*broker); + task->get_req()->set_api_type(Kafka_Produce); + task->user_data = (void *)parallel->size(); + series = Workflow::create_series_work(task, nullptr); + parallel->add_series(series); + } + series_of(this)->push_front(this); + series_of(this)->push_front(parallel); + break; + + case Kafka_Fetch: + if (this->toppar_list_map.size() == 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_KAFKA_FETCH_FAILED; + 
this->finish = true; + break; + } + + parallel = Workflow::create_parallel_work(kafka_parallel_callback); + this->result.create(this->toppar_list_map.size()); + parallel->set_context(this); + for (auto &v : this->toppar_list_map) + { + auto cb = std::bind(&KafkaClientTask::kafka_move_task_callback, this, + std::placeholders::_1); + KafkaBroker *broker = get_broker(v.first); + task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + broker->get_host(), + broker->get_port(), + member->ssl_ctx, + this->get_userinfo(), + this->retry_max, + std::move(cb)); + + task->get_req()->set_config(this->config); + task->get_req()->set_toppar_list(v.second); + task->get_req()->set_broker(*broker); + task->get_req()->set_api_type(Kafka_Fetch); + task->user_data = (void *)parallel->size(); + series = Workflow::create_series_work(task, nullptr); + parallel->add_series(series); + } + + series_of(this)->push_front(this); + series_of(this)->push_front(parallel); + break; + + case Kafka_Metadata: + this->finish = true; + break; + + case Kafka_OffsetCommit: + if (!member->cgroup.get_group()) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_KAFKA_COMMIT_FAILED; + this->finish = true; + break; + } + + this->result.create(1); + coordinator = member->cgroup.get_coordinator(); + task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + coordinator->get_host(), + coordinator->get_port(), + member->ssl_ctx, + this->get_userinfo(), + this->retry_max, + kafka_offsetcommit_callback); + task->user_data = this; + task->get_req()->set_config(this->config); + task->get_req()->set_cgroup(member->cgroup); + task->get_req()->set_broker(*coordinator); + task->get_req()->set_toppar_list(this->toppar_list); + task->get_req()->set_api_type(this->api_type); + series_of(this)->push_front(this); + series_of(this)->push_front(task); + break; + + case Kafka_LeaveGroup: + if (!member->cgroup.get_group()) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = 
WFT_ERR_KAFKA_LEAVEGROUP_FAILED; + this->finish = true; + break; + } + + coordinator = member->cgroup.get_coordinator(); + if (!coordinator->get_host()) + break; + + task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + coordinator->get_host(), + coordinator->get_port(), + member->ssl_ctx, + this->get_userinfo(), 0, + kafka_leavegroup_callback); + task->user_data = this; + task->get_req()->set_config(this->config); + task->get_req()->set_api_type(Kafka_LeaveGroup); + task->get_req()->set_broker(*coordinator); + task->get_req()->set_cgroup(member->cgroup); + series_of(this)->push_front(this); + series_of(this)->push_front(task); + break; + + case Kafka_ListOffsets: + if (this->toppar_list_map.size() == 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_KAFKA_LIST_OFFSETS_FAILED; + this->finish = true; + break; + } + + parallel = Workflow::create_parallel_work(kafka_parallel_callback); + this->result.create(this->toppar_list_map.size()); + parallel->set_context(this); + for (auto &v : this->toppar_list_map) + { + auto cb = std::bind(&KafkaClientTask::kafka_move_task_callback, this, + std::placeholders::_1); + KafkaBroker *broker = get_broker(v.first); + task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, + broker->get_host(), + broker->get_port(), + member->ssl_ctx, + this->get_userinfo(), + this->retry_max, + std::move(cb)); + task->get_req()->set_config(this->config); + task->get_req()->set_toppar_list(v.second); + task->get_req()->set_broker(*broker); + task->get_req()->set_api_type(Kafka_ListOffsets); + task->user_data = (void *)parallel->size(); + series = Workflow::create_series_work(task, nullptr); + parallel->add_series(series); + } + series_of(this)->push_front(this); + series_of(this)->push_front(parallel); + break; + + default: + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_KAFKA_API_UNKNOWN; + this->finish = true; + break; + } + + return 0; +} + +void KafkaClientTask::dispatch() +{ + 
if (this->finish) + { + this->subtask_done(); + return; + } + + if (this->msg) + { + KafkaClientTask *task = static_cast(this->msg); + if (this->wait_cgroup || this->compare_topics(task) == true) + { + this->state = task->get_state(); + this->error = task->get_error(); + this->kafka_error = get_kafka_error(); + this->finish = true; + this->subtask_done(); + return; + } + + this->msg = NULL; + } + + if (!this->query.empty()) + this->parse_query(); + + this->generate_info(); + + int flag; + this->member->mutex.lock(); + flag = this->dispatch_locked(); + if (flag) + this->subtask_done(); + this->member->mutex.unlock(); + + if (!flag) + this->subtask_done(); +} + +bool KafkaClientTask::add_topic(const std::string& topic) +{ + bool flag = false; + this->member->mutex.lock(); + + this->topic_set.insert(topic); + this->member->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = this->member->meta_list.get_next()) != NULL) + { + if (meta->get_topic() == topic) + { + flag = true; + break; + } + } + + if (!flag) + { + this->member->meta_status[topic] = false; + + KafkaMeta tmp; + if (!tmp.set_topic(topic)) + { + this->member->mutex.unlock(); + return false; + } + + this->meta_list.add_item(tmp); + this->member->meta_list.add_item(tmp); + + if (this->member->cgroup.get_group()) + this->member->cgroup_outdated = true; + } + else + { + this->meta_list.rewind(); + KafkaMeta *exist; + while ((exist = this->meta_list.get_next()) != NULL) + { + if (strcmp(exist->get_topic(), meta->get_topic()) == 0) + { + this->member->mutex.unlock(); + return true; + } + } + + this->meta_list.add_item(*meta); + } + + this->member->mutex.unlock(); + + return true; +} + +bool KafkaClientTask::add_toppar(const KafkaToppar& toppar) +{ + if (this->member->cgroup.get_group()) + return false; + + bool flag = false; + this->member->mutex.lock(); + + this->member->meta_list.rewind(); + KafkaMeta *meta; + while ((meta = this->member->meta_list.get_next()) != NULL) + { + if (strcmp(meta->get_topic(), 
toppar.get_topic()) == 0) + { + flag = true; + break; + } + } + + this->topic_set.insert(toppar.get_topic()); + if (!flag) + { + KafkaMeta tmp; + if (!tmp.set_topic(toppar.get_topic())) + { + this->member->mutex.unlock(); + return false; + } + + KafkaToppar new_toppar; + if (!new_toppar.set_topic_partition(toppar.get_topic(), toppar.get_partition())) + { + this->member->mutex.unlock(); + return false; + } + + new_toppar.set_offset(toppar.get_offset()); + new_toppar.set_offset_timestamp(toppar.get_offset_timestamp()); + new_toppar.set_low_watermark(toppar.get_low_watermark()); + new_toppar.set_high_watermark(toppar.get_high_watermark()); + this->toppar_list.add_item(new_toppar); + + this->meta_list.add_item(tmp); + this->member->meta_list.add_item(tmp); + + if (this->member->cgroup.get_group()) + this->member->cgroup_outdated = true; + } + else + { + this->toppar_list.rewind(); + KafkaToppar *exist; + while ((exist = this->toppar_list.get_next()) != NULL) + { + if (strcmp(exist->get_topic(), toppar.get_topic()) == 0 && + exist->get_partition() == toppar.get_partition()) + { + this->member->mutex.unlock(); + return true; + } + } + + KafkaToppar new_toppar; + if (!new_toppar.set_topic_partition(toppar.get_topic(), toppar.get_partition())) + { + this->member->mutex.unlock(); + return true; + } + + new_toppar.set_offset(toppar.get_offset()); + new_toppar.set_offset_timestamp(toppar.get_offset_timestamp()); + new_toppar.set_low_watermark(toppar.get_low_watermark()); + new_toppar.set_high_watermark(toppar.get_high_watermark()); + this->toppar_list.add_item(new_toppar); + + this->meta_list.add_item(*meta); + } + + this->member->mutex.unlock(); + + return true; +} + +bool KafkaClientTask::add_produce_record(const std::string& topic, + int partition, + KafkaRecord record) +{ + if (!add_topic(topic)) + return false; + + bool flag = false; + this->toppar_list.rewind(); + KafkaToppar *toppar; + while ((toppar = this->toppar_list.get_next()) != NULL) + { + if 
(toppar->get_topic() == topic && + toppar->get_partition() == partition) + { + flag = true; + break; + } + } + + if (!flag) + { + KafkaToppar new_toppar; + if (!new_toppar.set_topic_partition(topic, partition)) + return false; + + new_toppar.add_record(std::move(record)); + this->toppar_list.add_item(std::move(new_toppar)); + } + else + toppar->add_record(std::move(record)); + + return true; +} + +static bool check_replace_toppar(KafkaTopparList *toppar_list, KafkaToppar *toppar) +{ + bool flag = false; + toppar_list->rewind(); + KafkaToppar *exist; + while ((exist = toppar_list->get_next()) != NULL) + { + if (strcmp(exist->get_topic(), toppar->get_topic()) == 0 && + exist->get_partition() == toppar->get_partition()) + { + flag = true; + if (toppar->get_offset() > exist->get_offset()) + { + toppar_list->add_item(std::move(*toppar)); + toppar_list->del_cur(); + delete exist; + return true; + } + } + } + + if (!flag) + { + toppar_list->add_item(std::move(*toppar)); + return true; + } + + return false; +} + +int KafkaClientTask::arrange_toppar(int api_type) +{ + switch(api_type) + { + case Kafka_Produce: + return this->arrange_produce(); + + case Kafka_Fetch: + return this->arrange_fetch(); + + case Kafka_ListOffsets: + return this->arrange_offset(); + + case Kafka_OffsetCommit: + return this->arrange_commit(); + + default: + return 0; + } +} + +bool KafkaClientTask::add_offset_toppar(const protocol::KafkaToppar& toppar) +{ + if (!add_topic(toppar.get_topic())) + return false; + + KafkaToppar *exist; + bool found = false; + while ((exist = this->toppar_list.get_next()) != NULL) + { + if (strcmp(exist->get_topic(), toppar.get_topic()) == 0 && + exist->get_partition() == toppar.get_partition()) + { + found = true; + break; + } + } + + if (!found) + { + KafkaToppar toppar_t; + toppar_t.set_topic_partition(toppar.get_topic(), toppar.get_partition()); + this->toppar_list.add_item(std::move(toppar_t)); + } + + return true; +} + +int KafkaClientTask::arrange_offset() +{ + 
this->toppar_list.rewind(); + KafkaToppar *toppar; + while ((toppar = this->toppar_list.get_next()) != NULL) + { + int node_id = get_node_id(toppar); + if (node_id < 0) + return -1; + + if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) + this->toppar_list_map[node_id] = (KafkaTopparList()); + + KafkaToppar new_toppar; + if (!new_toppar.set_topic_partition(toppar->get_topic(), toppar->get_partition())) + return -1; + + this->toppar_list_map[node_id].add_item(std::move(new_toppar)); + } + + return 0; +} + +int KafkaClientTask::arrange_commit() +{ + this->toppar_list.rewind(); + KafkaTopparList new_toppar_list; + KafkaToppar *toppar; + while ((toppar = this->toppar_list.get_next()) != NULL) + { + check_replace_toppar(&new_toppar_list, toppar); + } + + this->toppar_list = std::move(new_toppar_list); + return 0; +} + +int KafkaClientTask::arrange_fetch() +{ + this->meta_list.rewind(); + for (auto& topic : topic_set) + { + if (this->member->cgroup.get_group()) + { + this->member->cgroup.assigned_toppar_rewind(); + KafkaToppar *toppar; + while ((toppar = this->member->cgroup.get_assigned_toppar_next()) != NULL) + { + if (topic.compare(toppar->get_topic()) == 0) + { + int node_id = get_node_id(toppar); + if (node_id < 0) + return -1; + + if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) + this->toppar_list_map[node_id] = (KafkaTopparList()); + + KafkaToppar new_toppar; + if (!new_toppar.set_topic_partition(toppar->get_topic(), toppar->get_partition())) + return -1; + + new_toppar.set_offset(toppar->get_offset()); + new_toppar.set_low_watermark(toppar->get_low_watermark()); + new_toppar.set_high_watermark(toppar->get_high_watermark()); + this->toppar_list_map[node_id].add_item(std::move(new_toppar)); + } + } + } + else + { + this->toppar_list.rewind(); + KafkaToppar *toppar; + while ((toppar = this->toppar_list.get_next()) != NULL) + { + if (topic.compare(toppar->get_topic()) == 0) + { + int node_id = get_node_id(toppar); + if 
(node_id < 0) + return -1; + + if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) + this->toppar_list_map[node_id] = KafkaTopparList(); + + KafkaToppar new_toppar; + if (!new_toppar.set_topic_partition(toppar->get_topic(), toppar->get_partition())) + return -1; + + new_toppar.set_offset(toppar->get_offset()); + new_toppar.set_offset_timestamp(toppar->get_offset_timestamp()); + new_toppar.set_low_watermark(toppar->get_low_watermark()); + new_toppar.set_high_watermark(toppar->get_high_watermark()); + this->toppar_list_map[node_id].add_item(std::move(new_toppar)); + } + } + } + } + + return 0; +} + +int KafkaClientTask::arrange_produce() +{ + this->toppar_list.rewind(); + KafkaToppar *toppar; + while ((toppar = this->toppar_list.get_next()) != NULL) + { + if (toppar->get_partition() < 0) + { + toppar->record_rewind(); + KafkaRecord *record; + while ((record = toppar->get_record_next()) != NULL) + { + int partition_num; + const KafkaMeta *meta; + meta = get_meta(toppar->get_topic(), &this->member->meta_list); + if (!meta) + return -1; + + partition_num = meta->get_partition_elements(); + if (partition_num <= 0) + return -1; + + int partition = -1; + if (this->partitioner) + { + const void *key; + size_t key_len; + record->get_key(&key, &key_len); + partition = this->partitioner(toppar->get_topic(), key, + key_len, partition_num); + } + else + partition = rand() % partition_num; + + KafkaToppar *new_toppar = get_toppar(toppar->get_topic(), + partition, + &this->toppar_list); + if (!new_toppar) + { + KafkaToppar tmp; + if (!tmp.set_topic_partition(toppar->get_topic(), partition)) + return -1; + + new_toppar = this->toppar_list.add_item(std::move(tmp)); + } + + record->get_raw_ptr()->toppar = new_toppar->get_raw_ptr(); + new_toppar->add_record(std::move(*record)); + toppar->del_record_cur(); + delete record; + } + this->toppar_list.del_cur(); + delete toppar; + } + else + { + KafkaRecord *record; + while ((record = toppar->get_record_next()) != NULL) 
+ record->get_raw_ptr()->toppar = toppar->get_raw_ptr(); + } + } + + this->toppar_list.rewind(); + KafkaTopparList toppar_list; + while ((toppar = this->toppar_list.get_next()) != NULL) + { + int node_id = get_node_id(toppar); + if (node_id < 0) + return -1; + + if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) + this->toppar_list_map[node_id] = KafkaTopparList(); + + this->toppar_list_map[node_id].add_item(std::move(*toppar)); + } + + return 0; +} + +SubTask *WFKafkaTask::done() +{ + SeriesWork *series = series_of(this); + + auto cb = [] (WFTimerTask *task) + { + WFKafkaTask *kafka_task = (WFKafkaTask *)task->user_data; + if (kafka_task->callback) + kafka_task->callback(kafka_task); + + delete kafka_task; + }; + + if (finish) + { + if (this->state == WFT_STATE_TASK_ERROR) + { + WFTimerTask *timer; + timer = WFTaskFactory::create_timer_task(0, 0, std::move(cb)); + timer->user_data = this; + series->push_front(timer); + } + else + { + if (this->callback) + this->callback(this); + + delete this; + } + } + + return series->pop(); +} + +int WFKafkaClient::init(const std::string& broker, SSL_CTX *ssl_ctx) +{ + std::vector broker_hosts; + std::string::size_type ppos = 0; + std::string::size_type pos; + bool use_ssl; + + use_ssl = (strncasecmp(broker.c_str(), "kafkas://", 9) == 0); + while (1) + { + pos = broker.find(',', ppos); + std::string host = broker.substr(ppos, pos - ppos); + if (use_ssl) + { + if (strncasecmp(host.c_str(), "kafkas://", 9) != 0) + { + errno = EINVAL; + return -1; + } + } + else if (strncasecmp(host.c_str(), "kafka://", 8) != 0) + { + if (strncasecmp(host.c_str(), "kafkas://", 9) == 0) + { + errno = EINVAL; + return -1; + } + + host = "kafka://" + host; + } + + broker_hosts.emplace_back(host); + if (pos == std::string::npos) + break; + + ppos = pos + 1; + } + + this->member = new KafkaMember; + this->member->broker_hosts = std::move(broker_hosts); + this->member->ssl_ctx = ssl_ctx; + if (use_ssl) + { + 
this->member->transport_type = TT_TCP_SSL; + this->member->scheme = "kafkas://"; + } + + return 0; +} + +int WFKafkaClient::init(const std::string& broker, const std::string& group, + SSL_CTX *ssl_ctx) +{ + if (this->init(broker, ssl_ctx) < 0) + return -1; + + this->member->cgroup.set_group(group); + this->member->cgroup_status = KAFKA_CGROUP_UNINIT; + return 0; +} + +int WFKafkaClient::deinit() +{ + this->member->mutex.lock(); + this->member->client_deinit = true; + this->member->mutex.unlock(); + this->member->decref(); + return 0; +} + +WFKafkaTask *WFKafkaClient::create_kafka_task(const std::string& query, + int retry_max, + kafka_callback_t cb) +{ + WFKafkaTask *task = new KafkaClientTask(query, retry_max, std::move(cb), this); + return task; +} + +WFKafkaTask *WFKafkaClient::create_kafka_task(int retry_max, + kafka_callback_t cb) +{ + WFKafkaTask *task = new KafkaClientTask("", retry_max, std::move(cb), this); + return task; +} + +WFKafkaTask *WFKafkaClient::create_leavegroup_task(int retry_max, + kafka_callback_t cb) +{ + WFKafkaTask *task = new KafkaClientTask("api=leavegroup", retry_max, + std::move(cb), this); + return task; +} + +void WFKafkaClient::set_config(protocol::KafkaConfig conf) +{ + this->member->config = std::move(conf); +} + +KafkaMetaList *WFKafkaClient::get_meta_list() +{ + return &this->member->meta_list; +} + diff --git a/src/client/WFKafkaClient.h b/src/client/WFKafkaClient.h new file mode 100644 index 0000000..61ec25e --- /dev/null +++ b/src/client/WFKafkaClient.h @@ -0,0 +1,194 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#ifndef _WFKAFKACLIENT_H_ +#define _WFKAFKACLIENT_H_ + +#include +#include +#include +#include +#include "WFTask.h" +#include "KafkaMessage.h" +#include "KafkaResult.h" + +class WFKafkaTask; +class WFKafkaClient; + +using kafka_callback_t = std::function; +using kafka_partitioner_t = std::function; + +class WFKafkaTask : public WFGenericTask +{ +public: + virtual bool add_topic(const std::string& topic) = 0; + + virtual bool add_toppar(const protocol::KafkaToppar& toppar) = 0; + + virtual bool add_produce_record(const std::string& topic, int partition, + protocol::KafkaRecord record) = 0; + + virtual bool add_offset_toppar(const protocol::KafkaToppar& toppar) = 0; + + void add_commit_record(const protocol::KafkaRecord& record) + { + protocol::KafkaToppar toppar; + toppar.set_topic_partition(record.get_topic(), record.get_partition()); + toppar.set_offset(record.get_offset()); + toppar.set_error(0); + this->toppar_list.add_item(std::move(toppar)); + } + + void add_commit_toppar(const protocol::KafkaToppar& toppar) + { + protocol::KafkaToppar toppar_t; + toppar_t.set_topic_partition(toppar.get_topic(), toppar.get_partition()); + toppar_t.set_offset(toppar.get_offset()); + toppar_t.set_error(0); + this->toppar_list.add_item(std::move(toppar_t)); + } + + void add_commit_item(const std::string& topic, int partition, + long long offset) + { + protocol::KafkaToppar toppar; + toppar.set_topic_partition(topic, partition); + toppar.set_offset(offset); + toppar.set_error(0); + 
this->toppar_list.add_item(std::move(toppar)); + } + + void set_api_type(int api_type) + { + this->api_type = api_type; + } + + int get_api_type() const + { + return this->api_type; + } + + void set_config(protocol::KafkaConfig conf) + { + this->config = std::move(conf); + } + + void set_partitioner(kafka_partitioner_t partitioner) + { + this->partitioner = std::move(partitioner); + } + + protocol::KafkaResult *get_result() + { + return &this->result; + } + + int get_kafka_error() const + { + return this->kafka_error; + } + + void set_callback(kafka_callback_t cb) + { + this->callback = std::move(cb); + } + +protected: + WFKafkaTask(int retry_max, kafka_callback_t&& cb) + { + this->callback = std::move(cb); + this->retry_max = retry_max; + this->finish = false; + } + + virtual ~WFKafkaTask() { } + + virtual SubTask *done(); + +protected: + protocol::KafkaConfig config; + protocol::KafkaTopparList toppar_list; + protocol::KafkaMetaList meta_list; + protocol::KafkaResult result; + kafka_callback_t callback; + kafka_partitioner_t partitioner; + int api_type; + int kafka_error; + int retry_max; + bool finish; + +private: + friend class WFKafkaClient; +}; + +class WFKafkaClient +{ +public: + // example: kafka://10.160.23.23:9000 + // example: kafka://kafka.sogou + // example: kafka.sogou:9090 + // example: kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou + // example: kafkas://kafka.sogou -> kafka over TLS + int init(const std::string& broker_url) + { + return this->init(broker_url, NULL); + } + + int init(const std::string& broker_url, const std::string& group) + { + return this->init(broker_url, group, NULL); + } + + // With a specific SSL_CTX. Effective only on brokers over TLS. 
+ int init(const std::string& broker_url, SSL_CTX *ssl_ctx); + + int init(const std::string& broker_url, const std::string& group, + SSL_CTX *ssl_ctx); + + int deinit(); + + // example: topic=xxx&topic=yyy&api=fetch + // example: api=commit + WFKafkaTask *create_kafka_task(const std::string& query, + int retry_max, + kafka_callback_t cb); + + WFKafkaTask *create_kafka_task(int retry_max, kafka_callback_t cb); + + void set_config(protocol::KafkaConfig conf); + +public: + /* If you don't leavegroup manually, rebalance would be triggered */ + WFKafkaTask *create_leavegroup_task(int retry_max, + kafka_callback_t callback); + +public: + protocol::KafkaMetaList *get_meta_list(); + + protocol::KafkaBrokerList *get_broker_list(); + +private: + class KafkaMember *member; + friend class KafkaClientTask; +}; + +#endif + diff --git a/src/client/WFMySQLConnection.cc b/src/client/WFMySQLConnection.cc new file mode 100644 index 0000000..ff66a57 --- /dev/null +++ b/src/client/WFMySQLConnection.cc @@ -0,0 +1,55 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "URIParser.h" +#include "WFMySQLConnection.h" + +int WFMySQLConnection::init(const std::string& url, SSL_CTX *ssl_ctx) +{ + std::string query; + ParsedURI uri; + + if (URIParser::parse(url, uri) >= 0) + { + if (uri.query) + { + query = uri.query; + query += '&'; + } + + query += "transaction=INTERNAL_CONN_ID_" + std::to_string(this->id); + free(uri.query); + uri.query = strdup(query.c_str()); + if (uri.query) + { + this->uri = std::move(uri); + this->ssl_ctx = ssl_ctx; + return 0; + } + } + else if (uri.state == URI_STATE_INVALID) + errno = EINVAL; + + return -1; +} + diff --git a/src/client/WFMySQLConnection.h b/src/client/WFMySQLConnection.h new file mode 100644 index 0000000..3f45228 --- /dev/null +++ b/src/client/WFMySQLConnection.h @@ -0,0 +1,89 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFMYSQLCONNECTION_H_ +#define _WFMYSQLCONNECTION_H_ + +#include +#include +#include +#include +#include "URIParser.h" +#include "WFTaskFactory.h" + +class WFMySQLConnection +{ +public: + /* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8 + * IP string is recommmended in url. When using a domain name, the first + * address resovled will be used. Don't use upstream name as a host. 
*/ + int init(const std::string& url) + { + return this->init(url, NULL); + } + + int init(const std::string& url, SSL_CTX *ssl_ctx); + + void deinit() { } + +public: + WFMySQLTask *create_query_task(const std::string& query, + mysql_callback_t callback) + { + WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, + std::move(callback)); + this->set_ssl_ctx(task); + task->get_req()->set_query(query); + return task; + } + + /* If you don't disconnect manually, the TCP connection will be + * kept alive after this object is deleted, and maybe reused by + * another WFMySQLConnection object with same id and url. */ + WFMySQLTask *create_disconnect_task(mysql_callback_t callback) + { + WFMySQLTask *task = this->create_query_task("", std::move(callback)); + this->set_ssl_ctx(task); + task->set_keep_alive(0); + return task; + } + +protected: + void set_ssl_ctx(WFMySQLTask *task) const + { + using MySQLRequest = protocol::MySQLRequest; + using MySQLResponse = protocol::MySQLResponse; + auto *t = (WFComplexClientTask *)task; + /* 'ssl_ctx' can be NULL and will use default. */ + t->set_ssl_ctx(this->ssl_ctx); + } + +protected: + ParsedURI uri; + SSL_CTX *ssl_ctx; + int id; + +public: + /* Make sure that concurrent connections have different id. + * When a connection object is deleted, id can be reused. */ + WFMySQLConnection(int id) { this->id = id; } + virtual ~WFMySQLConnection() { } +}; + +#endif + diff --git a/src/client/WFRedisSubscriber.cc b/src/client/WFRedisSubscriber.cc new file mode 100644 index 0000000..33828ec --- /dev/null +++ b/src/client/WFRedisSubscriber.cc @@ -0,0 +1,129 @@ +/* + Copyright (c) 2024 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+
+  Author: Xie Han (xiehan@sogou-inc.com)
+*/
+
+#include
+#include
+#include
+#include
+#include "URIParser.h"
+#include "RedisTaskImpl.inl"
+#include "WFRedisSubscriber.h"
+/* Send one command on the live subscription.
+ *
+ * The request text is an item-count header ('*', the number of items, CRLF)
+ * followed by one group per item ('$', the item's byte length, CRLF, the
+ * bytes, CRLF). The items are 'command' and then every entry of 'params'.
+ * The text is handed to task->push() while 'mutex' is held.
+ *
+ * Returns 0 when push() reports exactly str.size() bytes. Otherwise returns
+ * -1: errno is set to ENOBUFS when push() returned a shorter non-negative
+ * count, is not modified here when push() returned a negative value, and is
+ * set to ENOENT when 'task' is NULL (task_callback() clears it). */
+int WFRedisSubscribeTask::sync_send(const std::string& command,
+                                    const std::vector& params)
+{
+    std::string str("*" + std::to_string(1 + params.size()) + "\r\n");
+    int ret;
+
+    str += "$" + std::to_string(command.size()) + "\r\n" + command + "\r\n";
+    for (const std::string& p : params)
+        str += "$" + std::to_string(p.size()) + "\r\n" + p + "\r\n";
+
+    this->mutex.lock(); /* task_callback() takes the same mutex to clear 'task' */
+    if (this->task)
+    {
+        ret = this->task->push(str.c_str(), str.size());
+        if (ret == (int)str.size())
+            ret = 0;
+        else
+        {
+            if (ret >= 0)
+                errno = ENOBUFS; /* short count reported by push() */
+            ret = -1;
+        }
+    }
+    else
+    {
+        errno = ENOENT; /* no underlying task any more */
+        ret = -1;
+    }
+
+    this->mutex.unlock();
+    return ret;
+}
+/* Extract hook passed to create_subscribe_task(): recovers the
+ * WFRedisSubscribeTask from task->user_data and calls its 'extract'
+ * function when one is set. */
+void WFRedisSubscribeTask::task_extract(WFRedisTask *task)
+{
+    auto *t = (WFRedisSubscribeTask *)task->user_data;
+
+    if (t->extract)
+        t->extract(t);
+}
+/* Callback hook passed to create_subscribe_task(). Clears 't->task' under
+ * the mutex (later sync_send() calls then fail with ENOENT), copies the
+ * redis task's state and error into the wrapper, runs the user 'callback'
+ * when one is set, and finally makes one release() call on the wrapper. */
+void WFRedisSubscribeTask::task_callback(WFRedisTask *task)
+{
+    auto *t = (WFRedisSubscribeTask *)task->user_data;
+
+    t->mutex.lock();
+    t->task = NULL;
+    t->mutex.unlock();
+
+    t->state = task->get_state();
+    t->error = task->get_error();
+    if (t->callback)
+        t->callback(t);
+
+    t->release();
+}
+/* Parse 'url' into 'uri' and remember 'ssl_ctx'.
+ * Returns 0 when URIParser::parse() returns >= 0. Otherwise returns -1;
+ * errno is set to EINVAL when uri.state is URI_STATE_INVALID and is not
+ * modified here for any other parser state. */
+int WFRedisSubscriber::init(const std::string& url, SSL_CTX *ssl_ctx)
+{
+    if (URIParser::parse(url, this->uri) >= 0)
+    {
+        this->ssl_ctx = ssl_ctx;
+        return 0;
+    }
+
+    if (this->uri.state == URI_STATE_INVALID)
+        errno = EINVAL;
+
+    return -1;
+}
+/* Create the underlying redis task for 'command' with 'params': the task is
+ * built for this subscriber's 'uri' with task_extract/task_callback as its
+ * hooks, gets this subscriber's SSL context, and has its request filled in. */
+WFRedisTask *
+WFRedisSubscriber::create_redis_task(const std::string& command,
+                                     const std::vector& params)
+{
+    WFRedisTask *task = __WFRedisTaskFactory::create_subscribe_task(this->uri,
+                                        WFRedisSubscribeTask::task_extract,
+                                        WFRedisSubscribeTask::task_callback);
+    this->set_ssl_ctx(task);
+    task->get_req()->set_request(command, params);
+    return task;
+}
+/* Create a subscribe task whose first request is SUBSCRIBE on 'channels'.
+ * The wrapper is allocated with new and takes over 'extract' and 'callback'. */
+WFRedisSubscribeTask *
+WFRedisSubscriber::create_subscribe_task(
+                                    const std::vector& channels,
+                                    extract_t extract, callback_t callback)
+{
+    WFRedisTask *task = this->create_redis_task("SUBSCRIBE", channels);
+    return new WFRedisSubscribeTask(task, std::move(extract),
+                                    std::move(callback));
+}
+/* Same as create_subscribe_task(), but the first request is PSUBSCRIBE on
+ * 'patterns'. */
+WFRedisSubscribeTask *
+WFRedisSubscriber::create_psubscribe_task(
+                                    const std::vector& patterns,
+                                    extract_t extract, callback_t callback)
+{
+    WFRedisTask *task = this->create_redis_task("PSUBSCRIBE", patterns);
+    return new WFRedisSubscribeTask(task, std::move(extract),
+                                    std::move(callback));
+}
+
diff --git a/src/client/WFRedisSubscriber.h b/src/client/WFRedisSubscriber.h
new file mode 100644
index 0000000..c16ca60
--- /dev/null
+++ b/src/client/WFRedisSubscriber.h
@@ -0,0 +1,240 @@
+/*
+  Copyright (c) 2024 Sogou, Inc.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFREDISSUBSCRIBER_H_ +#define _WFREDISSUBSCRIBER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "WFTask.h" +#include "WFTaskFactory.h" + +class WFRedisSubscribeTask : public WFGenericTask +{ +public: + /* Note: Call 'get_resp()' only in the 'extract' function or + before the task is started to set response size limit. */ + protocol::RedisResponse *get_resp() + { + return this->task->get_resp(); + } + +public: + /* User needs to call 'release()' exactly once, anywhere. */ + void release() + { + if (this->flag.exchange(true)) + delete this; + } + +public: + /* Note: After 'release()' is called, all the requesting functions + should not be called except in 'extract', because the task + point may have been deleted because 'callback' finished. */ + + int subscribe(const std::vector& channels) + { + return this->sync_send("SUBSCRIBE", channels); + } + + int unsubscribe(const std::vector& channels) + { + return this->sync_send("UNSUBSCRIBE", channels); + } + + int unsubscribe() + { + return this->sync_send("UNSUBSCRIBE", { }); + } + + int psubscribe(const std::vector& patterns) + { + return this->sync_send("PSUBSCRIBE", patterns); + } + + int punsubscribe(const std::vector& patterns) + { + return this->sync_send("PUNSUBSCRIBE", patterns); + } + + int punsubscribe() + { + return this->sync_send("PUNSUBSCRIBE", { }); + } + + int ping(const std::string& message) + { + return this->sync_send("PING", { message }); + } + + int ping() + { + return this->sync_send("PING", { }); + } + + int quit() + { + return this->sync_send("QUIT", { }); + } + +public: + /* All 'timeout' proxy functions can only be called only before + the task is started or in 'extract'. */ + + /* Timeout of waiting for each message. Very useful. 
If not set, + the max waiting time will be the global 'response_timeout'*/ + void set_watch_timeout(int timeout) + { + this->task->set_watch_timeout(timeout); + } + + /* Timeout of receiving a complete message. */ + void set_recv_timeout(int timeout) + { + this->task->set_receive_timeout(timeout); + } + + /* Timeout of sending the first subscribe request. */ + void set_send_timeout(int timeout) + { + this->task->set_send_timeout(timeout); + } + + /* The default keep alive timeout is 0. If you want to keep + the connection alive, make sure not to send any request + after all channels/patterns were unsubscribed. */ + void set_keep_alive(int timeout) + { + this->task->set_keep_alive(timeout); + } + +public: + /* Call 'set_extract' or 'set_callback' only before the task + is started, or in 'extract'. */ + + void set_extract(std::function ex) + { + this->extract = std::move(ex); + } + + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch() + { + series_of(this)->push_front(this->task); + this->subtask_done(); + } + + virtual SubTask *done() + { + return series_of(this)->pop(); + } + +protected: + int sync_send(const std::string& command, + const std::vector& params); + static void task_extract(WFRedisTask *task); + static void task_callback(WFRedisTask *task); + +protected: + WFRedisTask *task; + std::mutex mutex; + std::atomic flag; + std::function extract; + std::function callback; + +protected: + WFRedisSubscribeTask(WFRedisTask *task, + std::function&& ex, + std::function&& cb) : + flag(false), + extract(std::move(ex)), + callback(std::move(cb)) + { + task->user_data = this; + this->task = task; + } + + virtual ~WFRedisSubscribeTask() + { + if (this->task) + this->task->dismiss(); + } + + friend class WFRedisSubscriber; +}; + +class WFRedisSubscriber +{ +public: + int init(const std::string& url) + { + return this->init(url, NULL); + } + + int init(const std::string& url, SSL_CTX *ssl_ctx); + + void 
deinit() { } + +public: + using extract_t = std::function; + using callback_t = std::function; + +public: + WFRedisSubscribeTask * + create_subscribe_task(const std::vector& channels, + extract_t extract, callback_t callback); + + WFRedisSubscribeTask * + create_psubscribe_task(const std::vector& patterns, + extract_t extract, callback_t callback); + +protected: + void set_ssl_ctx(WFRedisTask *task) const + { + using RedisRequest = protocol::RedisRequest; + using RedisResponse = protocol::RedisResponse; + auto *t = (WFComplexClientTask *)task; + /* 'ssl_ctx' can be NULL and will use default. */ + t->set_ssl_ctx(this->ssl_ctx); + } + +protected: + WFRedisTask *create_redis_task(const std::string& command, + const std::vector& params); + +protected: + ParsedURI uri; + SSL_CTX *ssl_ctx; + +public: + virtual ~WFRedisSubscriber() { } +}; + +#endif + diff --git a/src/client/xmake.lua b/src/client/xmake.lua new file mode 100644 index 0000000..8be59cf --- /dev/null +++ b/src/client/xmake.lua @@ -0,0 +1,23 @@ +target("client") + set_kind("object") + add_files("*.cc") + remove_files("WFKafkaClient.cc") + if not has_config("redis") then + remove_files("WFRedisSubscriber.cc") + end + if not has_config("mysql") then + remove_files("WFMySQLConnection.cc") + end + if not has_config("consul") then + remove_files("WFConsulClient.cc") + end + +target("kafka_client") + if has_config("kafka") then + add_files("WFKafkaClient.cc") + set_kind("object") + add_deps("client") + add_packages("zlib", "snappy", "zstd", "lz4") + else + set_kind("phony") + end diff --git a/src/factory/CMakeLists.txt b/src/factory/CMakeLists.txt new file mode 100644 index 0000000..c8efcaf --- /dev/null +++ b/src/factory/CMakeLists.txt @@ -0,0 +1,37 @@ +cmake_minimum_required(VERSION 3.6) +project(factory) + +set(SRC + WFGraphTask.cc + DnsTaskImpl.cc + WFTaskFactory.cc + Workflow.cc + HttpTaskImpl.cc + WFResourcePool.cc + WFMessageQueue.cc + FileTaskImpl.cc +) + +if (NOT MYSQL STREQUAL "n") + set(SRC + ${SRC} + 
MySQLTaskImpl.cc + ) +endif () + +if (NOT REDIS STREQUAL "n") + set(SRC + ${SRC} + RedisTaskImpl.cc + ) +endif () + +add_library(${PROJECT_NAME} OBJECT ${SRC}) + +if (KAFKA STREQUAL "y") + set(SRC + KafkaTaskImpl.cc + ) + add_library("factory_kafka" OBJECT ${SRC}) + set_property(SOURCE KafkaTaskImpl.cc APPEND PROPERTY COMPILE_OPTIONS "-fno-rtti") +endif () diff --git a/src/factory/DnsTaskImpl.cc b/src/factory/DnsTaskImpl.cc new file mode 100644 index 0000000..0ca7ea7 --- /dev/null +++ b/src/factory/DnsTaskImpl.cc @@ -0,0 +1,226 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include "DnsMessage.h" +#include "WFTaskError.h" +#include "WFTaskFactory.h" +#include "WFServer.h" + +using namespace protocol; + +#define DNS_KEEPALIVE_DEFAULT (60 * 1000) + +/**********Client**********/ + +class ComplexDnsTask : public WFComplexClientTask> +{ + static struct addrinfo hints; + static std::atomic seq; + +public: + ComplexDnsTask(int retry_max, dns_callback_t&& cb): + WFComplexClientTask(retry_max, std::move(cb)) + { + this->set_transport_type(TT_UDP); + } + +protected: + virtual CommMessageOut *message_out(); + virtual bool init_success(); + virtual bool finish_once(); + +private: + bool need_redirect(); +}; + +struct addrinfo ComplexDnsTask::hints = +{ + .ai_flags = AI_NUMERICSERV | AI_NUMERICHOST, + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM +}; + +std::atomic ComplexDnsTask::seq(0); + +CommMessageOut *ComplexDnsTask::message_out() +{ + DnsRequest *req = this->get_req(); + DnsResponse *resp = this->get_resp(); + enum TransportType type = this->get_transport_type(); + + if (req->get_id() == 0) + req->set_id(++ComplexDnsTask::seq * 99991 % 65535 + 1); + resp->set_request_id(req->get_id()); + resp->set_request_name(req->get_question_name()); + req->set_single_packet(type == TT_UDP); + resp->set_single_packet(type == TT_UDP); + + return this->WFClientTask::message_out(); +} + +bool ComplexDnsTask::init_success() +{ + if (uri_.scheme && strcasecmp(uri_.scheme, "dnss") == 0) + this->WFComplexClientTask::set_transport_type(TT_TCP_SSL); + else if (!uri_.scheme || strcasecmp(uri_.scheme, "dns") != 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + if (!this->route_result_.request_object) + { + enum TransportType type = this->get_transport_type(); + struct addrinfo *addr; + int ret; + + ret = getaddrinfo(uri_.host, uri_.port, &hints, &addr); + if (ret != 0) + { + this->state = WFT_STATE_DNS_ERROR; + this->error = 
ret; + return false; + } + + auto *ep = &WFGlobal::get_global_settings()->dns_server_params; + ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep, + uri_.host, ssl_ctx_, + route_result_); + freeaddrinfo(addr); + if (ret < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + } + + return true; +} + +bool ComplexDnsTask::finish_once() +{ + if (this->state == WFT_STATE_SUCCESS) + { + if (need_redirect()) + this->set_redirect(uri_); + else if (this->state != WFT_STATE_SUCCESS) + this->disable_retry(); + } + + /* If retry times meet retry max and there is no redirect, + * we ask the client for a retry or redirect. + */ + if (retry_times_ == retry_max_ && !redirect_ && *this->get_mutable_ctx()) + { + /* Reset type to UDP before a client redirect. */ + this->set_transport_type(TT_UDP); + (*this->get_mutable_ctx())(this); + } + + return true; +} + +bool ComplexDnsTask::need_redirect() +{ + DnsResponse *client_resp = this->get_resp(); + enum TransportType type = this->get_transport_type(); + + if (type == TT_UDP && client_resp->get_tc() == 1) + { + this->set_transport_type(TT_TCP); + return true; + } + + return false; +} + +/**********Client Factory**********/ + +WFDnsTask *WFTaskFactory::create_dns_task(const std::string& url, + int retry_max, + dns_callback_t callback) +{ + ParsedURI uri; + + URIParser::parse(url, uri); + return WFTaskFactory::create_dns_task(uri, retry_max, std::move(callback)); +} + +WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri, + int retry_max, + dns_callback_t callback) +{ + ComplexDnsTask *task = new ComplexDnsTask(retry_max, std::move(callback)); + const char *name; + + if (uri.path && uri.path[0] && uri.path[1]) + name = uri.path + 1; + else + name = "."; + + DnsRequest *req = task->get_req(); + req->set_question(name, DNS_TYPE_A, DNS_CLASS_IN); + + task->init(uri); + task->set_keep_alive(DNS_KEEPALIVE_DEFAULT); + return task; +} + + +/**********Server**********/ + +class 
WFDnsServerTask : public WFServerTask +{ +public: + WFDnsServerTask(CommService *service, + std::function& proc) : + WFServerTask(service, WFGlobal::get_scheduler(), proc) + { + this->type = ((WFServerBase *)service)->get_params()->transport_type; + } + +protected: + virtual CommMessageIn *message_in() + { + this->get_req()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_in(); + } + + virtual CommMessageOut *message_out() + { + this->get_resp()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_out(); + } + +protected: + enum TransportType type; +}; + +/**********Server Factory**********/ + +WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service, + std::function& proc) +{ + return new WFDnsServerTask(service, proc); +} + diff --git a/src/factory/FileTaskImpl.cc b/src/factory/FileTaskImpl.cc new file mode 100644 index 0000000..5d9de0a --- /dev/null +++ b/src/factory/FileTaskImpl.cc @@ -0,0 +1,400 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Li Jinghao (lijinghao@sogou-inc.com) +*/ + +#include +#include +#include +#include "WFGlobal.h" +#include "WFTaskFactory.h" + +class WFFilepreadTask : public WFFileIOTask +{ +public: + WFFilepreadTask(int fd, void *buf, size_t count, off_t offset, + IOService *service, fio_callback_t&& cb) : + WFFileIOTask(service, std::move(cb)) + { + this->args.fd = fd; + this->args.buf = buf; + this->args.count = count; + this->args.offset = offset; + } + +protected: + virtual int prepare() + { + this->prep_pread(this->args.fd, this->args.buf, this->args.count, + this->args.offset); + return 0; + } +}; + +class WFFilepwriteTask : public WFFileIOTask +{ +public: + WFFilepwriteTask(int fd, const void *buf, size_t count, off_t offset, + IOService *service, fio_callback_t&& cb) : + WFFileIOTask(service, std::move(cb)) + { + this->args.fd = fd; + this->args.buf = (void *)buf; + this->args.count = count; + this->args.offset = offset; + } + +protected: + virtual int prepare() + { + this->prep_pwrite(this->args.fd, this->args.buf, this->args.count, + this->args.offset); + return 0; + } +}; + +class WFFilepreadvTask : public WFFileVIOTask +{ +public: + WFFilepreadvTask(int fd, const struct iovec *iov, int iovcnt, off_t offset, + IOService *service, fvio_callback_t&& cb) : + WFFileVIOTask(service, std::move(cb)) + { + this->args.fd = fd; + this->args.iov = iov; + this->args.iovcnt = iovcnt; + this->args.offset = offset; + } + +protected: + virtual int prepare() + { + this->prep_preadv(this->args.fd, this->args.iov, this->args.iovcnt, + this->args.offset); + return 0; + } +}; + +class WFFilepwritevTask : public WFFileVIOTask +{ +public: + WFFilepwritevTask(int fd, const struct iovec *iov, int iovcnt, off_t offset, + IOService *service, fvio_callback_t&& cb) : + WFFileVIOTask(service, std::move(cb)) + { + this->args.fd = fd; + this->args.iov = iov; + this->args.iovcnt = iovcnt; + this->args.offset = offset; + } + +protected: + virtual int 
prepare() + { + this->prep_pwritev(this->args.fd, this->args.iov, this->args.iovcnt, + this->args.offset); + return 0; + } +}; + +class WFFilefsyncTask : public WFFileSyncTask +{ +public: + WFFilefsyncTask(int fd, IOService *service, fsync_callback_t&& cb) : + WFFileSyncTask(service, std::move(cb)) + { + this->args.fd = fd; + } + +protected: + virtual int prepare() + { + this->prep_fsync(this->args.fd); + return 0; + } +}; + +class WFFilefdsyncTask : public WFFileSyncTask +{ +public: + WFFilefdsyncTask(int fd, IOService *service, fsync_callback_t&& cb) : + WFFileSyncTask(service, std::move(cb)) + { + this->args.fd = fd; + } + +protected: + virtual int prepare() + { + this->prep_fdsync(this->args.fd); + return 0; + } +}; + +/* File tasks created with path name. */ + +class __WFFilepreadTask : public WFFilepreadTask +{ +public: + __WFFilepreadTask(const std::string& path, void *buf, size_t count, + off_t offset, IOService *service, fio_callback_t&& cb): + WFFilepreadTask(-1, buf, count, offset, service, std::move(cb)), + pathname(path) + { + } + +protected: + virtual int prepare() + { + this->args.fd = open(this->pathname.c_str(), O_RDONLY); + if (this->args.fd < 0) + return -1; + + return WFFilepreadTask::prepare(); + } + + virtual SubTask *done() + { + if (this->args.fd >= 0) + { + close(this->args.fd); + this->args.fd = -1; + } + + return WFFilepreadTask::done(); + } + +protected: + std::string pathname; +}; + +class __WFFilepwriteTask : public WFFilepwriteTask +{ +public: + __WFFilepwriteTask(const std::string& path, const void *buf, size_t count, + off_t offset, IOService *service, fio_callback_t&& cb): + WFFilepwriteTask(-1, buf, count, offset, service, std::move(cb)), + pathname(path) + { + } + +protected: + virtual int prepare() + { + this->args.fd = open(this->pathname.c_str(), O_WRONLY | O_CREAT, 0644); + if (this->args.fd < 0) + return -1; + + return WFFilepwriteTask::prepare(); + } + + virtual SubTask *done() + { + if (this->args.fd >= 0) + { + 
close(this->args.fd); + this->args.fd = -1; + } + + return WFFilepwriteTask::done(); + } + +protected: + std::string pathname; +}; + +class __WFFilepreadvTask : public WFFilepreadvTask +{ +public: + __WFFilepreadvTask(const std::string& path, const struct iovec *iov, + int iovcnt, off_t offset, IOService *service, + fvio_callback_t&& cb) : + WFFilepreadvTask(-1, iov, iovcnt, offset, service, std::move(cb)), + pathname(path) + { + } + +protected: + virtual int prepare() + { + this->args.fd = open(this->pathname.c_str(), O_RDONLY); + if (this->args.fd < 0) + return -1; + + return WFFilepreadvTask::prepare(); + } + + virtual SubTask *done() + { + if (this->args.fd >= 0) + { + close(this->args.fd); + this->args.fd = -1; + } + + return WFFilepreadvTask::done(); + } + +protected: + std::string pathname; +}; + +class __WFFilepwritevTask : public WFFilepwritevTask +{ +public: + __WFFilepwritevTask(const std::string& path, const struct iovec *iov, + int iovcnt, off_t offset, IOService *service, + fvio_callback_t&& cb) : + WFFilepwritevTask(-1, iov, iovcnt, offset, service, std::move(cb)), + pathname(path) + { + } + +protected: + virtual int prepare() + { + this->args.fd = open(this->pathname.c_str(), O_WRONLY | O_CREAT, 0644); + if (this->args.fd < 0) + return -1; + + return WFFilepwritevTask::prepare(); + } + +protected: + virtual SubTask *done() + { + if (this->args.fd >= 0) + { + close(this->args.fd); + this->args.fd = -1; + } + + return WFFilepwritevTask::done(); + } + +protected: + std::string pathname; +}; + +/* Factory functions with fd. 
*/
+/* Every function in this group allocates the matching task class with new,
+ * binds it to WFGlobal::get_io_service() and moves 'callback' into it. The
+ * descriptor is supplied by the caller and is not closed by the task classes
+ * used here. */
+WFFileIOTask *WFTaskFactory::create_pread_task(int fd,
+                                               void *buf,
+                                               size_t count,
+                                               off_t offset,
+                                               fio_callback_t callback)
+{
+    return new WFFilepreadTask(fd, buf, count, offset,
+                               WFGlobal::get_io_service(),
+                               std::move(callback));
+}
+/* pwrite counterpart of create_pread_task(). */
+WFFileIOTask *WFTaskFactory::create_pwrite_task(int fd,
+                                                const void *buf,
+                                                size_t count,
+                                                off_t offset,
+                                                fio_callback_t callback)
+{
+    return new WFFilepwriteTask(fd, buf, count, offset,
+                                WFGlobal::get_io_service(),
+                                std::move(callback));
+}
+/* Vectored read: 'iovec' and 'iovcnt' are stored as given. */
+WFFileVIOTask *WFTaskFactory::create_preadv_task(int fd,
+                                                 const struct iovec *iovec,
+                                                 int iovcnt,
+                                                 off_t offset,
+                                                 fvio_callback_t callback)
+{
+    return new WFFilepreadvTask(fd, iovec, iovcnt, offset,
+                                WFGlobal::get_io_service(),
+                                std::move(callback));
+}
+/* Vectored write: 'iovec' and 'iovcnt' are stored as given. */
+WFFileVIOTask *WFTaskFactory::create_pwritev_task(int fd,
+                                                  const struct iovec *iovec,
+                                                  int iovcnt,
+                                                  off_t offset,
+                                                  fvio_callback_t callback)
+{
+    return new WFFilepwritevTask(fd, iovec, iovcnt, offset,
+                                 WFGlobal::get_io_service(),
+                                 std::move(callback));
+}
+/* Creates a WFFilefsyncTask for 'fd'. */
+WFFileSyncTask *WFTaskFactory::create_fsync_task(int fd,
+                                                 fsync_callback_t callback)
+{
+    return new WFFilefsyncTask(fd,
+                               WFGlobal::get_io_service(),
+                               std::move(callback));
+}
+/* Creates a WFFilefdsyncTask for 'fd'. */
+WFFileSyncTask *WFTaskFactory::create_fdsync_task(int fd,
+                                                  fsync_callback_t callback)
+{
+    return new WFFilefdsyncTask(fd,
+                                WFGlobal::get_io_service(),
+                                std::move(callback));
+}
+
+/* Factory functions with path name.
*/
+/* Every function in this group allocates the matching path-based task class
+ * (the __WFFile* classes defined earlier in this file) with new, binds it to
+ * WFGlobal::get_io_service() and moves 'callback' into it. Those classes
+ * open 'path' in prepare() and close the descriptor in done(). */
+WFFileIOTask *WFTaskFactory::create_pread_task(const std::string& path,
+                                               void *buf,
+                                               size_t count,
+                                               off_t offset,
+                                               fio_callback_t callback)
+{
+    return new __WFFilepreadTask(path, buf, count, offset,
+                                 WFGlobal::get_io_service(),
+                                 std::move(callback));
+}
+/* pwrite counterpart; __WFFilepwriteTask opens with O_WRONLY | O_CREAT. */
+WFFileIOTask *WFTaskFactory::create_pwrite_task(const std::string& path,
+                                                const void *buf,
+                                                size_t count,
+                                                off_t offset,
+                                                fio_callback_t callback)
+{
+    return new __WFFilepwriteTask(path, buf, count, offset,
+                                  WFGlobal::get_io_service(),
+                                  std::move(callback));
+}
+/* Vectored read from 'path'. */
+WFFileVIOTask *WFTaskFactory::create_preadv_task(const std::string& path,
+                                                 const struct iovec *iovec,
+                                                 int iovcnt,
+                                                 off_t offset,
+                                                 fvio_callback_t callback)
+{
+    return new __WFFilepreadvTask(path, iovec, iovcnt, offset,
+                                  WFGlobal::get_io_service(),
+                                  std::move(callback));
+}
+/* Vectored write to 'path'; __WFFilepwritevTask opens with
+ * O_WRONLY | O_CREAT. */
+WFFileVIOTask *WFTaskFactory::create_pwritev_task(const std::string& path,
+                                                  const struct iovec *iovec,
+                                                  int iovcnt,
+                                                  off_t offset,
+                                                  fvio_callback_t callback)
+{
+    return new __WFFilepwritevTask(path, iovec, iovcnt, offset,
+                                   WFGlobal::get_io_service(),
+                                   std::move(callback));
+}
+
diff --git a/src/factory/HttpTaskImpl.cc b/src/factory/HttpTaskImpl.cc
new file mode 100644
index 0000000..5820876
--- /dev/null
+++ b/src/factory/HttpTaskImpl.cc
@@ -0,0 +1,1077 @@
+/*
+  Copyright (c) 2019 Sogou, Inc.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) + Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "WFTaskError.h" +#include "WFTaskFactory.h" +#include "StringUtil.h" +#include "WFGlobal.h" +#include "HttpUtil.h" +#include "SSLWrapper.h" + +using namespace protocol; + +#define HTTP_KEEPALIVE_DEFAULT (60 * 1000) +#define HTTP_KEEPALIVE_MAX (300 * 1000) + +/**********Client**********/ + +static int __encode_auth(const char *p, std::string& auth) +{ + size_t len = strlen(p); + size_t base64_len = (len + 2) / 3 * 4; + char *base64 = (char *)malloc(base64_len + 1); + + if (!base64) + return -1; + + EVP_EncodeBlock((unsigned char *)base64, (const unsigned char *)p, len); + auth.append("Basic "); + auth.append(base64, base64_len); + + free(base64); + return 0; +} + +class ComplexHttpTask : public WFComplexClientTask +{ +public: + ComplexHttpTask(int redirect_max, + int retry_max, + http_callback_t&& callback): + WFComplexClientTask(retry_max, std::move(callback)), + redirect_max_(redirect_max), + redirect_count_(0) + { + HttpRequest *client_req = this->get_req(); + + client_req->set_method(HttpMethodGet); + client_req->set_http_version("HTTP/1.1"); + } + +protected: + virtual CommMessageOut *message_out(); + virtual CommMessageIn *message_in(); + virtual int keep_alive_timeout(); + virtual bool init_success(); + virtual void init_failed(); + virtual bool finish_once(); + +protected: + bool need_redirect(const ParsedURI& uri, ParsedURI& new_uri); + bool redirect_url(HttpResponse *client_resp, + const ParsedURI& uri, ParsedURI& new_uri); + void set_empty_request(); + void check_response(); + +private: + int redirect_max_; + int redirect_count_; +}; + +CommMessageOut *ComplexHttpTask::message_out() +{ + HttpRequest *req = this->get_req(); + struct HttpMessageHeader header; + bool is_alive; + + if (!req->is_chunked() && 
!req->has_content_length_header()) + { + size_t body_size = req->get_output_body_size(); + const char *method = req->get_method(); + + if (body_size != 0 || strcmp(method, "POST") == 0 || + strcmp(method, "PUT") == 0) + { + char buf[32]; + header.name = "Content-Length"; + header.name_len = strlen("Content-Length"); + header.value = buf; + header.value_len = sprintf(buf, "%zu", body_size); + req->add_header(&header); + } + } + + if (req->has_connection_header()) + is_alive = req->is_keep_alive(); + else + { + header.name = "Connection"; + header.name_len = strlen("Connection"); + is_alive = (this->keep_alive_timeo != 0); + if (is_alive) + { + header.value = "Keep-Alive"; + header.value_len = strlen("Keep-Alive"); + } + else + { + header.value = "close"; + header.value_len = strlen("close"); + } + + req->add_header(&header); + } + + if (!is_alive) + this->keep_alive_timeo = 0; + else if (req->has_keep_alive_header()) + { + HttpHeaderCursor cursor(req); + + //req---Connection: Keep-Alive + //req---Keep-Alive: timeout=0,max=100 + header.name = "Keep-Alive"; + header.name_len = strlen("Keep-Alive"); + header.value = NULL; + header.value_len = 0; + if (cursor.find(&header)) + { + std::string keep_alive((const char *)header.value, header.value_len); + std::vector params = StringUtil::split(keep_alive, ','); + + for (const auto& kv : params) + { + std::vector arr = StringUtil::split(kv, '='); + if (arr.size() < 2) + arr.emplace_back("0"); + + std::string key = StringUtil::strip(arr[0]); + std::string val = StringUtil::strip(arr[1]); + if (strcasecmp(key.c_str(), "timeout") == 0) + { + this->keep_alive_timeo = 1000 * atoi(val.c_str()); + break; + } + } + } + + if ((unsigned int)this->keep_alive_timeo > HTTP_KEEPALIVE_MAX) + this->keep_alive_timeo = HTTP_KEEPALIVE_MAX; + } + + return this->WFComplexClientTask::message_out(); +} + +CommMessageIn *ComplexHttpTask::message_in() +{ + HttpResponse *resp = this->get_resp(); + + if (strcmp(this->get_req()->get_method(), 
HttpMethodHead) == 0) + resp->parse_zero_body(); + + return this->WFComplexClientTask::message_in(); +} + +int ComplexHttpTask::keep_alive_timeout() +{ + return this->resp.is_keep_alive() ? this->keep_alive_timeo : 0; +} + +void ComplexHttpTask::set_empty_request() +{ + HttpRequest *client_req = this->get_req(); + HttpHeaderCursor cursor(client_req); + struct HttpMessageHeader header = { + .name = "Host", + .name_len = strlen("Host"), + }; + + client_req->set_request_uri("/"); + cursor.find_and_erase(&header); + + header.name = "Authorization"; + header.name_len = strlen("Authorization"); + cursor.find_and_erase(&header); +} + +void ComplexHttpTask::init_failed() +{ + this->set_empty_request(); +} + +bool ComplexHttpTask::init_success() +{ + HttpRequest *client_req = this->get_req(); + std::string request_uri; + std::string header_host; + bool is_ssl; + + if (uri_.scheme && strcasecmp(uri_.scheme, "http") == 0) + is_ssl = false; + else if (uri_.scheme && strcasecmp(uri_.scheme, "https") == 0) + is_ssl = true; + else + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + //todo http+unix + //https://stackoverflow.com/questions/26964595/whats-the-correct-way-to-use-a-unix-domain-socket-in-requests-framework + //https://stackoverflow.com/questions/27037990/connecting-to-postgres-via-database-url-and-unix-socket-in-rails + + if (uri_.path && uri_.path[0]) + request_uri = uri_.path; + else + request_uri = "/"; + + if (uri_.query && uri_.query[0]) + { + request_uri += "?"; + request_uri += uri_.query; + } + + if (uri_.host && uri_.host[0]) + header_host = uri_.host; + + if (uri_.port && uri_.port[0]) + { + int port = atoi(uri_.port); + + if (is_ssl) + { + if (port != 443) + { + header_host += ":"; + header_host += uri_.port; + } + } + else + { + if (port != 80) + { + header_host += ":"; + header_host += uri_.port; + } + } + } + + this->WFComplexClientTask::set_transport_type(is_ssl ? 
TT_TCP_SSL : TT_TCP); + client_req->set_request_uri(request_uri.c_str()); + client_req->set_header_pair("Host", header_host.c_str()); + + if (uri_.userinfo && uri_.userinfo[0]) + { + std::string userinfo(uri_.userinfo); + std::string http_auth; + + StringUtil::url_decode(userinfo); + + if (__encode_auth(userinfo.c_str(), http_auth) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + + client_req->set_header_pair("Authorization", http_auth.c_str()); + } + + return true; +} + +bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, + const ParsedURI& uri, ParsedURI& new_uri) +{ + if (redirect_count_ < redirect_max_) + { + redirect_count_++; + std::string url; + HttpHeaderCursor cursor(client_resp); + + if (!cursor.find("Location", url) || url.empty()) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_HTTP_BAD_REDIRECT_HEADER; + return false; + } + + if (url[0] == '/') + { + if (url[1] != '/') + { + if (uri.port) + url = ':' + (uri.port + url); + + url = "//" + (uri.host + url); + } + + url = uri.scheme + (':' + url); + } + + URIParser::parse(url, new_uri); + return true; + } + + return false; +} + +bool ComplexHttpTask::need_redirect(const ParsedURI& uri, ParsedURI& new_uri) +{ + HttpRequest *client_req = this->get_req(); + HttpResponse *client_resp = this->get_resp(); + const char *status_code_str = client_resp->get_status_code(); + const char *method = client_req->get_method(); + + if (!status_code_str || !method) + return false; + + int status_code = atoi(status_code_str); + + switch (status_code) + { + case 301: + case 302: + case 303: + if (redirect_url(client_resp, uri, new_uri)) + { + if (strcasecmp(method, HttpMethodGet) != 0 && + strcasecmp(method, HttpMethodHead) != 0) + { + client_req->set_method(HttpMethodGet); + } + + return true; + } + else + break; + + case 307: + case 308: + if (redirect_url(client_resp, uri, new_uri)) + return true; + else + break; + + default: + break; + } + + return 
false; +} + +void ComplexHttpTask::check_response() +{ + HttpResponse *resp = this->get_resp(); + + resp->end_parsing(); + if (this->state == WFT_STATE_SYS_ERROR && this->error == ECONNRESET) + { + /* Servers can end the message by closing the connection. */ + if (resp->is_header_complete() && !resp->is_chunked() && + !resp->has_content_length_header()) + { + this->state = WFT_STATE_SUCCESS; + this->error = 0; + } + } +} + +bool ComplexHttpTask::finish_once() +{ + if (this->state != WFT_STATE_SUCCESS) + this->check_response(); + + if (this->state == WFT_STATE_SUCCESS) + { + ParsedURI new_uri; + if (this->need_redirect(uri_, new_uri)) + { + if (uri_.userinfo && strcasecmp(uri_.host, new_uri.host) == 0) + { + if (!new_uri.userinfo) + { + new_uri.userinfo = uri_.userinfo; + uri_.userinfo = NULL; + } + } + else if (uri_.userinfo) + { + HttpRequest *client_req = this->get_req(); + HttpHeaderCursor cursor(client_req); + struct HttpMessageHeader header = { + .name = "Authorization", + .name_len = strlen("Authorization") + }; + + cursor.find_and_erase(&header); + } + + this->set_redirect(new_uri); + } + else if (this->state != WFT_STATE_SUCCESS) + this->disable_retry(); + } + + return true; +} + +/*******Proxy Client*******/ + +static SSL *__create_ssl(SSL_CTX *ssl_ctx) +{ + BIO *wbio; + BIO *rbio; + SSL *ssl; + + rbio = BIO_new(BIO_s_mem()); + if (rbio) + { + wbio = BIO_new(BIO_s_mem()); + if (wbio) + { + ssl = SSL_new(ssl_ctx); + if (ssl) + { + SSL_set_bio(ssl, rbio, wbio); + return ssl; + } + + BIO_free(wbio); + } + + BIO_free(rbio); + } + + return NULL; +} + +class ComplexHttpProxyTask : public ComplexHttpTask +{ +public: + ComplexHttpProxyTask(int redirect_max, + int retry_max, + http_callback_t&& callback): + ComplexHttpTask(redirect_max, retry_max, std::move(callback)), + is_user_request_(true) + { } + + void set_user_uri(ParsedURI&& uri) { user_uri_ = std::move(uri); } + void set_user_uri(const ParsedURI& uri) { user_uri_ = uri; } + + virtual const ParsedURI 
*get_current_uri() const { return &user_uri_; } + +protected: + virtual CommMessageOut *message_out(); + virtual CommMessageIn *message_in(); + virtual int keep_alive_timeout(); + virtual int first_timeout(); + virtual bool init_success(); + virtual bool finish_once(); + +protected: + virtual WFConnection *get_connection() const + { + WFConnection *conn = this->ComplexHttpTask::get_connection(); + + if (conn && is_ssl_) + return (SSLConnection *)conn->get_context(); + + return conn; + } + +private: + struct SSLConnection : public WFConnection + { + SSL *ssl; + SSLHandshaker handshaker; + SSLWrapper wrapper; + SSLConnection(SSL *ssl) : handshaker(ssl), wrapper(&wrapper, ssl) + { + this->ssl = ssl; + } + }; + + SSLHandshaker *get_ssl_handshaker() const + { + return &((SSLConnection *)this->get_connection())->handshaker; + } + + SSLWrapper *get_ssl_wrapper(ProtocolMessage *msg) const + { + SSLConnection *conn = (SSLConnection *)this->get_connection(); + conn->wrapper = SSLWrapper(msg, conn->ssl); + return &conn->wrapper; + } + + int init_ssl_connection(); + + std::string proxy_auth_; + ParsedURI user_uri_; + bool is_ssl_; + bool is_user_request_; + short state_; + int error_; +}; + +int ComplexHttpProxyTask::init_ssl_connection() +{ + static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx(); + SSL *ssl = __create_ssl(ssl_ctx_ ? 
ssl_ctx_ : ssl_ctx); + WFConnection *conn; + + if (!ssl) + return -1; + + SSL_set_tlsext_host_name(ssl, user_uri_.host); + SSL_set_connect_state(ssl); + + conn = this->ComplexHttpTask::get_connection(); + SSLConnection *ssl_conn = new SSLConnection(ssl); + + auto&& deleter = [] (void *ctx) + { + SSLConnection *ssl_conn = (SSLConnection *)ctx; + SSL_free(ssl_conn->ssl); + delete ssl_conn; + }; + conn->set_context(ssl_conn, std::move(deleter)); + return 0; +} + +CommMessageOut *ComplexHttpProxyTask::message_out() +{ + long long seqid = this->get_seq(); + + if (seqid == 0) // CONNECT + { + HttpRequest *conn_req = new HttpRequest; + std::string request_uri(user_uri_.host); + + request_uri += ":"; + if (user_uri_.port) + request_uri += user_uri_.port; + else + request_uri += is_ssl_ ? "443" : "80"; + + conn_req->set_method("CONNECT"); + conn_req->set_request_uri(request_uri); + conn_req->set_http_version("HTTP/1.1"); + conn_req->add_header_pair("Host", request_uri.c_str()); + + if (!proxy_auth_.empty()) + conn_req->add_header_pair("Proxy-Authorization", proxy_auth_); + + is_user_request_ = false; + return conn_req; + } + else if (seqid == 1 && is_ssl_) // HANDSHAKE + { + is_user_request_ = false; + return get_ssl_handshaker(); + } + + auto *msg = (ProtocolMessage *)this->ComplexHttpTask::message_out(); + return is_ssl_ ? get_ssl_wrapper(msg) : msg; +} + +CommMessageIn *ComplexHttpProxyTask::message_in() +{ + long long seqid = this->get_seq(); + + if (seqid == 0) + { + HttpResponse *conn_resp = new HttpResponse; + conn_resp->parse_zero_body(); + return conn_resp; + } + else if (seqid == 1 && is_ssl_) + return get_ssl_handshaker(); + + auto *msg = (ProtocolMessage *)this->ComplexHttpTask::message_in(); + return is_ssl_ ? 
get_ssl_wrapper(msg) : msg; +} + +int ComplexHttpProxyTask::keep_alive_timeout() +{ + long long seqid = this->get_seq(); + + state_ = WFT_STATE_SUCCESS; + error_ = 0; + if (seqid == 0) + { + HttpResponse *resp = this->get_resp(); + const char *code_str; + int status_code; + + *resp = std::move(*(HttpResponse *)this->get_message_in()); + code_str = resp->get_status_code(); + status_code = code_str ? atoi(code_str) : 0; + + switch (status_code) + { + case 200: + break; + case 407: + this->disable_retry(); + default: + state_ = WFT_STATE_TASK_ERROR; + error_ = WFT_ERR_HTTP_PROXY_CONNECT_FAILED; + return 0; + } + + this->clear_resp(); + + if (is_ssl_ && init_ssl_connection() < 0) + { + state_ = WFT_STATE_SYS_ERROR; + error_ = errno; + return 0; + } + + return HTTP_KEEPALIVE_DEFAULT; + } + else if (seqid == 1 && is_ssl_) + return HTTP_KEEPALIVE_DEFAULT; + + return this->ComplexHttpTask::keep_alive_timeout(); +} + +int ComplexHttpProxyTask::first_timeout() +{ + return is_user_request_ ? this->watch_timeo : 0; +} + +bool ComplexHttpProxyTask::init_success() +{ + if (!uri_.scheme || strcasecmp(uri_.scheme, "http") != 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + if (user_uri_.state == URI_STATE_ERROR) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = uri_.error; + return false; + } + else if (user_uri_.state != URI_STATE_SUCCESS) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_PARSE_FAILED; + return false; + } + + if (user_uri_.scheme && strcasecmp(user_uri_.scheme, "http") == 0) + is_ssl_ = false; + else if (user_uri_.scheme && strcasecmp(user_uri_.scheme, "https") == 0) + is_ssl_ = true; + else + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + int user_port; + if (user_uri_.port) + { + user_port = atoi(user_uri_.port); + if (user_port <= 0 || user_port > 65535) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = 
WFT_ERR_URI_PORT_INVALID; + return false; + } + } + else + user_port = is_ssl_ ? 443 : 80; + + std::string info("http-proxy|remote:"); + info += is_ssl_ ? "https://" : "http://"; + info += user_uri_.host; + info += ":"; + if (user_uri_.port) + info += user_uri_.port; + else + info += is_ssl_ ? "443" : "80"; + + if (uri_.userinfo && uri_.userinfo[0]) + { + std::string userinfo(uri_.userinfo); + + StringUtil::url_decode(userinfo); + proxy_auth_.clear(); + + if (__encode_auth(userinfo.c_str(), proxy_auth_) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + + info += "|auth:"; + info += proxy_auth_; + } + + this->WFComplexClientTask::set_info(info); + + std::string request_uri; + std::string header_host; + + if (user_uri_.path && user_uri_.path[0]) + request_uri = user_uri_.path; + else + request_uri = "/"; + + if (user_uri_.query && user_uri_.query[0]) + { + request_uri += "?"; + request_uri += user_uri_.query; + } + + if (user_uri_.host && user_uri_.host[0]) + header_host = user_uri_.host; + + if ((is_ssl_ && user_port != 443) || (!is_ssl_ && user_port != 80)) + { + header_host += ":"; + header_host += uri_.port; + } + + HttpRequest *client_req = this->get_req(); + client_req->set_request_uri(request_uri.c_str()); + client_req->set_header_pair("Host", header_host.c_str()); + this->WFComplexClientTask::set_transport_type(TT_TCP); + + if (user_uri_.userinfo && user_uri_.userinfo[0]) + { + std::string userinfo(user_uri_.userinfo); + std::string http_auth; + + StringUtil::url_decode(userinfo); + + if (__encode_auth(userinfo.c_str(), http_auth) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + + client_req->set_header_pair("Authorization", http_auth.c_str()); + } + + return true; +} + +bool ComplexHttpProxyTask::finish_once() +{ + if (!is_user_request_) + { + if (this->state == WFT_STATE_SUCCESS && state_ != WFT_STATE_SUCCESS) + { + this->state = state_; + this->error = error_; + } + + if 
(this->get_seq() == 0) + { + delete this->get_message_in(); + delete this->get_message_out(); + } + + is_user_request_ = true; + return false; + } + + if (this->state != WFT_STATE_SUCCESS) + this->check_response(); + + if (this->state == WFT_STATE_SUCCESS) + { + ParsedURI new_uri; + if (this->need_redirect(user_uri_, new_uri)) + { + if (user_uri_.userinfo && + strcasecmp(user_uri_.host, new_uri.host) == 0) + { + if (!new_uri.userinfo) + { + new_uri.userinfo = user_uri_.userinfo; + user_uri_.userinfo = NULL; + } + } + else if (user_uri_.userinfo) + { + HttpRequest *client_req = this->get_req(); + HttpHeaderCursor cursor(client_req); + struct HttpMessageHeader header = { + .name = "Authorization", + .name_len = strlen("Authorization") + }; + + cursor.find_and_erase(&header); + } + + user_uri_ = std::move(new_uri); + this->set_redirect(uri_); + } + else if (this->state != WFT_STATE_SUCCESS) + this->disable_retry(); + } + + return true; +} + +/**********Client Factory**********/ + +WFHttpTask *WFTaskFactory::create_http_task(const std::string& url, + int redirect_max, + int retry_max, + http_callback_t callback) +{ + auto *task = new ComplexHttpTask(redirect_max, + retry_max, + std::move(callback)); + ParsedURI uri; + + URIParser::parse(url, uri); + task->init(std::move(uri)); + task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); + return task; +} + +WFHttpTask *WFTaskFactory::create_http_task(const ParsedURI& uri, + int redirect_max, + int retry_max, + http_callback_t callback) +{ + auto *task = new ComplexHttpTask(redirect_max, + retry_max, + std::move(callback)); + + task->init(uri); + task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); + return task; +} + +WFHttpTask *WFTaskFactory::create_http_task(const std::string& url, + const std::string& proxy_url, + int redirect_max, + int retry_max, + http_callback_t callback) +{ + auto *task = new ComplexHttpProxyTask(redirect_max, + retry_max, + std::move(callback)); + + ParsedURI uri, user_uri; + URIParser::parse(url, user_uri); + 
URIParser::parse(proxy_url, uri); + + task->set_user_uri(std::move(user_uri)); + task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); + task->init(std::move(uri)); + return task; +} + +WFHttpTask *WFTaskFactory::create_http_task(const ParsedURI& uri, + const ParsedURI& proxy_uri, + int redirect_max, + int retry_max, + http_callback_t callback) +{ + auto *task = new ComplexHttpProxyTask(redirect_max, + retry_max, + std::move(callback)); + + task->set_user_uri(uri); + task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); + task->init(proxy_uri); + return task; +} + +/**********Server**********/ + +class WFHttpServerTask : public WFServerTask +{ +private: + using TASK = WFNetworkTask; + +public: + WFHttpServerTask(CommService *service, std::function& proc) : + WFServerTask(service, WFGlobal::get_scheduler(), proc), + req_is_alive_(false), + req_has_keep_alive_header_(false) + {} + +protected: + virtual void handle(int state, int error); + virtual CommMessageOut *message_out(); + +protected: + bool req_is_alive_; + bool req_has_keep_alive_header_; + std::string req_keep_alive_; +}; + +void WFHttpServerTask::handle(int state, int error) +{ + if (state == WFT_STATE_TOREPLY) + { + req_is_alive_ = this->req.is_keep_alive(); + if (req_is_alive_ && this->req.has_keep_alive_header()) + { + HttpHeaderCursor cursor(&this->req); + struct HttpMessageHeader header = { + .name = "Keep-Alive", + .name_len = strlen("Keep-Alive"), + }; + + req_has_keep_alive_header_ = cursor.find(&header); + if (req_has_keep_alive_header_) + { + req_keep_alive_.assign((const char *)header.value, + header.value_len); + } + } + } + + this->WFServerTask::handle(state, error); +} + +CommMessageOut *WFHttpServerTask::message_out() +{ + HttpResponse *resp = this->get_resp(); + struct HttpMessageHeader header; + + if (!resp->get_http_version()) + resp->set_http_version("HTTP/1.1"); + + const char *status_code_str = resp->get_status_code(); + if (!status_code_str || !resp->get_reason_phrase()) + { + int status_code; + + if 
(status_code_str) + status_code = atoi(status_code_str); + else + status_code = HttpStatusOK; + + HttpUtil::set_response_status(resp, status_code); + } + + if (!resp->is_chunked() && !resp->has_content_length_header()) + { + char buf[32]; + header.name = "Content-Length"; + header.name_len = strlen("Content-Length"); + header.value = buf; + header.value_len = sprintf(buf, "%zu", resp->get_output_body_size()); + resp->add_header(&header); + } + + bool is_alive; + + if (resp->has_connection_header()) + is_alive = resp->is_keep_alive(); + else + is_alive = req_is_alive_; + + if (!is_alive) + this->keep_alive_timeo = 0; + else + { + //req---Connection: Keep-Alive + //req---Keep-Alive: timeout=5,max=100 + + if (req_has_keep_alive_header_) + { + int flag = 0; + std::vector params = StringUtil::split(req_keep_alive_, ','); + + for (const auto& kv : params) + { + std::vector arr = StringUtil::split(kv, '='); + if (arr.size() < 2) + arr.emplace_back("0"); + + std::string key = StringUtil::strip(arr[0]); + std::string val = StringUtil::strip(arr[1]); + if (!(flag & 1) && strcasecmp(key.c_str(), "timeout") == 0) + { + flag |= 1; + // keep_alive_timeo = 5000ms when Keep-Alive: timeout=5 + this->keep_alive_timeo = 1000 * atoi(val.c_str()); + if (flag == 3) + break; + } + else if (!(flag & 2) && strcasecmp(key.c_str(), "max") == 0) + { + flag |= 2; + if (this->get_seq() >= atoi(val.c_str())) + { + this->keep_alive_timeo = 0; + break; + } + + if (flag == 3) + break; + } + } + } + + if ((unsigned int)this->keep_alive_timeo > HTTP_KEEPALIVE_MAX) + this->keep_alive_timeo = HTTP_KEEPALIVE_MAX; + //if (this->keep_alive_timeo < 0 || this->keep_alive_timeo > HTTP_KEEPALIVE_MAX) + + } + + if (!resp->has_connection_header()) + { + header.name = "Connection"; + header.name_len = 10; + if (this->keep_alive_timeo == 0) + { + header.value = "close"; + header.value_len = 5; + } + else + { + header.value = "Keep-Alive"; + header.value_len = 10; + } + + resp->add_header(&header); + } + + return 
this->WFServerTask::message_out(); +} + +/**********Server Factory**********/ + +WFHttpTask *WFServerTaskFactory::create_http_task(CommService *service, + std::function& process) +{ + return new WFHttpServerTask(service, process); +} + diff --git a/src/factory/KafkaTaskImpl.cc b/src/factory/KafkaTaskImpl.cc new file mode 100644 index 0000000..6b87e78 --- /dev/null +++ b/src/factory/KafkaTaskImpl.cc @@ -0,0 +1,784 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "StringUtil.h" +#include "KafkaTaskImpl.inl" + +using namespace protocol; + +#define KAFKA_KEEPALIVE_DEFAULT (60 * 1000) +#define KAFKA_ROUNDTRIP_TIMEOUT (5 * 1000) + +static KafkaCgroup __create_cgroup(const KafkaCgroup *c) +{ + KafkaCgroup g; + const char *member_id = c->get_member_id(); + + if (member_id) + g.set_member_id(member_id); + + g.set_group(c->get_group()); + + return g; +} + +/**********Client**********/ + +class __ComplexKafkaTask : public WFComplexClientTask +{ +public: + __ComplexKafkaTask(int retry_max, __kafka_callback_t&& callback) : + WFComplexClientTask(retry_max, std::move(callback)) + { + is_user_request_ = true; + is_redirect_ = false; + ctx_ = 0; + } + +protected: + virtual CommMessageOut *message_out(); + virtual CommMessageIn *message_in(); + virtual bool init_success(); + virtual bool 
/* Choose the next message to send on this connection.
 *
 * Before the user's request can go out, a fresh connection may need helper
 * requests: an ApiVersions query (seq 0, only when no broker version is
 * configured) and a SASL handshake/authenticate request (seq 1, only when a
 * SASL mechanism is configured and not yet completed). Fetch/ListOffsets
 * requests with unresolved offsets are additionally preceded by a
 * ListOffsets helper request.
 *
 * Every helper request is a newly allocated KafkaRequest and is announced by
 * is_user_request_ = false; finish_once() deletes it. Returns NULL on
 * failure (unknown broker version, allocation failure, unsupported SASL
 * mechanism).
 */
CommMessageOut *__ComplexKafkaTask::message_out()
{
	long long seqid = this->get_seq();
	if (seqid == 0)
	{
		/* First message on this connection: attach per-connection state
		 * (API version table, SASL state), owned by the connection. */
		KafkaConnectionInfo *conn_info = new KafkaConnectionInfo;

		this->get_req()->set_api(&conn_info->api);
		this->get_connection()->set_context(conn_info, [](void *ctx) {
			delete (KafkaConnectionInfo *)ctx;
		});

		if (!this->get_req()->get_config()->get_broker_version())
		{
			/* No broker version configured: ask the broker. */
			KafkaRequest *req = new KafkaRequest;
			req->duplicate(*this->get_req());
			req->set_api_type(Kafka_ApiVersions);
			is_user_request_ = false;
			return req;
		}
		else
		{
			/* Broker version configured: copy the matching API table into
			 * the connection info instead of querying the broker. */
			kafka_api_version_t *api;
			size_t api_cnt;
			const char *v = this->get_req()->get_config()->get_broker_version();
			int ret = kafka_api_version_is_queryable(v, &api, &api_cnt);
			kafka_api_version_t *p = NULL;

			if (ret == 0)
			{
				p = (kafka_api_version_t *)malloc(api_cnt * sizeof(*p));
				if (p)
				{
					memcpy(p, api, api_cnt * sizeof(kafka_api_version_t));
					conn_info->api.api = p;
					conn_info->api.elements = api_cnt;
					conn_info->api.features = kafka_get_features(p, api_cnt);
				}
			}

			/* Either the lookup did not return 0 or malloc() failed. */
			if (!p)
				return NULL;

			/* The ApiVersions round trip is skipped; continue as seq 1. */
			seqid++;
		}
	}

	if (seqid == 1)
	{
		const char *sasl_mech = this->get_req()->get_config()->get_sasl_mech();
		KafkaConnectionInfo *conn_info =
			(KafkaConnectionInfo *)this->get_connection()->get_context();
		/* NOTE(review): sasl.status == 0 presumably means "not yet
		 * authenticated" — confirm against kafka_sasl_t. */
		if (sasl_mech && conn_info->sasl.status == 0)
		{
			if (!conn_info->init(sasl_mech))
				return NULL;

			this->get_req()->set_api(&conn_info->api);
			this->get_req()->set_sasl(&conn_info->sasl);

			KafkaRequest *req = new KafkaRequest;
			req->duplicate(*this->get_req());
			/* Use SaslHandshake when the broker's feature set has it,
			 * otherwise go straight to SaslAuthenticate. */
			if (conn_info->api.features & KAFKA_FEATURE_SASL_HANDSHAKE)
				req->set_api_type(Kafka_SaslHandshake);
			else
				req->set_api_type(Kafka_SaslAuthenticate);
			req->set_correlation_id(1);
			is_user_request_ = false;
			return req;
		}
	}

	KafkaConnectionInfo *conn_info =
		(KafkaConnectionInfo *)this->get_connection()->get_context();
	KafkaRequest *req = this->get_req();
	req->set_api(&conn_info->api);

	if (req->get_api_type() == Kafka_Fetch ||
		req->get_api_type() == Kafka_ListOffsets)
	{
		KafkaTopparList *req_toppar_lst = req->get_toppar_list();
		KafkaToppar *toppar;
		KafkaTopparList toppar_list;	/* partitions still lacking an offset */
		bool flag = false;				/* true once toppar_list is non-empty */
		long long cfg_ts = req->get_config()->get_offset_timestamp();
		long long tp_ts;

		req_toppar_lst->rewind();
		while ((toppar = req_toppar_lst->get_next()) != NULL)
		{
			/* A per-partition timestamp overrides the configured one. */
			tp_ts = toppar->get_offset_timestamp();
			if (tp_ts == KAFKA_TIMESTAMP_UNINIT)
				tp_ts = cfg_ts;

			if (toppar->get_offset() == KAFKA_OFFSET_UNINIT)
			{
				/* No offset yet: start from a watermark when the timestamp
				 * asks for earliest, or is another negative value (treated
				 * as latest). A non-negative timestamp leaves the offset
				 * untouched here. */
				if (tp_ts == KAFKA_TIMESTAMP_EARLIEST)
					toppar->set_offset(toppar->get_low_watermark());
				else if (tp_ts < 0)
				{
					toppar->set_offset(toppar->get_high_watermark());
					tp_ts = KAFKA_TIMESTAMP_LATEST;
				}
			}
			else if (toppar->get_offset() == KAFKA_OFFSET_OVERFLOW)
			{
				/* Offset was marked out of range (see process_fetch()):
				 * fall back to a watermark in every case. */
				if (tp_ts == KAFKA_TIMESTAMP_EARLIEST)
					toppar->set_offset(toppar->get_low_watermark());
				else
				{
					toppar->set_offset(toppar->get_high_watermark());
					tp_ts = KAFKA_TIMESTAMP_LATEST;
				}
			}

			/* Still negative: the offset has to be asked from the broker. */
			if (toppar->get_offset() < 0)
			{
				toppar->set_offset_timestamp(tp_ts);
				toppar_list.add_item(*toppar);
				flag = true;
			}
		}

		if (flag)
		{
			/* Send a ListOffsets helper request for those partitions
			 * before the user's request. */
			KafkaRequest *new_req = new KafkaRequest;
			new_req->set_api(&conn_info->api);
			new_req->set_broker(*req->get_broker());
			new_req->set_toppar_list(toppar_list);
			new_req->set_config(*req->get_config());
			new_req->set_api_type(Kafka_ListOffsets);
			new_req->set_correlation_id(seqid);
			is_user_request_ = false;
			return new_req;
		}
	}

	/* Nothing left to prepare: send the user's own request. */
	this->get_req()->set_correlation_id(seqid);
	return this->WFComplexClientTask::message_out();
}
this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + std::string username, password, sasl, client; + if (uri_.userinfo) + { + const char *pos = strchr(uri_.userinfo, ':'); + if (pos) + { + username = std::string(uri_.userinfo, pos - uri_.userinfo); + StringUtil::url_decode(username); + const char *pos1 = strchr(pos + 1, ':'); + if (pos1) + { + password = std::string(pos + 1, pos1 - pos - 1); + StringUtil::url_decode(password); + const char *pos2 = strchr(pos1 + 1, ':'); + if (pos2) + { + sasl = std::string(pos1 + 1, pos2 - pos1 - 1); + client = std::string(pos1 + 1); + } + } + } + + if (username.empty() || password.empty() || sasl.empty() || client.empty()) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + user_info_ = uri_.userinfo; + size_t info_len = username.size() + password.size() + sasl.size() + + client.size() + 50; + char *info = new char[info_len]; + + snprintf(info, info_len, "%s|user:%s|pass:%s|sasl:%s|client:%s|", "kafka", + username.c_str(), password.c_str(), sasl.c_str(), client.c_str()); + + this->WFComplexClientTask::set_info(info); + delete []info; + } + + this->WFComplexClientTask::set_transport_type(type); + return true; +} + +int __ComplexKafkaTask::keep_alive_timeout() +{ + if (this->get_resp()->get_broker()->get_error()) + return 0; + + return this->WFComplexClientTask::keep_alive_timeout(); +} + +int __ComplexKafkaTask::first_timeout() +{ + KafkaRequest *client_req = this->get_req(); + int ret = 0; + + switch(client_req->get_api_type()) + { + case Kafka_Fetch: + ret = client_req->get_config()->get_fetch_timeout(); + break; + + case Kafka_JoinGroup: + ret = client_req->get_config()->get_session_timeout(); + break; + + case Kafka_SyncGroup: + ret = client_req->get_config()->get_rebalance_timeout(); + break; + + case Kafka_Produce: + ret = client_req->get_config()->get_produce_timeout(); + break; + + default: + return 0; + } + + return 
ret + KAFKA_ROUNDTRIP_TIMEOUT; +} + +bool __ComplexKafkaTask::process_find_coordinator() +{ + KafkaCgroup *cgroup = this->get_resp()->get_cgroup(); + ctx_ = cgroup->get_error(); + if (ctx_) + { + this->error = WFT_ERR_KAFKA_CGROUP_FAILED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + else + { + this->get_req()->set_cgroup(*cgroup); + KafkaBroker *coordinator = cgroup->get_coordinator(); + std::string url(uri_.scheme); + url += "://"; + url += user_info_ + "@"; + url += coordinator->get_host(); + url += ":" + std::to_string(coordinator->get_port()); + + ParsedURI uri; + URIParser::parse(url, uri); + set_redirect(std::move(uri)); + this->get_req()->set_api_type(Kafka_JoinGroup); + is_redirect_ = true; + return true; + } +} + +bool __ComplexKafkaTask::process_join_group() +{ + KafkaResponse *msg = this->get_resp(); + switch(msg->get_cgroup()->get_error()) + { + case KAFKA_MEMBER_ID_REQUIRED: + this->get_req()->set_api_type(Kafka_JoinGroup); + break; + + case KAFKA_UNKNOWN_MEMBER_ID: + msg->get_cgroup()->set_member_id(""); + this->get_req()->set_api_type(Kafka_JoinGroup); + break; + + case 0: + this->get_req()->set_api_type(Kafka_Metadata); + break; + + default: + ctx_ = msg->get_cgroup()->get_error(); + this->error = WFT_ERR_KAFKA_CGROUP_FAILED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + + return true; +} + +bool __ComplexKafkaTask::process_sync_group() +{ + ctx_ = this->get_resp()->get_cgroup()->get_error(); + if (ctx_) + { + this->error = WFT_ERR_KAFKA_CGROUP_FAILED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + else + { + this->get_req()->set_api_type(Kafka_OffsetFetch); + return true; + } +} + +bool __ComplexKafkaTask::process_metadata() +{ + KafkaResponse *msg = this->get_resp(); + msg->get_meta_list()->rewind(); + KafkaMeta *meta; + while ((meta = msg->get_meta_list()->get_next()) != NULL) + { + switch (meta->get_error()) + { + case KAFKA_LEADER_NOT_AVAILABLE: + this->get_req()->set_api_type(Kafka_Metadata); + return 
true; + case 0: + break; + default: + ctx_ = meta->get_error(); + this->error = WFT_ERR_KAFKA_META_FAILED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + } + + this->get_req()->set_meta_list(*msg->get_meta_list()); + if (msg->get_cgroup()->get_group()) + { + if (msg->get_cgroup()->is_leader()) + { + KafkaCgroup *cgroup = msg->get_cgroup(); + if (cgroup->run_assignor(msg->get_meta_list(), + cgroup->get_protocol_name()) < 0) + { + this->error = WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + } + + this->get_req()->set_api_type(Kafka_SyncGroup); + return true; + } + + return false; +} + +bool __ComplexKafkaTask::process_fetch() +{ + bool ret = false; + KafkaToppar *toppar; + this->get_resp()->get_toppar_list()->rewind(); + while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL) + { + if (toppar->get_error() == KAFKA_OFFSET_OUT_OF_RANGE) + { + toppar->set_offset(KAFKA_OFFSET_OVERFLOW); + toppar->set_low_watermark(KAFKA_OFFSET_UNINIT); + toppar->set_high_watermark(KAFKA_OFFSET_UNINIT); + ret = true; + } + + switch (toppar->get_error()) + { + case KAFKA_UNKNOWN_TOPIC_OR_PARTITION: + case KAFKA_LEADER_NOT_AVAILABLE: + case KAFKA_NOT_LEADER_FOR_PARTITION: + case KAFKA_BROKER_NOT_AVAILABLE: + case KAFKA_REPLICA_NOT_AVAILABLE: + case KAFKA_KAFKA_STORAGE_ERROR: + case KAFKA_FENCED_LEADER_EPOCH: + this->get_req()->set_api_type(Kafka_Metadata); + return true; + case 0: + case KAFKA_OFFSET_OUT_OF_RANGE: + break; + default: + ctx_ = toppar->get_error(); + this->error = WFT_ERR_KAFKA_FETCH_FAILED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + } + return ret; +} + +bool __ComplexKafkaTask::process_list_offsets() +{ + KafkaToppar *toppar; + this->get_resp()->get_toppar_list()->rewind(); + while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL) + { + if (toppar->get_error()) + { + this->error = toppar->get_error(); + this->state = WFT_STATE_TASK_ERROR; + } + } + return 
false; +} + +bool __ComplexKafkaTask::process_produce() +{ + KafkaToppar *toppar; + this->get_resp()->get_toppar_list()->rewind(); + while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL) + { + if (!toppar->record_reach_end()) + { + this->get_req()->set_api_type(Kafka_Produce); + return true; + } + + switch (toppar->get_error()) + { + case KAFKA_UNKNOWN_TOPIC_OR_PARTITION: + case KAFKA_LEADER_NOT_AVAILABLE: + case KAFKA_NOT_LEADER_FOR_PARTITION: + case KAFKA_BROKER_NOT_AVAILABLE: + case KAFKA_REPLICA_NOT_AVAILABLE: + case KAFKA_KAFKA_STORAGE_ERROR: + case KAFKA_FENCED_LEADER_EPOCH: + this->get_req()->set_api_type(Kafka_Metadata); + return true; + case 0: + break; + default: + this->error = toppar->get_error(); + this->state = WFT_STATE_TASK_ERROR; + return false; + } + } + return false; +} + +bool __ComplexKafkaTask::process_sasl_handshake() +{ + ctx_ = this->get_resp()->get_broker()->get_error(); + if (ctx_) + { + this->error = WFT_ERR_KAFKA_SASL_DISALLOWED; + this->state = WFT_STATE_TASK_ERROR; + return false; + } + return true; +} + +bool __ComplexKafkaTask::process_sasl_authenticate() +{ + ctx_ = this->get_resp()->get_broker()->get_error(); + if (ctx_) + { + this->error = WFT_ERR_KAFKA_SASL_DISALLOWED; + this->state = WFT_STATE_TASK_ERROR; + } + return false; +} + +bool __ComplexKafkaTask::has_next() +{ + switch (this->get_resp()->get_api_type()) + { + case Kafka_Produce: + return this->process_produce(); + case Kafka_Fetch: + return this->process_fetch(); + case Kafka_Metadata: + return this->process_metadata(); + case Kafka_FindCoordinator: + return this->process_find_coordinator(); + case Kafka_JoinGroup: + return this->process_join_group(); + case Kafka_SyncGroup: + return this->process_sync_group(); + case Kafka_SaslHandshake: + return this->process_sasl_handshake(); + case Kafka_SaslAuthenticate: + return this->process_sasl_authenticate(); + case Kafka_ListOffsets: + return this->process_list_offsets(); + case Kafka_OffsetCommit: + 
/* Called after each round trip; decides whether the task is complete.
 *
 * Returns false to have the task send another message (a helper request has
 * just finished, or has_next() asked for a follow-up request); returns true
 * once this pass is over.
 */
bool __ComplexKafkaTask::finish_once()
{
	bool finish = true;
	/* Only a successful response is inspected for follow-up work. */
	if (this->state == WFT_STATE_SUCCESS)
		finish = !has_next();

	/* The message just sent was a helper request allocated in message_out():
	 * free it and drop the response data that belongs to it. */
	if (!is_user_request_)
	{
		delete this->get_message_out();
		this->get_resp()->clear_buf();
	}

	/* NOTE(review): state is presumably reset to WFT_STATE_UNDEFINED by
	 * set_redirect() in process_find_coordinator() — confirm. */
	if (is_redirect_ && this->state == WFT_STATE_UNDEFINED)
	{
		this->get_req()->clear_buf();
		is_redirect_ = false;
	}
	else if (this->state == WFT_STATE_SUCCESS)
	{
		/* A helper request succeeded: the user's request is next. */
		if (!is_user_request_)
		{
			is_user_request_ = true;
			return false;
		}

		/* The user's request needs a follow-up (api type already changed
		 * by has_next()): start again with clean buffers. */
		if (!finish)
		{
			this->get_req()->clear_buf();
			this->get_resp()->clear_buf();
			return false;
		}
	}
	else
	{
		/* Failure: make the response mirror the request's api type and
		 * version. */
		this->get_resp()->set_api_type(this->get_req()->get_api_type());
		this->get_resp()->set_api_version(this->get_req()->get_api_version());
	}

	is_user_request_ = true;
	return true;
}
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT); + return task; +} + +__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(enum TransportType type, + const char *host, + unsigned short port, + SSL_CTX *ssl_ctx, + const std::string& info, + int retry_max, + __kafka_callback_t callback) +{ + auto *task = new __ComplexKafkaTask(retry_max, std::move(callback)); + task->set_ssl_ctx(ssl_ctx); + + ParsedURI uri; + char buf[32]; + + if (type == TT_TCP_SSL) + uri.scheme = strdup("kafkas"); + else + uri.scheme = strdup("kafka"); + + if (!info.empty()) + uri.userinfo = strdup(info.c_str()); + + uri.host = strdup(host); + sprintf(buf, "%u", port); + uri.port = strdup(buf); + + if (!uri.scheme || !uri.host || !uri.port || + (!info.empty() && !uri.userinfo)) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + } + else + uri.state = URI_STATE_SUCCESS; + + task->init(std::move(uri)); + task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT); + return task; +} + diff --git a/src/factory/KafkaTaskImpl.inl b/src/factory/KafkaTaskImpl.inl new file mode 100644 index 0000000..83ec4c7 --- /dev/null +++ b/src/factory/KafkaTaskImpl.inl @@ -0,0 +1,53 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#include +#include "WFTaskFactory.h" +#include "KafkaMessage.h" + +// Kafka internal task. 
For __ComplexKafkaTask usage only +using __WFKafkaTask = WFNetworkTask; +using __kafka_callback_t = std::function; + +class __WFKafkaTaskFactory +{ +public: + /* __WFKafkaTask is create by __ComplexKafkaTask. This is an internal + * interface for create internal task. It should not be created directly by common + * user task. + */ + static __WFKafkaTask *create_kafka_task(const ParsedURI& uri, + SSL_CTX *ssl_ctx, + int retry_max, + __kafka_callback_t callback); + + static __WFKafkaTask *create_kafka_task(const std::string& url, + SSL_CTX *ssl_ctx, + int retry_max, + __kafka_callback_t callback); + + static __WFKafkaTask *create_kafka_task(enum TransportType type, + const char *host, + unsigned short port, + SSL_CTX *ssl_ctx, + const std::string& info, + int retry_max, + __kafka_callback_t callback); +}; + diff --git a/src/factory/MySQLTaskImpl.cc b/src/factory/MySQLTaskImpl.cc new file mode 100644 index 0000000..ad3dd23 --- /dev/null +++ b/src/factory/MySQLTaskImpl.cc @@ -0,0 +1,855 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) + Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "WFTaskError.h" +#include "WFTaskFactory.h" +#include "StringUtil.h" +#include "WFGlobal.h" +#include "mysql_types.h" + +using namespace protocol; + +#define MYSQL_KEEPALIVE_DEFAULT (60 * 1000) +#define MYSQL_KEEPALIVE_TRANSACTION (3600 * 1000) + +/**********Client**********/ + +class ComplexMySQLTask : public WFComplexClientTask +{ +protected: + virtual bool check_request(); + virtual CommMessageOut *message_out(); + virtual CommMessageIn *message_in(); + virtual int keep_alive_timeout(); + virtual int first_timeout(); + virtual bool init_success(); + virtual bool finish_once(); + +protected: + virtual WFConnection *get_connection() const + { + WFConnection *conn = this->WFComplexClientTask::get_connection(); + + if (conn) + { + void *ctx = conn->get_context(); + if (ctx) + conn = (WFConnection *)ctx; + } + + return conn; + } + +private: + enum ConnState + { + ST_SSL_REQUEST, + ST_AUTH_REQUEST, + ST_AUTH_SWITCH_REQUEST, + ST_CLEAR_PASSWORD_REQUEST, + ST_SHA256_PUBLIC_KEY_REQUEST, + ST_CSHA2_PUBLIC_KEY_REQUEST, + ST_RSA_AUTH_REQUEST, + ST_CHARSET_REQUEST, + ST_FIRST_USER_REQUEST, + ST_USER_REQUEST + }; + + struct MyConnection : public WFConnection + { + std::string str; // shared by auth, auth_swich and rsa_auth requests + unsigned char seed[20]; + enum ConnState state; + unsigned char mysql_seqid; + SSL *ssl; + SSLWrapper wrapper; + MyConnection(SSL *ssl) : wrapper(&wrapper, ssl) + { + this->ssl = ssl; + } + }; + + int check_handshake(MySQLHandshakeResponse *resp); + int auth_switch(MySQLAuthResponse *resp, MyConnection *conn); + + struct MySSLWrapper : public SSLWrapper + { + MySSLWrapper(ProtocolMessage *msg, SSL *ssl) : + SSLWrapper(msg, ssl) + { } + ProtocolMessage *get_msg() const { return this->message; } + virtual ~MySSLWrapper() { delete this->message; } + 
}; + +private: + std::string username_; + std::string password_; + std::string db_; + std::string res_charset_; + short character_set_; + short state_; + int error_; + bool is_ssl_; + bool is_user_request_; + +public: + ComplexMySQLTask(int retry_max, mysql_callback_t&& callback): + WFComplexClientTask(retry_max, std::move(callback)), + character_set_(33), + is_user_request_(true) + {} +}; + +bool ComplexMySQLTask::check_request() +{ + if (this->req.query_is_unset() == false) + { + if (this->req.get_command() == MYSQL_COM_QUERY) + { + std::string query = this->req.get_query(); + + if (strncasecmp(query.c_str(), "USE ", 4) && + strncasecmp(query.c_str(), "SET NAMES ", 10) && + strncasecmp(query.c_str(), "SET CHARSET ", 12) && + strncasecmp(query.c_str(), "SET CHARACTER SET ", 18)) + { + return true; + } + } + + this->error = WFT_ERR_MYSQL_COMMAND_DISALLOWED; + } + else + this->error = WFT_ERR_MYSQL_QUERY_NOT_SET; + + this->state = WFT_STATE_TASK_ERROR; + return false; +} + +static SSL *__create_ssl(SSL_CTX *ssl_ctx) +{ + BIO *wbio; + BIO *rbio; + SSL *ssl; + + rbio = BIO_new(BIO_s_mem()); + if (rbio) + { + wbio = BIO_new(BIO_s_mem()); + if (wbio) + { + ssl = SSL_new(ssl_ctx); + if (ssl) + { + SSL_set_bio(ssl, rbio, wbio); + return ssl; + } + + BIO_free(wbio); + } + + BIO_free(rbio); + } + + return NULL; +} + +CommMessageOut *ComplexMySQLTask::message_out() +{ + MySQLAuthSwitchRequest *auth_switch_req; + MySQLRSAAuthRequest *rsa_auth_req; + MySQLAuthRequest *auth_req; + MySQLRequest *req; + + is_user_request_ = false; + if (this->get_seq() == 0) + return new MySQLHandshakeRequest; + + auto *conn = (MyConnection *)this->get_connection(); + switch (conn->state) + { + case ST_SSL_REQUEST: + req = new MySQLSSLRequest(character_set_, conn->ssl); + req->set_seqid(conn->mysql_seqid); + return req; + + case ST_AUTH_REQUEST: + req = new MySQLAuthRequest; + auth_req = (MySQLAuthRequest *)req; + auth_req->set_auth(username_, password_, db_, character_set_); + 
auth_req->set_auth_plugin_name(std::move(conn->str)); + auth_req->set_seed(conn->seed); + break; + + case ST_CLEAR_PASSWORD_REQUEST: + conn->str = "mysql_clear_password"; + case ST_AUTH_SWITCH_REQUEST: + req = new MySQLAuthSwitchRequest; + auth_switch_req = (MySQLAuthSwitchRequest *)req; + auth_switch_req->set_password(password_); + auth_switch_req->set_auth_plugin_name(std::move(conn->str)); + auth_switch_req->set_seed(conn->seed); +#if OPENSSL_VERSION_NUMBER < 0x10100000L + WFGlobal::get_ssl_client_ctx(); +#endif + break; + + case ST_SHA256_PUBLIC_KEY_REQUEST: + req = new MySQLPublicKeyRequest; + ((MySQLPublicKeyRequest *)req)->set_sha256(); + break; + + case ST_CSHA2_PUBLIC_KEY_REQUEST: + req = new MySQLPublicKeyRequest; + ((MySQLPublicKeyRequest *)req)->set_caching_sha2(); + break; + + case ST_RSA_AUTH_REQUEST: + req = new MySQLRSAAuthRequest; + rsa_auth_req = (MySQLRSAAuthRequest *)req; + rsa_auth_req->set_password(password_); + rsa_auth_req->set_public_key(std::move(conn->str)); + rsa_auth_req->set_seed(conn->seed); + break; + + case ST_CHARSET_REQUEST: + req = new MySQLRequest; + req->set_query("SET NAMES " + res_charset_); + break; + + case ST_FIRST_USER_REQUEST: + if (this->is_fixed_conn()) + { + auto *target = (RouteManager::RouteTarget *)this->target; + + /* If it's a transaction task, generate a ECONNRESET error when + * the target was reconnected. 
*/ + if (target->state) + { + is_user_request_ = true; + errno = ECONNRESET; + return NULL; + } + + target->state = 1; + } + + case ST_USER_REQUEST: + is_user_request_ = true; + req = (MySQLRequest *)this->WFComplexClientTask::message_out(); + break; + + default: + assert(0); + return NULL; + } + + if (!is_user_request_ && conn->state != ST_CHARSET_REQUEST) + req->set_seqid(conn->mysql_seqid); + + if (!is_ssl_) + return req; + + if (is_user_request_) + { + conn->wrapper = SSLWrapper(req, conn->ssl); + return &conn->wrapper; + } + else + return new MySSLWrapper(req, conn->ssl); +} + +CommMessageIn *ComplexMySQLTask::message_in() +{ + MySQLResponse *resp; + + if (this->get_seq() == 0) + return new MySQLHandshakeResponse; + + auto *conn = (MyConnection *)this->get_connection(); + switch (conn->state) + { + case ST_SSL_REQUEST: + return new SSLHandshaker(conn->ssl); + + case ST_AUTH_REQUEST: + case ST_AUTH_SWITCH_REQUEST: + resp = new MySQLAuthResponse; + break; + + case ST_CLEAR_PASSWORD_REQUEST: + case ST_RSA_AUTH_REQUEST: + resp = new MySQLResponse; + break; + + case ST_SHA256_PUBLIC_KEY_REQUEST: + case ST_CSHA2_PUBLIC_KEY_REQUEST: + resp = new MySQLPublicKeyResponse; + break; + + case ST_CHARSET_REQUEST: + resp = new MySQLResponse; + break; + + case ST_FIRST_USER_REQUEST: + case ST_USER_REQUEST: + resp = (MySQLResponse *)this->WFComplexClientTask::message_in(); + break; + + default: + assert(0); + return NULL; + } + + if (!is_ssl_) + return resp; + + if (is_user_request_) + { + conn->wrapper = SSLWrapper(resp, conn->ssl); + return &conn->wrapper; + } + else + return new MySSLWrapper(resp, conn->ssl); +} + +int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp) +{ + SSL *ssl = NULL; + + if (resp->host_disallowed()) + { + this->resp = std::move(*(MySQLResponse *)resp); + state_ = WFT_STATE_TASK_ERROR; + error_ = WFT_ERR_MYSQL_HOST_NOT_ALLOWED; + return 0; + } + + if (is_ssl_) + { + if (resp->get_capability_flags() & 0x800) + { + static SSL_CTX *ssl_ctx = 
WFGlobal::get_ssl_client_ctx(); + + ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx); + if (!ssl) + { + state_ = WFT_STATE_SYS_ERROR; + error_ = errno; + return 0; + } + + SSL_set_connect_state(ssl); + } + else + { + this->resp = std::move(*(MySQLResponse *)resp); + state_ = WFT_STATE_TASK_ERROR; + error_ = WFT_ERR_MYSQL_SSL_NOT_SUPPORTED; + return 0; + } + + } + + auto *conn = this->get_connection(); + auto *my_conn = new MyConnection(ssl); + + my_conn->str = resp->get_auth_plugin_name(); + if (!password_.empty() && my_conn->str == "sha256_password") + my_conn->str = "caching_sha2_password"; + + resp->get_seed(my_conn->seed); + my_conn->state = is_ssl_ ? ST_SSL_REQUEST : ST_AUTH_REQUEST; + my_conn->mysql_seqid = resp->get_seqid() + 1; + conn->set_context(my_conn, [](void *ctx) { + auto *my_conn = (MyConnection *)ctx; + if (my_conn->ssl) + SSL_free(my_conn->ssl); + delete my_conn; + }); + + return MYSQL_KEEPALIVE_DEFAULT; +} + +int ComplexMySQLTask::auth_switch(MySQLAuthResponse *resp, MyConnection *conn) +{ + std::string name = resp->get_auth_plugin_name(); + + if (conn->state != ST_AUTH_REQUEST || + (name == "mysql_clear_password" && !is_ssl_)) + { + state_ = WFT_STATE_SYS_ERROR; + error_ = EBADMSG; + return 0; + } + + if (password_.empty()) + { + conn->state = ST_CLEAR_PASSWORD_REQUEST; + } + else if (name == "sha256_password") + { + if (is_ssl_) + conn->state = ST_CLEAR_PASSWORD_REQUEST; + else + conn->state = ST_SHA256_PUBLIC_KEY_REQUEST; + } + else + { + conn->str = std::move(name); + conn->state = ST_AUTH_SWITCH_REQUEST; + } + + resp->get_seed(conn->seed); + conn->mysql_seqid = resp->get_seqid() + 1; + return MYSQL_KEEPALIVE_DEFAULT; +} + +int ComplexMySQLTask::keep_alive_timeout() +{ + auto *msg = (ProtocolMessage *)this->get_message_in(); + MySQLAuthResponse *auth_resp; + MySQLResponse *resp; + + state_ = WFT_STATE_SUCCESS; + error_ = 0; + if (this->get_seq() == 0) + return check_handshake((MySQLHandshakeResponse *)msg); + + auto *conn = (MyConnection 
*)this->get_connection(); + if (conn->state == ST_SSL_REQUEST) + { + conn->state = ST_AUTH_REQUEST; + conn->mysql_seqid++; + return MYSQL_KEEPALIVE_DEFAULT; + } + + if (is_ssl_) + resp = (MySQLResponse *)((MySSLWrapper *)msg)->get_msg(); + else + resp = (MySQLResponse *)msg; + + switch (conn->state) + { + case ST_AUTH_REQUEST: + case ST_AUTH_SWITCH_REQUEST: + case ST_CLEAR_PASSWORD_REQUEST: + case ST_RSA_AUTH_REQUEST: + if (resp->is_ok_packet()) + { + if (!res_charset_.empty()) + conn->state = ST_CHARSET_REQUEST; + else + conn->state = ST_FIRST_USER_REQUEST; + + break; + } + + if (resp->is_error_packet() || + conn->state == ST_CLEAR_PASSWORD_REQUEST || + conn->state == ST_RSA_AUTH_REQUEST) + { + this->resp = std::move(*resp); + state_ = WFT_STATE_TASK_ERROR; + error_ = WFT_ERR_MYSQL_ACCESS_DENIED; + return 0; + } + + auth_resp = (MySQLAuthResponse *)resp; + if (auth_resp->is_continue()) + { + if (is_ssl_) + conn->state = ST_CLEAR_PASSWORD_REQUEST; + else + conn->state = ST_CSHA2_PUBLIC_KEY_REQUEST; + + break; + } + + return auth_switch(auth_resp, conn); + + case ST_SHA256_PUBLIC_KEY_REQUEST: + case ST_CSHA2_PUBLIC_KEY_REQUEST: + conn->str = ((MySQLPublicKeyResponse *)resp)->get_public_key(); + conn->state = ST_RSA_AUTH_REQUEST; + break; + + case ST_CHARSET_REQUEST: + if (!resp->is_ok_packet()) + { + this->resp = std::move(*resp); + state_ = WFT_STATE_TASK_ERROR; + error_ = WFT_ERR_MYSQL_INVALID_CHARACTER_SET; + return 0; + } + + conn->state = ST_FIRST_USER_REQUEST; + return MYSQL_KEEPALIVE_DEFAULT; + + case ST_FIRST_USER_REQUEST: + conn->state = ST_USER_REQUEST; + case ST_USER_REQUEST: + return this->keep_alive_timeo; + + default: + assert(0); + return 0; + } + + conn->mysql_seqid = resp->get_seqid() + 1; + return MYSQL_KEEPALIVE_DEFAULT; +} + +int ComplexMySQLTask::first_timeout() +{ + return is_user_request_ ? 
/*
+--------------------+---------------------+-----+
| CHARACTER_SET_NAME | COLLATION_NAME      | ID  |
+--------------------+---------------------+-----+
| big5               | big5_chinese_ci     |   1 |
| dec8               | dec8_swedish_ci     |   3 |
| cp850              | cp850_general_ci    |   4 |
| hp8                | hp8_english_ci      |   6 |
| koi8r              | koi8r_general_ci    |   7 |
| latin1             | latin1_swedish_ci   |   8 |
| latin2             | latin2_general_ci   |   9 |
| swe7               | swe7_swedish_ci     |  10 |
| ascii              | ascii_general_ci    |  11 |
| ujis               | ujis_japanese_ci    |  12 |
| sjis               | sjis_japanese_ci    |  13 |
| hebrew             | hebrew_general_ci   |  16 |
| tis620             | tis620_thai_ci      |  18 |
| euckr              | euckr_korean_ci     |  19 |
| koi8u              | koi8u_general_ci    |  22 |
| gb2312             | gb2312_chinese_ci   |  24 |
| greek              | greek_general_ci    |  25 |
| cp1250             | cp1250_general_ci   |  26 |
| gbk                | gbk_chinese_ci      |  28 |
| latin5             | latin5_turkish_ci   |  30 |
| armscii8           | armscii8_general_ci |  32 |
| utf8               | utf8_general_ci     |  33 |
| ucs2               | ucs2_general_ci     |  35 |
| cp866              | cp866_general_ci    |  36 |
| keybcs2            | keybcs2_general_ci  |  37 |
| macce              | macce_general_ci    |  38 |
| macroman           | macroman_general_ci |  39 |
| cp852              | cp852_general_ci    |  40 |
| latin7             | latin7_general_ci   |  41 |
| cp1251             | cp1251_general_ci   |  51 |
| utf16              | utf16_general_ci    |  54 |
| utf16le            | utf16le_general_ci  |  56 |
| cp1256             | cp1256_general_ci   |  57 |
| cp1257             | cp1257_general_ci   |  59 |
| utf32              | utf32_general_ci    |  60 |
| binary             | binary              |  63 |
| geostd8            | geostd8_general_ci  |  92 |
| cp932              | cp932_japanese_ci   |  95 |
| eucjpms            | eucjpms_japanese_ci |  97 |
| gb18030            | gb18030_chinese_ci  | 248 |
| utf8mb4            | utf8mb4_0900_ai_ci  | 255 |
+--------------------+---------------------+-----+
*/

/*
 * Maps a MySQL character set name to the id of its default collation, as
 * listed in the table above. The lookup is an exact, case-sensitive match.
 * Returns -1 when the name is not in the table.
 *
 * The ids of hp8, koi8r, latin1 and latin2 are 6, 7, 8 and 9: they must
 * agree with the table above, since a wrong id selects a different
 * collation than the one the caller asked for.
 */
static int __mysql_get_character_set(const std::string& charset)
{
	static const std::unordered_map<std::string, int> charset_map = {
		{"big5",	1},
		{"dec8",	3},
		{"cp850",	4},
		{"hp8",		6},
		{"koi8r",	7},
		{"latin1",	8},
		{"latin2",	9},
		{"swe7",	10},
		{"ascii",	11},
		{"ujis",	12},
		{"sjis",	13},
		{"hebrew",	16},
		{"tis620",	18},
		{"euckr",	19},
		{"koi8u",	22},
		{"gb2312",	24},
		{"greek",	25},
		{"cp1250",	26},
		{"gbk",		28},
		{"latin5",	30},
		{"armscii8",32},
		{"utf8",	33},
		{"ucs2",	35},
		{"cp866",	36},
		{"keybcs2",	37},
		{"macce",	38},
		{"macroman",39},
		{"cp852",	40},
		{"latin7",	41},
		{"cp1251",	51},
		{"utf16",	54},
		{"utf16le",	56},
		{"cp1256",	57},
		{"cp1257",	59},
		{"utf32",	60},
		{"binary",	63},
		{"geostd8",	92},
		{"cp932",	95},
		{"eucjpms",	97},
		{"gb18030",	248},
		{"utf8mb4",	255},
	};

	const auto it = charset_map.find(charset);

	if (it != charset_map.cend())
		return it->second;

	return -1;
}
__mysql_get_character_set(kv.second); + if (character_set_ < 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_MYSQL_INVALID_CHARACTER_SET; + return false; + } + } + else if (strcasecmp(kv.first.c_str(), "character_set_results") == 0) + res_charset_ = std::move(kv.second); + } + } + + size_t info_len = username_.size() + password_.size() + db_.size() + + res_charset_.size() + 50; + char *info = new char[info_len]; + + snprintf(info, info_len, "%s|user:%s|pass:%s|db:%s|" + "charset:%d|rcharset:%s", + is_ssl_ ? "mysqls" : "mysql", username_.c_str(), password_.c_str(), + db_.c_str(), character_set_, res_charset_.c_str()); + this->WFComplexClientTask::set_transport_type(TT_TCP); + + if (!transaction.empty()) + { + this->set_fixed_addr(true); + this->set_fixed_conn(true); + this->WFComplexClientTask::set_info(info + ("|txn:" + transaction)); + } + else + this->WFComplexClientTask::set_info(info); + + delete []info; + return true; +} + +bool ComplexMySQLTask::finish_once() +{ + if (!is_user_request_) + { + delete this->get_message_out(); + delete this->get_message_in(); + + if (this->state == WFT_STATE_SUCCESS && state_ != WFT_STATE_SUCCESS) + { + this->state = state_; + this->error = error_; + this->disable_retry(); + } + + is_user_request_ = true; + return false; + } + + if (this->is_fixed_conn()) + { + if (this->state != WFT_STATE_SUCCESS || this->keep_alive_timeo == 0) + { + if (this->target) + ((RouteManager::RouteTarget *)this->target)->state = 0; + } + } + + return true; +} + +/**********Client Factory**********/ + +// mysql://user:password@host:port/db_name +// url = "mysql://admin:123456@192.168.1.101:3301/test" +// url = "mysql://127.0.0.1:3306" +WFMySQLTask *WFTaskFactory::create_mysql_task(const std::string& url, + int retry_max, + mysql_callback_t callback) +{ + auto *task = new ComplexMySQLTask(retry_max, std::move(callback)); + ParsedURI uri; + + URIParser::parse(url, uri); + task->init(std::move(uri)); + if (task->is_fixed_conn()) + 
task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION); + else + task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT); + + return task; +} + +WFMySQLTask *WFTaskFactory::create_mysql_task(const ParsedURI& uri, + int retry_max, + mysql_callback_t callback) +{ + auto *task = new ComplexMySQLTask(retry_max, std::move(callback)); + + task->init(uri); + if (task->is_fixed_conn()) + task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION); + else + task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT); + + return task; +} + +/**********Server**********/ + +class WFMySQLServerTask : public WFServerTask +{ +public: + WFMySQLServerTask(CommService *service, + std::function& proc): + WFServerTask(service, WFGlobal::get_scheduler(), proc) + {} + +protected: + virtual SubTask *done(); + virtual CommMessageOut *message_out(); + virtual CommMessageIn *message_in(); +}; + +SubTask *WFMySQLServerTask::done() +{ + if (this->get_seq() == 0) + delete this->get_message_in(); + + return this->WFServerTask::done(); +} + +CommMessageOut *WFMySQLServerTask::message_out() +{ + long long seqid = this->get_seq(); + + if (seqid == 0) + this->resp.set_ok_packet(); // always success + + return this->WFServerTask::message_out(); +} + +CommMessageIn *WFMySQLServerTask::message_in() +{ + long long seqid = this->get_seq(); + + if (seqid == 0) + return new MySQLAuthRequest; + + return this->WFServerTask::message_in(); +} + +/**********Server Factory**********/ + +WFMySQLTask *WFServerTaskFactory::create_mysql_task(CommService *service, + std::function& process) +{ + return new WFMySQLServerTask(service, process); +} + diff --git a/src/factory/RedisTaskImpl.cc b/src/factory/RedisTaskImpl.cc new file mode 100644 index 0000000..b32d339 --- /dev/null +++ b/src/factory/RedisTaskImpl.cc @@ -0,0 +1,446 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Li Yingxin (liyingxin@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include "PackageWrapper.h" +#include "WFTaskError.h" +#include "WFTaskFactory.h" +#include "StringUtil.h" +#include "RedisTaskImpl.inl" + +using namespace protocol; + +#define REDIS_KEEPALIVE_DEFAULT (60 * 1000) +#define REDIS_REDIRECT_MAX 3 + +/**********Client**********/ + +class ComplexRedisTask : public WFComplexClientTask +{ +public: + ComplexRedisTask(int retry_max, redis_callback_t&& callback): + WFComplexClientTask(retry_max, std::move(callback)), + db_num_(0), + is_user_request_(true), + redirect_count_(0) + {} + +protected: + virtual bool check_request(); + virtual CommMessageOut *message_out(); + virtual CommMessageIn *message_in(); + virtual int keep_alive_timeout(); + virtual int first_timeout(); + virtual bool init_success(); + virtual bool finish_once(); + +protected: + bool need_redirect(); + + std::string username_; + std::string password_; + int db_num_; + bool succ_; + bool is_user_request_; + int redirect_count_; +}; + +bool ComplexRedisTask::check_request() +{ + std::string command; + + if (this->req.get_command(command) && + (strcasecmp(command.c_str(), "AUTH") == 0 || + strcasecmp(command.c_str(), "SELECT") == 0 || + strcasecmp(command.c_str(), "RESET") == 0 || + strcasecmp(command.c_str(), "ASKING") == 0)) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_REDIS_COMMAND_DISALLOWED; + return false; + } + + return true; +} + +CommMessageOut 
*ComplexRedisTask::message_out() +{ + long long seqid = this->get_seq(); + + if (seqid <= 1) + { + if (seqid == 0 && (!password_.empty() || !username_.empty())) + { + auto *auth_req = new RedisRequest; + + if (!username_.empty()) + auth_req->set_request("AUTH", {username_, password_}); + else + auth_req->set_request("AUTH", {password_}); + + succ_ = false; + is_user_request_ = false; + return auth_req; + } + + if (db_num_ > 0 && + (seqid == 0 || !password_.empty() || !username_.empty())) + { + auto *select_req = new RedisRequest; + char buf[32]; + + sprintf(buf, "%d", db_num_); + select_req->set_request("SELECT", {buf}); + + succ_ = false; + is_user_request_ = false; + return select_req; + } + } + + return this->WFComplexClientTask::message_out(); +} + +CommMessageIn *ComplexRedisTask::message_in() +{ + RedisRequest *req = this->get_req(); + RedisResponse *resp = this->get_resp(); + + if (is_user_request_) + resp->set_asking(req->is_asking()); + else + resp->set_asking(false); + + return this->WFComplexClientTask::message_in(); +} + +int ComplexRedisTask::keep_alive_timeout() +{ + if (this->is_user_request_) + return this->keep_alive_timeo; + + RedisResponse *resp = this->get_resp(); + + succ_ = (resp->parse_success() && + resp->result_ptr()->type != REDIS_REPLY_TYPE_ERROR); + + return succ_ ? REDIS_KEEPALIVE_DEFAULT : 0; +} + +int ComplexRedisTask::first_timeout() +{ + return is_user_request_ ? 
this->watch_timeo : 0; +} + +bool ComplexRedisTask::init_success() +{ + enum TransportType type; + + if (uri_.scheme && strcasecmp(uri_.scheme, "redis") == 0) + type = TT_TCP; + else if (uri_.scheme && strcasecmp(uri_.scheme, "rediss") == 0) + type = TT_TCP_SSL; + else + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; + } + + //todo redis+unix + //https://stackoverflow.com/questions/26964595/whats-the-correct-way-to-use-a-unix-domain-socket-in-requests-framework + //https://stackoverflow.com/questions/27037990/connecting-to-postgres-via-database-url-and-unix-socket-in-rails + + if (uri_.userinfo) + { + char *p = strchr(uri_.userinfo, ':'); + if (p) + { + username_.assign(uri_.userinfo, p); + password_.assign(p + 1); + StringUtil::url_decode(username_); + StringUtil::url_decode(password_); + } + else + { + username_.assign(uri_.userinfo); + StringUtil::url_decode(username_); + } + } + + if (uri_.path && uri_.path[0] == '/' && uri_.path[1]) + db_num_ = atoi(uri_.path + 1); + + size_t info_len = username_.size() + password_.size() + 32 + 32; + char *info = new char[info_len]; + + sprintf(info, "redis|user:%s|pass:%s|db:%d", username_.c_str(), + password_.c_str(), db_num_); + this->WFComplexClientTask::set_transport_type(type); + this->WFComplexClientTask::set_info(info); + + delete []info; + return true; +} + +bool ComplexRedisTask::need_redirect() +{ + RedisRequest *client_req = this->get_req(); + RedisResponse *client_resp = this->get_resp(); + redis_reply_t *reply = client_resp->result_ptr(); + + if (reply->type == REDIS_REPLY_TYPE_ERROR) + { + if (reply->str == NULL) + return false; + + if (strncasecmp(reply->str, "NOAUTH ", 7) == 0) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_REDIS_ACCESS_DENIED; + return false; + } + + bool asking = false; + if (strncasecmp(reply->str, "ASK ", 4) == 0) + asking = true; + else if (strncasecmp(reply->str, "MOVED ", 6) != 0) + return false; + + if 
(redirect_count_ >= REDIS_REDIRECT_MAX) + return false; + + std::string err_str(reply->str, reply->len); + auto split_result = StringUtil::split_filter_empty(err_str, ' '); + if (split_result.size() == 3) + { + client_req->set_asking(asking); + + // format: COMMAND SLOT HOSTPORT + // example: MOVED/ASK 123 127.0.0.1:6379 + std::string& hostport = split_result[2]; + redirect_count_++; + + ParsedURI uri; + std::string url; + url.append(uri_.scheme); + url.append("://"); + url.append(hostport); + + URIParser::parse(url, uri); + std::swap(uri.host, uri_.host); + std::swap(uri.port, uri_.port); + std::swap(uri.state, uri_.state); + std::swap(uri.error, uri_.error); + + return true; + } + } + + return false; +} + +bool ComplexRedisTask::finish_once() +{ + if (!is_user_request_) + { + is_user_request_ = true; + delete this->get_message_out(); + + if (this->state == WFT_STATE_SUCCESS) + { + if (succ_) + this->clear_resp(); + else + { + this->disable_retry(); + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_REDIS_ACCESS_DENIED; + } + } + + return false; + } + + if (this->state == WFT_STATE_SUCCESS) + { + if (need_redirect()) + this->set_redirect(uri_); + else if (this->state != WFT_STATE_SUCCESS) + this->disable_retry(); + } + + return true; +} + +/****** Redis Subscribe ******/ + +class ComplexRedisSubscribeTask : public ComplexRedisTask +{ +public: + virtual int push(const void *buf, size_t size) + { + if (finished_) + { + errno = ENOENT; + return -1; + } + + if (!watching_) + { + errno = EAGAIN; + return -1; + } + + return this->scheduler->push(buf, size, this); + } + +protected: + virtual CommMessageIn *message_in() + { + if (!is_user_request_) + return this->ComplexRedisTask::message_in(); + + return &wrapper_; + } + + virtual int first_timeout() + { + return watching_ ? 
this->watch_timeo : 0; + } + +protected: + class SubscribeWrapper : public PackageWrapper + { + protected: + virtual ProtocolMessage *next_in(ProtocolMessage *message); + + protected: + ComplexRedisSubscribeTask *task_; + + public: + SubscribeWrapper(ComplexRedisSubscribeTask *task) : + PackageWrapper(task->get_resp()) + { + task_ = task; + } + }; + +protected: + SubscribeWrapper wrapper_; + bool watching_; + bool finished_; + std::function extract_; + +public: + ComplexRedisSubscribeTask(std::function&& extract, + redis_callback_t&& callback) : + ComplexRedisTask(0, std::move(callback)), + wrapper_(this), + extract_(std::move(extract)) + { + watching_ = false; + finished_ = false; + } +}; + +ProtocolMessage * +ComplexRedisSubscribeTask::SubscribeWrapper::next_in(ProtocolMessage *message) +{ + redis_reply_t *reply = ((RedisResponse *)message)->result_ptr(); + + if (reply->type != REDIS_REPLY_TYPE_ARRAY) + { + task_->finished_ = true; + return NULL; + } + + if (reply->elements == 3 && + reply->element[2]->type == REDIS_REPLY_TYPE_INTEGER && + reply->element[2]->integer == 0) + { + task_->finished_ = true; + } + + task_->watching_ = true; + task_->extract_(task_); + + task_->clear_resp(); + return task_->finished_ ? 
NULL : &task_->resp; +} + +/**********Factory**********/ + +// redis://:password@host:port/db_num +// url = "redis://:admin@192.168.1.101:6001/3" +// url = "redis://127.0.0.1:6379" +WFRedisTask *WFTaskFactory::create_redis_task(const std::string& url, + int retry_max, + redis_callback_t callback) +{ + auto *task = new ComplexRedisTask(retry_max, std::move(callback)); + ParsedURI uri; + + URIParser::parse(url, uri); + task->init(std::move(uri)); + task->set_keep_alive(REDIS_KEEPALIVE_DEFAULT); + return task; +} + +WFRedisTask *WFTaskFactory::create_redis_task(const ParsedURI& uri, + int retry_max, + redis_callback_t callback) +{ + auto *task = new ComplexRedisTask(retry_max, std::move(callback)); + + task->init(uri); + task->set_keep_alive(REDIS_KEEPALIVE_DEFAULT); + return task; +} + +WFRedisTask * +__WFRedisTaskFactory::create_subscribe_task(const std::string& url, + extract_t extract, + redis_callback_t callback) +{ + auto *task = new ComplexRedisSubscribeTask(std::move(extract), + std::move(callback)); + ParsedURI uri; + + URIParser::parse(url, uri); + task->init(std::move(uri)); + return task; +} + +WFRedisTask * +__WFRedisTaskFactory::create_subscribe_task(const ParsedURI& uri, + extract_t extract, + redis_callback_t callback) +{ + auto *task = new ComplexRedisSubscribeTask(std::move(extract), + std::move(callback)); + + task->init(uri); + return task; +} + diff --git a/src/factory/RedisTaskImpl.inl b/src/factory/RedisTaskImpl.inl new file mode 100644 index 0000000..2329ba8 --- /dev/null +++ b/src/factory/RedisTaskImpl.inl @@ -0,0 +1,37 @@ +/* + Copyright (c) 2024 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#include "WFTaskFactory.h" + +// Internal, for WFRedisSubscribeTask only. + +class __WFRedisTaskFactory +{ +private: + using extract_t = std::function; + +public: + static WFRedisTask *create_subscribe_task(const std::string& url, + extract_t extract, + redis_callback_t callback); + + static WFRedisTask *create_subscribe_task(const ParsedURI& uri, + extract_t extract, + redis_callback_t callback); +}; + diff --git a/src/factory/WFAlgoTaskFactory.h b/src/factory/WFAlgoTaskFactory.h new file mode 100644 index 0000000..301615d --- /dev/null +++ b/src/factory/WFAlgoTaskFactory.h @@ -0,0 +1,259 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFALGOTASKFACTORY_H_ +#define _WFALGOTASKFACTORY_H_ + +#include +#include +#include +#include "WFTask.h" + +namespace algorithm +{ + +template +struct SortInput +{ + T *first; + T *last; +}; + +template +struct SortOutput +{ + T *first; + T *last; +}; + +template +struct MergeInput +{ + T *first1; + T *last1; + T *first2; + T *last2; + T *d_first; +}; + +template +struct MergeOutput +{ + T *first; + T *last; +}; + +template +struct ShuffleInput +{ + T *first; + T *last; +}; + +template +struct ShuffleOutput +{ + T *first; + T *last; +}; + +template +struct RemoveInput +{ + T *first; + T *last; + T value; +}; + +template +struct RemoveOutput +{ + T *first; + T *last; +}; + +template +struct UniqueInput +{ + T *first; + T *last; +}; + +template +struct UniqueOutput +{ + T *first; + T *last; +}; + +template +struct ReverseInput +{ + T *first; + T *last; +}; + +template +struct ReverseOutput +{ + T *first; + T *last; +}; + +template +struct RotateInput +{ + T *first; + T *middle; + T *last; +}; + +template +struct RotateOutput +{ + T *first; + T *last; +}; + +template +using ReduceInput = std::vector>; + +template +using ReduceOutput = std::vector>; + +} /* namespace algorithm */ + +template +using WFSortTask = WFThreadTask, + algorithm::SortOutput>; +template +using sort_callback_t = std::function *)>; + +template +using WFMergeTask = WFThreadTask, + algorithm::MergeOutput>; +template +using merge_callback_t = std::function *)>; + +template +using WFShuffleTask = WFThreadTask, + algorithm::ShuffleOutput>; +template +using shuffle_callback_t = std::function *)>; + +template +using WFRemoveTask = WFThreadTask, + algorithm::RemoveOutput>; +template +using remove_callback_t = std::function *)>; + +template +using WFUniqueTask = WFThreadTask, + algorithm::UniqueOutput>; +template +using unique_callback_t = std::function *)>; + +template +using WFReverseTask = WFThreadTask, + algorithm::ReverseOutput>; +template 
+using reverse_callback_t = std::function *)>; + +template +using WFRotateTask = WFThreadTask, + algorithm::RotateOutput>; +template +using rotate_callback_t = std::function *)>; + +class WFAlgoTaskFactory +{ +public: + template> + static WFSortTask *create_sort_task(const std::string& queue_name, + T *first, T *last, + CB callback); + + template> + static WFSortTask *create_sort_task(const std::string& queue_name, + T *first, T *last, + CMP compare, + CB callback); + + template> + static WFSortTask *create_psort_task(const std::string& queue_name, + T *first, T *last, + CB callback); + + template> + static WFSortTask *create_psort_task(const std::string& queue_name, + T *first, T *last, + CMP compare, + CB callback); + + template> + static WFMergeTask *create_merge_task(const std::string& queue_name, + T *first1, T *last1, + T *first2, T *last2, + T *d_first, + CB callback); + + template> + static WFMergeTask *create_merge_task(const std::string& queue_name, + T *first1, T *last1, + T *first2, T *last2, + T *d_first, + CMP compare, + CB callback); + + template> + static WFShuffleTask *create_shuffle_task(const std::string& queue_name, + T *first, T *last, + CB callback); + + template> + static WFShuffleTask *create_shuffle_task(const std::string& queue_name, + T *first, T *last, + URBG generator, + CB callback); + + template> + static WFRemoveTask *create_remove_task(const std::string& queue_name, + T *first, T *last, + T value, + CB callback); + + template> + static WFUniqueTask *create_unique_task(const std::string& queue_name, + T *first, T *last, + CB callback); + + template> + static WFReverseTask *create_reverse_task(const std::string& queue_name, + T *first, T *last, + CB callback); + + template> + static WFRotateTask *create_rotate_task(const std::string& queue_name, + T *first, T *middle, T *last, + CB callback); +}; + +#include "WFAlgoTaskFactory.inl" + +#endif + diff --git a/src/factory/WFAlgoTaskFactory.inl b/src/factory/WFAlgoTaskFactory.inl new file 
mode 100644 index 0000000..746b321 --- /dev/null +++ b/src/factory/WFAlgoTaskFactory.inl @@ -0,0 +1,651 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include "Workflow.h" +#include "WFGlobal.h" + +/********** Classes without CMP **********/ + +template +class __WFSortTask : public WFSortTask +{ +protected: + virtual void execute() + { + std::sort(this->input.first, this->input.last); + this->output.first = this->input.first; + this->output.last = this->input.last; + } + +public: + __WFSortTask(ExecQueue *queue, Executor *executor, + T *first, T *last, + sort_callback_t&& cb) : + WFSortTask(queue, executor, std::move(cb)) + { + this->input.first = first; + this->input.last = last; + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +class __WFMergeTask : public WFMergeTask +{ +protected: + virtual void execute(); + +public: + __WFMergeTask(ExecQueue *queue, Executor *executor, + T *first1, T *last1, T *first2, T *last2, T *d_first, + merge_callback_t&& cb) : + WFMergeTask(queue, executor, std::move(cb)) + { + this->input.first1 = first1; + this->input.last1 = last1; + this->input.first2 = first2; + this->input.last2 = last2; + this->input.d_first = d_first; + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +void __WFMergeTask::execute() +{ + auto *input = &this->input; + auto 
*output = &this->output; + + if (input->first1 == input->d_first && input->last1 == input->first2) + { + std::inplace_merge(input->first1, input->first2, input->last2); + output->last = input->last2; + } + else if (input->first2 == input->d_first && input->last2 == input->first1) + { + std::inplace_merge(input->first2, input->first1, input->last1); + output->last = input->last1; + } + else + { + output->last = std::merge(input->first1, input->last1, + input->first2, input->last2, + input->d_first); + } + + output->first = input->d_first; +} + +template +class __WFParSortTask : public __WFSortTask +{ +public: + virtual void dispatch(); + +protected: + virtual SubTask *done() + { + if (this->flag) + return series_of(this)->pop(); + + return this->WFSortTask::done(); + } + + virtual void execute(); + +protected: + int depth; + int flag; + +public: + __WFParSortTask(ExecQueue *queue, Executor *executor, + T *first, T *last, int depth, + sort_callback_t&& cb) : + __WFSortTask(queue, executor, first, last, std::move(cb)) + { + this->depth = depth; + this->flag = 0; + } +}; + +template +void __WFParSortTask::dispatch() +{ + size_t n = this->input.last - this->input.first; + + if (!this->flag && this->depth < 7 && n >= 32) + { + SeriesWork *series = series_of(this); + T *middle = this->input.first + n / 2; + auto *task1 = + new __WFParSortTask(this->queue, this->executor, + this->input.first, middle, + this->depth + 1, + nullptr); + auto *task2 = + new __WFParSortTask(this->queue, this->executor, + middle, this->input.last, + this->depth + 1, + nullptr); + SeriesWork *sub_series[2] = { + Workflow::create_series_work(task1, nullptr), + Workflow::create_series_work(task2, nullptr) + }; + ParallelWork *parallel = + Workflow::create_parallel_work(sub_series, 2, nullptr); + + series->push_front(this); + series->push_front(parallel); + this->flag = 1; + this->subtask_done(); + } + else + this->__WFSortTask::dispatch(); +} + +template +void __WFParSortTask::execute() +{ + if 
(this->flag) + { + size_t n = this->input.last - this->input.first; + T *middle = this->input.first + n / 2; + + std::inplace_merge(this->input.first, middle, this->input.last); + this->output.first = this->input.first; + this->output.last = this->input.last; + this->flag = 0; + } + else + this->__WFSortTask::execute(); +} + +/********** Classes with CMP **********/ + +template +class __WFSortTaskCmp : public __WFSortTask +{ +protected: + virtual void execute() + { + std::sort(this->input.first, this->input.last, + std::move(this->compare)); + this->output.first = this->input.first; + this->output.last = this->input.last; + } + +protected: + CMP compare; + +public: + __WFSortTaskCmp(ExecQueue *queue, Executor *executor, + T *first, T *last, CMP&& cmp, + sort_callback_t&& cb) : + __WFSortTask(queue, executor, first, last, std::move(cb)), + compare(std::move(cmp)) + { + } +}; + +template +class __WFMergeTaskCmp : public __WFMergeTask +{ +protected: + virtual void execute(); + +protected: + CMP compare; + +public: + __WFMergeTaskCmp(ExecQueue *queue, Executor *executor, + T *first1, T *last1, T *first2, T *last2, + T *d_first, CMP&& cmp, + merge_callback_t&& cb) : + __WFMergeTask(queue, executor, first1, last1, first2, last2, d_first, + std::move(cb)), + compare(std::move(cmp)) + { + } +}; + +template +void __WFMergeTaskCmp::execute() +{ + auto *input = &this->input; + auto *output = &this->output; + + if (input->first1 == input->d_first && input->last1 == input->first2) + { + std::inplace_merge(input->first1, input->first2, input->last2, + std::move(this->compare)); + output->last = input->last2; + } + else if (input->first2 == input->d_first && input->last2 == input->first1) + { + std::inplace_merge(input->first2, input->first1, input->last1, + std::move(this->compare)); + output->last = input->last1; + } + else + { + output->last = std::merge(input->first1, input->last1, + input->first2, input->last2, + input->d_first, + std::move(this->compare)); + } + + 
output->first = input->d_first; +} + +template +class __WFParSortTaskCmp : public __WFSortTaskCmp +{ +public: + virtual void dispatch(); + +protected: + virtual SubTask *done() + { + if (this->flag) + return series_of(this)->pop(); + + return this->WFSortTask::done(); + } + + virtual void execute(); + +protected: + int depth; + int flag; + +public: + __WFParSortTaskCmp(ExecQueue *queue, Executor *executor, + T *first, T *last, CMP cmp, int depth, + sort_callback_t&& cb) : + __WFSortTaskCmp(queue, executor, first, last, std::move(cmp), + std::move(cb)) + { + this->depth = depth; + this->flag = 0; + } +}; + +template +void __WFParSortTaskCmp::dispatch() +{ + size_t n = this->input.last - this->input.first; + + if (!this->flag && this->depth < 7 && n >= 32) + { + SeriesWork *series = series_of(this); + T *middle = this->input.first + n / 2; + auto *task1 = + new __WFParSortTaskCmp(this->queue, this->executor, + this->input.first, middle, + this->compare, this->depth + 1, + nullptr); + auto *task2 = + new __WFParSortTaskCmp(this->queue, this->executor, + middle, this->input.last, + this->compare, this->depth + 1, + nullptr); + SeriesWork *sub_series[2] = { + Workflow::create_series_work(task1, nullptr), + Workflow::create_series_work(task2, nullptr) + }; + ParallelWork *parallel = + Workflow::create_parallel_work(sub_series, 2, nullptr); + + series->push_front(this); + series->push_front(parallel); + this->flag = 1; + this->subtask_done(); + } + else + this->__WFSortTaskCmp::dispatch(); +} + +template +void __WFParSortTaskCmp::execute() +{ + if (this->flag) + { + size_t n = this->input.last - this->input.first; + T *middle = this->input.first + n / 2; + + std::inplace_merge(this->input.first, middle, this->input.last, + std::move(this->compare)); + this->output.first = this->input.first; + this->output.last = this->input.last; + this->flag = 0; + } + else + this->__WFSortTaskCmp::execute(); +} + +/********** Factory functions without CMP **********/ + +template 
+WFSortTask *WFAlgoTaskFactory::create_sort_task(const std::string& name, + T *first, T *last, + CB callback) +{ + return new __WFSortTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, + std::move(callback)); +} + +template +WFMergeTask *WFAlgoTaskFactory::create_merge_task(const std::string& name, + T *first1, T *last1, + T *first2, T *last2, + T *d_first, + CB callback) +{ + return new __WFMergeTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first1, last1, first2, last2, d_first, + std::move(callback)); +} + +template +WFSortTask *WFAlgoTaskFactory::create_psort_task(const std::string& name, + T *first, T *last, + CB callback) +{ + return new __WFParSortTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, 0, + std::move(callback)); +} + +/********** Factory functions with CMP **********/ + +template +WFSortTask *WFAlgoTaskFactory::create_sort_task(const std::string& name, + T *first, T *last, + CMP compare, + CB callback) +{ + return new __WFSortTaskCmp(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, std::move(compare), + std::move(callback)); +} + +template +WFMergeTask *WFAlgoTaskFactory::create_merge_task(const std::string& name, + T *first1, T *last1, + T *first2, T *last2, + T *d_first, + CMP compare, + CB callback) +{ + return new __WFMergeTaskCmp(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first1, last1, first2, last2, + d_first, std::move(compare), + std::move(callback)); +} + +template +WFSortTask *WFAlgoTaskFactory::create_psort_task(const std::string& name, + T *first, T *last, + CMP compare, + CB callback) +{ + return new __WFParSortTaskCmp(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, std::move(compare), 0, + std::move(callback)); +} + +/****************** Shuffle ******************/ + +template +class __WFShuffleTask : public WFShuffleTask +{ +protected: + 
virtual void execute() /* Shuffles [input.first, input.last) in place with a std::mt19937_64 seeded from random(). */ + { + std::shuffle(this->input.first, this->input.last, + std::mt19937_64(random())); + this->output.first = this->input.first; + this->output.last = this->input.last; + } + +public: + __WFShuffleTask(ExecQueue *queue, Executor *executor, + T *first, T *last, + shuffle_callback_t&& cb) : + WFShuffleTask(queue, executor, std::move(cb)) + { + this->input.first = first; + this->input.last = last; + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +class __WFShuffleTaskGen : public __WFShuffleTask /* Shuffle task that uses a caller-supplied random generator instead of the default one. */ +{ +protected: + virtual void execute() /* Same as the base execute(), but the stored generator is moved into std::shuffle. */ + { + std::shuffle(this->input.first, this->input.last, + std::move(this->generator)); + this->output.first = this->input.first; + this->output.last = this->input.last; + } + +protected: + URBG generator; /* moved-from once execute() has run */ + +public: + __WFShuffleTaskGen(ExecQueue *queue, Executor *executor, + T *first, T *last, URBG&& gen, + shuffle_callback_t&& cb) : + __WFShuffleTask(queue, executor, first, last, std::move(cb)), /* forward the range: the base constructor takes (queue, executor, first, last, cb) and is what fills input.first/input.last */ + generator(std::move(gen)) + { + } +}; + +template +WFShuffleTask *WFAlgoTaskFactory::create_shuffle_task(const std::string& name, /* Creates a shuffle task on the named exec queue using the default generator. */ + T *first, T *last, + CB callback) +{ + return new __WFShuffleTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, + std::move(callback)); +} + +template +WFShuffleTask *WFAlgoTaskFactory::create_shuffle_task(const std::string& name, /* Creates a shuffle task on the named exec queue using the caller's generator. */ + T *first, T *last, + URBG generator, + CB callback) +{ + return new __WFShuffleTaskGen(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, std::move(generator), + std::move(callback)); +} + +/****************** Remove ******************/ + +template +class __WFRemoveTask : public WFRemoveTask +{ +protected: + virtual void execute() + { + this->output.last = std::remove(this->input.first, this->input.last, + this->input.value); + this->output.first = this->input.first; + } + +public: + __WFRemoveTask(ExecQueue *queue, Executor *executor, + T *first, T *last, T&& value, +
remove_callback_t&& cb) : + WFRemoveTask(queue, executor, std::move(cb)) + { + this->input.first = first; + this->input.last = last; + this->input.value = std::move(value); + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +WFRemoveTask *WFAlgoTaskFactory::create_remove_task(const std::string& name, + T *first, T *last, + T value, + CB callback) +{ + return new __WFRemoveTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, std::move(value), + std::move(callback)); +} + +/****************** Unique ******************/ + +template +class __WFUniqueTask : public WFUniqueTask +{ +protected: + virtual void execute() + { + this->output.last = std::unique(this->input.first, this->input.last); + this->output.first = this->input.first; + } + +public: + __WFUniqueTask(ExecQueue *queue, Executor *executor, + T *first, T *last, + unique_callback_t&& cb) : + WFUniqueTask(queue, executor, std::move(cb)) + { + this->input.first = first; + this->input.last = last; + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +WFUniqueTask *WFAlgoTaskFactory::create_unique_task(const std::string& name, + T *first, T *last, + CB callback) +{ + return new __WFUniqueTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, + std::move(callback)); +} + +/****************** Reverse ******************/ + +template +class __WFReverseTask : public WFReverseTask +{ +protected: + virtual void execute() + { + std::reverse(this->input.first, this->input.last); + this->output.first = this->input.first; + this->output.last = this->input.last; + } + +public: + __WFReverseTask(ExecQueue *queue, Executor *executor, + T *first, T *last, + reverse_callback_t&& cb) : + WFReverseTask(queue, executor, std::move(cb)) + { + this->input.first = first; + this->input.last = last; + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +WFReverseTask 
*WFAlgoTaskFactory::create_reverse_task(const std::string& name, + T *first, T *last, + CB callback) +{ + return new __WFReverseTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, last, + std::move(callback)); +} + +/****************** Rotate ******************/ + +template +class __WFRotateTask : public WFRotateTask +{ +protected: + virtual void execute() + { + std::rotate(this->input.first, this->input.middle, this->input.last); + this->output.first = this->input.first; + this->output.last = this->input.last; + } + +public: + __WFRotateTask(ExecQueue *queue, Executor *executor, + T *first, T* middle, T *last, + rotate_callback_t&& cb) : + WFRotateTask(queue, executor, std::move(cb)) + { + this->input.first = first; + this->input.middle = middle; + this->input.last = last; + this->output.first = NULL; + this->output.last = NULL; + } +}; + +template +WFRotateTask *WFAlgoTaskFactory::create_rotate_task(const std::string& name, + T *first, T *middle, T *last, + CB callback) +{ + return new __WFRotateTask(WFGlobal::get_exec_queue(name), + WFGlobal::get_compute_executor(), + first, middle, last, + std::move(callback)); +} + diff --git a/src/factory/WFConnection.h b/src/factory/WFConnection.h new file mode 100644 index 0000000..68adf05 --- /dev/null +++ b/src/factory/WFConnection.h @@ -0,0 +1,69 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFCONNECTION_H_ +#define _WFCONNECTION_H_ + +#include +#include +#include +#include "Communicator.h" + +class WFConnection : public CommConnection /* Connection that carries an opaque context pointer plus an optional deleter for it. */ +{ +public: + void *get_context() const /* Returns the context pointer currently stored (NULL until one is set). */ + { + return this->context; + } + + void set_context(void *context, std::function deleter) /* Unconditionally replaces both the context and the deleter; the previous deleter is not invoked here. */ + { + this->context = context; + this->deleter = std::move(deleter); + } + + void *test_set_context(void *test_context, void *new_context, /* Installs new_context (and deleter) only if the stored context equals test_context; returns the context in effect afterwards. */ + std::function deleter) + { + if (this->context.compare_exchange_strong(test_context, new_context)) + { + this->deleter = std::move(deleter); /* NOTE(review): assigned after the exchange and not atomic itself -- confirm callers do not race on it */ + return new_context; + } + + return test_context; /* on failure compare_exchange_strong has written the currently stored context into test_context */ + } + +private: + std::atomic context; /* user context pointer, read and exchanged atomically */ + std::function deleter; /* may be empty; checked before use in the destructor */ + +public: + WFConnection() : context(NULL) { } + +protected: + virtual ~WFConnection() /* Runs the deleter, if one is set, on the final context value. */ + { + if (this->deleter) + this->deleter(this->context); + } +}; + +#endif + diff --git a/src/factory/WFGraphTask.cc b/src/factory/WFGraphTask.cc new file mode 100644 index 0000000..3332a7a --- /dev/null +++ b/src/factory/WFGraphTask.cc @@ -0,0 +1,110 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include "Workflow.h" +#include "WFGraphTask.h" + +SubTask *WFGraphNode::done() /* First completion marks the node (value = 1, user_data set); a later completion deletes it. Always continues with the next task of the series. */ +{ + SeriesWork *series = series_of(this); /* fetched up front: "this" may be deleted below */ + + if (!this->user_data) + { + this->value = 1; + this->user_data = (void *)1; /* non-NULL user_data is the "already completed once" flag */ + } + else + delete this; + + return series->pop(); +} + +WFGraphNode::~WFGraphNode() /* If the node completed at least once: propagates cancellation to successors, then counts every successor once. */ +{ + if (this->user_data) + { + if (series_of(this)->is_canceled()) + { + for (WFGraphNode *node : this->successors) + series_of(node)->SeriesWork::cancel(); /* cancel first, so successors are already canceled when counted below */ + } + + for (WFGraphNode *node : this->successors) + node->WFCounterTask::count(); + } +} + +WFGraphNode& WFGraphTask::create_graph_node(SubTask *task) /* Wraps "task" in a new series headed by a new WFGraphNode, adds that series to the graph's parallel and returns the node. */ +{ + WFGraphNode *node = new WFGraphNode; + SeriesWork *series = Workflow::create_series_work(node, node, nullptr); /* node is passed for both task arguments; presumably first and last task of the series -- confirm against Workflow.h */ + + series->push_back(task); + this->parallel->add_series(series); /* NOTE(review): dispatch() sets this->parallel to NULL, so this must not be called once the graph has been dispatched */ + return *node; +} + +void WFGraphTask::dispatch() /* First dispatch queues the parallel, then this task again, at the front of the series; the second dispatch marks success. */ +{ + SeriesWork *series = series_of(this); + + if (this->parallel) + { + series->push_front(this); + series->push_front(this->parallel); /* pushed last, so it sits in front of "this" */ + this->parallel = NULL; /* handed over to the series; the destructor must no longer dismiss it */ + } + else + this->state = WFT_STATE_SUCCESS; + + this->subtask_done(); +} + +SubTask *WFGraphTask::done() /* Once state is WFT_STATE_SUCCESS: runs the callback, if any, and deletes the task. Always continues with the next task of the series. */ +{ + SeriesWork *series = series_of(this); /* fetched up front: "this" may be deleted below */ + + if (this->state == WFT_STATE_SUCCESS) + { + if (this->callback) + this->callback(this); + + delete this; + } + + return series->pop(); +} + +WFGraphTask::~WFGraphTask() /* If the parallel is still owned here, calls unset_last_task() on each of its series and then dismisses the parallel. */ +{ + SeriesWork *series; + size_t i; + + if (this->parallel) + { + for (i = 0; i < this->parallel->size(); i++) + { + series = this->parallel->series_at(i); + series->unset_last_task(); + } + + this->parallel->dismiss(); + } +} + diff --git a/src/factory/WFGraphTask.h b/src/factory/WFGraphTask.h new file mode 100644 index 0000000..752b2d9 --- /dev/null +++ b/src/factory/WFGraphTask.h @@ -0,0 +1,107 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFGRAPHTASK_H_ +#define _WFGRAPHTASK_H_ + +#include +#include +#include +#include "Workflow.h" +#include "WFTask.h" + +class WFGraphNode : protected WFCounterTask /* One vertex of a task graph; built on a counter task and only constructible by WFGraphTask (protected constructor + friend). */ +{ +public: + void precede(WFGraphNode& node) /* Adds the edge this -> node: increments node.value and records node as a successor. */ + { + node.value++; + this->successors.push_back(&node); + } + + void succeed(WFGraphNode& node) /* Adds the edge node -> this; mirror of precede(). */ + { + node.precede(*this); + } + +protected: + virtual SubTask *done(); + +protected: + std::vector successors; /* nodes registered through precede() */ + +protected: + WFGraphNode() : WFCounterTask(0, nullptr) { } + virtual ~WFGraphNode(); + friend class WFGraphTask; +}; + +static inline WFGraphNode& operator --(WFGraphNode& node, int) /* Postfix no-op returning the node, so that "a-->b" parses as (a--) > b. */ +{ + return node; +} + +static inline WFGraphNode& operator > (WFGraphNode& prec, WFGraphNode& succ) /* Adds the edge prec -> succ and returns succ. */ +{ + prec.precede(succ); + return succ; +} + +static inline WFGraphNode& operator < (WFGraphNode& succ, WFGraphNode& prec) /* Adds the edge prec -> succ and returns prec. */ +{ + succ.succeed(prec); + return prec; +} + +static inline WFGraphNode& operator --(WFGraphNode& node) /* Prefix no-op returning the node, so that "a<--b" parses as a < (--b). */ +{ + return node; +} + +class WFGraphTask : public WFGenericTask /* Task that owns a parallel of graph-node series and an optional completion callback. */ +{ +public: + WFGraphNode& create_graph_node(SubTask *task); + +public: + void set_callback(std::function cb) /* Replaces the completion callback. */ + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch(); + virtual SubTask *done(); + +protected: + ParallelWork *parallel; /* created in the constructor */ + std::function callback; + +public: + WFGraphTask(std::function&& cb) : /* Stores the callback and creates the initially empty parallel. */ + callback(std::move(cb)) + { + this->parallel = Workflow::create_parallel_work(nullptr); + } + +protected: + virtual ~WFGraphTask(); +}; + +#endif + diff --git a/src/factory/WFMessageQueue.cc
b/src/factory/WFMessageQueue.cc new file mode 100644 index 0000000..6c13c6c --- /dev/null +++ b/src/factory/WFMessageQueue.cc @@ -0,0 +1,94 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#include "list.h" +#include "WFTask.h" +#include "WFMessageQueue.h" + +class __MQConditional : public WFConditional +{ +public: + struct list_head list; + struct WFMessageQueue::Data *data; + +public: + virtual void dispatch(); + virtual void signal(void *msg) { } + +public: + __MQConditional(SubTask *task, void **msgbuf, + struct WFMessageQueue::Data *data) : + WFConditional(task, msgbuf) + { + this->data = data; + } + + __MQConditional(SubTask *task, + struct WFMessageQueue::Data *data) : + WFConditional(task) + { + this->data = data; + } +}; + +void __MQConditional::dispatch() +{ + struct WFMessageQueue::Data *data = this->data; + + data->mutex.lock(); + if (!list_empty(&data->msg_list)) + this->WFConditional::signal(data->pop()); + else + list_add_tail(&this->list, &data->wait_list); + + data->mutex.unlock(); + this->WFConditional::dispatch(); +} + +WFConditional *WFMessageQueue::get(SubTask *task, void **msgbuf) +{ + return new __MQConditional(task, msgbuf, &this->data); +} + +WFConditional *WFMessageQueue::get(SubTask *task) +{ + return new __MQConditional(task, &this->data); +} + +void WFMessageQueue::post(void *msg) +{ + struct WFMessageQueue::Data *data = &this->data; + WFConditional *cond; + + 
data->mutex.lock(); + if (!list_empty(&data->wait_list)) + { + cond = list_entry(data->wait_list.next, __MQConditional, list); + list_del(data->wait_list.next); + } + else + { + cond = NULL; + this->push(msg); + } + + data->mutex.unlock(); + if (cond) + cond->WFConditional::signal(msg); +} + diff --git a/src/factory/WFMessageQueue.h b/src/factory/WFMessageQueue.h new file mode 100644 index 0000000..e7fac0e --- /dev/null +++ b/src/factory/WFMessageQueue.h @@ -0,0 +1,88 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFMESSAGEQUEUE_H_ +#define _WFMESSAGEQUEUE_H_ + +#include +#include "list.h" +#include "WFTask.h" + +class WFMessageQueue /* Queue of opaque (void *) messages kept in an intrusive list, with a second list for waiters. */ +{ +public: + WFConditional *get(SubTask *task, void **msgbuf); + WFConditional *get(SubTask *task); + void post(void *msg); + +public: + struct Data /* State shared with the waiters: both lists, the mutex, and a back pointer to the owning queue. */ + { + void *pop() { return this->queue->pop(); } /* forwards to the owning queue's virtual pop() */ + void push(void *msg) { this->queue->push(msg); } /* forwards to the owning queue's virtual push() */ + + struct list_head msg_list; /* pending MessageEntry items */ + struct list_head wait_list; + std::mutex mutex; + WFMessageQueue *queue; /* set to the owning queue in the constructor */ + }; + +protected: + struct MessageEntry /* list node holding one queued message */ + { + struct list_head list; + void *msg; + }; + +protected: + virtual void *pop() /* Removes the head entry of msg_list, frees it and returns its message; no emptiness check is done here. */ + { + struct MessageEntry *entry; + void *msg; + + entry = list_entry(this->data.msg_list.next, struct MessageEntry, list); + list_del(&entry->list); + msg = entry->msg; + delete entry; + + return msg; + } + + virtual void push(void *msg) /* Allocates a MessageEntry for msg and appends it to the tail of msg_list; takes no lock itself. */ + { + struct MessageEntry *entry = new struct MessageEntry; + entry->msg = msg; + list_add_tail(&entry->list, &this->data.msg_list); + } + +protected: + struct Data data; + +public: + WFMessageQueue() /* Initializes both lists as empty and points data.queue back at this object. */ + { + INIT_LIST_HEAD(&this->data.msg_list); + INIT_LIST_HEAD(&this->data.wait_list); + this->data.queue = this; + } + + virtual ~WFMessageQueue() { } +}; + +#endif + diff --git a/src/factory/WFOperator.h b/src/factory/WFOperator.h new file mode 100644 index 0000000..954d798 --- /dev/null +++ b/src/factory/WFOperator.h @@ -0,0 +1,341 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _WFOPERATOR_H_ +#define _WFOPERATOR_H_ + +#include "Workflow.h" + +/** + * @file WFOperator.h + * @brief Workflow Series/Parallel/Task operator + */ + +/** + * @brief S=S>P + * @note Equivalent to x.push_back(&y) +*/ +static inline SeriesWork& operator>(SeriesWork& x, ParallelWork& y); + +/** + * @brief S=P>S + * @note Equivalent to y.push_front(&x) +*/ +static inline SeriesWork& operator>(ParallelWork& x, SeriesWork& y); + +/** + * @brief S=S>t + * @note Equivalent to x.push_back(&y) +*/ +static inline SeriesWork& operator>(SeriesWork& x, SubTask& y); + +/** + * @brief S=S>t + * @note Equivalent to x.push_back(y) +*/ +static inline SeriesWork& operator>(SeriesWork& x, SubTask *y); + +/** + * @brief S=t>S + * @note Equivalent to y.push_front(&x) +*/ +static inline SeriesWork& operator>(SubTask& x, SeriesWork& y); + +/** + * @brief S=t>S + * @note Equivalent to y.push_front(x) +*/ +static inline SeriesWork& operator>(SubTask *x, SeriesWork& y); + +/** + * @brief S=P>P + * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) +*/ +static inline SeriesWork& operator>(ParallelWork& x, ParallelWork& y); + +/** + * @brief S=P>t + * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) +*/ +static inline SeriesWork& operator>(ParallelWork& x, SubTask& y); + +/** + * @brief S=P>t + * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(y) +*/ +static inline SeriesWork& operator>(ParallelWork& x, SubTask *y); + +/** + * @brief S=t>P + * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) +*/ +static inline SeriesWork& operator>(SubTask& x, ParallelWork& y); + +/** + * @brief S=t>P + * @note Equivalent to Workflow::create_series_work(x, nullptr)->push_back(&y) +*/ +static inline SeriesWork& operator>(SubTask *x, 
ParallelWork& y); + +/** + * @brief S=t>t + * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) +*/ +static inline SeriesWork& operator>(SubTask& x, SubTask& y); + +/** + * @brief S=t>t + * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(y) +*/ +static inline SeriesWork& operator>(SubTask& x, SubTask *y); + +/** + * @brief S=t>t + * @note Equivalent to Workflow::create_series_work(x, nullptr)->push_back(&y) +*/ +static inline SeriesWork& operator>(SubTask *x, SubTask& y); +//static inline SeriesWork& operator>(SubTask *x, SubTask *y);//compile error! + +/** + * @brief P=P*S + * @note Equivalent to x.add_series(&y) +*/ +static inline ParallelWork& operator*(ParallelWork& x, SeriesWork& y); + +/** + * @brief P=S*P + * @note Equivalent to y.push_back(&x) +*/ +static inline ParallelWork& operator*(SeriesWork& x, ParallelWork& y); + +/** + * @brief P=P*t + * @note Equivalent to x.add_series(Workflow::create_series_work(&y, nullptr)) +*/ +static inline ParallelWork& operator*(ParallelWork& x, SubTask& y); + +/** + * @brief P=P*t + * @note Equivalent to x.add_series(Workflow::create_series_work(y, nullptr)) +*/ +static inline ParallelWork& operator*(ParallelWork& x, SubTask *y); + +/** + * @brief P=t*P + * @note Equivalent to y.add_series(Workflow::create_series_work(&x, nullptr)) +*/ +static inline ParallelWork& operator*(SubTask& x, ParallelWork& y); + +/** + * @brief P=t*P + * @note Equivalent to y.add_series(Workflow::create_series_work(x, nullptr)) +*/ +static inline ParallelWork& operator*(SubTask *x, ParallelWork& y); + +/** + * @brief P=S*S + * @note Equivalent to Workflow::create_parallel_work({&x, &y}, 2, nullptr) +*/ +static inline ParallelWork& operator*(SeriesWork& x, SeriesWork& y); + +/** + * @brief P=S*t + * @note Equivalent to Workflow::create_parallel_work({&y}, 1, nullptr)->add_series(&x) +*/ +static inline ParallelWork& operator*(SeriesWork& x, SubTask& y); + +/** + * @brief P=S*t + * @note Equivalent 
to Workflow::create_parallel_work({y}, 1, nullptr)->add_series(&x) +*/ +static inline ParallelWork& operator*(SeriesWork& x, SubTask *y); + +/** + * @brief P=t*S + * @note Equivalent to Workflow::create_parallel_work({&x}, 1, nullptr)->add_series(&y) +*/ +static inline ParallelWork& operator*(SubTask& x, SeriesWork& y); + +/** + * @brief P=t*S + * @note Equivalent to Workflow::create_parallel_work({x}, 1, nullptr)->add_series(&y) +*/ +static inline ParallelWork& operator*(SubTask *x, SeriesWork& y); + +/** + * @brief P=t*t + * @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(&y, nullptr)}, 2, nullptr) +*/ +static inline ParallelWork& operator*(SubTask& x, SubTask& y); + +/** + * @brief P=t*t + * @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(y, nullptr)}, 2, nullptr) +*/ +static inline ParallelWork& operator*(SubTask& x, SubTask *y); + +/** + * @brief P=t*t + * @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(x, nullptr), Workflow::create_series_work(&y, nullptr)}, 2, nullptr) +*/ +static inline ParallelWork& operator*(SubTask *x, SubTask& y); +//static inline ParallelWork& operator*(SubTask *x, SubTask *y);//compile error! 
+ + +//S=S>t +static inline SeriesWork& operator>(SeriesWork& x, SubTask& y) +{ + x.push_back(&y); + return x; +} +static inline SeriesWork& operator>(SeriesWork& x, SubTask *y) +{ + return x > *y; +} +//S=t>S +static inline SeriesWork& operator>(SubTask& x, SeriesWork& y) +{ + y.push_front(&x); + return y; +} +static inline SeriesWork& operator>(SubTask *x, SeriesWork& y) +{ + return *x > y; +} +//S=t>t +static inline SeriesWork& operator>(SubTask& x, SubTask& y) +{ + SeriesWork *series = Workflow::create_series_work(&x, nullptr); + series->push_back(&y); + return *series; +} +static inline SeriesWork& operator>(SubTask& x, SubTask *y) +{ + return x > *y; +} +static inline SeriesWork& operator>(SubTask *x, SubTask& y) +{ + return *x > y; +} +//S=S>P +static inline SeriesWork& operator>(SeriesWork& x, ParallelWork& y) +{ + return x > (SubTask&)y; +} +//S=P>S +static inline SeriesWork& operator>(ParallelWork& x, SeriesWork& y) +{ + return y > (SubTask&)x; +} +//S=P>P +static inline SeriesWork& operator>(ParallelWork& x, ParallelWork& y) +{ + return (SubTask&)x > (SubTask&)y; +} +//S=P>t +static inline SeriesWork& operator>(ParallelWork& x, SubTask& y) +{ + return (SubTask&)x > y; +} +static inline SeriesWork& operator>(ParallelWork& x, SubTask *y) +{ + return x > *y; +} +//S=t>P +static inline SeriesWork& operator>(SubTask& x, ParallelWork& y) +{ + return x > (SubTask&)y; +} +static inline SeriesWork& operator>(SubTask *x, ParallelWork& y) +{ + return *x > y; +} + +//P=P*S +static inline ParallelWork& operator*(ParallelWork& x, SeriesWork& y) +{ + x.add_series(&y); + return x; +} +//P=S*P +static inline ParallelWork& operator*(SeriesWork& x, ParallelWork& y) +{ + return y * x; +} +//P=P*t +static inline ParallelWork& operator*(ParallelWork& x, SubTask& y) +{ + x.add_series(Workflow::create_series_work(&y, nullptr)); + return x; +} +static inline ParallelWork& operator*(ParallelWork& x, SubTask *y) +{ + return x * (*y); +} +//P=t*P +static inline ParallelWork& 
operator*(SubTask& x, ParallelWork& y) +{ + return y * x; +} +static inline ParallelWork& operator*(SubTask *x, ParallelWork& y) +{ + return (*x) * y; +} +//P=S*S +static inline ParallelWork& operator*(SeriesWork& x, SeriesWork& y) +{ + SeriesWork *arr[2] = {&x, &y}; + return *Workflow::create_parallel_work(arr, 2, nullptr); +} +//P=S*t +static inline ParallelWork& operator*(SeriesWork& x, SubTask& y) +{ + return x * (*Workflow::create_series_work(&y, nullptr)); +} +static inline ParallelWork& operator*(SeriesWork& x, SubTask *y) +{ + return x * (*y); +} +//P=t*S +static inline ParallelWork& operator*(SubTask& x, SeriesWork& y) +{ + return y * x; +} +static inline ParallelWork& operator*(SubTask *x, SeriesWork& y) +{ + return (*x) * y; +} +//P=t*t +static inline ParallelWork& operator*(SubTask& x, SubTask& y) +{ + SeriesWork *arr[2] = {Workflow::create_series_work(&x, nullptr), + Workflow::create_series_work(&y, nullptr)}; + return *Workflow::create_parallel_work(arr, 2, nullptr); +} +static inline ParallelWork& operator*(SubTask& x, SubTask *y) +{ + return x * (*y); +} +static inline ParallelWork& operator*(SubTask *x, SubTask& y) +{ + return (*x) * y; +} + + +#endif + diff --git a/src/factory/WFResourcePool.cc b/src/factory/WFResourcePool.cc new file mode 100644 index 0000000..865bae3 --- /dev/null +++ b/src/factory/WFResourcePool.cc @@ -0,0 +1,117 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include "list.h" +#include "WFTask.h" +#include "WFResourcePool.h" + +class __RPConditional : public WFConditional +{ +public: + struct list_head list; + struct WFResourcePool::Data *data; + +public: + virtual void dispatch(); + virtual void signal(void *res) { } + +public: + __RPConditional(SubTask *task, void **resbuf, + struct WFResourcePool::Data *data) : + WFConditional(task, resbuf) + { + this->data = data; + } + + __RPConditional(SubTask *task, + struct WFResourcePool::Data *data) : + WFConditional(task) + { + this->data = data; + } +}; + +void __RPConditional::dispatch() +{ + struct WFResourcePool::Data *data = this->data; + + data->mutex.lock(); + if (--data->value >= 0) + this->WFConditional::signal(data->pop()); + else + list_add_tail(&this->list, &data->wait_list); + + data->mutex.unlock(); + this->WFConditional::dispatch(); +} + +WFConditional *WFResourcePool::get(SubTask *task, void **resbuf) +{ + return new __RPConditional(task, resbuf, &this->data); +} + +WFConditional *WFResourcePool::get(SubTask *task) +{ + return new __RPConditional(task, &this->data); +} + +void WFResourcePool::create(size_t n) +{ + this->data.res = new void *[n]; + this->data.value = n; + this->data.index = 0; + INIT_LIST_HEAD(&this->data.wait_list); + this->data.pool = this; +} + +WFResourcePool::WFResourcePool(void *const *res, size_t n) +{ + this->create(n); + memcpy(this->data.res, res, n * sizeof (void *)); +} + +WFResourcePool::WFResourcePool(size_t n) +{ + this->create(n); + memset(this->data.res, 0, n * sizeof (void *)); +} + +void WFResourcePool::post(void *res) +{ + struct WFResourcePool::Data *data = &this->data; + WFConditional *cond; + + data->mutex.lock(); + if (++data->value <= 0) + { + cond = list_entry(data->wait_list.next, __RPConditional, list); + list_del(data->wait_list.next); + } + else + { + cond = NULL; + this->push(res); + } + + data->mutex.unlock(); + 
if (cond) + cond->WFConditional::signal(res); +} + diff --git a/src/factory/WFResourcePool.h b/src/factory/WFResourcePool.h new file mode 100644 index 0000000..1cdce2f --- /dev/null +++ b/src/factory/WFResourcePool.h @@ -0,0 +1,72 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFRESOURCEPOOL_H_ +#define _WFRESOURCEPOOL_H_ + +#include +#include "list.h" +#include "WFTask.h" + +class WFResourcePool +{ +public: + WFConditional *get(SubTask *task, void **resbuf); + WFConditional *get(SubTask *task); + void post(void *res); + +public: + struct Data + { + void *pop() { return this->pool->pop(); } + void push(void *res) { this->pool->push(res); } + + void **res; + long value; + size_t index; + struct list_head wait_list; + std::mutex mutex; + WFResourcePool *pool; + }; + +protected: + virtual void *pop() + { + return this->data.res[this->data.index++]; + } + + virtual void push(void *res) + { + this->data.res[--this->data.index] = res; + } + +protected: + struct Data data; + +private: + void create(size_t n); + +public: + WFResourcePool(void *const *res, size_t n); + WFResourcePool(size_t n); + virtual ~WFResourcePool() { delete []this->data.res; } +}; + +#endif + diff --git a/src/factory/WFTask.h b/src/factory/WFTask.h new file mode 100644 index 0000000..218f08a --- /dev/null +++ b/src/factory/WFTask.h @@ -0,0 +1,872 @@ +/* + Copyright (c) 2019 
Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFTASK_H_ +#define _WFTASK_H_ + +#include +#include +#include +#include +#include +#include +#include "Executor.h" +#include "ExecRequest.h" +#include "Communicator.h" +#include "CommScheduler.h" +#include "CommRequest.h" +#include "SleepRequest.h" +#include "IORequest.h" +#include "Workflow.h" +#include "WFConnection.h" + +enum +{ + WFT_STATE_UNDEFINED = -1, + WFT_STATE_SUCCESS = CS_STATE_SUCCESS, + WFT_STATE_TOREPLY = CS_STATE_TOREPLY, /* for server task only */ + WFT_STATE_NOREPLY = CS_STATE_TOREPLY + 1, /* for server task only */ + WFT_STATE_SYS_ERROR = CS_STATE_ERROR, + WFT_STATE_SSL_ERROR = 65, + WFT_STATE_DNS_ERROR = 66, /* for client task only */ + WFT_STATE_TASK_ERROR = 67, + WFT_STATE_ABORTED = CS_STATE_STOPPED +}; + +template +class WFThreadTask : public ExecRequest +{ +public: + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + INPUT *get_input() { return &this->input; } + OUTPUT *get_output() { return &this->output; } + +public: + void *user_data; + +public: + int get_state() const { return this->state; } + int get_error() const { return this->error; } + +public: + void set_callback(std::function *)> cb) + { + this->callback = std::move(cb); + } + +protected: + virtual SubTask *done() + { + SeriesWork *series = 
series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + INPUT input; + OUTPUT output; + std::function *)> callback; + +public: + WFThreadTask(ExecQueue *queue, Executor *executor, + std::function *)>&& cb) : + ExecRequest(queue, executor), + callback(std::move(cb)) + { + this->user_data = NULL; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + } + +protected: + virtual ~WFThreadTask() { } +}; + +template +class WFNetworkTask : public CommRequest +{ +public: + /* start(), dismiss() are for client tasks only. */ + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + REQ *get_req() { return &this->req; } + RESP *get_resp() { return &this->resp; } + +public: + void *user_data; + +public: + int get_state() const { return this->state; } + int get_error() const { return this->error; } + + /* Call when error is ETIMEDOUT, return values: + * TOR_NOT_TIMEOUT, TOR_WAIT_TIMEOUT, TOR_CONNECT_TIMEOUT, + * TOR_TRANSMIT_TIMEOUT (send or receive). + * SSL connect timeout also returns TOR_CONNECT_TIMEOUT. */ + int get_timeout_reason() const { return this->timeout_reason; } + + /* Call only in callback or server's process. */ + long long get_task_seq() const + { + if (!this->target) + { + errno = ENOTCONN; + return -1; + } + + return this->get_seq(); + } + + int get_peer_addr(struct sockaddr *addr, socklen_t *addrlen) const; + + virtual WFConnection *get_connection() const = 0; + +public: + /* All in milliseconds. timeout == -1 for unlimited. */ + void set_send_timeout(int timeout) { this->send_timeo = timeout; } + void set_receive_timeout(int timeout) { this->receive_timeo = timeout; } + void set_keep_alive(int timeout) { this->keep_alive_timeo = timeout; } + void set_watch_timeout(int timeout) { this->watch_timeo = timeout; } + +public: + /* Do not reply this request. 
*/ + void noreply() + { + if (this->state == WFT_STATE_TOREPLY) + this->state = WFT_STATE_NOREPLY; + } + + /* Push reply data synchronously. */ + virtual int push(const void *buf, size_t size) + { + if (this->state != WFT_STATE_TOREPLY && + this->state != WFT_STATE_NOREPLY) + { + errno = ENOENT; + return -1; + } + + return this->scheduler->push(buf, size, this); + } + + /* To check if the connection was closed before replying. + Always returns 'true' in callback. */ + bool closed() const + { + switch (this->state) + { + case WFT_STATE_UNDEFINED: + return false; + case WFT_STATE_TOREPLY: + case WFT_STATE_NOREPLY: + return !this->target->has_idle_conn(); + default: + return true; + } + } + +public: + void set_callback(std::function *)> cb) + { + this->callback = std::move(cb); + } + +protected: + virtual int send_timeout() { return this->send_timeo; } + virtual int receive_timeout() { return this->receive_timeo; } + virtual int keep_alive_timeout() { return this->keep_alive_timeo; } + virtual int first_timeout() { return this->watch_timeo; } + +protected: + int send_timeo; + int receive_timeo; + int keep_alive_timeo; + int watch_timeo; + REQ req; + RESP resp; + std::function *)> callback; + +protected: + WFNetworkTask(CommSchedObject *object, CommScheduler *scheduler, + std::function *)>&& cb) : + CommRequest(object, scheduler), + callback(std::move(cb)) + { + this->send_timeo = -1; + this->receive_timeo = -1; + this->keep_alive_timeo = 0; + this->watch_timeo = 0; + this->target = NULL; + this->timeout_reason = TOR_NOT_TIMEOUT; + this->user_data = NULL; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + } + + virtual ~WFNetworkTask() { } +}; + +class WFTimerTask : public SleepRequest +{ +public: + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + void *user_data; + +public: + int get_state() const { return this->state; } + int 
get_error() const { return this->error; } + +public: + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + std::function callback; + +public: + WFTimerTask(CommScheduler *scheduler, + std::function cb) : + SleepRequest(scheduler), + callback(std::move(cb)) + { + this->user_data = NULL; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + } + +protected: + virtual ~WFTimerTask() { } +}; + +template +class WFFileTask : public IORequest +{ +public: + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + ARGS *get_args() { return &this->args; } + + long get_retval() const + { + if (this->state == WFT_STATE_SUCCESS) + return this->get_res(); + else + return -1; + } + +public: + void *user_data; + +public: + int get_state() const { return this->state; } + int get_error() const { return this->error; } + +public: + void set_callback(std::function *)> cb) + { + this->callback = std::move(cb); + } + +protected: + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + ARGS args; + std::function *)> callback; + +public: + WFFileTask(IOService *service, + std::function *)>&& cb) : + IORequest(service), + callback(std::move(cb)) + { + this->user_data = NULL; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + } + +protected: + virtual ~WFFileTask() { } +}; + +class WFGenericTask : public SubTask +{ +public: + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + void *user_data; + +public: 
+ int get_state() const { return this->state; } + int get_error() const { return this->error; } + +protected: + virtual void dispatch() + { + this->subtask_done(); + } + + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + delete this; + return series->pop(); + } + +protected: + int state; + int error; + +public: + WFGenericTask() + { + this->user_data = NULL; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + } + +protected: + virtual ~WFGenericTask() { } +}; + +class WFCounterTask : public WFGenericTask +{ +public: + virtual void count() + { + if (--this->value == 0) + { + this->state = WFT_STATE_SUCCESS; + this->subtask_done(); + } + } + +public: + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch() + { + this->WFCounterTask::count(); + } + + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + std::atomic value; + std::function callback; + +public: + WFCounterTask(unsigned int target_value, + std::function&& cb) : + value(target_value + 1), + callback(std::move(cb)) + { + } + +protected: + virtual ~WFCounterTask() { } +}; + +class WFMailboxTask : public WFGenericTask +{ +public: + virtual void send(void *msg) + { + *this->mailbox = msg; + if (this->flag.exchange(true)) + { + this->state = WFT_STATE_SUCCESS; + this->subtask_done(); + } + } + + void **get_mailbox() const { return this->mailbox; } + +public: + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch() + { + if (this->flag.exchange(true)) + { + this->state = WFT_STATE_SUCCESS; + this->subtask_done(); + } + } + + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + void **mailbox; + std::atomic flag; + 
std::function callback; + +public: + WFMailboxTask(void **mailbox, + std::function&& cb) : + flag(false), + callback(std::move(cb)) + { + this->mailbox = mailbox; + } + + WFMailboxTask(std::function&& cb) : + flag(false), + callback(std::move(cb)) + { + this->mailbox = &this->user_data; + } + +protected: + virtual ~WFMailboxTask() { } +}; + +class WFSelectorTask : public WFGenericTask +{ +public: + virtual int submit(void *msg) + { + void *tmp = NULL; + int ret = 0; + + if (this->message.compare_exchange_strong(tmp, msg) && msg) + { + ret = 1; + if (this->flag.exchange(true)) + { + this->state = WFT_STATE_SUCCESS; + this->subtask_done(); + } + } + + if (--this->nleft == 0) + { + if (!this->message) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = ENOMSG; + this->subtask_done(); + } + + delete this; + } + + return ret; + } + + void *get_message() const { return this->message; } + +public: + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch() + { + if (this->flag.exchange(true)) + { + this->state = WFT_STATE_SUCCESS; + this->subtask_done(); + } + + if (--this->nleft == 0) + { + if (!this->message) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = ENOMSG; + this->subtask_done(); + } + + delete this; + } + } + + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + return series->pop(); + } + +protected: + std::atomic message; + std::atomic flag; + std::atomic nleft; + std::function callback; + +public: + WFSelectorTask(size_t candidates, + std::function&& cb) : + message(NULL), + flag(false), + nleft(candidates + 1), + callback(std::move(cb)) + { + } + +protected: + virtual ~WFSelectorTask() { } +}; + +class WFConditional : public WFGenericTask +{ +public: + virtual void signal(void *msg) + { + *this->msgbuf = msg; + if (this->flag.exchange(true)) + this->subtask_done(); + } + +protected: + virtual void dispatch() + { + 
series_of(this)->push_front(this->task); + this->task = NULL; + if (this->flag.exchange(true)) + this->subtask_done(); + } + +protected: + std::atomic flag; + SubTask *task; + void **msgbuf; + +public: + WFConditional(SubTask *task, void **msgbuf) : + flag(false) + { + this->task = task; + this->msgbuf = msgbuf; + } + + WFConditional(SubTask *task) : + flag(false) + { + this->task = task; + this->msgbuf = &this->user_data; + } + +protected: + virtual ~WFConditional() + { + delete this->task; + } +}; + +class WFGoTask : public ExecRequest +{ +public: + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + void *user_data; + +public: + int get_state() const { return this->state; } + int get_error() const { return this->error; } + +public: + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + std::function callback; + +public: + WFGoTask(ExecQueue *queue, Executor *executor) : + ExecRequest(queue, executor) + { + this->user_data = NULL; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + } + +protected: + virtual ~WFGoTask() { } +}; + +class WFRepeaterTask : public WFGenericTask +{ +public: + void set_create(std::function create) + { + this->create = std::move(create); + } + + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch() + { + SubTask *task = this->create(this); + + if (task) + { + series_of(this)->push_front(this); + series_of(this)->push_front(task); + } + else + this->state = WFT_STATE_SUCCESS; + + this->subtask_done(); + } + + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->state != WFT_STATE_UNDEFINED) + { + if 
(this->callback) + this->callback(this); + + delete this; + } + + return series->pop(); + } + +protected: + std::function create; + std::function callback; + +public: + WFRepeaterTask(std::function&& create, + std::function&& cb) : + create(std::move(create)), + callback(std::move(cb)) + { + } + +protected: + virtual ~WFRepeaterTask() { } +}; + +class WFModuleTask : public ParallelTask, protected SeriesWork +{ +public: + void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + SeriesWork *sub_series() { return this; } + + const SeriesWork *sub_series() const { return this; } + +public: + void *user_data; + +public: + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +protected: + SubTask *first; + std::function callback; + +public: + WFModuleTask(SubTask *first, + std::function&& cb) : + ParallelTask(&this->first, 1), + SeriesWork(first, nullptr), + callback(std::move(cb)) + { + this->first = first; + this->set_in_parallel(this); + this->user_data = NULL; + } + +protected: + virtual ~WFModuleTask() + { + if (!this->is_finished()) + this->dismiss_recursive(); + } +}; + +#include "WFTask.inl" + +#endif + diff --git a/src/factory/WFTask.inl b/src/factory/WFTask.inl new file mode 100644 index 0000000..3ce9b53 --- /dev/null +++ b/src/factory/WFTask.inl @@ -0,0 +1,256 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +template +int WFNetworkTask::get_peer_addr(struct sockaddr *addr, + socklen_t *addrlen) const +{ + const struct sockaddr *p; + socklen_t len; + + if (this->target) + { + this->target->get_addr(&p, &len); + if (*addrlen >= len) + { + memcpy(addr, p, len); + *addrlen = len; + return 0; + } + + errno = ENOBUFS; + } + else + errno = ENOTCONN; + + return -1; +} + +template +class WFClientTask : public WFNetworkTask +{ +protected: + virtual CommMessageOut *message_out() + { + /* By using prepare function, users can modify request after + * the connection is established. 
*/ + if (this->prepare) + this->prepare(this); + + return &this->req; + } + + virtual CommMessageIn *message_in() { return &this->resp; } + +protected: + virtual WFConnection *get_connection() const + { + CommConnection *conn; + + if (this->target) + { + conn = this->CommSession::get_connection(); + if (conn) + return (WFConnection *)conn; + } + + errno = ENOTCONN; + return NULL; + } + +protected: + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->state == WFT_STATE_SYS_ERROR && this->error < 0) + { + this->state = WFT_STATE_SSL_ERROR; + this->error = -this->error; + } + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +public: + void set_prepare(std::function *)> prep) + { + this->prepare = std::move(prep); + } + +protected: + std::function *)> prepare; + +public: + WFClientTask(CommSchedObject *object, CommScheduler *scheduler, + std::function *)>&& cb) : + WFNetworkTask(object, scheduler, std::move(cb)) + { + } + +protected: + virtual ~WFClientTask() { } +}; + +template +class WFServerTask : public WFNetworkTask +{ +protected: + virtual CommMessageOut *message_out() { return &this->resp; } + virtual CommMessageIn *message_in() { return &this->req; } + virtual void handle(int state, int error); + +protected: + /* CommSession::get_connection() is supposed to be called only in the + * implementations of it's virtual functions. As a server task, to call + * this function after process() and before callback() is very dangerous + * and should be blocked. */ + virtual WFConnection *get_connection() const + { + if (this->processor.task) + return (WFConnection *)this->CommSession::get_connection(); + + errno = EPERM; + return NULL; + } + +protected: + virtual void dispatch() + { + if (this->state == WFT_STATE_TOREPLY) + { + /* Enable get_connection() again if the reply() call is success. 
*/ + this->processor.task = this; + if (this->scheduler->reply(this) >= 0) + return; + + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + this->processor.task = NULL; + } + else + this->scheduler->shutdown(this); + + this->subtask_done(); + } + + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->state == WFT_STATE_SYS_ERROR && this->error < 0) + { + this->state = WFT_STATE_SSL_ERROR; + this->error = -this->error; + } + + if (this->callback) + this->callback(this); + + /* Defer deleting the task. */ + return series->pop(); + } + +protected: + class Processor : public SubTask + { + public: + Processor(WFServerTask *task, + std::function *)>& proc) : + process(proc) + { + this->task = task; + } + + virtual void dispatch() + { + this->process(this->task); + this->task = NULL; /* As a flag. get_conneciton() disabled. */ + this->subtask_done(); + } + + virtual SubTask *done() + { + return series_of(this)->pop(); + } + + std::function *)>& process; + WFServerTask *task; + } processor; + + class Series : public SeriesWork + { + public: + Series(WFServerTask *task) : + SeriesWork(&task->processor, nullptr) + { + this->set_last_task(task); + this->task = task; + } + + virtual ~Series() + { + delete this->task; + } + + WFServerTask *task; + }; + +public: + WFServerTask(CommService *service, CommScheduler *scheduler, + std::function *)>& proc) : + WFNetworkTask(NULL, scheduler, nullptr), + processor(this, proc) + { + } + +protected: + virtual ~WFServerTask() + { + if (this->target) + ((Series *)series_of(this))->task = NULL; + } +}; + +template +void WFServerTask::handle(int state, int error) +{ + if (state == WFT_STATE_TOREPLY) + { + this->state = WFT_STATE_TOREPLY; + this->target = this->get_target(); + new Series(this); + this->processor.dispatch(); + } + else if (this->state == WFT_STATE_TOREPLY) + { + this->state = state; + this->error = error; + if (error == ETIMEDOUT) + this->timeout_reason = TOR_TRANSMIT_TIMEOUT; + + 
this->subtask_done(); + } + else + delete this; +} + diff --git a/src/factory/WFTaskError.h b/src/factory/WFTaskError.h new file mode 100644 index 0000000..6f5613d --- /dev/null +++ b/src/factory/WFTaskError.h @@ -0,0 +1,75 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _WFTASKERROR_H_ +#define _WFTASKERROR_H_ + +/** + * @file WFTaskError.h + * @brief Workflow Task Error Code List + */ + +/** + * @brief Definition of task state code + * @note Only for WFNetworkTask and only when get_state()==WFT_STATE_TASK_ERROR + */ +enum +{ + //COMMON + WFT_ERR_URI_PARSE_FAILED = 1001, ///< URI, parse failed + WFT_ERR_URI_SCHEME_INVALID = 1002, ///< URI, invalid scheme + WFT_ERR_URI_PORT_INVALID = 1003, ///< URI, invalid port + WFT_ERR_UPSTREAM_UNAVAILABLE = 1004, ///< Upstream, all target servers down + + //HTTP + WFT_ERR_HTTP_BAD_REDIRECT_HEADER = 2001, ///< Http, 301/302/303/307/308 Location header value is NULL + WFT_ERR_HTTP_PROXY_CONNECT_FAILED = 2002, ///< Http, proxy CONNECT returned non-200 + + //REDIS + WFT_ERR_REDIS_ACCESS_DENIED = 3001, ///< Redis, invalid password + WFT_ERR_REDIS_COMMAND_DISALLOWED = 3002, ///< Redis, command disabled, cannot be "AUTH"/"SELECT" + + //MYSQL + WFT_ERR_MYSQL_HOST_NOT_ALLOWED = 4001, ///< MySQL + WFT_ERR_MYSQL_ACCESS_DENIED = 4002, ///< MySQL, authentication failed + WFT_ERR_MYSQL_INVALID_CHARACTER_SET = 4003, ///< MySQL, invalid charset, not found in 
MySQL-Documentation + WFT_ERR_MYSQL_COMMAND_DISALLOWED = 4004, ///< MySQL, SQL command disabled, cannot be "USE"/"SET NAMES"/"SET CHARSET"/"SET CHARACTER SET" + WFT_ERR_MYSQL_QUERY_NOT_SET = 4005, ///< MySQL, SQL query not set; check that it was not forgotten + WFT_ERR_MYSQL_SSL_NOT_SUPPORTED = 4006, ///< MySQL, SSL not supported by the server + + //KAFKA + WFT_ERR_KAFKA_PARSE_RESPONSE_FAILED = 5001, ///< Kafka parse response failed + WFT_ERR_KAFKA_PRODUCE_FAILED = 5002, + WFT_ERR_KAFKA_FETCH_FAILED = 5003, + WFT_ERR_KAFKA_CGROUP_FAILED = 5004, + WFT_ERR_KAFKA_COMMIT_FAILED = 5005, + WFT_ERR_KAFKA_META_FAILED = 5006, + WFT_ERR_KAFKA_LEAVEGROUP_FAILED = 5007, + WFT_ERR_KAFKA_API_UNKNOWN = 5008, ///< api type not supported + WFT_ERR_KAFKA_VERSION_DISALLOWED = 5009, ///< broker version not supported + WFT_ERR_KAFKA_SASL_DISALLOWED = 5010, ///< SASL not supported + WFT_ERR_KAFKA_ARRANGE_FAILED = 5011, ///< arrange toppar failed + WFT_ERR_KAFKA_LIST_OFFSETS_FAILED = 5012, + WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED = 5013, + + //CONSUL + WFT_ERR_CONSUL_API_UNKNOWN = 6001, ///< api type not supported + WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED = 6002, ///< Consul http code failed +}; + +#endif diff --git a/src/factory/WFTaskFactory.cc b/src/factory/WFTaskFactory.cc new file mode 100644 index 0000000..79844e1 --- /dev/null +++ b/src/factory/WFTaskFactory.cc @@ -0,0 +1,1168 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "list.h" +#include "rbtree.h" +#include "WFGlobal.h" +#include "WFTaskFactory.h" + +class __WFTimerTask : public WFTimerTask +{ +protected: + virtual int duration(struct timespec *value) + { + value->tv_sec = this->seconds; + value->tv_nsec = this->nanoseconds; + return 0; + } + +protected: + time_t seconds; + long nanoseconds; + +public: + __WFTimerTask(time_t seconds, long nanoseconds, CommScheduler *scheduler, + timer_callback_t&& cb) : + WFTimerTask(scheduler, std::move(cb)) + { + this->seconds = seconds; + this->nanoseconds = nanoseconds; + } +}; + +WFTimerTask *WFTaskFactory::create_timer_task(time_t seconds, long nanoseconds, + timer_callback_t callback) +{ + return new __WFTimerTask(seconds, nanoseconds, WFGlobal::get_scheduler(), + std::move(callback)); +} + +/* Deprecated. */ +WFTimerTask *WFTaskFactory::create_timer_task(unsigned int microseconds, + timer_callback_t callback) +{ + return WFTaskFactory::create_timer_task(microseconds / 1000000, + microseconds % 1000000 * 1000, + std::move(callback)); +} + +/****************** Named Tasks ******************/ + +template +struct __NamedObjectList +{ + __NamedObjectList(const std::string& str): + name(str) + { + INIT_LIST_HEAD(&this->head); + } + + void push_back(T *node) + { + list_add_tail(&node->list, &this->head); + } + + bool empty() const + { + return list_empty(&this->head); + } + + bool del(T *node, rb_root *root) + { + list_del(&node->list); + if (this->empty()) + { + rb_erase(&this->rb, root); + return true; + } + else + return false; + } + + struct rb_node rb; + struct list_head head; + std::string name; +}; + +template +static T *__get_object_list(const std::string& name, struct rb_root *root, + bool insert) +{ + struct rb_node **p = &root->rb_node; + struct rb_node *parent = NULL; + T *objs; + int n; + + while (*p) + { + parent = *p; + objs = rb_entry(*p, T, rb); + n 
= name.compare(objs->name); + if (n < 0) + p = &(*p)->rb_left; + else if (n > 0) + p = &(*p)->rb_right; + else + return objs; + } + + if (insert) + { + objs = new T(name); + rb_link_node(&objs->rb, parent, p); + rb_insert_color(&objs->rb, root); + return objs; + } + + return NULL; +} + +/****************** Named Timer ******************/ + +class __WFNamedTimerTask; + +struct __timer_node +{ + struct list_head list; + __WFNamedTimerTask *task; +}; + +static class __NamedTimerMap +{ +public: + using TimerList = __NamedObjectList; + +public: + WFTimerTask *create(const std::string& name, + time_t seconds, long nanoseconds, + CommScheduler *scheduler, + timer_callback_t&& cb); + +public: + int cancel(const std::string& name, size_t max); + +private: + struct rb_root root_; + std::mutex mutex_; + +public: + __NamedTimerMap() + { + root_.rb_node = NULL; + } + + friend class __WFNamedTimerTask; +} __timer_map; + +class __WFNamedTimerTask : public __WFTimerTask +{ +public: + __WFNamedTimerTask(time_t seconds, long nanoseconds, + CommScheduler *scheduler, + timer_callback_t&& cb) : + __WFTimerTask(seconds, nanoseconds, scheduler, std::move(cb)), + flag_(false) + { + node_.task = this; + } + + void push_to(__NamedTimerMap::TimerList *timers) + { + timers->push_back(&node_); + timers_ = timers; + } + + virtual ~__WFNamedTimerTask() + { + if (node_.task) + { + bool erased = false; + + __timer_map.mutex_.lock(); + if (node_.task) + erased = timers_->del(&node_, &__timer_map.root_); + + __timer_map.mutex_.unlock(); + if (erased) + delete timers_; + } + } + +protected: + virtual void dispatch(); + virtual void handle(int state, int error); + +private: + struct __timer_node node_; + __NamedTimerMap::TimerList *timers_; + std::atomic flag_; + std::mutex mutex_; + friend class __NamedTimerMap; +}; + +void __WFNamedTimerTask::dispatch() +{ + int ret; + + mutex_.lock(); + ret = this->scheduler->sleep(this); + if (ret >= 0 && flag_.exchange(true)) + this->cancel(); + + 
mutex_.unlock(); + if (ret < 0) + this->handle(WFT_STATE_SYS_ERROR, errno); +} + +void __WFNamedTimerTask::handle(int state, int error) +{ + bool canceled = true; + + if (node_.task) + { + bool erased = false; + + __timer_map.mutex_.lock(); + if (node_.task) + { + canceled = false; + erased = timers_->del(&node_, &__timer_map.root_); + node_.task = NULL; + } + + __timer_map.mutex_.unlock(); + if (erased) + delete timers_; + } + + if (canceled) + { + state = WFT_STATE_SYS_ERROR; + error = ECANCELED; + } + + mutex_.lock(); + mutex_.unlock(); + this->__WFTimerTask::handle(state, error); +} + +WFTimerTask *__NamedTimerMap::create(const std::string& name, + time_t seconds, long nanoseconds, + CommScheduler *scheduler, + timer_callback_t&& cb) +{ + auto *task = new __WFNamedTimerTask(seconds, nanoseconds, scheduler, + std::move(cb)); + mutex_.lock(); + task->push_to(__get_object_list(name, &root_, true)); + mutex_.unlock(); + return task; +} + +int __NamedTimerMap::cancel(const std::string& name, size_t max) +{ + struct __timer_node *node; + TimerList *timers; + int ret = 0; + + mutex_.lock(); + timers = __get_object_list(name, &root_, false); + if (timers) + { + while (1) + { + if (max == 0) + { + timers = NULL; + break; + } + + node = list_entry(timers->head.next, struct __timer_node, list); + list_del(&node->list); + if (node->task->flag_.exchange(true)) + node->task->cancel(); + + node->task = NULL; + max--; + ret++; + if (timers->empty()) + { + rb_erase(&timers->rb, &root_); + break; + } + } + } + + mutex_.unlock(); + delete timers; + return ret; +} + +WFTimerTask *WFTaskFactory::create_timer_task(const std::string& name, + time_t seconds, long nanoseconds, + timer_callback_t callback) +{ + return __timer_map.create(name, seconds, nanoseconds, + WFGlobal::get_scheduler(), + std::move(callback)); +} + +int WFTaskFactory::cancel_by_name(const std::string& name, size_t max) +{ + return __timer_map.cancel(name, max); +} + +/****************** Named Counter 
******************/ + +class __WFNamedCounterTask; + +struct __counter_node +{ + struct list_head list; + unsigned int target_value; + __WFNamedCounterTask *task; +}; + +static class __NamedCounterMap +{ +public: + using CounterList = __NamedObjectList; + +public: + WFCounterTask *create(const std::string& name, unsigned int target_value, + counter_callback_t&& cb); + + int count_n(const std::string& name, unsigned int n); + void count(CounterList *counters, struct __counter_node *node); + + void remove(CounterList *counters, struct __counter_node *node) + { + bool erased; + + mutex_.lock(); + erased = counters->del(node, &root_); + mutex_.unlock(); + if (erased) + delete counters; + } + +private: + bool count_n_locked(CounterList *counters, unsigned int n, + struct list_head *task_list); + struct rb_root root_; + std::mutex mutex_; + +public: + __NamedCounterMap() + { + root_.rb_node = NULL; + } +} __counter_map; + +class __WFNamedCounterTask : public WFCounterTask +{ +public: + __WFNamedCounterTask(unsigned int target_value, counter_callback_t&& cb) : + WFCounterTask(1, std::move(cb)) + { + node_.target_value = target_value; + node_.task = this; + } + + void push_to(__NamedCounterMap::CounterList *counters) + { + counters->push_back(&node_); + counters_ = counters; + } + + virtual void count() + { + __counter_map.count(counters_, &node_); + } + + virtual ~__WFNamedCounterTask() + { + if (this->value != 0) + __counter_map.remove(counters_, &node_); + } + +private: + struct __counter_node node_; + __NamedCounterMap::CounterList *counters_; +}; + +WFCounterTask *__NamedCounterMap::create(const std::string& name, + unsigned int target_value, + counter_callback_t&& cb) +{ + if (target_value == 0) + return new WFCounterTask(0, std::move(cb)); + + auto *task = new __WFNamedCounterTask(target_value, std::move(cb)); + mutex_.lock(); + task->push_to(__get_object_list(name, &root_, true)); + mutex_.unlock(); + return task; +} + +bool 
__NamedCounterMap::count_n_locked(CounterList *counters, unsigned int n, + struct list_head *task_list) +{ + struct __counter_node *node; + + while (n > 0) + { + node = list_entry(counters->head.next, struct __counter_node, list); + if (n >= node->target_value) + { + n -= node->target_value; + node->target_value = 0; + list_move_tail(&node->list, task_list); + if (counters->empty()) + { + rb_erase(&counters->rb, &root_); + return true; + } + } + else + { + node->target_value -= n; + break; + } + } + + return false; +} + +int __NamedCounterMap::count_n(const std::string& name, unsigned int n) +{ + LIST_HEAD(task_list); + struct __counter_node *node; + CounterList *counters; + bool erased = false; + int ret = 0; + + mutex_.lock(); + counters = __get_object_list(name, &root_, false); + if (counters) + erased = count_n_locked(counters, n, &task_list); + + mutex_.unlock(); + if (erased) + delete counters; + + while (!list_empty(&task_list)) + { + node = list_entry(task_list.next, struct __counter_node, list); + list_del(&node->list); + node->task->WFCounterTask::count(); + ret++; + } + + return ret; +} + +void __NamedCounterMap::count(CounterList *counters, + struct __counter_node *node) +{ + __WFNamedCounterTask *task = NULL; + bool erased = false; + + mutex_.lock(); + if (--node->target_value == 0) + { + task = node->task; + erased = counters->del(node, &root_); + } + + mutex_.unlock(); + if (erased) + delete counters; + + if (task) + task->WFCounterTask::count(); +} + +WFCounterTask *WFTaskFactory::create_counter_task(const std::string& name, + unsigned int target_value, + counter_callback_t callback) +{ + return __counter_map.create(name, target_value, std::move(callback)); +} + +int WFTaskFactory::count_by_name(const std::string& name, unsigned int n) +{ + return __counter_map.count_n(name, n); +} + +/****************** Named Mailbox ******************/ + +class __WFNamedMailboxTask; + +struct __mailbox_node +{ + struct list_head list; + __WFNamedMailboxTask *task; 
+}; + +static class __NamedMailboxMap +{ +public: + using MailboxList = __NamedObjectList; + +public: + WFMailboxTask *create(const std::string& name, void **mailbox, + mailbox_callback_t&& cb); + WFMailboxTask *create(const std::string& name, mailbox_callback_t&& cb); + + int send(const std::string& name, void *const msg[], size_t max, int inc); + void send(MailboxList *mailboxes, struct __mailbox_node *node, void *msg); + + void remove(MailboxList *mailboxes, struct __mailbox_node *node) + { + bool erased; + + mutex_.lock(); + erased = mailboxes->del(node, &root_); + mutex_.unlock(); + if (erased) + delete mailboxes; + } + +private: + bool send_max_locked(MailboxList *mailboxes, size_t max, + struct list_head *task_list); + struct rb_root root_; + std::mutex mutex_; + +public: + __NamedMailboxMap() + { + root_.rb_node = NULL; + } +} __mailbox_map; + +class __WFNamedMailboxTask : public WFMailboxTask +{ +public: + __WFNamedMailboxTask(void **mailbox, mailbox_callback_t&& cb) : + WFMailboxTask(mailbox, std::move(cb)) + { + node_.task = this; + } + + __WFNamedMailboxTask(mailbox_callback_t&& cb) : + WFMailboxTask(std::move(cb)) + { + node_.task = this; + } + + void push_to(__NamedMailboxMap::MailboxList *mailboxes) + { + mailboxes->push_back(&node_); + mailboxes_ = mailboxes; + } + + virtual void send(void *msg) + { + __mailbox_map.send(mailboxes_, &node_, msg); + } + + virtual ~__WFNamedMailboxTask() + { + if (!this->flag) + __mailbox_map.remove(mailboxes_, &node_); + } + +private: + struct __mailbox_node node_; + __NamedMailboxMap::MailboxList *mailboxes_; +}; + +WFMailboxTask *__NamedMailboxMap::create(const std::string& name, + void **mailbox, + mailbox_callback_t&& cb) +{ + auto *task = new __WFNamedMailboxTask(mailbox, std::move(cb)); + mutex_.lock(); + task->push_to(__get_object_list(name, &root_, true)); + mutex_.unlock(); + return task; +} + +WFMailboxTask *__NamedMailboxMap::create(const std::string& name, + mailbox_callback_t&& cb) +{ + auto *task = new 
__WFNamedMailboxTask(std::move(cb)); + mutex_.lock(); + task->push_to(__get_object_list(name, &root_, true)); + mutex_.unlock(); + return task; +} + +bool __NamedMailboxMap::send_max_locked(MailboxList *mailboxes, size_t max, + struct list_head *task_list) +{ + if (max == (size_t)-1) + list_splice(&mailboxes->head, task_list); + else + { + do + { + if (max == 0) + return false; + + list_move_tail(mailboxes->head.next, task_list); + max--; + } while (!mailboxes->empty()); + } + + rb_erase(&mailboxes->rb, &root_); + return true; +} + +int __NamedMailboxMap::send(const std::string& name, void *const msg[], + size_t max, int inc) +{ + LIST_HEAD(task_list); + struct __mailbox_node *node; + MailboxList *mailboxes; + bool erased = false; + int ret = 0; + + mutex_.lock(); + mailboxes = __get_object_list(name, &root_, false); + if (mailboxes) + erased = send_max_locked(mailboxes, max, &task_list); + + mutex_.unlock(); + if (erased) + delete mailboxes; + + while (!list_empty(&task_list)) + { + node = list_entry(task_list.next, struct __mailbox_node, list); + list_del(&node->list); + node->task->WFMailboxTask::send(*msg); + msg += inc; + ret++; + } + + return ret; +} + +void __NamedMailboxMap::send(MailboxList *mailboxes, + struct __mailbox_node *node, + void *msg) +{ + bool erased; + + mutex_.lock(); + erased = mailboxes->del(node, &root_); + mutex_.unlock(); + if (erased) + delete mailboxes; + + node->task->WFMailboxTask::send(msg); +} + +WFMailboxTask *WFTaskFactory::create_mailbox_task(const std::string& name, + void **mailbox, + mailbox_callback_t callback) +{ + return __mailbox_map.create(name, mailbox, std::move(callback)); +} + +WFMailboxTask *WFTaskFactory::create_mailbox_task(const std::string& name, + mailbox_callback_t callback) +{ + return __mailbox_map.create(name, std::move(callback)); +} + +int WFTaskFactory::send_by_name(const std::string& name, void *msg, + size_t max) +{ + return __mailbox_map.send(name, &msg, max, 0); +} + +int 
WFTaskFactory::send_by_name(const std::string& name, void *const msg[], + size_t max) +{ + return __mailbox_map.send(name, msg, max, 1); +} + +/****************** Named Conditional ******************/ + +class __WFNamedConditional; + +struct __conditional_node +{ + struct list_head list; + __WFNamedConditional *cond; +}; + +static class __NamedConditionalMap +{ +public: + using ConditionalList = __NamedObjectList; + +public: + WFConditional *create(const std::string& name, SubTask *task, + void **msgbuf); + WFConditional *create(const std::string& name, SubTask *task); + + int signal(const std::string& name, void *const msg[], size_t max, int inc); + void signal(ConditionalList *conds, struct __conditional_node *node, + void *msg); + + void remove(ConditionalList *conds, struct __conditional_node *node) + { + bool erased; + + mutex_.lock(); + erased = conds->del(node, &root_); + mutex_.unlock(); + if (erased) + delete conds; + } + +private: + bool signal_max_locked(ConditionalList *conds, size_t max, + struct list_head *cond_list); + struct rb_root root_; + std::mutex mutex_; + +public: + __NamedConditionalMap() + { + root_.rb_node = NULL; + } +} __conditional_map; + +class __WFNamedConditional : public WFConditional +{ +public: + __WFNamedConditional(SubTask *task, void **msgbuf) : + WFConditional(task, msgbuf) + { + node_.cond = this; + } + + __WFNamedConditional(SubTask *task) : + WFConditional(task) + { + node_.cond = this; + } + + void push_to(__NamedConditionalMap::ConditionalList *conds) + { + conds->push_back(&node_); + conds_ = conds; + } + + virtual void signal(void *msg) + { + __conditional_map.signal(conds_, &node_, msg); + } + + virtual ~__WFNamedConditional() + { + if (!this->flag) + __conditional_map.remove(conds_, &node_); + } + +private: + struct __conditional_node node_; + __NamedConditionalMap::ConditionalList *conds_; +}; + +WFConditional *__NamedConditionalMap::create(const std::string& name, + SubTask *task, void **msgbuf) +{ + auto *cond = 
new __WFNamedConditional(task, msgbuf); + mutex_.lock(); + cond->push_to(__get_object_list(name, &root_, true)); + mutex_.unlock(); + return cond; +} + +WFConditional *__NamedConditionalMap::create(const std::string& name, + SubTask *task) +{ + auto *cond = new __WFNamedConditional(task); + mutex_.lock(); + cond->push_to(__get_object_list(name, &root_, true)); + mutex_.unlock(); + return cond; +} + +bool __NamedConditionalMap::signal_max_locked(ConditionalList *conds, + size_t max, + struct list_head *cond_list) +{ + if (max == (size_t)-1) + list_splice(&conds->head, cond_list); + else + { + do + { + if (max == 0) + return false; + + list_move_tail(conds->head.next, cond_list); + max--; + } while (!conds->empty()); + } + + rb_erase(&conds->rb, &root_); + return true; +} + +int __NamedConditionalMap::signal(const std::string& name, void *const msg[], + size_t max, int inc) +{ + LIST_HEAD(cond_list); + struct __conditional_node *node; + ConditionalList *conds; + bool erased = false; + int ret = 0; + + mutex_.lock(); + conds = __get_object_list(name, &root_, false); + if (conds) + erased = signal_max_locked(conds, max, &cond_list); + + mutex_.unlock(); + if (erased) + delete conds; + + while (!list_empty(&cond_list)) + { + node = list_entry(cond_list.next, struct __conditional_node, list); + list_del(&node->list); + node->cond->WFConditional::signal(*msg); + msg += inc; + ret++; + } + + return ret; +} + +void __NamedConditionalMap::signal(ConditionalList *conds, + struct __conditional_node *node, + void *msg) +{ + bool erased; + + mutex_.lock(); + erased = conds->del(node, &root_); + mutex_.unlock(); + if (erased) + delete conds; + + node->cond->WFConditional::signal(msg); +} + +WFConditional *WFTaskFactory::create_conditional(const std::string& name, + SubTask *task, void **msgbuf) +{ + return __conditional_map.create(name, task, msgbuf); +} + +WFConditional *WFTaskFactory::create_conditional(const std::string& name, + SubTask *task) +{ + return 
__conditional_map.create(name, task); +} + +int WFTaskFactory::signal_by_name(const std::string& name, void *msg, + size_t max) +{ + return __conditional_map.signal(name, &msg, max, 0); +} + +int WFTaskFactory::signal_by_name(const std::string& name, void *const msg[], + size_t max) +{ + return __conditional_map.signal(name, msg, max, 1); +} + +/****************** Named Guard ******************/ + +class __WFNamedGuard; + +struct __guard_node +{ + struct list_head list; + __WFNamedGuard *guard; +}; + +static class __NamedGuardMap +{ +public: + struct GuardList : public __NamedObjectList + { + GuardList(const std::string& name) : __NamedObjectList(name) + { + acquired = false; + refcnt = 0; + } + + bool acquired; + size_t refcnt; + std::mutex mutex; + }; + +public: + WFConditional *create(const std::string& name, SubTask *task); + WFConditional *create(const std::string& name, SubTask *task, + void **msgbuf); + + struct __guard_node *release(const std::string& name); + + void unref(GuardList *guards) + { + mutex_.lock(); + if (--guards->refcnt == 0) + rb_erase(&guards->rb, &root_); + else + guards = NULL; + + mutex_.unlock(); + delete guards; + } + +private: + struct rb_root root_; + std::mutex mutex_; + +public: + __NamedGuardMap() + { + root_.rb_node = NULL; + } +} __guard_map; + +class __WFNamedGuard : public WFConditional +{ +public: + __WFNamedGuard(SubTask *task) : WFConditional(task) + { + node_.guard = this; + } + + __WFNamedGuard(SubTask *task, void **msgbuf) : WFConditional(task, msgbuf) + { + node_.guard = this; + } + + virtual ~__WFNamedGuard() + { + if (!this->flag) + __guard_map.unref(guards_); + } + + SubTask *get_task() const { return this->task; } + + void set_task(SubTask *task) { this->task = task; } + +protected: + virtual void dispatch(); + virtual void signal(void *msg) { } + +private: + struct __guard_node node_; + __NamedGuardMap::GuardList *guards_; + friend __NamedGuardMap; +}; + +void __WFNamedGuard::dispatch() +{ + guards_->mutex.lock(); 
+ if (guards_->acquired) + guards_->push_back(&node_); + else + { + guards_->acquired = true; + this->WFConditional::signal(NULL); + } + + guards_->mutex.unlock(); + this->WFConditional::dispatch(); +} + +WFConditional *__NamedGuardMap::create(const std::string& name, + SubTask *task) +{ + auto *guard = new __WFNamedGuard(task); + mutex_.lock(); + guard->guards_ = __get_object_list(name, &root_, true); + guard->guards_->refcnt++; + mutex_.unlock(); + return guard; +} + +WFConditional *__NamedGuardMap::create(const std::string& name, + SubTask *task, void **msgbuf) +{ + auto *guard = new __WFNamedGuard(task, msgbuf); + mutex_.lock(); + guard->guards_ = __get_object_list(name, &root_, true); + guard->guards_->refcnt++; + mutex_.unlock(); + return guard; +} + +struct __guard_node *__NamedGuardMap::release(const std::string& name) +{ + struct __guard_node *node = NULL; + GuardList *guards; + + mutex_.lock(); + guards = __get_object_list(name, &root_, false); + if (guards) + { + if (--guards->refcnt == 0) + rb_erase(&guards->rb, &root_); + else + { + guards->mutex.lock(); + if (!guards->empty()) + { + node = list_entry(guards->head.next, struct __guard_node, list); + list_del(&node->list); + } + else + guards->acquired = false; + + guards->mutex.unlock(); + guards = NULL; + } + } + + mutex_.unlock(); + delete guards; + return node; +} + +WFConditional *WFTaskFactory::create_guard(const std::string& name, + SubTask *task) +{ + return __guard_map.create(name, task); +} + +WFConditional *WFTaskFactory::create_guard(const std::string& name, + SubTask *task, void **msgbuf) +{ + return __guard_map.create(name, task, msgbuf); +} + +int WFTaskFactory::release_guard(const std::string& name, void *msg) +{ + struct __guard_node *node = __guard_map.release(name); + + if (!node) + return 0; + + node->guard->WFConditional::signal(msg); + return 1; +} + +int WFTaskFactory::release_guard_safe(const std::string& name, void *msg) +{ + struct __guard_node *node = 
__guard_map.release(name); + WFTimerTask *timer; + + if (!node) + return 0; + + timer = WFTaskFactory::create_timer_task(0, 0, [](WFTimerTask *timer) { + series_of(timer)->push_front((SubTask *)timer->user_data); + }); + timer->user_data = node->guard->get_task(); + node->guard->set_task(timer); + node->guard->WFConditional::signal(msg); + return 1; +} + +/**************** Timed Go Task *****************/ + +void __WFTimedGoTask::dispatch() +{ + WFTimerTask *timer; + + timer = WFTaskFactory::create_timer_task(this->seconds, this->nanoseconds, + __WFTimedGoTask::timer_callback); + timer->user_data = this; + + this->__WFGoTask::dispatch(); + timer->start(); +} + +SubTask *__WFTimedGoTask::done() +{ + if (this->callback) + this->callback(this); + + return series_of(this)->pop(); +} + +void __WFTimedGoTask::handle(int state, int error) +{ + if (--this->ref == 3) + { + this->state = state; + this->error = error; + this->subtask_done(); + } + + if (--this->ref == 0) + delete this; +} + +void __WFTimedGoTask::timer_callback(WFTimerTask *timer) +{ + __WFTimedGoTask *task = (__WFTimedGoTask *)timer->user_data; + + if (--task->ref == 3) + { + if (timer->get_state() == WFT_STATE_SUCCESS) + { + task->state = WFT_STATE_SYS_ERROR; + task->error = ETIMEDOUT; + } + else + { + task->state = timer->get_state(); + task->error = timer->get_error(); + } + + task->subtask_done(); + } + + if (--task->ref == 0) + delete task; +} + diff --git a/src/factory/WFTaskFactory.h b/src/factory/WFTaskFactory.h new file mode 100644 index 0000000..3834567 --- /dev/null +++ b/src/factory/WFTaskFactory.h @@ -0,0 +1,499 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _WFTASKFACTORY_H_ +#define _WFTASKFACTORY_H_ + +#include +#include +#include +#include +#include +#include +#include "URIParser.h" +#include "RedisMessage.h" +#include "HttpMessage.h" +#include "MySQLMessage.h" +#include "DnsMessage.h" +#include "Workflow.h" +#include "WFTask.h" +#include "WFGraphTask.h" +#include "EndpointParams.h" + +// Network Client/Server tasks + +using WFHttpTask = WFNetworkTask; +using http_callback_t = std::function; + +using WFRedisTask = WFNetworkTask; +using redis_callback_t = std::function; + +using WFMySQLTask = WFNetworkTask; +using mysql_callback_t = std::function; + +using WFDnsTask = WFNetworkTask; +using dns_callback_t = std::function; + +// File IO tasks + +struct FileIOArgs +{ + int fd; + void *buf; + size_t count; + off_t offset; +}; + +struct FileVIOArgs +{ + int fd; + const struct iovec *iov; + int iovcnt; + off_t offset; +}; + +struct FileSyncArgs +{ + int fd; +}; + +using WFFileIOTask = WFFileTask; +using fio_callback_t = std::function; + +using WFFileVIOTask = WFFileTask; +using fvio_callback_t = std::function; + +using WFFileSyncTask = WFFileTask; +using fsync_callback_t = std::function; + +// Timer and counter +using timer_callback_t = std::function; +using counter_callback_t = std::function; + +using mailbox_callback_t = std::function; + +using selector_callback_t = std::function; + +// Graph (DAG) task. 
+using graph_callback_t = std::function; + +using WFEmptyTask = WFGenericTask; + +using WFDynamicTask = WFGenericTask; +using dynamic_create_t = std::function; + +using repeated_create_t = std::function; +using repeater_callback_t = std::function; + +using module_callback_t = std::function; + +class WFTaskFactory +{ +public: + static WFHttpTask *create_http_task(const std::string& url, + int redirect_max, + int retry_max, + http_callback_t callback); + + static WFHttpTask *create_http_task(const ParsedURI& uri, + int redirect_max, + int retry_max, + http_callback_t callback); + + static WFHttpTask *create_http_task(const std::string& url, + const std::string& proxy_url, + int redirect_max, + int retry_max, + http_callback_t callback); + + static WFHttpTask *create_http_task(const ParsedURI& uri, + const ParsedURI& proxy_uri, + int redirect_max, + int retry_max, + http_callback_t callback); + + static WFRedisTask *create_redis_task(const std::string& url, + int retry_max, + redis_callback_t callback); + + static WFRedisTask *create_redis_task(const ParsedURI& uri, + int retry_max, + redis_callback_t callback); + + static WFMySQLTask *create_mysql_task(const std::string& url, + int retry_max, + mysql_callback_t callback); + + static WFMySQLTask *create_mysql_task(const ParsedURI& uri, + int retry_max, + mysql_callback_t callback); + + static WFDnsTask *create_dns_task(const std::string& url, + int retry_max, + dns_callback_t callback); + + static WFDnsTask *create_dns_task(const ParsedURI& uri, + int retry_max, + dns_callback_t callback); + +public: + static WFFileIOTask *create_pread_task(int fd, + void *buf, + size_t count, + off_t offset, + fio_callback_t callback); + + static WFFileIOTask *create_pwrite_task(int fd, + const void *buf, + size_t count, + off_t offset, + fio_callback_t callback); + + /* preadv and pwritev tasks are supported by Linux aio only. + * On macOS or others, you will get an ENOSYS error in callback. 
*/ + + static WFFileVIOTask *create_preadv_task(int fd, + const struct iovec *iov, + int iovcnt, + off_t offset, + fvio_callback_t callback); + + static WFFileVIOTask *create_pwritev_task(int fd, + const struct iovec *iov, + int iovcnt, + off_t offset, + fvio_callback_t callback); + + static WFFileSyncTask *create_fsync_task(int fd, + fsync_callback_t callback); + + /* On systems that do not support fdatasync(), like macOS, + * fdsync task is equal to fsync task. */ + static WFFileSyncTask *create_fdsync_task(int fd, + fsync_callback_t callback); + + /* File tasks with path name. */ +public: + static WFFileIOTask *create_pread_task(const std::string& path, + void *buf, + size_t count, + off_t offset, + fio_callback_t callback); + + static WFFileIOTask *create_pwrite_task(const std::string& path, + const void *buf, + size_t count, + off_t offset, + fio_callback_t callback); + + static WFFileVIOTask *create_preadv_task(const std::string& path, + const struct iovec *iov, + int iovcnt, + off_t offset, + fvio_callback_t callback); + + static WFFileVIOTask *create_pwritev_task(const std::string& path, + const struct iovec *iov, + int iovcnt, + off_t offset, + fvio_callback_t callback); + +public: + static WFTimerTask *create_timer_task(time_t seconds, long nanoseconds, + timer_callback_t callback); + + /* create a named timer. */ + static WFTimerTask *create_timer_task(const std::string& timer_name, + time_t seconds, long nanoseconds, + timer_callback_t callback); + + /* cancel all timers under the name. */ + static int cancel_by_name(const std::string& timer_name) + { + return WFTaskFactory::cancel_by_name(timer_name, (size_t)-1); + } + + /* cancel at most 'max' timers under the name. */ + static int cancel_by_name(const std::string& timer_name, size_t max); + + /* timer in microseconds (deprecated) */ + static WFTimerTask *create_timer_task(unsigned int microseconds, + timer_callback_t callback); + +public: + /* Create an unnamed counter. 
Call counter->count() directly. + * NOTE: never call count() exceeding target_value. */ + static WFCounterTask *create_counter_task(unsigned int target_value, + counter_callback_t callback) + { + return new WFCounterTask(target_value, std::move(callback)); + } + + /* Create a named counter. */ + static WFCounterTask *create_counter_task(const std::string& counter_name, + unsigned int target_value, + counter_callback_t callback); + + /* Count by a counter's name. When count_by_name(), it's safe to count + * exceeding target_value. When multiple counters share the same name, + * this operation will be performed on the first created. If no counter + * matches the name, nothing is performed. */ + static int count_by_name(const std::string& counter_name) + { + return WFTaskFactory::count_by_name(counter_name, 1); + } + + /* Count by name with a value n. When multiple counters share this name, + * the operation is performed on the counters in the order of their + * creation, and more than one counter may reach target value. */ + static int count_by_name(const std::string& counter_name, unsigned int n); + +public: + static WFMailboxTask *create_mailbox_task(void **mailbox, + mailbox_callback_t callback) + { + return new WFMailboxTask(mailbox, std::move(callback)); + } + + /* Use 'user_data' as mailbox. */ + static WFMailboxTask *create_mailbox_task(mailbox_callback_t callback) + { + return new WFMailboxTask(std::move(callback)); + } + + static WFMailboxTask *create_mailbox_task(const std::string& mailbox_name, + void **mailbox, + mailbox_callback_t callback); + + static WFMailboxTask *create_mailbox_task(const std::string& mailbox_name, + mailbox_callback_t callback); + + /* The 'msg' will be sent to all the mailbox tasks under the name, and + * would be lost if no task matched. 
*/ + static int send_by_name(const std::string& mailbox_name, void *msg) + { + return WFTaskFactory::send_by_name(mailbox_name, msg, (size_t)-1); + } + + static int send_by_name(const std::string& mailbox_name, void *msg, + size_t max); + + static int send_by_name(const std::string& mailbox_name, void *const msg[], + size_t max); + +public: + static WFSelectorTask *create_selector_task(size_t candidates, + selector_callback_t callback) + { + return new WFSelectorTask(candidates, std::move(callback)); + } + +public: + static WFConditional *create_conditional(SubTask *task, void **msgbuf) + { + return new WFConditional(task, msgbuf); + } + + static WFConditional *create_conditional(SubTask *task) + { + return new WFConditional(task); + } + + static WFConditional *create_conditional(const std::string& cond_name, + SubTask *task, void **msgbuf); + + static WFConditional *create_conditional(const std::string& cond_name, + SubTask *task); + + static int signal_by_name(const std::string& cond_name, void *msg) + { + return WFTaskFactory::signal_by_name(cond_name, msg, (size_t)-1); + } + + static int signal_by_name(const std::string& cond_name, void *msg, + size_t max); + + static int signal_by_name(const std::string& cond_name, void *const msg[], + size_t max); + +public: + static WFConditional *create_guard(const std::string& resource_name, + SubTask *task); + + static WFConditional *create_guard(const std::string& resource_name, + SubTask *task, void **msgbuf); + + /* The 'guard' is acquired after started, so call 'release_guard' after + and only after the task is finished, typically in its callback. + The function returns 1 if another is signaled, otherwise returns 0. 
*/ + static int release_guard(const std::string& resource_name) + { + return WFTaskFactory::release_guard(resource_name, NULL); + } + + static int release_guard(const std::string& resource_name, void *msg); + + static int release_guard_safe(const std::string& resource_name) + { + return WFTaskFactory::release_guard_safe(resource_name, NULL); + } + + static int release_guard_safe(const std::string& resource_name, void *msg); + +public: + template + static WFGoTask *create_go_task(const std::string& queue_name, + FUNC&& func, ARGS&&... args); + + /* Create 'Go' task with running time limit in seconds plus nanoseconds. + * If time exceeded, state WFT_STATE_SYS_ERROR and error ETIMEDOUT + * will be got in callback. */ + template + static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + FUNC&& func, ARGS&&... args); + + /* Create 'Go' task on user's executor and execution queue. */ + template + static WFGoTask *create_go_task(ExecQueue *queue, Executor *executor, + FUNC&& func, ARGS&&... args); + + template + static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + FUNC&& func, ARGS&&... args); + + /* For capturing 'task' itself in go task's running function. */ + template + static void reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... 
args); + +public: + static WFGraphTask *create_graph_task(graph_callback_t callback) + { + return new WFGraphTask(std::move(callback)); + } + +public: + static WFEmptyTask *create_empty_task() + { + return new WFEmptyTask; + } + + static WFDynamicTask *create_dynamic_task(dynamic_create_t create); + + static WFRepeaterTask *create_repeater_task(repeated_create_t create, + repeater_callback_t callback) + { + return new WFRepeaterTask(std::move(create), std::move(callback)); + } + +public: + static WFModuleTask *create_module_task(SubTask *first, + module_callback_t callback) + { + return new WFModuleTask(first, std::move(callback)); + } + + static WFModuleTask *create_module_task(SubTask *first, SubTask *last, + module_callback_t callback) + { + WFModuleTask *task = new WFModuleTask(first, std::move(callback)); + task->sub_series()->set_last_task(last); + return task; + } +}; + +template +class WFNetworkTaskFactory +{ +private: + using T = WFNetworkTask; + +public: + static T *create_client_task(enum TransportType type, + const std::string& host, + unsigned short port, + int retry_max, + std::function callback); + + static T *create_client_task(enum TransportType type, + const std::string& url, + int retry_max, + std::function callback); + + static T *create_client_task(enum TransportType type, + const ParsedURI& uri, + int retry_max, + std::function callback); + + static T *create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + int retry_max, + std::function callback); + + static T *create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + SSL_CTX *ssl_ctx, + int retry_max, + std::function callback); +public: + static T *create_server_task(CommService *service, + std::function& process); +}; + +template +class WFThreadTaskFactory +{ +private: + using T = WFThreadTask; + +public: + static T *create_thread_task(const std::string& queue_name, + std::function routine, + std::function 
callback); + + /* Create thread task with running time limit. */ + static T *create_thread_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + std::function routine, + std::function callback); + +public: + /* Create thread task on user's executor and execution queue. */ + static T *create_thread_task(ExecQueue *queue, Executor *executor, + std::function routine, + std::function callback); + + /* With running time limit. */ + static T *create_thread_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::function routine, + std::function callback); +}; + +#include "WFTaskFactory.inl" + +#endif + diff --git a/src/factory/WFTaskFactory.inl b/src/factory/WFTaskFactory.inl new file mode 100644 index 0000000..07bd443 --- /dev/null +++ b/src/factory/WFTaskFactory.inl @@ -0,0 +1,916 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) + Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "WFGlobal.h" +#include "Workflow.h" +#include "WFTask.h" +#include "RouteManager.h" +#include "URIParser.h" +#include "WFTaskError.h" +#include "EndpointParams.h" +#include "WFNameService.h" + +class __WFDynamicTask : public WFDynamicTask +{ +protected: + virtual void dispatch() + { + series_of(this)->push_front(this->create(this)); + this->WFDynamicTask::dispatch(); + } + +protected: + std::function create; + +public: + __WFDynamicTask(std::function&& create) : + create(std::move(create)) + { + } +}; + +inline WFDynamicTask * +WFTaskFactory::create_dynamic_task(dynamic_create_t create) +{ + return new __WFDynamicTask(std::move(create)); +} + +template +class WFComplexClientTask : public WFClientTask +{ +protected: + using task_callback_t = std::function *)>; + +public: + WFComplexClientTask(int retry_max, task_callback_t&& cb): + WFClientTask(NULL, WFGlobal::get_scheduler(), std::move(cb)) + { + type_ = TT_TCP; + ssl_ctx_ = NULL; + fixed_addr_ = false; + fixed_conn_ = false; + retry_max_ = retry_max; + retry_times_ = 0; + redirect_ = false; + ns_policy_ = NULL; + router_task_ = NULL; + } + +protected: + // new api for children + virtual bool init_success() { return true; } + virtual void init_failed() {} + virtual bool check_request() { return true; } + virtual WFRouterTask *route(); + virtual bool finish_once() { return true; } + +public: + void init(const ParsedURI& uri) + { + uri_ = uri; + init_with_uri(); + } + + void init(ParsedURI&& uri) + { + uri_ = std::move(uri); + init_with_uri(); + } + + void init(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + const std::string& info); + + void set_transport_type(enum TransportType type) + { + type_ = type; + } + + enum TransportType 
get_transport_type() const { return type_; } + + void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; } + + virtual const ParsedURI *get_current_uri() const { return &uri_; } + + void set_redirect(const ParsedURI& uri) + { + redirect_ = true; + init(uri); + } + + void set_redirect(enum TransportType type, const struct sockaddr *addr, + socklen_t addrlen, const std::string& info) + { + redirect_ = true; + init(type, addr, addrlen, info); + } + + bool is_fixed_addr() const { return this->fixed_addr_; } + + bool is_fixed_conn() const { return this->fixed_conn_; } + +protected: + void set_fixed_addr(int fixed) { this->fixed_addr_ = fixed; } + + void set_fixed_conn(int fixed) { this->fixed_conn_ = fixed; } + + void set_info(const std::string& info) + { + info_.assign(info); + } + + void set_info(const char *info) + { + info_.assign(info); + } + +protected: + virtual void dispatch(); + virtual SubTask *done(); + + void clear_resp() + { + RESP resp; + *(protocol::ProtocolMessage *)&resp = std::move(this->resp); + this->resp = std::move(resp); + } + + void disable_retry() + { + retry_times_ = retry_max_; + } + +protected: + enum TransportType type_; + ParsedURI uri_; + std::string info_; + SSL_CTX *ssl_ctx_; + bool fixed_addr_; + bool fixed_conn_; + bool redirect_; + CTX ctx_; + int retry_max_; + int retry_times_; + WFNSPolicy *ns_policy_; + WFRouterTask *router_task_; + RouteManager::RouteResult route_result_; + WFNSTracing tracing_; + +public: + CTX *get_mutable_ctx() { return &ctx_; } + +private: + void clear_prev_state(); + void init_with_uri(); + bool set_port(); + void router_callback(void *t); + void switch_callback(void *t); +}; + +template +void WFComplexClientTask::clear_prev_state() +{ + ns_policy_ = NULL; + route_result_.clear(); + if (tracing_.deleter) + { + tracing_.deleter(tracing_.data); + tracing_.deleter = NULL; + } + tracing_.data = NULL; + retry_times_ = 0; + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + this->timeout_reason = 
TOR_NOT_TIMEOUT; +} + +template +void WFComplexClientTask::init(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + const std::string& info) +{ + if (redirect_) + clear_prev_state(); + + auto params = WFGlobal::get_global_settings()->endpoint_params; + struct addrinfo addrinfo = { }; + addrinfo.ai_family = addr->sa_family; + addrinfo.ai_addr = (struct sockaddr *)addr; + addrinfo.ai_addrlen = addrlen; + + type_ = type; + info_.assign(info); + params.use_tls_sni = false; + if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, ¶ms, + "", ssl_ctx_, route_result_) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + } + else if (this->init_success()) + return; + + this->init_failed(); +} + +template +bool WFComplexClientTask::set_port() +{ + if (uri_.port) + { + int port = atoi(uri_.port); + + if (port <= 0 || port > 65535) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_PORT_INVALID; + return false; + } + + return true; + } + + if (uri_.scheme) + { + const char *port_str = WFGlobal::get_default_port(uri_.scheme); + + if (port_str) + { + uri_.port = strdup(port_str); + if (uri_.port) + return true; + + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + } + + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_SCHEME_INVALID; + return false; +} + +template +void WFComplexClientTask::init_with_uri() +{ + if (redirect_) + { + clear_prev_state(); + ns_policy_ = WFGlobal::get_dns_resolver(); + } + + if (uri_.state == URI_STATE_SUCCESS) + { + if (this->set_port()) + { + if (this->init_success()) + return; + } + } + else if (uri_.state == URI_STATE_ERROR) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = uri_.error; + } + else + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_URI_PARSE_FAILED; + } + + this->init_failed(); +} + +template +WFRouterTask *WFComplexClientTask::route() +{ + auto&& cb = 
std::bind(&WFComplexClientTask::router_callback, + this, + std::placeholders::_1); + struct WFNSParams params = { + .type = type_, + .uri = uri_, + .info = info_.c_str(), + .ssl_ctx = ssl_ctx_, + .fixed_addr = fixed_addr_, + .fixed_conn = fixed_conn_, + .retry_times = retry_times_, + .tracing = &tracing_, + }; + + if (!ns_policy_) + { + WFNameService *ns = WFGlobal::get_name_service(); + ns_policy_ = ns->get_policy(uri_.host ? uri_.host : ""); + } + + return ns_policy_->create_router_task(¶ms, std::move(cb)); +} + +template +void WFComplexClientTask::router_callback(void *t) +{ + WFRouterTask *task = (WFRouterTask *)t; + + this->state = task->get_state(); + if (this->state == WFT_STATE_SUCCESS) + route_result_ = std::move(*task->get_result()); + else if (this->state == WFT_STATE_UNDEFINED) + { + /* should not happend */ + this->state = WFT_STATE_SYS_ERROR; + this->error = ENOSYS; + } + else + this->error = task->get_error(); +} + +template +void WFComplexClientTask::dispatch() +{ + switch (this->state) + { + case WFT_STATE_UNDEFINED: + if (this->check_request()) + { + if (this->route_result_.request_object) + { + case WFT_STATE_SUCCESS: + this->set_request_object(route_result_.request_object); + this->WFClientTask::dispatch(); + return; + } + + router_task_ = this->route(); + series_of(this)->push_front(this); + series_of(this)->push_front(router_task_); + } + + default: + break; + } + + this->subtask_done(); +} + +template +void WFComplexClientTask::switch_callback(void *t) +{ + if (!redirect_) + { + if (this->state == WFT_STATE_SYS_ERROR && this->error < 0) + { + this->state = WFT_STATE_SSL_ERROR; + this->error = -this->error; + } + + if (tracing_.deleter) + { + tracing_.deleter(tracing_.data); + tracing_.deleter = NULL; + } + + if (this->callback) + this->callback(this); + } + + if (redirect_) + { + redirect_ = false; + clear_resp(); + this->target = NULL; + series_of(this)->push_front(this); + } + else + delete this; +} + +template +SubTask 
*WFComplexClientTask::done() +{ + SeriesWork *series = series_of(this); + + if (router_task_) + { + router_task_ = NULL; + return series->pop(); + } + + bool is_user_request = this->finish_once(); + + if (ns_policy_) + { + if (this->state == WFT_STATE_SYS_ERROR || + this->state == WFT_STATE_DNS_ERROR) + { + ns_policy_->failed(&route_result_, &tracing_, this->target); + } + else if (route_result_.request_object) + { + ns_policy_->success(&route_result_, &tracing_, this->target); + } + } + + if (this->state == WFT_STATE_SUCCESS) + { + if (!is_user_request) + return this; + } + else if (this->state == WFT_STATE_SYS_ERROR) + { + if (retry_times_ < retry_max_) + { + redirect_ = true; + if (ns_policy_) + route_result_.clear(); + + this->state = WFT_STATE_UNDEFINED; + this->error = 0; + this->timeout_reason = 0; + retry_times_++; + } + } + + /* When the target or the connection is NULL, it's very likely that we are + * in the caller's thread. Running a timer will switch callback function to + * a handler thread, and this can prevent stack overflow. 
*/ + if (!this->target || !this->CommSession::get_connection()) + { + auto&& cb = std::bind(&WFComplexClientTask::switch_callback, + this, + std::placeholders::_1); + WFTimerTask *timer; + + timer = WFTaskFactory::create_timer_task(0, 0, std::move(cb)); + series->push_front(timer); + } + else + this->switch_callback(NULL); + + return series->pop(); +} + +/**********Template Network Factory**********/ + +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const std::string& host, + unsigned short port, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + ParsedURI uri; + char buf[32]; + + sprintf(buf, "%u", port); + uri.scheme = strdup("scheme"); + uri.host = strdup(host.c_str()); + uri.port = strdup(buf); + if (!uri.scheme || !uri.host || !uri.port) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + } + else + uri.state = URI_STATE_SUCCESS; + + task->init(std::move(uri)); + task->set_transport_type(type); + return task; +} + +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const std::string& url, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + ParsedURI uri; + + URIParser::parse(url, uri); + task->init(std::move(uri)); + task->set_transport_type(type); + return task; +} + +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const ParsedURI& uri, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + + task->init(uri); + task->set_transport_type(type); + return task; +} + +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, 
std::move(callback)); + + task->init(type, addr, addrlen, ""); + return task; +} + +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + SSL_CTX *ssl_ctx, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + + task->set_ssl_ctx(ssl_ctx); + task->init(type, addr, addrlen, ""); + return task; +} + +template +WFNetworkTask * +WFNetworkTaskFactory::create_server_task(CommService *service, + std::function *)>& process) +{ + return new WFServerTask(service, WFGlobal::get_scheduler(), process); +} + +/**********Server Factory**********/ + +class WFServerTaskFactory +{ +public: + static WFDnsTask *create_dns_task(CommService *service, + std::function& process); + + static WFHttpTask *create_http_task(CommService *service, + std::function& process); + + static WFMySQLTask *create_mysql_task(CommService *service, + std::function& process); +}; + +/************Go Task Factory************/ + +class __WFGoTask : public WFGoTask +{ +public: + void set_go_func(std::function func) + { + this->go = std::move(func); + } + +protected: + virtual void execute() + { + this->go(); + } + +protected: + std::function go; + +public: + __WFGoTask(ExecQueue *queue, Executor *executor, + std::function&& func) : + WFGoTask(queue, executor), + go(std::move(func)) + { + } +}; + +class __WFTimedGoTask : public __WFGoTask +{ +protected: + virtual void dispatch(); + virtual SubTask *done(); + +protected: + virtual void handle(int state, int error); + +protected: + static void timer_callback(WFTimerTask *timer); + +protected: + time_t seconds; + long nanoseconds; + std::atomic ref; + +public: + __WFTimedGoTask(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::function&& func) : + __WFGoTask(queue, executor, std::move(func)), + ref(4) + { + this->seconds = seconds; + this->nanoseconds = nanoseconds; + } +}; + 
+template +WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name, + FUNC&& func, ARGS&&... args) +{ + auto&& tmp = std::bind(std::forward(func), + std::forward(args)...); + return new __WFGoTask(WFGlobal::get_exec_queue(queue_name), + WFGlobal::get_compute_executor(), + std::move(tmp)); +} + +template +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + FUNC&& func, ARGS&&... args) +{ + auto&& tmp = std::bind(std::forward(func), + std::forward(args)...); + return new __WFTimedGoTask(seconds, nanoseconds, + WFGlobal::get_exec_queue(queue_name), + WFGlobal::get_compute_executor(), + std::move(tmp)); +} + +template +WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor, + FUNC&& func, ARGS&&... args) +{ + auto&& tmp = std::bind(std::forward(func), + std::forward(args)...); + return new __WFGoTask(queue, executor, std::move(tmp)); +} + +template +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + FUNC&& func, ARGS&&... args) +{ + auto&& tmp = std::bind(std::forward(func), + std::forward(args)...); + return new __WFTimedGoTask(seconds, nanoseconds, + queue, executor, + std::move(tmp)); +} + +template +void WFTaskFactory::reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... 
args) +{ + auto&& tmp = std::bind(std::forward(func), + std::forward(args)...); + ((__WFGoTask *)task)->set_go_func(std::move(tmp)); +} + +/**********Create go task with nullptr func**********/ + +template<> inline +WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name, + std::nullptr_t&& func) +{ + return new __WFGoTask(WFGlobal::get_exec_queue(queue_name), + WFGlobal::get_compute_executor(), + nullptr); +} + +template<> inline +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + std::nullptr_t&& func) +{ + return new __WFTimedGoTask(seconds, nanoseconds, + WFGlobal::get_exec_queue(queue_name), + WFGlobal::get_compute_executor(), + nullptr); +} + +template<> inline +WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor, + std::nullptr_t&& func) +{ + return new __WFGoTask(queue, executor, nullptr); +} + +template<> inline +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::nullptr_t&& func) +{ + return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, nullptr); +} + +template<> inline +void WFTaskFactory::reset_go_task(WFGoTask *task, std::nullptr_t&& func) +{ + ((__WFGoTask *)task)->set_go_func(nullptr); +} + +/**********Template Thread Task Factory**********/ + +template +class __WFThreadTask : public WFThreadTask +{ +protected: + virtual void execute() + { + this->routine(&this->input, &this->output); + } + +protected: + std::function routine; + +public: + __WFThreadTask(ExecQueue *queue, Executor *executor, + std::function&& rt, + std::function *)>&& cb) : + WFThreadTask(queue, executor, std::move(cb)), + routine(std::move(rt)) + { + } +}; + +template +class __WFTimedThreadTask : public __WFThreadTask +{ +protected: + virtual void dispatch(); + virtual SubTask *done(); + +protected: + virtual void handle(int state, int error); + +protected: + static void 
timer_callback(WFTimerTask *timer); + +protected: + time_t seconds; + long nanoseconds; + std::atomic ref; + +public: + __WFTimedThreadTask(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::function&& rt, + std::function *)>&& cb) : + __WFThreadTask(queue, executor, std::move(rt), std::move(cb)), + ref(4) + { + this->seconds = seconds; + this->nanoseconds = nanoseconds; + } +}; + +template +void __WFTimedThreadTask::dispatch() +{ + WFTimerTask *timer; + + timer = WFTaskFactory::create_timer_task(this->seconds, this->nanoseconds, + __WFTimedThreadTask::timer_callback); + timer->user_data = this; + + this->__WFThreadTask::dispatch(); + timer->start(); +} + +template +SubTask *__WFTimedThreadTask::done() +{ + if (this->callback) + this->callback(this); + + return series_of(this)->pop(); +} + +template +void __WFTimedThreadTask::handle(int state, int error) +{ + if (--this->ref == 3) + { + this->state = state; + this->error = error; + this->subtask_done(); + } + + if (--this->ref == 0) + delete this; +} + +template +void __WFTimedThreadTask::timer_callback(WFTimerTask *timer) +{ + auto *task = (__WFTimedThreadTask *)timer->user_data; + + if (--task->ref == 3) + { + if (timer->get_state() == WFT_STATE_SUCCESS) + { + task->state = WFT_STATE_SYS_ERROR; + task->error = ETIMEDOUT; + } + else + { + task->state = timer->get_state(); + task->error = timer->get_error(); + } + + task->subtask_done(); + } + + if (--task->ref == 0) + delete task; +} + +template +WFThreadTask * +WFThreadTaskFactory::create_thread_task(const std::string& queue_name, + std::function routine, + std::function *)> callback) +{ + return new __WFThreadTask(WFGlobal::get_exec_queue(queue_name), + WFGlobal::get_compute_executor(), + std::move(routine), + std::move(callback)); +} + +template +WFThreadTask * +WFThreadTaskFactory::create_thread_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + std::function routine, + std::function *)> callback) +{ + 
return new __WFTimedThreadTask(seconds, nanoseconds, + WFGlobal::get_exec_queue(queue_name), + WFGlobal::get_compute_executor(), + std::move(routine), + std::move(callback)); +} + +template +WFThreadTask * +WFThreadTaskFactory::create_thread_task(ExecQueue *queue, Executor *executor, + std::function routine, + std::function *)> callback) +{ + return new __WFThreadTask(queue, executor, + std::move(routine), + std::move(callback)); +} + +template +WFThreadTask * +WFThreadTaskFactory::create_thread_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::function routine, + std::function *)> callback) +{ + return new __WFTimedThreadTask(seconds, nanoseconds, + queue, executor, + std::move(routine), + std::move(callback)); +} + diff --git a/src/factory/Workflow.cc b/src/factory/Workflow.cc new file mode 100644 index 0000000..fe9926e --- /dev/null +++ b/src/factory/Workflow.cc @@ -0,0 +1,248 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include "Workflow.h" + +SeriesWork::SeriesWork(SubTask *first, series_callback_t&& cb) : + callback(std::move(cb)) +{ + this->queue = this->buf; + this->queue_size = sizeof this->buf / sizeof *this->buf; + this->front = 0; + this->back = 0; + this->canceled = false; + this->finished = false; + assert(!series_of(first)); + first->set_pointer(this); + this->first = first; + this->last = NULL; + this->context = NULL; + this->in_parallel = NULL; +} + +SeriesWork::~SeriesWork() +{ + if (this->queue != this->buf) + delete []this->queue; +} + +void SeriesWork::dismiss_recursive() +{ + SubTask *task = first; + + this->callback = nullptr; + do + { + delete task; + task = this->pop_task(); + } while (task); +} + +void SeriesWork::expand_queue() +{ + int size = 2 * this->queue_size; + SubTask **queue = new SubTask *[size]; + int i, j; + + i = 0; + j = this->front; + do + { + queue[i++] = this->queue[j++]; + if (j == this->queue_size) + j = 0; + } while (j != this->back); + + if (this->queue != this->buf) + delete []this->queue; + + this->queue = queue; + this->queue_size = size; + this->front = 0; + this->back = i; +} + +void SeriesWork::push_front(SubTask *task) +{ + this->mutex.lock(); + if (--this->front == -1) + this->front = this->queue_size - 1; + + task->set_pointer(this); + this->queue[this->front] = task; + if (this->front == this->back) + this->expand_queue(); + + this->mutex.unlock(); +} + +void SeriesWork::push_back(SubTask *task) +{ + this->mutex.lock(); + task->set_pointer(this); + this->queue[this->back] = task; + if (++this->back == this->queue_size) + this->back = 0; + + if (this->front == this->back) + this->expand_queue(); + + this->mutex.unlock(); +} + +SubTask *SeriesWork::pop() +{ + bool canceled = this->canceled; + SubTask *task = this->pop_task(); + + if (!canceled) + return task; + + while (task) + { + delete task; + task = this->pop_task(); 
+ } + + return NULL; +} + +SubTask *SeriesWork::pop_task() +{ + SubTask *task; + + this->mutex.lock(); + if (this->front != this->back) + { + task = this->queue[this->front]; + if (++this->front == this->queue_size) + this->front = 0; + } + else + { + task = this->last; + this->last = NULL; + } + + this->mutex.unlock(); + if (!task) + { + this->finished = true; + + if (this->callback) + this->callback(this); + + if (!this->in_parallel) + delete this; + } + + return task; +} + +ParallelWork::ParallelWork(parallel_callback_t&& cb) : + ParallelTask(new SubTask *[2 * 4], 0), + callback(std::move(cb)) +{ + this->buf_size = 4; + this->all_series = (SeriesWork **)&this->subtasks[this->buf_size]; + this->context = NULL; +} + +ParallelWork::ParallelWork(SeriesWork *const all_series[], size_t n, + parallel_callback_t&& cb) : + ParallelTask(new SubTask *[2 * (n > 4 ? n : 4)], n), + callback(std::move(cb)) +{ + size_t i; + + this->buf_size = (n > 4 ? n : 4); + this->all_series = (SeriesWork **)&this->subtasks[this->buf_size]; + for (i = 0; i < n; i++) + { + assert(!all_series[i]->in_parallel); + all_series[i]->in_parallel = this; + this->all_series[i] = all_series[i]; + this->subtasks[i] = all_series[i]->first; + } + + this->context = NULL; +} + +void ParallelWork::expand_buf() +{ + SubTask **buf; + size_t size; + + this->buf_size *= 2; + buf = new SubTask *[2 * this->buf_size]; + size = this->subtasks_nr * sizeof (void *); + memcpy(buf, this->subtasks, size); + memcpy(buf + this->buf_size, this->all_series, size); + + delete []this->subtasks; + this->subtasks = buf; + this->all_series = (SeriesWork **)&buf[this->buf_size]; +} + +void ParallelWork::add_series(SeriesWork *series) +{ + if (this->subtasks_nr == this->buf_size) + this->expand_buf(); + + assert(!series->in_parallel); + series->in_parallel = this; + this->all_series[this->subtasks_nr] = series; + this->subtasks[this->subtasks_nr] = series->first; + this->subtasks_nr++; +} + +SubTask *ParallelWork::done() +{ + 
SeriesWork *series = series_of(this); + size_t i; + + if (this->callback) + this->callback(this); + + for (i = 0; i < this->subtasks_nr; i++) + delete this->all_series[i]; + + this->subtasks_nr = 0; + delete this; + return series->pop(); +} + +ParallelWork::~ParallelWork() +{ + size_t i; + + for (i = 0; i < this->subtasks_nr; i++) + { + this->all_series[i]->in_parallel = NULL; + this->all_series[i]->dismiss_recursive(); + } + + delete []this->subtasks; +} + diff --git a/src/factory/Workflow.h b/src/factory/Workflow.h new file mode 100644 index 0000000..377c888 --- /dev/null +++ b/src/factory/Workflow.h @@ -0,0 +1,310 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WORKFLOW_H_ +#define _WORKFLOW_H_ + +#include +#include +#include +#include +#include +#include "SubTask.h" + +class SeriesWork; +class ParallelWork; + +using series_callback_t = std::function; +using parallel_callback_t = std::function; + +class Workflow +{ +public: + static SeriesWork * + create_series_work(SubTask *first, series_callback_t callback); + + static void + start_series_work(SubTask *first, series_callback_t callback); + + static ParallelWork * + create_parallel_work(parallel_callback_t callback); + + static ParallelWork * + create_parallel_work(SeriesWork *const all_series[], size_t n, + parallel_callback_t callback); + + static void + start_parallel_work(SeriesWork *const all_series[], size_t n, + parallel_callback_t callback); + +public: + static SeriesWork * + create_series_work(SubTask *first, SubTask *last, + series_callback_t callback); + + static void + start_series_work(SubTask *first, SubTask *last, + series_callback_t callback); +}; + +class SeriesWork +{ +public: + void start() + { + assert(!this->in_parallel); + this->first->dispatch(); + } + + /* Call dismiss() only when you don't want to start a created series. + * This operation is recursive, so only call on the "root". */ + void dismiss() + { + assert(!this->in_parallel); + this->dismiss_recursive(); + } + +public: + void push_back(SubTask *task); + void push_front(SubTask *task); + +public: + void *get_context() const { return this->context; } + void set_context(void *context) { this->context = context; } + +public: + /* Cancel a running series. Typically, called in the callback of a task + * that belongs to the series. All subsequent tasks in the series will be + * destroyed immediately and recursively (ParallelWork), without callback. + * But the callback of this canceled series will still be called. 
*/ + virtual void cancel() { this->canceled = true; } + + /* Parallel work's callback may check the cancellation state of each + * sub-series, and cancel its super-series recursively. */ + bool is_canceled() const { return this->canceled; } + + /* 'false' until the time of callback. Mainly for sub-class. */ + bool is_finished() const { return this->finished; } + +public: + void set_callback(series_callback_t callback) + { + this->callback = std::move(callback); + } + +public: + /* Base implementation ignores `key` and always returns NULL; it is virtual so a sub-class can override it. */ virtual void *get_specific(const char *key) { return NULL; } + +public: + /* The following functions are intended for task implementations only. */ + SubTask *pop(); + + SubTask *get_last_task() const { return this->last; } + + /* Registers `last` as the last task and points that task back at this series through set_pointer(). `last` must not be NULL: it is dereferenced. */ void set_last_task(SubTask *last) + { + last->set_pointer(this); + this->last = last; + } + + void unset_last_task() { this->last = NULL; } + + const ParallelTask *get_in_parallel() const { return this->in_parallel; } + +protected: + void set_in_parallel(const ParallelTask *task) { this->in_parallel = task; } + + void dismiss_recursive(); + +protected: + void *context; + series_callback_t callback; + +private: + SubTask *pop_task(); + void expand_queue(); + +private: + SubTask *buf[4]; + SubTask *first; + SubTask *last; + SubTask **queue; + int queue_size; + int front; + int back; + bool canceled; + bool finished; + const ParallelTask *in_parallel; + std::mutex mutex; + +protected: + SeriesWork(SubTask *first, series_callback_t&& callback); + virtual ~SeriesWork(); + friend class ParallelWork; + friend class Workflow; +}; + +/* Returns the pointer stored in the task (get_pointer()) cast to SeriesWork *. set_last_task() above stores the owning series in that pointer. */ static inline SeriesWork *series_of(const SubTask *task) +{ + return (SeriesWork *)task->get_pointer(); +} + +/* `*task` yields a reference to series_of(&task). There is no NULL check, so the task must already have a series pointer. */ static inline SeriesWork& operator *(const SubTask& task) +{ + return *series_of(&task); +} + +/* `series << task` appends the task with push_back() and returns the series, so insertions can be chained. */ static inline SeriesWork& operator << (SeriesWork& series, SubTask *task) +{ + series.push_back(task); + return series; +} + +inline SeriesWork * +Workflow::create_series_work(SubTask *first, series_callback_t callback) +{ + return new
SeriesWork(first, std::move(callback)); +} + +/* Allocates a series around `first` and dispatches `first` right away. The new SeriesWork pointer is not kept here; presumably the series releases itself when it finishes -- confirm in Workflow.cc. */ inline void +Workflow::start_series_work(SubTask *first, series_callback_t callback) +{ + new SeriesWork(first, std::move(callback)); + first->dispatch(); +} + +/* Same as create_series_work(first, callback), and also registers `last` through set_last_task(). */ inline SeriesWork * +Workflow::create_series_work(SubTask *first, SubTask *last, + series_callback_t callback) +{ + SeriesWork *series = new SeriesWork(first, std::move(callback)); + series->set_last_task(last); + return series; +} + +/* Same as start_series_work(first, callback), with `last` registered before `first` is dispatched. */ inline void +Workflow::start_series_work(SubTask *first, SubTask *last, + series_callback_t callback) +{ + SeriesWork *series = new SeriesWork(first, std::move(callback)); + series->set_last_task(last); + first->dispatch(); +} + +/* A ParallelTask that keeps an array of sub-series (all_series), a user context pointer and a completion callback. */ class ParallelWork : public ParallelTask +{ +public: + /* Asserts that this parallel work does not belong to a series yet, then starts it as the first task of a new series that has no callback. */ void start() + { + assert(!series_of(this)); + Workflow::start_series_work(this, nullptr); + } + + /* Deletes this object; asserts that it does not belong to a series. */ void dismiss() + { + assert(!series_of(this)); + delete this; + } + +public: + void add_series(SeriesWork *series); + +public: + void *get_context() const { return this->context; } + void set_context(void *context) { this->context = context; } + +public: + /* Returns the sub-series at `index`, or NULL when index >= subtasks_nr. */ SeriesWork *series_at(size_t index) + { + if (index < this->subtasks_nr) + return this->all_series[index]; + else + return NULL; + } + + const SeriesWork *series_at(size_t index) const + { + if (index < this->subtasks_nr) + return this->all_series[index]; + else + return NULL; + } + + /* No bounds check: dereferences series_at(index), which is NULL for an out-of-range index. */ SeriesWork& operator[] (size_t index) + { + return *this->series_at(index); + } + + const SeriesWork& operator[] (size_t index) const + { + return *this->series_at(index); + } + + size_t size() const { return this->subtasks_nr; } + +public: + void set_callback(parallel_callback_t callback) + { + this->callback = std::move(callback); + } + +protected: + virtual SubTask *done(); + +protected: + void *context; + parallel_callback_t callback; + +private: + void expand_buf(); + +private: + size_t buf_size; + SeriesWork **all_series; + +protected: + ParallelWork(parallel_callback_t&& callback); + ParallelWork(SeriesWork *const
all_series[], size_t n, + parallel_callback_t&& callback); + virtual ~ParallelWork(); + friend class Workflow; +}; + +/* Returns a new ParallelWork built from the callback alone. Nothing is started. */ inline ParallelWork * +Workflow::create_parallel_work(parallel_callback_t callback) +{ + return new ParallelWork(std::move(callback)); +} + +/* Returns a new ParallelWork built from the first n entries of all_series. Nothing is started. */ inline ParallelWork * +Workflow::create_parallel_work(SeriesWork *const all_series[], size_t n, + parallel_callback_t callback) +{ + return new ParallelWork(all_series, n, std::move(callback)); +} + +/* Builds the ParallelWork and starts it immediately as the first task of a new series that has no callback. */ inline void +Workflow::start_parallel_work(SeriesWork *const all_series[], size_t n, + parallel_callback_t callback) +{ + ParallelWork *p = new ParallelWork(all_series, n, std::move(callback)); + Workflow::start_series_work(p, nullptr); +} + +#endif + diff --git a/src/factory/xmake.lua b/src/factory/xmake.lua new file mode 100644 index 0000000..e1e76b2 --- /dev/null +++ b/src/factory/xmake.lua @@ -0,0 +1,21 @@ +target("factory") + add_files("*.cc") + set_kind("object") + if not has_config("mysql") then + remove_files("MySQLTaskImpl.cc") + end + if not has_config("redis") then + remove_files("RedisTaskImpl.cc") + end + remove_files("KafkaTaskImpl.cc") + +target("kafka_factory") + if has_config("kafka") then + add_files("KafkaTaskImpl.cc") + set_kind("object") + add_cxxflags("-fno-rtti") + add_deps("factory") + add_packages("zlib", "snappy", "zstd", "lz4") + else + set_kind("phony") + end diff --git a/src/include/workflow/CommRequest.h b/src/include/workflow/CommRequest.h new file mode 100644 index 0000000..e861e0c --- /dev/null +++ b/src/include/workflow/CommRequest.h @@ -0,0 +1 @@ +../../kernel/CommRequest.h \ No newline at end of file diff --git a/src/include/workflow/CommScheduler.h b/src/include/workflow/CommScheduler.h new file mode 100644 index 0000000..fc6dffe --- /dev/null +++ b/src/include/workflow/CommScheduler.h @@ -0,0 +1 @@ +../../kernel/CommScheduler.h \ No newline at end of file diff --git a/src/include/workflow/Communicator.h b/src/include/workflow/Communicator.h new file mode 100644 index
0000000..5c8c200 --- /dev/null +++ b/src/include/workflow/Communicator.h @@ -0,0 +1 @@ +../../kernel/Communicator.h \ No newline at end of file diff --git a/src/include/workflow/ConsulDataTypes.h b/src/include/workflow/ConsulDataTypes.h new file mode 100644 index 0000000..b411253 --- /dev/null +++ b/src/include/workflow/ConsulDataTypes.h @@ -0,0 +1 @@ +../../protocol/ConsulDataTypes.h \ No newline at end of file diff --git a/src/include/workflow/DnsCache.h b/src/include/workflow/DnsCache.h new file mode 100644 index 0000000..d42b69e --- /dev/null +++ b/src/include/workflow/DnsCache.h @@ -0,0 +1 @@ +../../manager/DnsCache.h \ No newline at end of file diff --git a/src/include/workflow/DnsMessage.h b/src/include/workflow/DnsMessage.h new file mode 100644 index 0000000..55a0843 --- /dev/null +++ b/src/include/workflow/DnsMessage.h @@ -0,0 +1 @@ +../../protocol/DnsMessage.h \ No newline at end of file diff --git a/src/include/workflow/DnsUtil.h b/src/include/workflow/DnsUtil.h new file mode 100644 index 0000000..246b2f2 --- /dev/null +++ b/src/include/workflow/DnsUtil.h @@ -0,0 +1 @@ +../../protocol/DnsUtil.h \ No newline at end of file diff --git a/src/include/workflow/EncodeStream.h b/src/include/workflow/EncodeStream.h new file mode 100644 index 0000000..bf259c8 --- /dev/null +++ b/src/include/workflow/EncodeStream.h @@ -0,0 +1 @@ +../../util/EncodeStream.h \ No newline at end of file diff --git a/src/include/workflow/EndpointParams.h b/src/include/workflow/EndpointParams.h new file mode 100644 index 0000000..2cfce30 --- /dev/null +++ b/src/include/workflow/EndpointParams.h @@ -0,0 +1 @@ +../../manager/EndpointParams.h \ No newline at end of file diff --git a/src/include/workflow/ExecRequest.h b/src/include/workflow/ExecRequest.h new file mode 100644 index 0000000..d8dfd3f --- /dev/null +++ b/src/include/workflow/ExecRequest.h @@ -0,0 +1 @@ +../../kernel/ExecRequest.h \ No newline at end of file diff --git a/src/include/workflow/Executor.h 
b/src/include/workflow/Executor.h new file mode 100644 index 0000000..cce19d6 --- /dev/null +++ b/src/include/workflow/Executor.h @@ -0,0 +1 @@ +../../kernel/Executor.h \ No newline at end of file diff --git a/src/include/workflow/HttpMessage.h b/src/include/workflow/HttpMessage.h new file mode 100644 index 0000000..544f9ac --- /dev/null +++ b/src/include/workflow/HttpMessage.h @@ -0,0 +1 @@ +../../protocol/HttpMessage.h \ No newline at end of file diff --git a/src/include/workflow/HttpUtil.h b/src/include/workflow/HttpUtil.h new file mode 100644 index 0000000..01a1472 --- /dev/null +++ b/src/include/workflow/HttpUtil.h @@ -0,0 +1 @@ +../../protocol/HttpUtil.h \ No newline at end of file diff --git a/src/include/workflow/IORequest.h b/src/include/workflow/IORequest.h new file mode 100644 index 0000000..dca002b --- /dev/null +++ b/src/include/workflow/IORequest.h @@ -0,0 +1 @@ +../../kernel/IORequest.h \ No newline at end of file diff --git a/src/include/workflow/IOService_linux.h b/src/include/workflow/IOService_linux.h new file mode 100644 index 0000000..3601e74 --- /dev/null +++ b/src/include/workflow/IOService_linux.h @@ -0,0 +1 @@ +../../kernel/IOService_linux.h \ No newline at end of file diff --git a/src/include/workflow/IOService_thread.h b/src/include/workflow/IOService_thread.h new file mode 100644 index 0000000..6bbf062 --- /dev/null +++ b/src/include/workflow/IOService_thread.h @@ -0,0 +1 @@ +../../kernel/IOService_thread.h \ No newline at end of file diff --git a/src/include/workflow/KafkaDataTypes.h b/src/include/workflow/KafkaDataTypes.h new file mode 100644 index 0000000..f2cb305 --- /dev/null +++ b/src/include/workflow/KafkaDataTypes.h @@ -0,0 +1 @@ +../../protocol/KafkaDataTypes.h \ No newline at end of file diff --git a/src/include/workflow/KafkaMessage.h b/src/include/workflow/KafkaMessage.h new file mode 100644 index 0000000..4586291 --- /dev/null +++ b/src/include/workflow/KafkaMessage.h @@ -0,0 +1 @@ +../../protocol/KafkaMessage.h \ No newline 
at end of file diff --git a/src/include/workflow/KafkaResult.h b/src/include/workflow/KafkaResult.h new file mode 100644 index 0000000..87419ff --- /dev/null +++ b/src/include/workflow/KafkaResult.h @@ -0,0 +1 @@ +../../protocol/KafkaResult.h \ No newline at end of file diff --git a/src/include/workflow/KafkaTaskImpl.inl b/src/include/workflow/KafkaTaskImpl.inl new file mode 100644 index 0000000..16723a4 --- /dev/null +++ b/src/include/workflow/KafkaTaskImpl.inl @@ -0,0 +1 @@ +../../factory/KafkaTaskImpl.inl \ No newline at end of file diff --git a/src/include/workflow/LRUCache.h b/src/include/workflow/LRUCache.h new file mode 100644 index 0000000..fd61118 --- /dev/null +++ b/src/include/workflow/LRUCache.h @@ -0,0 +1 @@ +../../util/LRUCache.h \ No newline at end of file diff --git a/src/include/workflow/MySQLMessage.h b/src/include/workflow/MySQLMessage.h new file mode 100644 index 0000000..c9afe23 --- /dev/null +++ b/src/include/workflow/MySQLMessage.h @@ -0,0 +1 @@ +../../protocol/MySQLMessage.h \ No newline at end of file diff --git a/src/include/workflow/MySQLMessage.inl b/src/include/workflow/MySQLMessage.inl new file mode 100644 index 0000000..c0db157 --- /dev/null +++ b/src/include/workflow/MySQLMessage.inl @@ -0,0 +1 @@ +../../protocol/MySQLMessage.inl \ No newline at end of file diff --git a/src/include/workflow/MySQLResult.h b/src/include/workflow/MySQLResult.h new file mode 100644 index 0000000..0feb4d8 --- /dev/null +++ b/src/include/workflow/MySQLResult.h @@ -0,0 +1 @@ +../../protocol/MySQLResult.h \ No newline at end of file diff --git a/src/include/workflow/MySQLResult.inl b/src/include/workflow/MySQLResult.inl new file mode 100644 index 0000000..f32aaa3 --- /dev/null +++ b/src/include/workflow/MySQLResult.inl @@ -0,0 +1 @@ +../../protocol/MySQLResult.inl \ No newline at end of file diff --git a/src/include/workflow/MySQLUtil.h b/src/include/workflow/MySQLUtil.h new file mode 100644 index 0000000..9bda22d --- /dev/null +++ 
b/src/include/workflow/MySQLUtil.h @@ -0,0 +1 @@ +../../protocol/MySQLUtil.h \ No newline at end of file diff --git a/src/include/workflow/PackageWrapper.h b/src/include/workflow/PackageWrapper.h new file mode 100644 index 0000000..f4031f3 --- /dev/null +++ b/src/include/workflow/PackageWrapper.h @@ -0,0 +1 @@ +../../protocol/PackageWrapper.h \ No newline at end of file diff --git a/src/include/workflow/ProtocolMessage.h b/src/include/workflow/ProtocolMessage.h new file mode 100644 index 0000000..8038c75 --- /dev/null +++ b/src/include/workflow/ProtocolMessage.h @@ -0,0 +1 @@ +../../protocol/ProtocolMessage.h \ No newline at end of file diff --git a/src/include/workflow/RedisMessage.h b/src/include/workflow/RedisMessage.h new file mode 100644 index 0000000..d691425 --- /dev/null +++ b/src/include/workflow/RedisMessage.h @@ -0,0 +1 @@ +../../protocol/RedisMessage.h \ No newline at end of file diff --git a/src/include/workflow/RedisTaskImpl.inl b/src/include/workflow/RedisTaskImpl.inl new file mode 100644 index 0000000..04a7d9a --- /dev/null +++ b/src/include/workflow/RedisTaskImpl.inl @@ -0,0 +1 @@ +../../factory/RedisTaskImpl.inl \ No newline at end of file diff --git a/src/include/workflow/RouteManager.h b/src/include/workflow/RouteManager.h new file mode 100644 index 0000000..0b31f23 --- /dev/null +++ b/src/include/workflow/RouteManager.h @@ -0,0 +1 @@ +../../manager/RouteManager.h \ No newline at end of file diff --git a/src/include/workflow/SSLWrapper.h b/src/include/workflow/SSLWrapper.h new file mode 100644 index 0000000..ed559cc --- /dev/null +++ b/src/include/workflow/SSLWrapper.h @@ -0,0 +1 @@ +../../protocol/SSLWrapper.h \ No newline at end of file diff --git a/src/include/workflow/SleepRequest.h b/src/include/workflow/SleepRequest.h new file mode 100644 index 0000000..463d8de --- /dev/null +++ b/src/include/workflow/SleepRequest.h @@ -0,0 +1 @@ +../../kernel/SleepRequest.h \ No newline at end of file diff --git a/src/include/workflow/StringUtil.h 
b/src/include/workflow/StringUtil.h new file mode 100644 index 0000000..59e2c62 --- /dev/null +++ b/src/include/workflow/StringUtil.h @@ -0,0 +1 @@ +../../util/StringUtil.h \ No newline at end of file diff --git a/src/include/workflow/SubTask.h b/src/include/workflow/SubTask.h new file mode 100644 index 0000000..0f82645 --- /dev/null +++ b/src/include/workflow/SubTask.h @@ -0,0 +1 @@ +../../kernel/SubTask.h \ No newline at end of file diff --git a/src/include/workflow/TLVMessage.h b/src/include/workflow/TLVMessage.h new file mode 100644 index 0000000..a9cf834 --- /dev/null +++ b/src/include/workflow/TLVMessage.h @@ -0,0 +1 @@ +../../protocol/TLVMessage.h \ No newline at end of file diff --git a/src/include/workflow/URIParser.h b/src/include/workflow/URIParser.h new file mode 100644 index 0000000..cb28df9 --- /dev/null +++ b/src/include/workflow/URIParser.h @@ -0,0 +1 @@ +../../util/URIParser.h \ No newline at end of file diff --git a/src/include/workflow/UpstreamManager.h b/src/include/workflow/UpstreamManager.h new file mode 100644 index 0000000..7d4eda5 --- /dev/null +++ b/src/include/workflow/UpstreamManager.h @@ -0,0 +1 @@ +../../manager/UpstreamManager.h \ No newline at end of file diff --git a/src/include/workflow/UpstreamPolicies.h b/src/include/workflow/UpstreamPolicies.h new file mode 100644 index 0000000..71326f2 --- /dev/null +++ b/src/include/workflow/UpstreamPolicies.h @@ -0,0 +1 @@ +../../nameservice/UpstreamPolicies.h \ No newline at end of file diff --git a/src/include/workflow/WFAlgoTaskFactory.h b/src/include/workflow/WFAlgoTaskFactory.h new file mode 100644 index 0000000..551dbd3 --- /dev/null +++ b/src/include/workflow/WFAlgoTaskFactory.h @@ -0,0 +1 @@ +../../factory/WFAlgoTaskFactory.h \ No newline at end of file diff --git a/src/include/workflow/WFAlgoTaskFactory.inl b/src/include/workflow/WFAlgoTaskFactory.inl new file mode 100644 index 0000000..5947a6f --- /dev/null +++ b/src/include/workflow/WFAlgoTaskFactory.inl @@ -0,0 +1 @@ 
+../../factory/WFAlgoTaskFactory.inl \ No newline at end of file diff --git a/src/include/workflow/WFConnection.h b/src/include/workflow/WFConnection.h new file mode 100644 index 0000000..46190b6 --- /dev/null +++ b/src/include/workflow/WFConnection.h @@ -0,0 +1 @@ +../../factory/WFConnection.h \ No newline at end of file diff --git a/src/include/workflow/WFConsulClient.h b/src/include/workflow/WFConsulClient.h new file mode 100644 index 0000000..1bb8326 --- /dev/null +++ b/src/include/workflow/WFConsulClient.h @@ -0,0 +1 @@ +../../client/WFConsulClient.h \ No newline at end of file diff --git a/src/include/workflow/WFDnsClient.h b/src/include/workflow/WFDnsClient.h new file mode 100644 index 0000000..b305415 --- /dev/null +++ b/src/include/workflow/WFDnsClient.h @@ -0,0 +1 @@ +../../client/WFDnsClient.h \ No newline at end of file diff --git a/src/include/workflow/WFDnsResolver.h b/src/include/workflow/WFDnsResolver.h new file mode 100644 index 0000000..b94b16a --- /dev/null +++ b/src/include/workflow/WFDnsResolver.h @@ -0,0 +1 @@ +../../nameservice/WFDnsResolver.h \ No newline at end of file diff --git a/src/include/workflow/WFDnsServer.h b/src/include/workflow/WFDnsServer.h new file mode 100644 index 0000000..820439c --- /dev/null +++ b/src/include/workflow/WFDnsServer.h @@ -0,0 +1 @@ +../../server/WFDnsServer.h \ No newline at end of file diff --git a/src/include/workflow/WFFacilities.h b/src/include/workflow/WFFacilities.h new file mode 100644 index 0000000..d780786 --- /dev/null +++ b/src/include/workflow/WFFacilities.h @@ -0,0 +1 @@ +../../manager/WFFacilities.h \ No newline at end of file diff --git a/src/include/workflow/WFFacilities.inl b/src/include/workflow/WFFacilities.inl new file mode 100644 index 0000000..7497e13 --- /dev/null +++ b/src/include/workflow/WFFacilities.inl @@ -0,0 +1 @@ +../../manager/WFFacilities.inl \ No newline at end of file diff --git a/src/include/workflow/WFFuture.h b/src/include/workflow/WFFuture.h new file mode 100644 index 
0000000..814fbd9 --- /dev/null +++ b/src/include/workflow/WFFuture.h @@ -0,0 +1 @@ +../../manager/WFFuture.h \ No newline at end of file diff --git a/src/include/workflow/WFGlobal.h b/src/include/workflow/WFGlobal.h new file mode 100644 index 0000000..162a3d9 --- /dev/null +++ b/src/include/workflow/WFGlobal.h @@ -0,0 +1 @@ +../../manager/WFGlobal.h \ No newline at end of file diff --git a/src/include/workflow/WFGraphTask.h b/src/include/workflow/WFGraphTask.h new file mode 100644 index 0000000..fb0dc82 --- /dev/null +++ b/src/include/workflow/WFGraphTask.h @@ -0,0 +1 @@ +../../factory/WFGraphTask.h \ No newline at end of file diff --git a/src/include/workflow/WFHttpServer.h b/src/include/workflow/WFHttpServer.h new file mode 100644 index 0000000..bc262a0 --- /dev/null +++ b/src/include/workflow/WFHttpServer.h @@ -0,0 +1 @@ +../../server/WFHttpServer.h \ No newline at end of file diff --git a/src/include/workflow/WFKafkaClient.h b/src/include/workflow/WFKafkaClient.h new file mode 100644 index 0000000..86453de --- /dev/null +++ b/src/include/workflow/WFKafkaClient.h @@ -0,0 +1 @@ +../../client/WFKafkaClient.h \ No newline at end of file diff --git a/src/include/workflow/WFMessageQueue.h b/src/include/workflow/WFMessageQueue.h new file mode 100644 index 0000000..f375696 --- /dev/null +++ b/src/include/workflow/WFMessageQueue.h @@ -0,0 +1 @@ +../../factory/WFMessageQueue.h \ No newline at end of file diff --git a/src/include/workflow/WFMySQLConnection.h b/src/include/workflow/WFMySQLConnection.h new file mode 100644 index 0000000..928480f --- /dev/null +++ b/src/include/workflow/WFMySQLConnection.h @@ -0,0 +1 @@ +../../client/WFMySQLConnection.h \ No newline at end of file diff --git a/src/include/workflow/WFMySQLServer.h b/src/include/workflow/WFMySQLServer.h new file mode 100644 index 0000000..4ab0662 --- /dev/null +++ b/src/include/workflow/WFMySQLServer.h @@ -0,0 +1 @@ +../../server/WFMySQLServer.h \ No newline at end of file diff --git 
a/src/include/workflow/WFNameService.h b/src/include/workflow/WFNameService.h new file mode 100644 index 0000000..3843497 --- /dev/null +++ b/src/include/workflow/WFNameService.h @@ -0,0 +1 @@ +../../nameservice/WFNameService.h \ No newline at end of file diff --git a/src/include/workflow/WFOperator.h b/src/include/workflow/WFOperator.h new file mode 100644 index 0000000..5db73dc --- /dev/null +++ b/src/include/workflow/WFOperator.h @@ -0,0 +1 @@ +../../factory/WFOperator.h \ No newline at end of file diff --git a/src/include/workflow/WFRedisServer.h b/src/include/workflow/WFRedisServer.h new file mode 100644 index 0000000..b6648f9 --- /dev/null +++ b/src/include/workflow/WFRedisServer.h @@ -0,0 +1 @@ +../../server/WFRedisServer.h \ No newline at end of file diff --git a/src/include/workflow/WFRedisSubscriber.h b/src/include/workflow/WFRedisSubscriber.h new file mode 100644 index 0000000..4687bd3 --- /dev/null +++ b/src/include/workflow/WFRedisSubscriber.h @@ -0,0 +1 @@ +../../client/WFRedisSubscriber.h \ No newline at end of file diff --git a/src/include/workflow/WFResourcePool.h b/src/include/workflow/WFResourcePool.h new file mode 100644 index 0000000..8b3f9c0 --- /dev/null +++ b/src/include/workflow/WFResourcePool.h @@ -0,0 +1 @@ +../../factory/WFResourcePool.h \ No newline at end of file diff --git a/src/include/workflow/WFServer.h b/src/include/workflow/WFServer.h new file mode 100644 index 0000000..78d76b9 --- /dev/null +++ b/src/include/workflow/WFServer.h @@ -0,0 +1 @@ +../../server/WFServer.h \ No newline at end of file diff --git a/src/include/workflow/WFServiceGovernance.h b/src/include/workflow/WFServiceGovernance.h new file mode 100644 index 0000000..5b5056a --- /dev/null +++ b/src/include/workflow/WFServiceGovernance.h @@ -0,0 +1 @@ +../../nameservice/WFServiceGovernance.h \ No newline at end of file diff --git a/src/include/workflow/WFTask.h b/src/include/workflow/WFTask.h new file mode 100644 index 0000000..7cb6392 --- /dev/null +++ 
b/src/include/workflow/WFTask.h @@ -0,0 +1 @@ +../../factory/WFTask.h \ No newline at end of file diff --git a/src/include/workflow/WFTask.inl b/src/include/workflow/WFTask.inl new file mode 100644 index 0000000..940fed8 --- /dev/null +++ b/src/include/workflow/WFTask.inl @@ -0,0 +1 @@ +../../factory/WFTask.inl \ No newline at end of file diff --git a/src/include/workflow/WFTaskError.h b/src/include/workflow/WFTaskError.h new file mode 100644 index 0000000..034f2ac --- /dev/null +++ b/src/include/workflow/WFTaskError.h @@ -0,0 +1 @@ +../../factory/WFTaskError.h \ No newline at end of file diff --git a/src/include/workflow/WFTaskFactory.h b/src/include/workflow/WFTaskFactory.h new file mode 100644 index 0000000..ff3dd6a --- /dev/null +++ b/src/include/workflow/WFTaskFactory.h @@ -0,0 +1 @@ +../../factory/WFTaskFactory.h \ No newline at end of file diff --git a/src/include/workflow/WFTaskFactory.inl b/src/include/workflow/WFTaskFactory.inl new file mode 100644 index 0000000..0ca4f01 --- /dev/null +++ b/src/include/workflow/WFTaskFactory.inl @@ -0,0 +1 @@ +../../factory/WFTaskFactory.inl \ No newline at end of file diff --git a/src/include/workflow/Workflow.h b/src/include/workflow/Workflow.h new file mode 100644 index 0000000..4ac3d2c --- /dev/null +++ b/src/include/workflow/Workflow.h @@ -0,0 +1 @@ +../../factory/Workflow.h \ No newline at end of file diff --git a/src/include/workflow/crc32c.h b/src/include/workflow/crc32c.h new file mode 100644 index 0000000..117f928 --- /dev/null +++ b/src/include/workflow/crc32c.h @@ -0,0 +1 @@ +../../util/crc32c.h \ No newline at end of file diff --git a/src/include/workflow/dns_parser.h b/src/include/workflow/dns_parser.h new file mode 100644 index 0000000..f2fd51f --- /dev/null +++ b/src/include/workflow/dns_parser.h @@ -0,0 +1 @@ +../../protocol/dns_parser.h \ No newline at end of file diff --git a/src/include/workflow/http_parser.h b/src/include/workflow/http_parser.h new file mode 100644 index 0000000..6c4abe7 --- /dev/null 
+++ b/src/include/workflow/http_parser.h @@ -0,0 +1 @@ +../../protocol/http_parser.h \ No newline at end of file diff --git a/src/include/workflow/json_parser.h b/src/include/workflow/json_parser.h new file mode 100644 index 0000000..affe8ca --- /dev/null +++ b/src/include/workflow/json_parser.h @@ -0,0 +1 @@ +../../util/json_parser.h \ No newline at end of file diff --git a/src/include/workflow/kafka_parser.h b/src/include/workflow/kafka_parser.h new file mode 100644 index 0000000..f2de99c --- /dev/null +++ b/src/include/workflow/kafka_parser.h @@ -0,0 +1 @@ +../../protocol/kafka_parser.h \ No newline at end of file diff --git a/src/include/workflow/list.h b/src/include/workflow/list.h new file mode 100644 index 0000000..b28d68b --- /dev/null +++ b/src/include/workflow/list.h @@ -0,0 +1 @@ +../../kernel/list.h \ No newline at end of file diff --git a/src/include/workflow/mpoller.h b/src/include/workflow/mpoller.h new file mode 100644 index 0000000..5900cda --- /dev/null +++ b/src/include/workflow/mpoller.h @@ -0,0 +1 @@ +../../kernel/mpoller.h \ No newline at end of file diff --git a/src/include/workflow/msgqueue.h b/src/include/workflow/msgqueue.h new file mode 100644 index 0000000..c443c9b --- /dev/null +++ b/src/include/workflow/msgqueue.h @@ -0,0 +1 @@ +../../kernel/msgqueue.h \ No newline at end of file diff --git a/src/include/workflow/mysql_byteorder.h b/src/include/workflow/mysql_byteorder.h new file mode 100644 index 0000000..e3a4c15 --- /dev/null +++ b/src/include/workflow/mysql_byteorder.h @@ -0,0 +1 @@ +../../protocol/mysql_byteorder.h \ No newline at end of file diff --git a/src/include/workflow/mysql_parser.h b/src/include/workflow/mysql_parser.h new file mode 100644 index 0000000..31ae084 --- /dev/null +++ b/src/include/workflow/mysql_parser.h @@ -0,0 +1 @@ +../../protocol/mysql_parser.h \ No newline at end of file diff --git a/src/include/workflow/mysql_stream.h b/src/include/workflow/mysql_stream.h new file mode 100644 index 0000000..fed62e7 --- 
/dev/null +++ b/src/include/workflow/mysql_stream.h @@ -0,0 +1 @@ +../../protocol/mysql_stream.h \ No newline at end of file diff --git a/src/include/workflow/mysql_types.h b/src/include/workflow/mysql_types.h new file mode 100644 index 0000000..8c331cb --- /dev/null +++ b/src/include/workflow/mysql_types.h @@ -0,0 +1 @@ +../../protocol/mysql_types.h \ No newline at end of file diff --git a/src/include/workflow/poller.h b/src/include/workflow/poller.h new file mode 100644 index 0000000..180e72b --- /dev/null +++ b/src/include/workflow/poller.h @@ -0,0 +1 @@ +../../kernel/poller.h \ No newline at end of file diff --git a/src/include/workflow/rbtree.h b/src/include/workflow/rbtree.h new file mode 100644 index 0000000..298b1fe --- /dev/null +++ b/src/include/workflow/rbtree.h @@ -0,0 +1 @@ +../../kernel/rbtree.h \ No newline at end of file diff --git a/src/include/workflow/redis_parser.h b/src/include/workflow/redis_parser.h new file mode 100644 index 0000000..2e5bca5 --- /dev/null +++ b/src/include/workflow/redis_parser.h @@ -0,0 +1 @@ +../../protocol/redis_parser.h \ No newline at end of file diff --git a/src/include/workflow/thrdpool.h b/src/include/workflow/thrdpool.h new file mode 100644 index 0000000..cbd4b59 --- /dev/null +++ b/src/include/workflow/thrdpool.h @@ -0,0 +1 @@ +../../kernel/thrdpool.h \ No newline at end of file diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt new file mode 100644 index 0000000..209f0ec --- /dev/null +++ b/src/kernel/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.6) +project(kernel) + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") + set(IOSERVICE_FILE IOService_linux.cc) +elseif (UNIX) + set(IOSERVICE_FILE IOService_thread.cc) +else () + message(FATAL_ERROR "IOService unsupported.") +endif () + +set(SRC + ${IOSERVICE_FILE} + mpoller.c + poller.c + rbtree.c + msgqueue.c + thrdpool.c + CommRequest.cc + CommScheduler.cc + Communicator.cc + Executor.cc + SubTask.cc 
+) + +add_library(${PROJECT_NAME} OBJECT ${SRC}) + diff --git a/src/kernel/CommRequest.cc b/src/kernel/CommRequest.cc new file mode 100644 index 0000000..ccf9d60 --- /dev/null +++ b/src/kernel/CommRequest.cc @@ -0,0 +1,38 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include "CommScheduler.h" +#include "CommRequest.h" + +/* Records the final state and error of the request, classifies a timeout, then calls subtask_done(). timeout_reason is TOR_NOT_TIMEOUT unless error is ETIMEDOUT; for ETIMEDOUT it is TOR_WAIT_TIMEOUT when target is still null, TOR_CONNECT_TIMEOUT when get_message_out() returns null, and TOR_TRANSMIT_TIMEOUT otherwise. */ void CommRequest::handle(int state, int error) +{ + this->state = state; + this->error = error; + if (error != ETIMEDOUT) + this->timeout_reason = TOR_NOT_TIMEOUT; + else if (!this->target) + this->timeout_reason = TOR_WAIT_TIMEOUT; + else if (!this->get_message_out()) + this->timeout_reason = TOR_CONNECT_TIMEOUT; + else + this->timeout_reason = TOR_TRANSMIT_TIMEOUT; + + this->subtask_done(); +} + diff --git a/src/kernel/CommRequest.h b/src/kernel/CommRequest.h new file mode 100644 index 0000000..4413231 --- /dev/null +++ b/src/kernel/CommRequest.h @@ -0,0 +1,75 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _COMMREQUEST_H_ +#define _COMMREQUEST_H_ + +#include +#include +#include "SubTask.h" +#include "Communicator.h" +#include "CommScheduler.h" + +/* A SubTask that is also a CommSession. dispatch() hands the request to the scheduler for the configured object; handle() records the outcome. */ class CommRequest : public SubTask, public CommSession +{ +public: + /* Stores the scheduler and the request object and sets wait_timeout to 0. state, error, target and timeout_reason are not initialized here. */ CommRequest(CommSchedObject *object, CommScheduler *scheduler) + { + this->scheduler = scheduler; + this->object = object; + this->wait_timeout = 0; + } + + CommSchedObject *get_request_object() const { return this->object; } + void set_request_object(CommSchedObject *object) { this->object = object; } + int get_wait_timeout() const { return this->wait_timeout; } + /* The value is passed unchanged to scheduler->request() by dispatch(). NOTE(review): presumably milliseconds, matching CommSchedTarget::acquire() -- confirm. */ void set_wait_timeout(int timeout) { this->wait_timeout = timeout; } + +public: + /* Asks the scheduler to issue this request; the scheduler receives &target as an output argument. A negative return completes the request at once through handle(CS_STATE_ERROR, errno). */ virtual void dispatch() + { + if (this->scheduler->request(this, this->object, this->wait_timeout, + &this->target) < 0) + { + this->handle(CS_STATE_ERROR, errno); + } + } + +protected: + int state; + int error; + +protected: + CommTarget *target; +#define TOR_NOT_TIMEOUT 0 +#define TOR_WAIT_TIMEOUT 1 +#define TOR_CONNECT_TIMEOUT 2 +#define TOR_TRANSMIT_TIMEOUT 3 + /* One of the TOR_* codes defined just above; assigned in handle(). */ int timeout_reason; + +protected: + int wait_timeout; + CommSchedObject *object; + CommScheduler *scheduler; + +protected: + virtual void handle(int state, int error); +}; + +#endif + diff --git a/src/kernel/CommScheduler.cc b/src/kernel/CommScheduler.cc new file mode 100644 index 0000000..2f668fc --- /dev/null +++ b/src/kernel/CommScheduler.cc @@ -0,0 +1,444 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "CommScheduler.h" + +#define PTHREAD_COND_TIMEDWAIT(cond, mutex, abstime) \ + ((abstime) ? pthread_cond_timedwait(cond, mutex, abstime) : \ + pthread_cond_wait(cond, mutex)) + +/* Converts a relative timeout in milliseconds into an absolute CLOCK_REALTIME deadline written to *ts, carrying nanosecond overflow into tv_sec. Returns ts, or NULL for a negative timeout; PTHREAD_COND_TIMEDWAIT above turns a NULL deadline into an untimed pthread_cond_wait(). */ static struct timespec *__get_abstime(int timeout, struct timespec *ts) +{ + if (timeout < 0) + return NULL; + + clock_gettime(CLOCK_REALTIME, ts); + ts->tv_sec += timeout / 1000; + ts->tv_nsec += timeout % 1000 * 1000000; + if (ts->tv_nsec >= 1000000000) + { + ts->tv_nsec -= 1000000000; + ts->tv_sec++; + } + + return ts; +} + +/* Initializes the CommTarget base, then this target's mutex, condition variable and load counters (max_load = max_connections, cur_load = 0, wait_cnt = 0, no group). Returns 0 on success and -1 on failure. max_connections == 0 is rejected with errno = EINVAL. When a pthread init call fails, its error code is stored in errno and whatever was already initialized is torn down. */ int CommSchedTarget::init(const struct sockaddr *addr, socklen_t addrlen, + int connect_timeout, int response_timeout, + size_t max_connections) +{ + int ret; + + if (max_connections == 0) + { + errno = EINVAL; + return -1; + } + + if (this->CommTarget::init(addr, addrlen, connect_timeout, + response_timeout) >= 0) + { + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret == 0) + { + ret = pthread_cond_init(&this->cond, NULL); + if (ret == 0) + { + this->max_load = max_connections; + this->cur_load = 0; + this->wait_cnt = 0; + this->group = NULL; + return 0; + } + + pthread_mutex_destroy(&this->mutex); + } + + errno = ret; + this->CommTarget::deinit(); + } + + return -1; +} + +/* Destroys the condition variable and the mutex, then deinitializes the CommTarget base. */ void CommSchedTarget::deinit() +{ + pthread_cond_destroy(&this->cond); + pthread_mutex_destroy(&this->mutex); + this->CommTarget::deinit(); +} + +/* Takes one load slot on this target. When the target is already at max_load: wait_timeout == 0 fails immediately with EAGAIN, any other value blocks on the condition variable until a slot is free or the wait returns an error. Returns this on success, or NULL with errno set. */ CommTarget *CommSchedTarget::acquire(int wait_timeout) +{ + pthread_mutex_t *mutex = &this->mutex; + int ret; + + pthread_mutex_lock(mutex);
+ if (this->group) + { + mutex = &this->group->mutex; + pthread_mutex_lock(mutex); + pthread_mutex_unlock(&this->mutex); + } + + if (this->cur_load >= this->max_load) + { + if (wait_timeout != 0) + { + struct timespec ts; + struct timespec *abstime = __get_abstime(wait_timeout, &ts); + + do + { + this->wait_cnt++; + ret = PTHREAD_COND_TIMEDWAIT(&this->cond, mutex, abstime); + this->wait_cnt--; + } while (this->cur_load >= this->max_load && ret == 0); + } + else + ret = EAGAIN; + } + + if (this->cur_load < this->max_load) + { + this->cur_load++; + if (this->group) + { + this->group->cur_load++; + this->group->heapify(this->index); + } + + ret = 0; + } + + pthread_mutex_unlock(mutex); + if (ret) + { + errno = ret; + return NULL; + } + + return this; +} + +void CommSchedTarget::release() +{ + pthread_mutex_t *mutex = &this->mutex; + + pthread_mutex_lock(mutex); + if (this->group) + { + mutex = &this->group->mutex; + pthread_mutex_lock(mutex); + pthread_mutex_unlock(&this->mutex); + } + + this->cur_load--; + if (this->wait_cnt > 0) + pthread_cond_signal(&this->cond); + + if (this->group) + { + this->group->cur_load--; + if (this->wait_cnt == 0 && this->group->wait_cnt > 0) + pthread_cond_signal(&this->group->cond); + + this->group->heap_adjust(this->index, this->has_idle_conn()); + } + + pthread_mutex_unlock(mutex); +} + +int CommSchedGroup::target_cmp(CommSchedTarget *target1, + CommSchedTarget *target2) +{ + size_t load1 = target1->cur_load * target2->max_load; + size_t load2 = target2->cur_load * target1->max_load; + + if (load1 < load2) + return -1; + else if (load1 > load2) + return 1; + else + return 0; +} + +void CommSchedGroup::heap_adjust(int index, int swap_on_equal) +{ + CommSchedTarget *target = this->tg_heap[index]; + CommSchedTarget *parent; + + while (index > 0) + { + parent = this->tg_heap[(index - 1) / 2]; + if (CommSchedGroup::target_cmp(target, parent) < swap_on_equal) + { + this->tg_heap[index] = parent; + parent->index = index; + index = (index - 
1) / 2; + } + else + break; + } + + this->tg_heap[index] = target; + target->index = index; +} + +/* Fastest heapify ever. */ +void CommSchedGroup::heapify(int top) +{ + CommSchedTarget *target = this->tg_heap[top]; + int last = this->heap_size - 1; + CommSchedTarget **child; + int i; + + while (i = 2 * top + 1, i < last) + { + child = &this->tg_heap[i]; + if (CommSchedGroup::target_cmp(child[0], target) < 0) + { + if (CommSchedGroup::target_cmp(child[1], child[0]) < 0) + { + this->tg_heap[top] = child[1]; + child[1]->index = top; + top = i + 1; + } + else + { + this->tg_heap[top] = child[0]; + child[0]->index = top; + top = i; + } + } + else + { + if (CommSchedGroup::target_cmp(child[1], target) < 0) + { + this->tg_heap[top] = child[1]; + child[1]->index = top; + top = i + 1; + } + else + { + this->tg_heap[top] = target; + target->index = top; + return; + } + } + } + + if (i == last) + { + child = &this->tg_heap[i]; + if (CommSchedGroup::target_cmp(child[0], target) < 0) + { + this->tg_heap[top] = child[0]; + child[0]->index = top; + top = i; + } + } + + this->tg_heap[top] = target; + target->index = top; +} + +int CommSchedGroup::heap_insert(CommSchedTarget *target) +{ + if (this->heap_size == this->heap_buf_size) + { + int new_size = 2 * this->heap_buf_size; + void *new_base = realloc(this->tg_heap, new_size * sizeof (void *)); + + if (new_base) + { + this->tg_heap = (CommSchedTarget **)new_base; + this->heap_buf_size = new_size; + } + else + return -1; + } + + this->tg_heap[this->heap_size] = target; + target->index = this->heap_size; + this->heap_adjust(this->heap_size, 0); + this->heap_size++; + return 0; +} + +void CommSchedGroup::heap_remove(int index) +{ + CommSchedTarget *target; + + this->heap_size--; + if (index != this->heap_size) + { + target = this->tg_heap[this->heap_size]; + this->tg_heap[index] = target; + target->index = index; + this->heap_adjust(index, 0); + this->heapify(target->index); + } +} + +#define COMMGROUP_INIT_SIZE 4 + +int 
CommSchedGroup::init() +{ + size_t size = COMMGROUP_INIT_SIZE * sizeof (void *); + int ret; + + this->tg_heap = (CommSchedTarget **)malloc(size); + if (this->tg_heap) + { + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret == 0) + { + ret = pthread_cond_init(&this->cond, NULL); + if (ret == 0) + { + this->heap_buf_size = COMMGROUP_INIT_SIZE; + this->heap_size = 0; + this->max_load = 0; + this->cur_load = 0; + this->wait_cnt = 0; + return 0; + } + + pthread_mutex_destroy(&this->mutex); + } + + errno = ret; + free(this->tg_heap); + } + + return -1; +} + +void CommSchedGroup::deinit() +{ + pthread_cond_destroy(&this->cond); + pthread_mutex_destroy(&this->mutex); + free(this->tg_heap); +} + +int CommSchedGroup::add(CommSchedTarget *target) +{ + int ret = -1; + + pthread_mutex_lock(&target->mutex); + pthread_mutex_lock(&this->mutex); + if (target->group == NULL && target->wait_cnt == 0) + { + if (this->heap_insert(target) >= 0) + { + target->group = this; + this->max_load += target->max_load; + this->cur_load += target->cur_load; + if (this->wait_cnt > 0 && this->cur_load < this->max_load) + pthread_cond_signal(&this->cond); + + ret = 0; + } + } + else if (target->group == this) + errno = EEXIST; + else if (target->group) + errno = EINVAL; + else + errno = EBUSY; + + pthread_mutex_unlock(&this->mutex); + pthread_mutex_unlock(&target->mutex); + return ret; +} + +int CommSchedGroup::remove(CommSchedTarget *target) +{ + int ret = -1; + + pthread_mutex_lock(&target->mutex); + pthread_mutex_lock(&this->mutex); + if (target->group == this && target->wait_cnt == 0) + { + this->heap_remove(target->index); + this->max_load -= target->max_load; + this->cur_load -= target->cur_load; + target->group = NULL; + ret = 0; + } + else if (target->group != this) + errno = ENOENT; + else + errno = EBUSY; + + pthread_mutex_unlock(&this->mutex); + pthread_mutex_unlock(&target->mutex); + return ret; +} + +CommTarget *CommSchedGroup::acquire(int wait_timeout) +{ + pthread_mutex_t *mutex 
= &this->mutex; + CommSchedTarget *target; + int ret; + + pthread_mutex_lock(mutex); + if (this->cur_load >= this->max_load) + { + if (wait_timeout != 0) + { + struct timespec ts; + struct timespec *abstime = __get_abstime(wait_timeout, &ts); + + do + { + this->wait_cnt++; + ret = PTHREAD_COND_TIMEDWAIT(&this->cond, mutex, abstime); + this->wait_cnt--; + } while (this->cur_load >= this->max_load && ret == 0); + } + else + ret = EAGAIN; + } + + if (this->cur_load < this->max_load) + { + target = this->tg_heap[0]; + target->cur_load++; + this->cur_load++; + this->heapify(0); + ret = 0; + } + + pthread_mutex_unlock(mutex); + if (ret) + { + errno = ret; + return NULL; + } + + return target; +} + diff --git a/src/kernel/CommScheduler.h b/src/kernel/CommScheduler.h new file mode 100644 index 0000000..f0fbdcb --- /dev/null +++ b/src/kernel/CommScheduler.h @@ -0,0 +1,214 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _COMMSCHEDULER_H_ +#define _COMMSCHEDULER_H_ + +#include +#include +#include +#include +#include "Communicator.h" + +class CommSchedObject +{ +public: + size_t get_max_load() const { return this->max_load; } + size_t get_cur_load() const { return this->cur_load; } + +private: + virtual CommTarget *acquire(int wait_timeout) = 0; + +protected: + size_t max_load; + size_t cur_load; + +public: + virtual ~CommSchedObject() { } + friend class CommScheduler; +}; + +class CommSchedGroup; + +class CommSchedTarget : public CommSchedObject, public CommTarget +{ +public: + int init(const struct sockaddr *addr, socklen_t addrlen, + int connect_timeout, int response_timeout, + size_t max_connections); + void deinit(); + +public: + int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, + int connect_timeout, int ssl_connect_timeout, int response_timeout, + size_t max_connections) + { + int ret = this->init(addr, addrlen, connect_timeout, response_timeout, + max_connections); + + if (ret >= 0) + this->set_ssl(ssl_ctx, ssl_connect_timeout); + + return ret; + } + +private: + virtual CommTarget *acquire(int wait_timeout); /* final */ + virtual void release(); /* final */ + +private: + CommSchedGroup *group; + int index; + int wait_cnt; + pthread_mutex_t mutex; + pthread_cond_t cond; + friend class CommSchedGroup; +}; + +class CommSchedGroup : public CommSchedObject +{ +public: + int init(); + void deinit(); + int add(CommSchedTarget *target); + int remove(CommSchedTarget *target); + +private: + virtual CommTarget *acquire(int wait_timeout); /* final */ + +private: + CommSchedTarget **tg_heap; + int heap_size; + int heap_buf_size; + int wait_cnt; + pthread_mutex_t mutex; + pthread_cond_t cond; + +private: + static int target_cmp(CommSchedTarget *target1, CommSchedTarget *target2); + void heapify(int top); + void heap_adjust(int index, int swap_on_equal); + int heap_insert(CommSchedTarget *target); + void 
heap_remove(int index); + friend class CommSchedTarget; +}; + +class CommScheduler +{ +public: + int init(size_t poller_threads, size_t handler_threads) + { + return this->comm.init(poller_threads, handler_threads); + } + + void deinit() + { + this->comm.deinit(); + } + + /* wait_timeout in milliseconds, -1 for no timeout. */ + int request(CommSession *session, CommSchedObject *object, + int wait_timeout, CommTarget **target) + { + int ret = -1; + + *target = object->acquire(wait_timeout); + if (*target) + { + ret = this->comm.request(session, *target); + if (ret < 0) + (*target)->release(); + } + + return ret; + } + + /* for services. */ + int reply(CommSession *session) + { + return this->comm.reply(session); + } + + int shutdown(CommSession *session) + { + return this->comm.shutdown(session); + } + + int push(const void *buf, size_t size, CommSession *session) + { + return this->comm.push(buf, size, session); + } + + int bind(CommService *service) + { + return this->comm.bind(service); + } + + void unbind(CommService *service) + { + this->comm.unbind(service); + } + + /* for sleepers. */ + int sleep(SleepSession *session) + { + return this->comm.sleep(session); + } + + /* Call 'unsleep' only before 'handle()' returns. */ + int unsleep(SleepSession *session) + { + return this->comm.unsleep(session); + } + + /* for file aio services. 
*/ + int io_bind(IOService *service) + { + return this->comm.io_bind(service); + } + + void io_unbind(IOService *service) + { + this->comm.io_unbind(service); + } + +public: + int is_handler_thread() const + { + return this->comm.is_handler_thread(); + } + + int increase_handler_thread() + { + return this->comm.increase_handler_thread(); + } + + int decrease_handler_thread() + { + return this->comm.decrease_handler_thread(); + } + +private: + Communicator comm; + +public: + virtual ~CommScheduler() { } +}; + +#endif + diff --git a/src/kernel/Communicator.cc b/src/kernel/Communicator.cc new file mode 100644 index 0000000..5d13cae --- /dev/null +++ b/src/kernel/Communicator.cc @@ -0,0 +1,2232 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "list.h" +#include "msgqueue.h" +#include "thrdpool.h" +#include "poller.h" +#include "mpoller.h" +#include "Communicator.h" + +struct CommConnEntry +{ + struct list_head list; + CommConnection *conn; + long long seq; + int sockfd; +#define CONN_STATE_CONNECTING 0 +#define CONN_STATE_CONNECTED 1 +#define CONN_STATE_RECEIVING 2 +#define CONN_STATE_SUCCESS 3 +#define CONN_STATE_IDLE 4 +#define CONN_STATE_KEEPALIVE 5 +#define CONN_STATE_CLOSING 6 +#define CONN_STATE_ERROR 7 + int state; + int error; + int ref; + struct iovec *write_iov; + SSL *ssl; + CommSession *session; + CommTarget *target; + CommService *service; + mpoller_t *mpoller; + /* Connection entry's mutex is for client session only. */ + pthread_mutex_t mutex; +}; + +static inline int __set_fd_nonblock(int fd) +{ + int flags = fcntl(fd, F_GETFL); + + if (flags >= 0) + flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK); + + return flags; +} + +static int __bind_sockaddr(int sockfd, const struct sockaddr *addr, + socklen_t addrlen) +{ + struct sockaddr_storage ss; + socklen_t len; + + len = sizeof (struct sockaddr_storage); + if (getsockname(sockfd, (struct sockaddr *)&ss, &len) < 0) + return -1; + + ss.ss_family = 0; + while (len != 0) + { + if (((char *)&ss)[--len] != 0) + break; + } + + if (len == 0) + { + if (bind(sockfd, addr, addrlen) < 0) + return -1; + } + + return 0; +} + +static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry) +{ + BIO *bio = BIO_new_socket(entry->sockfd, BIO_NOCLOSE); + + if (bio) + { + entry->ssl = SSL_new(ssl_ctx); + if (entry->ssl) + { + SSL_set_bio(entry->ssl, bio, bio); + return 0; + } + + BIO_free(bio); + } + + return -1; +} + +static void __release_conn(struct CommConnEntry *entry) +{ + delete entry->conn; + if (!entry->service) + 
pthread_mutex_destroy(&entry->mutex); + + if (entry->ssl) + SSL_free(entry->ssl); + + close(entry->sockfd); + free(entry); +} + +int CommTarget::init(const struct sockaddr *addr, socklen_t addrlen, + int connect_timeout, int response_timeout) +{ + int ret; + + this->addr = (struct sockaddr *)malloc(addrlen); + if (this->addr) + { + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret == 0) + { + memcpy(this->addr, addr, addrlen); + this->addrlen = addrlen; + this->connect_timeout = connect_timeout; + this->response_timeout = response_timeout; + INIT_LIST_HEAD(&this->idle_list); + + this->ssl_ctx = NULL; + this->ssl_connect_timeout = 0; + return 0; + } + + errno = ret; + free(this->addr); + } + + return -1; +} + +void CommTarget::deinit() +{ + pthread_mutex_destroy(&this->mutex); + free(this->addr); +} + +int CommMessageIn::feedback(const void *buf, size_t size) +{ + struct CommConnEntry *entry = this->entry; + const struct sockaddr *addr; + socklen_t addrlen; + int ret; + + if (!entry->ssl) + { + if (entry->service) + { + entry->target->get_addr(&addr, &addrlen); + return sendto(entry->sockfd, buf, size, 0, addr, addrlen); + } + else + return write(entry->sockfd, buf, size); + } + + if (size == 0) + return 0; + + ret = SSL_write(entry->ssl, buf, size); + if (ret <= 0) + { + ret = SSL_get_error(entry->ssl, ret); + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + ret = -1; + } + + return ret; +} + +void CommMessageIn::renew() +{ + CommSession *session = this->entry->session; + session->timeout = -1; + session->begin_time.tv_sec = -1; + session->begin_time.tv_nsec = -1; +} + +int CommService::init(const struct sockaddr *bind_addr, socklen_t addrlen, + int listen_timeout, int response_timeout) +{ + int ret; + + this->bind_addr = (struct sockaddr *)malloc(addrlen); + if (this->bind_addr) + { + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret == 0) + { + memcpy(this->bind_addr, bind_addr, addrlen); + this->addrlen = addrlen; + this->listen_timeout = 
listen_timeout; + this->response_timeout = response_timeout; + INIT_LIST_HEAD(&this->alive_list); + + this->ssl_ctx = NULL; + this->ssl_accept_timeout = 0; + return 0; + } + + errno = ret; + free(this->bind_addr); + } + + return -1; +} + +void CommService::deinit() +{ + pthread_mutex_destroy(&this->mutex); + free(this->bind_addr); +} + +int CommService::drain(int max) +{ + struct CommConnEntry *entry; + struct list_head *pos; + int errno_bak; + int cnt = 0; + + errno_bak = errno; + pthread_mutex_lock(&this->mutex); + while (cnt != max && !list_empty(&this->alive_list)) + { + pos = this->alive_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + cnt++; + + /* Cannot change the sequence of next two lines. */ + mpoller_del(entry->sockfd, entry->mpoller); + entry->state = CONN_STATE_CLOSING; + } + + pthread_mutex_unlock(&this->mutex); + errno = errno_bak; + return cnt; +} + +inline void CommService::incref() +{ + __sync_add_and_fetch(&this->ref, 1); +} + +inline void CommService::decref() +{ + if (__sync_sub_and_fetch(&this->ref, 1) == 0) + this->handle_unbound(); +} + +class CommServiceTarget : public CommTarget +{ +public: + void incref() + { + __sync_add_and_fetch(&this->ref, 1); + } + + void decref() + { + if (__sync_sub_and_fetch(&this->ref, 1) == 0) + { + this->service->decref(); + this->deinit(); + delete this; + } + } + +public: + int shutdown(); + +private: + int sockfd; + int ref; + +private: + CommService *service; + +private: + virtual int create_connect_fd() + { + errno = EPERM; + return -1; + } + + friend class Communicator; +}; + +int CommServiceTarget::shutdown() +{ + struct CommConnEntry *entry; + int errno_bak; + int ret = 0; + + pthread_mutex_lock(&this->mutex); + if (!list_empty(&this->idle_list)) + { + entry = list_entry(this->idle_list.next, struct CommConnEntry, list); + list_del(&entry->list); + + if (this->service->reliable) + { + errno_bak = errno; + mpoller_del(entry->sockfd, entry->mpoller); + entry->state = 
CONN_STATE_CLOSING; + errno = errno_bak; + } + else + { + __release_conn(entry); + this->decref(); + } + + ret = 1; + } + + pthread_mutex_unlock(&this->mutex); + return ret; +} + +CommSession::~CommSession() +{ + CommServiceTarget *target; + + if (!this->passive) + return; + + target = (CommServiceTarget *)this->target; + if (this->passive == 2) + target->shutdown(); + + target->decref(); +} + +inline int Communicator::first_timeout(CommSession *session) +{ + int timeout = session->target->response_timeout; + + if (timeout < 0 || (unsigned int)session->timeout <= (unsigned int)timeout) + { + timeout = session->timeout; + session->timeout = 0; + } + else + clock_gettime(CLOCK_MONOTONIC, &session->begin_time); + + return timeout; +} + +int Communicator::next_timeout(CommSession *session) +{ + int timeout = session->target->response_timeout; + struct timespec cur_time; + int time_used, time_left; + + if (session->timeout > 0) + { + clock_gettime(CLOCK_MONOTONIC, &cur_time); + time_used = 1000 * (cur_time.tv_sec - session->begin_time.tv_sec) + + (cur_time.tv_nsec - session->begin_time.tv_nsec) / 1000000; + time_left = session->timeout - time_used; + if (time_left <= timeout) /* here timeout >= 0 */ + { + timeout = time_left < 0 ? 
0 : time_left; + session->timeout = 0; + } + } + + return timeout; +} + +int Communicator::first_timeout_send(CommSession *session) +{ + session->timeout = session->send_timeout(); + return Communicator::first_timeout(session); +} + +int Communicator::first_timeout_recv(CommSession *session) +{ + session->timeout = session->receive_timeout(); + return Communicator::first_timeout(session); +} + +void Communicator::shutdown_service(CommService *service) +{ + close(service->listen_fd); + service->listen_fd = -1; + service->drain(-1); + service->decref(); +} + +#ifndef IOV_MAX +# ifdef UIO_MAXIOV +# define IOV_MAX UIO_MAXIOV +# else +# define IOV_MAX 1024 +# endif +#endif + +int Communicator::send_message_sync(struct iovec vectors[], int cnt, + struct CommConnEntry *entry) +{ + CommSession *session = entry->session; + CommService *service; + int timeout; + ssize_t n; + int i; + + while (cnt > 0) + { + if (!entry->ssl) + { + n = writev(entry->sockfd, vectors, cnt <= IOV_MAX ? cnt : IOV_MAX); + if (n < 0) + return errno == EAGAIN ? 
cnt : -1; + } + else if (vectors->iov_len > 0) + { + n = SSL_write(entry->ssl, vectors->iov_base, vectors->iov_len); + if (n <= 0) + return cnt; + } + else + n = 0; + + for (i = 0; i < cnt; i++) + { + if ((size_t)n >= vectors[i].iov_len) + n -= vectors[i].iov_len; + else + { + vectors[i].iov_base = (char *)vectors[i].iov_base + n; + vectors[i].iov_len -= n; + break; + } + } + + vectors += i; + cnt -= i; + } + + service = entry->service; + if (service) + { + __sync_add_and_fetch(&entry->ref, 1); + timeout = session->keep_alive_timeout(); + switch (timeout) + { + default: + mpoller_set_timeout(entry->sockfd, timeout, this->mpoller); + pthread_mutex_lock(&service->mutex); + if (service->listen_fd >= 0) + { + entry->state = CONN_STATE_KEEPALIVE; + list_add_tail(&entry->list, &service->alive_list); + entry = NULL; + } + + pthread_mutex_unlock(&service->mutex); + if (entry) + { + case 0: + mpoller_del(entry->sockfd, this->mpoller); + entry->state = CONN_STATE_CLOSING; + } + } + } + else + { + if (entry->state == CONN_STATE_IDLE) + { + timeout = session->first_timeout(); + if (timeout == 0) + timeout = Communicator::first_timeout_recv(session); + else + { + session->timeout = -1; + session->begin_time.tv_sec = -1; + session->begin_time.tv_nsec = 0; + } + + mpoller_set_timeout(entry->sockfd, timeout, this->mpoller); + } + + entry->state = CONN_STATE_RECEIVING; + } + + return 0; +} + +int Communicator::send_message_async(struct iovec vectors[], int cnt, + struct CommConnEntry *entry) +{ + struct poller_data data; + int timeout; + int ret; + int i; + + entry->write_iov = (struct iovec *)malloc(cnt * sizeof (struct iovec)); + if (entry->write_iov) + { + for (i = 0; i < cnt; i++) + entry->write_iov[i] = vectors[i]; + } + else + return -1; + + data.operation = PD_OP_WRITE; + data.fd = entry->sockfd; + data.ssl = entry->ssl; + data.partial_written = Communicator::partial_written; + data.context = entry; + data.write_iov = entry->write_iov; + data.iovcnt = cnt; + timeout = 
Communicator::first_timeout_send(entry->session); + if (entry->state == CONN_STATE_IDLE) + { + ret = mpoller_mod(&data, timeout, this->mpoller); + if (ret < 0 && errno == ENOENT) + entry->state = CONN_STATE_RECEIVING; + } + else + { + ret = mpoller_add(&data, timeout, this->mpoller); + if (ret >= 0) + { + if (this->stop_flag) + mpoller_del(data.fd, this->mpoller); + } + } + + if (ret < 0) + { + free(entry->write_iov); + if (entry->state != CONN_STATE_RECEIVING) + return -1; + } + + return 1; +} + +#define ENCODE_IOV_MAX 2048 + +int Communicator::send_message(struct CommConnEntry *entry) +{ + struct iovec vectors[ENCODE_IOV_MAX]; + struct iovec *end; + int cnt; + + cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); + if ((unsigned int)cnt > ENCODE_IOV_MAX) + { + if (cnt > ENCODE_IOV_MAX) + errno = EOVERFLOW; + return -1; + } + + end = vectors + cnt; + cnt = this->send_message_sync(vectors, cnt, entry); + if (cnt <= 0) + return cnt; + + return this->send_message_async(end - cnt, cnt, entry); +} + +void Communicator::handle_incoming_request(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + CommTarget *target = entry->target; + CommSession *session = NULL; + int state; + + switch (res->state) + { + case PR_ST_SUCCESS: + session = entry->session; + state = CS_STATE_TOREPLY; + pthread_mutex_lock(&target->mutex); + if (entry->state == CONN_STATE_SUCCESS) + { + __sync_add_and_fetch(&entry->ref, 1); + entry->state = CONN_STATE_IDLE; + list_add(&entry->list, &target->idle_list); + } + + session->passive = 2; + pthread_mutex_unlock(&target->mutex); + break; + + case PR_ST_FINISHED: + res->error = ECONNRESET; + if (1) + case PR_ST_ERROR: + state = CS_STATE_ERROR; + else + case PR_ST_DELETED: + case PR_ST_STOPPED: + state = CS_STATE_STOPPED; + + pthread_mutex_lock(&target->mutex); + switch (entry->state) + { + case CONN_STATE_KEEPALIVE: + pthread_mutex_lock(&entry->service->mutex); + if (entry->state == 
CONN_STATE_KEEPALIVE) + list_del(&entry->list); + pthread_mutex_unlock(&entry->service->mutex); + break; + + case CONN_STATE_IDLE: + list_del(&entry->list); + break; + + case CONN_STATE_ERROR: + res->error = entry->error; + state = CS_STATE_ERROR; + case CONN_STATE_RECEIVING: + session = entry->session; + break; + + case CONN_STATE_SUCCESS: + /* This may happen only if handler_threads > 1. */ + entry->state = CONN_STATE_CLOSING; + entry = NULL; + break; + } + + pthread_mutex_unlock(&target->mutex); + break; + } + + if (entry) + { + if (session) + session->handle(state, res->error); + + if (__sync_sub_and_fetch(&entry->ref, 1) == 0) + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + } +} + +void Communicator::handle_incoming_reply(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + CommTarget *target = entry->target; + CommSession *session = NULL; + pthread_mutex_t *mutex; + int state; + + switch (res->state) + { + case PR_ST_SUCCESS: + session = entry->session; + state = CS_STATE_SUCCESS; + pthread_mutex_lock(&target->mutex); + if (entry->state == CONN_STATE_SUCCESS) + { + __sync_add_and_fetch(&entry->ref, 1); + if (session->timeout != 0) /* This is keep-alive timeout. 
*/ + { + entry->state = CONN_STATE_IDLE; + list_add(&entry->list, &target->idle_list); + } + else + entry->state = CONN_STATE_CLOSING; + } + + pthread_mutex_unlock(&target->mutex); + break; + + case PR_ST_FINISHED: + res->error = ECONNRESET; + if (1) + case PR_ST_ERROR: + state = CS_STATE_ERROR; + else + case PR_ST_DELETED: + case PR_ST_STOPPED: + state = CS_STATE_STOPPED; + + mutex = &entry->mutex; + pthread_mutex_lock(&target->mutex); + pthread_mutex_lock(mutex); + switch (entry->state) + { + case CONN_STATE_IDLE: + list_del(&entry->list); + break; + + case CONN_STATE_ERROR: + res->error = entry->error; + state = CS_STATE_ERROR; + case CONN_STATE_RECEIVING: + session = entry->session; + break; + + case CONN_STATE_SUCCESS: + /* This may happen only if handler_threads > 1. */ + entry->state = CONN_STATE_CLOSING; + entry = NULL; + break; + } + + pthread_mutex_unlock(&target->mutex); + pthread_mutex_unlock(mutex); + break; + } + + if (entry) + { + if (session) + { + target->release(); + session->handle(state, res->error); + } + + if (__sync_sub_and_fetch(&entry->ref, 1) == 0) + __release_conn(entry); + } +} + +void Communicator::handle_read_result(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + + if (res->state != PR_ST_MODIFIED) + { + if (entry->service) + this->handle_incoming_request(res); + else + this->handle_incoming_reply(res); + } +} + +void Communicator::handle_reply_result(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + CommService *service = entry->service; + CommSession *session = entry->session; + CommTarget *target = entry->target; + int timeout; + int state; + + switch (res->state) + { + case PR_ST_FINISHED: + timeout = session->keep_alive_timeout(); + if (timeout != 0) + { + __sync_add_and_fetch(&entry->ref, 1); + res->data.operation = PD_OP_READ; + res->data.create_message = Communicator::create_request; + res->data.message = NULL; + 
pthread_mutex_lock(&target->mutex); + if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) + { + pthread_mutex_lock(&service->mutex); + if (!this->stop_flag && service->listen_fd >= 0) + { + entry->state = CONN_STATE_KEEPALIVE; + list_add_tail(&entry->list, &service->alive_list); + } + else + { + mpoller_del(res->data.fd, this->mpoller); + entry->state = CONN_STATE_CLOSING; + } + + pthread_mutex_unlock(&service->mutex); + } + else + __sync_sub_and_fetch(&entry->ref, 1); + + pthread_mutex_unlock(&target->mutex); + } + + if (1) + state = CS_STATE_SUCCESS; + else if (1) + case PR_ST_ERROR: + state = CS_STATE_ERROR; + else + case PR_ST_DELETED: /* DELETED seems not possible. */ + case PR_ST_STOPPED: + state = CS_STATE_STOPPED; + + session->handle(state, res->error); + if (__sync_sub_and_fetch(&entry->ref, 1) == 0) + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + + break; + } +} + +void Communicator::handle_request_result(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + CommSession *session = entry->session; + int timeout; + int state; + + switch (res->state) + { + case PR_ST_FINISHED: + entry->state = CONN_STATE_RECEIVING; + res->data.operation = PD_OP_READ; + res->data.create_message = Communicator::create_reply; + res->data.message = NULL; + timeout = session->first_timeout(); + if (timeout == 0) + timeout = Communicator::first_timeout_recv(session); + else + { + session->timeout = -1; + session->begin_time.tv_sec = -1; + session->begin_time.tv_nsec = 0; + } + + if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) + { + if (this->stop_flag) + mpoller_del(res->data.fd, this->mpoller); + break; + } + + res->error = errno; + if (1) + case PR_ST_ERROR: + state = CS_STATE_ERROR; + else + case PR_ST_DELETED: + case PR_ST_STOPPED: + state = CS_STATE_STOPPED; + + entry->target->release(); + session->handle(state, res->error); + pthread_mutex_lock(&entry->mutex); + /* do nothing 
*/ + pthread_mutex_unlock(&entry->mutex); + if (__sync_sub_and_fetch(&entry->ref, 1) == 0) + __release_conn(entry); + + break; + } +} + +void Communicator::handle_write_result(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + + free(entry->write_iov); + if (entry->service) + this->handle_reply_result(res); + else + this->handle_request_result(res); +} + +struct CommConnEntry *Communicator::accept_conn(CommServiceTarget *target, + CommService *service) +{ + struct CommConnEntry *entry; + size_t size; + + if (__set_fd_nonblock(target->sockfd) >= 0) + { + size = offsetof(struct CommConnEntry, mutex); + entry = (struct CommConnEntry *)malloc(size); + if (entry) + { + entry->conn = service->new_connection(target->sockfd); + if (entry->conn) + { + entry->seq = 0; + entry->mpoller = NULL; + entry->service = service; + entry->target = target; + entry->ssl = NULL; + entry->sockfd = target->sockfd; + entry->state = CONN_STATE_CONNECTED; + entry->ref = 1; + return entry; + } + + free(entry); + } + } + + return NULL; +} + +void Communicator::handle_connect_result(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + CommSession *session = entry->session; + CommTarget *target = entry->target; + int timeout; + int state; + int ret; + + switch (res->state) + { + case PR_ST_FINISHED: + if (target->ssl_ctx && !entry->ssl) + { + if (__create_ssl(target->ssl_ctx, entry) >= 0 && + target->init_ssl(entry->ssl) >= 0) + { + ret = 0; + res->data.operation = PD_OP_SSL_CONNECT; + res->data.ssl = entry->ssl; + timeout = target->ssl_connect_timeout; + } + else + ret = -1; + } + else if ((session->out = session->message_out()) != NULL) + { + ret = this->send_message(entry); + if (ret == 0) + { + res->data.operation = PD_OP_READ; + res->data.create_message = Communicator::create_reply; + res->data.message = NULL; + timeout = session->first_timeout(); + if (timeout == 0) + timeout = 
Communicator::first_timeout_recv(session); + else + { + session->timeout = -1; + session->begin_time.tv_sec = -1; + session->begin_time.tv_nsec = 0; + } + } + else if (ret > 0) + break; + } + else + ret = -1; + + if (ret >= 0) + { + if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) + { + if (this->stop_flag) + mpoller_del(res->data.fd, this->mpoller); + break; + } + } + + res->error = errno; + if (1) + case PR_ST_ERROR: + state = CS_STATE_ERROR; + else + case PR_ST_DELETED: + case PR_ST_STOPPED: + state = CS_STATE_STOPPED; + + target->release(); + session->handle(state, res->error); + __release_conn(entry); + break; + } +} + +void Communicator::handle_listen_result(struct poller_result *res) +{ + CommService *service = (CommService *)res->data.context; + struct CommConnEntry *entry; + CommServiceTarget *target; + int timeout; + + switch (res->state) + { + case PR_ST_SUCCESS: + target = (CommServiceTarget *)res->data.result; + entry = Communicator::accept_conn(target, service); + if (entry) + { + entry->mpoller = this->mpoller; + if (service->ssl_ctx) + { + if (__create_ssl(service->ssl_ctx, entry) >= 0 && + service->init_ssl(entry->ssl) >= 0) + { + res->data.operation = PD_OP_SSL_ACCEPT; + timeout = service->ssl_accept_timeout; + } + } + else + { + res->data.operation = PD_OP_READ; + res->data.create_message = Communicator::create_request; + res->data.message = NULL; + timeout = target->response_timeout; + } + + if (res->data.operation != PD_OP_LISTEN) + { + res->data.fd = entry->sockfd; + res->data.ssl = entry->ssl; + res->data.context = entry; + if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) + { + if (this->stop_flag) + mpoller_del(res->data.fd, this->mpoller); + break; + } + } + + __release_conn(entry); + } + else + close(target->sockfd); + + target->decref(); + break; + + case PR_ST_DELETED: + this->shutdown_service(service); + break; + + case PR_ST_ERROR: + case PR_ST_STOPPED: + service->handle_stop(res->error); + break; + } +} + +void 
Communicator::handle_recvfrom_result(struct poller_result *res) +{ + CommService *service = (CommService *)res->data.context; + struct CommConnEntry *entry; + CommTarget *target; + int state, error; + + switch (res->state) + { + case PR_ST_SUCCESS: + entry = (struct CommConnEntry *)res->data.result; + target = entry->target; + if (entry->state == CONN_STATE_SUCCESS) + { + state = CS_STATE_TOREPLY; + error = 0; + entry->state = CONN_STATE_IDLE; + list_add(&entry->list, &target->idle_list); + } + else + { + state = CS_STATE_ERROR; + if (entry->state == CONN_STATE_ERROR) + error = entry->error; + else + error = EBADMSG; + } + + entry->session->handle(state, error); + if (state == CS_STATE_ERROR) + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + + break; + + case PR_ST_DELETED: + this->shutdown_service(service); + break; + + case PR_ST_ERROR: + case PR_ST_STOPPED: + service->handle_stop(res->error); + break; + } +} + +void Communicator::handle_ssl_accept_result(struct poller_result *res) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; + CommTarget *target = entry->target; + int timeout; + + switch (res->state) + { + case PR_ST_FINISHED: + res->data.operation = PD_OP_READ; + res->data.create_message = Communicator::create_request; + res->data.message = NULL; + timeout = target->response_timeout; + if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) + { + if (this->stop_flag) + mpoller_del(res->data.fd, this->mpoller); + break; + } + + case PR_ST_DELETED: + case PR_ST_ERROR: + case PR_ST_STOPPED: + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + break; + } +} + +void Communicator::handle_sleep_result(struct poller_result *res) +{ + SleepSession *session = (SleepSession *)res->data.context; + int state; + + switch (res->state) + { + case PR_ST_FINISHED: + state = SS_STATE_COMPLETE; + break; + case PR_ST_DELETED: + res->error = ECANCELED; + case PR_ST_ERROR: + state = SS_STATE_ERROR; + 
break; + case PR_ST_STOPPED: + state = SS_STATE_DISRUPTED; + break; + } + + session->handle(state, res->error); +} + +void Communicator::handle_aio_result(struct poller_result *res) +{ + IOService *service = (IOService *)res->data.context; + IOSession *session; + int state, error; + + switch (res->state) + { + case PR_ST_SUCCESS: + session = (IOSession *)res->data.result; + pthread_mutex_lock(&service->mutex); + list_del(&session->list); + pthread_mutex_unlock(&service->mutex); + if (session->res >= 0) + { + state = IOS_STATE_SUCCESS; + error = 0; + } + else + { + state = IOS_STATE_ERROR; + error = -session->res; + } + + session->handle(state, error); + service->decref(); + break; + + case PR_ST_DELETED: + this->shutdown_io_service(service); + break; + + case PR_ST_ERROR: + case PR_ST_STOPPED: + service->handle_stop(res->error); + break; + } +} + +void Communicator::handler_thread_routine(void *context) +{ + Communicator *comm = (Communicator *)context; + struct poller_result *res; + + while (1) + { + res = (struct poller_result *)msgqueue_get(comm->msgqueue); + if (!res) + break; + + switch (res->data.operation) + { + case PD_OP_TIMER: + comm->handle_sleep_result(res); + break; + case PD_OP_READ: + comm->handle_read_result(res); + break; + case PD_OP_WRITE: + comm->handle_write_result(res); + break; + case PD_OP_CONNECT: + case PD_OP_SSL_CONNECT: + comm->handle_connect_result(res); + break; + case PD_OP_LISTEN: + comm->handle_listen_result(res); + break; + case PD_OP_RECVFROM: + comm->handle_recvfrom_result(res); + break; + case PD_OP_SSL_ACCEPT: + comm->handle_ssl_accept_result(res); + break; + case PD_OP_EVENT: + case PD_OP_NOTIFY: + comm->handle_aio_result(res); + break; + default: + free(res); + if (comm->thrdpool) + thrdpool_exit(comm->thrdpool); + continue; + } + + free(res); + } + + if (!comm->thrdpool) + { + mpoller_destroy(comm->mpoller); + msgqueue_destroy(comm->msgqueue); + } +} + +int Communicator::append_message(const void *buf, size_t *size, + 
poller_message_t *msg) +{ + CommMessageIn *in = (CommMessageIn *)msg; + struct CommConnEntry *entry = in->entry; + CommSession *session = entry->session; + int timeout; + int ret; + + ret = in->append(buf, size); + if (ret > 0) + { + entry->state = CONN_STATE_SUCCESS; + if (!entry->service) + { + timeout = session->keep_alive_timeout(); + session->timeout = timeout; /* Reuse session's timeout field. */ + if (timeout == 0) + { + mpoller_del(entry->sockfd, entry->mpoller); + return ret; + } + } + else + timeout = -1; + } + else if (ret == 0 && session->timeout != 0) + { + if (session->begin_time.tv_sec < 0) + { + if (session->begin_time.tv_nsec < 0) + timeout = session->first_timeout(); + else + timeout = 0; + + if (timeout == 0) + timeout = Communicator::first_timeout_recv(session); + else + session->begin_time.tv_nsec = 0; + } + else + timeout = Communicator::next_timeout(session); + } + else + return ret; + + /* This set_timeout() never fails, which is very important. */ + mpoller_set_timeout(entry->sockfd, timeout, entry->mpoller); + return ret; +} + +poller_message_t *Communicator::create_request(void *context) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)context; + CommService *service = entry->service; + CommTarget *target = entry->target; + CommSession *session; + CommMessageIn *in; + int timeout; + + if (entry->state == CONN_STATE_IDLE) + { + pthread_mutex_lock(&target->mutex); + /* do nothing */ + pthread_mutex_unlock(&target->mutex); + } + + pthread_mutex_lock(&service->mutex); + if (entry->state == CONN_STATE_KEEPALIVE) + list_del(&entry->list); + else if (entry->state != CONN_STATE_CONNECTED) + entry = NULL; + + pthread_mutex_unlock(&service->mutex); + if (!entry) + { + errno = EBADMSG; + return NULL; + } + + session = service->new_session(entry->seq, entry->conn); + if (!session) + return NULL; + + session->passive = 1; + entry->session = session; + session->target = target; + session->conn = entry->conn; + session->seq = entry->seq++; + 
session->out = NULL; + session->in = NULL; + + timeout = Communicator::first_timeout_recv(session); + mpoller_set_timeout(entry->sockfd, timeout, entry->mpoller); + entry->state = CONN_STATE_RECEIVING; + + ((CommServiceTarget *)target)->incref(); + + in = session->message_in(); + if (in) + { + in->poller_message_t::append = Communicator::append_message; + in->entry = entry; + session->in = in; + } + + return in; +} + +poller_message_t *Communicator::create_reply(void *context) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)context; + CommSession *session; + CommMessageIn *in; + + if (entry->state == CONN_STATE_IDLE) + { + pthread_mutex_lock(&entry->mutex); + /* do nothing */ + pthread_mutex_unlock(&entry->mutex); + } + + if (entry->state != CONN_STATE_RECEIVING) + { + errno = EBADMSG; + return NULL; + } + + session = entry->session; + in = session->message_in(); + if (in) + { + in->poller_message_t::append = Communicator::append_message; + in->entry = entry; + session->in = in; + } + + return in; +} + +int Communicator::recv_request(const void *buf, size_t size, + struct CommConnEntry *entry) +{ + CommService *service = entry->service; + CommTarget *target = entry->target; + CommSession *session; + CommMessageIn *in; + size_t n; + int ret; + + session = service->new_session(entry->seq, entry->conn); + if (!session) + return -1; + + session->passive = 1; + entry->session = session; + session->target = target; + session->conn = entry->conn; + session->seq = entry->seq++; + session->out = NULL; + session->in = NULL; + + entry->state = CONN_STATE_RECEIVING; + + ((CommServiceTarget *)target)->incref(); + + in = session->message_in(); + if (in) + { + in->entry = entry; + session->in = in; + do + { + n = size; + ret = in->append(buf, &n); + if (ret == 0) + { + size -= n; + buf = (const char *)buf + n; + } + else if (ret < 0) + { + entry->error = errno; + entry->state = CONN_STATE_ERROR; + } + else + entry->state = CONN_STATE_SUCCESS; + + } while (ret == 0 && 
size > 0); + } + + return 0; +} + +int Communicator::partial_written(size_t n, void *context) +{ + struct CommConnEntry *entry = (struct CommConnEntry *)context; + CommSession *session = entry->session; + int timeout; + + timeout = Communicator::next_timeout(session); + mpoller_set_timeout(entry->sockfd, timeout, entry->mpoller); + return 0; +} + +void *Communicator::accept(const struct sockaddr *addr, socklen_t addrlen, + int sockfd, void *context) +{ + CommService *service = (CommService *)context; + CommServiceTarget *target = new CommServiceTarget; + + if (target) + { + if (target->init(addr, addrlen, 0, service->response_timeout) >= 0) + { + service->incref(); + target->service = service; + target->sockfd = sockfd; + target->ref = 1; + return target; + } + + delete target; + } + + close(sockfd); + return NULL; +} + +void *Communicator::recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context) +{ + CommService *service = (CommService *)context; + struct CommConnEntry *entry; + CommServiceTarget *target; + void *result; + int sockfd; + + sockfd = dup(service->listen_fd); + if (sockfd >= 0) + { + result = Communicator::accept(addr, addrlen, sockfd, context); + if (result) + { + target = (CommServiceTarget *)result; + entry = Communicator::accept_conn(target, service); + if (entry) + { + if (Communicator::recv_request(buf, size, entry) >= 0) + return entry; + + __release_conn(entry); + } + else + close(sockfd); + + target->decref(); + } + } + + return NULL; +} + +void Communicator::callback(struct poller_result *res, void *context) +{ + msgqueue_t *msgqueue = (msgqueue_t *)context; + msgqueue_put(res, msgqueue); +} + +int Communicator::create_handler_threads(size_t handler_threads) +{ + struct thrdpool_task task = { + .routine = Communicator::handler_thread_routine, + .context = this + }; + size_t i; + + this->thrdpool = thrdpool_create(handler_threads, 0); + if (this->thrdpool) + { + for (i = 0; i < handler_threads; 
i++) + { + if (thrdpool_schedule(&task, this->thrdpool) < 0) + break; + } + + if (i == handler_threads) + return 0; + + msgqueue_set_nonblock(this->msgqueue); + thrdpool_destroy(NULL, this->thrdpool); + } + + return -1; +} + +int Communicator::create_poller(size_t poller_threads) +{ + struct poller_params params = { + .max_open_files = (size_t)sysconf(_SC_OPEN_MAX), + .callback = Communicator::callback, + }; + + if ((ssize_t)params.max_open_files < 0) + return -1; + + this->msgqueue = msgqueue_create(16 * 1024, sizeof (struct poller_result)); + if (this->msgqueue) + { + params.context = this->msgqueue; + this->mpoller = mpoller_create(&params, poller_threads); + if (this->mpoller) + { + if (mpoller_start(this->mpoller) >= 0) + return 0; + + mpoller_destroy(this->mpoller); + } + + msgqueue_destroy(this->msgqueue); + } + + return -1; +} + +int Communicator::init(size_t poller_threads, size_t handler_threads) +{ + if (poller_threads == 0) + { + errno = EINVAL; + return -1; + } + + if (this->create_poller(poller_threads) >= 0) + { + if (this->create_handler_threads(handler_threads) >= 0) + { + this->stop_flag = 0; + return 0; + } + + mpoller_stop(this->mpoller); + mpoller_destroy(this->mpoller); + msgqueue_destroy(this->msgqueue); + } + + return -1; +} + +void Communicator::deinit() +{ + int in_handler = this->is_handler_thread(); + + this->stop_flag = 1; + mpoller_stop(this->mpoller); + msgqueue_set_nonblock(this->msgqueue); + thrdpool_destroy(NULL, this->thrdpool); + this->thrdpool = NULL; + if (!in_handler) + Communicator::handler_thread_routine(this); +} + +int Communicator::nonblock_connect(CommTarget *target) +{ + int sockfd = target->create_connect_fd(); + + if (sockfd >= 0) + { + if (__set_fd_nonblock(sockfd) >= 0) + { + if (connect(sockfd, target->addr, target->addrlen) >= 0 || + errno == EINPROGRESS) + { + return sockfd; + } + } + + close(sockfd); + } + + return -1; +} + +struct CommConnEntry *Communicator::launch_conn(CommSession *session, + CommTarget *target) 
+{ + struct CommConnEntry *entry; + int sockfd; + int ret; + + sockfd = Communicator::nonblock_connect(target); + if (sockfd >= 0) + { + entry = (struct CommConnEntry *)malloc(sizeof (struct CommConnEntry)); + if (entry) + { + ret = pthread_mutex_init(&entry->mutex, NULL); + if (ret == 0) + { + entry->conn = target->new_connection(sockfd); + if (entry->conn) + { + entry->seq = 0; + entry->mpoller = NULL; + entry->service = NULL; + entry->target = target; + entry->session = session; + entry->ssl = NULL; + entry->sockfd = sockfd; + entry->state = CONN_STATE_CONNECTING; + entry->ref = 1; + return entry; + } + + pthread_mutex_destroy(&entry->mutex); + } + else + errno = ret; + + free(entry); + } + + close(sockfd); + } + + return NULL; +} + +int Communicator::request_idle_conn(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct list_head *pos; + int ret = -1; + + while (1) + { + pthread_mutex_lock(&target->mutex); + if (!list_empty(&target->idle_list)) + { + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + pthread_mutex_lock(&entry->mutex); + } + else + entry = NULL; + + pthread_mutex_unlock(&target->mutex); + if (!entry) + { + errno = ENOENT; + return -1; + } + + if (mpoller_set_timeout(entry->sockfd, -1, this->mpoller) >= 0) + break; + + entry->state = CONN_STATE_CLOSING; + pthread_mutex_unlock(&entry->mutex); + } + + entry->session = session; + session->conn = entry->conn; + session->seq = entry->seq++; + session->out = session->message_out(); + if (session->out) + ret = this->send_message(entry); + + if (ret < 0) + { + entry->error = errno; + mpoller_del(entry->sockfd, this->mpoller); + entry->state = CONN_STATE_ERROR; + ret = 1; + } + + pthread_mutex_unlock(&entry->mutex); + return ret; +} + +int Communicator::request_new_conn(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct poller_data data; + int timeout; + + entry = 
Communicator::launch_conn(session, target); + if (entry) + { + entry->mpoller = this->mpoller; + session->conn = entry->conn; + session->seq = entry->seq++; + data.operation = PD_OP_CONNECT; + data.fd = entry->sockfd; + data.ssl = NULL; + data.context = entry; + timeout = session->target->connect_timeout; + if (mpoller_add(&data, timeout, this->mpoller) >= 0) + return 0; + + __release_conn(entry); + } + + return -1; +} + +int Communicator::request(CommSession *session, CommTarget *target) +{ + int errno_bak; + + if (session->passive) + { + errno = EINVAL; + return -1; + } + + errno_bak = errno; + session->target = target; + session->out = NULL; + session->in = NULL; + if (this->request_idle_conn(session, target) < 0) + { + if (this->request_new_conn(session, target) < 0) + { + session->conn = NULL; + session->seq = 0; + return -1; + } + } + + errno = errno_bak; + return 0; +} + +int Communicator::nonblock_listen(CommService *service) +{ + int sockfd = service->create_listen_fd(); + int ret; + + if (sockfd >= 0) + { + if (__set_fd_nonblock(sockfd) >= 0) + { + if (__bind_sockaddr(sockfd, service->bind_addr, + service->addrlen) >= 0) + { + ret = listen(sockfd, SOMAXCONN); + if (ret >= 0 || errno == EOPNOTSUPP) + { + service->reliable = (ret >= 0); + return sockfd; + } + } + } + + close(sockfd); + } + + return -1; +} + +int Communicator::bind(CommService *service) +{ + struct poller_data data; + int errno_bak = errno; + int sockfd; + + sockfd = this->nonblock_listen(service); + if (sockfd >= 0) + { + service->listen_fd = sockfd; + service->ref = 1; + data.fd = sockfd; + data.context = service; + data.result = NULL; + if (service->reliable) + { + data.operation = PD_OP_LISTEN; + data.accept = Communicator::accept; + } + else + { + data.operation = PD_OP_RECVFROM; + data.recvfrom = Communicator::recvfrom; + } + + if (mpoller_add(&data, service->listen_timeout, this->mpoller) >= 0) + { + errno = errno_bak; + return 0; + } + + close(sockfd); + } + + return -1; +} + +void 
Communicator::unbind(CommService *service) +{ + int errno_bak = errno; + + if (mpoller_del(service->listen_fd, this->mpoller) < 0) + { + /* Error occurred on listen_fd or Communicator::deinit() called. */ + this->shutdown_service(service); + errno = errno_bak; + } +} + +int Communicator::reply_reliable(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct list_head *pos; + int ret = -1; + + pthread_mutex_lock(&target->mutex); + if (!list_empty(&target->idle_list)) + { + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + + session->out = session->message_out(); + if (session->out) + ret = this->send_message(entry); + + if (ret < 0) + { + entry->error = errno; + mpoller_del(entry->sockfd, this->mpoller); + entry->state = CONN_STATE_ERROR; + ret = 1; + } + } + else + errno = ENOENT; + + pthread_mutex_unlock(&target->mutex); + return ret; +} + +int Communicator::reply_message_unreliable(struct CommConnEntry *entry) +{ + struct iovec vectors[ENCODE_IOV_MAX]; + int cnt; + + cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); + if ((unsigned int)cnt > ENCODE_IOV_MAX) + { + if (cnt > ENCODE_IOV_MAX) + errno = EOVERFLOW; + return -1; + } + + if (cnt > 0) + { + struct msghdr message = { + .msg_name = entry->target->addr, + .msg_namelen = entry->target->addrlen, + .msg_iov = vectors, +#ifdef __linux__ + .msg_iovlen = (size_t)cnt, +#else + .msg_iovlen = cnt, +#endif + }; + if (sendmsg(entry->sockfd, &message, 0) < 0) + return -1; + } + + return 0; +} + +int Communicator::reply_unreliable(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct list_head *pos; + + if (!list_empty(&target->idle_list)) + { + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + + session->out = session->message_out(); + if (session->out) + { + if (this->reply_message_unreliable(entry) >= 0) + return 0; + } + + 
__release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + else + errno = ENOENT; + + return -1; +} + +int Communicator::reply(CommSession *session) +{ + struct CommConnEntry *entry; + CommServiceTarget *target; + int errno_bak; + int ret; + + if (session->passive != 2) + { + errno = EINVAL; + return -1; + } + + errno_bak = errno; + session->passive = 3; + target = (CommServiceTarget *)session->target; + if (target->service->reliable) + ret = this->reply_reliable(session, target); + else + ret = this->reply_unreliable(session, target); + + if (ret == 0) + { + entry = session->in->entry; + session->handle(CS_STATE_SUCCESS, 0); + if (__sync_sub_and_fetch(&entry->ref, 1) == 0) + { + __release_conn(entry); + target->decref(); + } + } + else if (ret < 0) + return -1; + + errno = errno_bak; + return 0; +} + +int Communicator::push(const void *buf, size_t size, CommSession *session) +{ + CommMessageIn *in = session->in; + pthread_mutex_t *mutex; + int ret; + + if (!in) + { + errno = ENOENT; + return -1; + } + + if (session->passive) + mutex = &session->target->mutex; + else + mutex = &in->entry->mutex; + + pthread_mutex_lock(mutex); + if ((session->passive == 2 && !list_empty(&session->target->idle_list)) || + (!session->passive && in->entry->session == session) || + session->passive == 1) + { + ret = in->inner()->feedback(buf, size); + } + else + { + errno = ENOENT; + ret = -1; + } + + pthread_mutex_unlock(mutex); + return ret; +} + +int Communicator::shutdown(CommSession *session) +{ + CommServiceTarget *target; + + if (session->passive != 2) + { + errno = EINVAL; + return -1; + } + + session->passive = 3; + target = (CommServiceTarget *)session->target; + if (!target->shutdown()) + { + errno = ENOENT; + return -1; + } + + return 0; +} + +int Communicator::sleep(SleepSession *session) +{ + struct timespec value; + + if (session->duration(&value) >= 0) + { + if (mpoller_add_timer(&value, session, &session->timer, &session->index, + this->mpoller) >= 0) + 
return 0; + } + + return -1; +} + +int Communicator::unsleep(SleepSession *session) +{ + return mpoller_del_timer(session->timer, session->index, this->mpoller); +} + +int Communicator::is_handler_thread() const +{ + return thrdpool_in_pool(this->thrdpool); +} + +extern "C" void __thrdpool_schedule(const struct thrdpool_task *, void *, + thrdpool_t *); + +int Communicator::increase_handler_thread() +{ + void *buf = malloc(4 * sizeof (void *)); + + if (buf) + { + if (thrdpool_increase(this->thrdpool) >= 0) + { + struct thrdpool_task task = { + .routine = Communicator::handler_thread_routine, + .context = this + }; + __thrdpool_schedule(&task, buf, this->thrdpool); + return 0; + } + + free(buf); + } + + return -1; +} + +int Communicator::decrease_handler_thread() +{ + struct poller_result *res; + size_t size; + + size = sizeof (struct poller_result) + sizeof (void *); + res = (struct poller_result *)malloc(size); + if (res) + { + res->data.operation = -1; + msgqueue_put_head(res, this->msgqueue); + return 0; + } + + return -1; +} + +#ifdef __linux__ + +void Communicator::shutdown_io_service(IOService *service) +{ + pthread_mutex_lock(&service->mutex); + close(service->event_fd); + service->event_fd = -1; + pthread_mutex_unlock(&service->mutex); + service->decref(); +} + +int Communicator::io_bind(IOService *service) +{ + struct poller_data data; + int event_fd; + + event_fd = service->create_event_fd(); + if (event_fd >= 0) + { + if (__set_fd_nonblock(event_fd) >= 0) + { + service->ref = 1; + data.operation = PD_OP_EVENT; + data.fd = event_fd; + data.event = IOService::aio_finish; + data.context = service; + data.result = NULL; + if (mpoller_add(&data, -1, this->mpoller) >= 0) + { + service->event_fd = event_fd; + return 0; + } + } + + close(event_fd); + } + + return -1; +} + +void Communicator::io_unbind(IOService *service) +{ + int errno_bak = errno; + + if (mpoller_del(service->event_fd, this->mpoller) < 0) + { + /* Error occurred on event_fd or 
Communicator::deinit() called. */ + this->shutdown_io_service(service); + errno = errno_bak; + } +} + +#else + +void Communicator::shutdown_io_service(IOService *service) +{ + pthread_mutex_lock(&service->mutex); + close(service->pipe_fd[0]); + close(service->pipe_fd[1]); + service->pipe_fd[0] = -1; + service->pipe_fd[1] = -1; + pthread_mutex_unlock(&service->mutex); + service->decref(); +} + +int Communicator::io_bind(IOService *service) +{ + struct poller_data data; + int pipe_fd[2]; + + if (service->create_pipe_fd(pipe_fd) >= 0) + { + if (__set_fd_nonblock(pipe_fd[0]) >= 0) + { + service->ref = 1; + data.operation = PD_OP_NOTIFY; + data.fd = pipe_fd[0]; + data.notify = IOService::aio_finish; + data.context = service; + data.result = NULL; + if (mpoller_add(&data, -1, this->mpoller) >= 0) + { + service->pipe_fd[0] = pipe_fd[0]; + service->pipe_fd[1] = pipe_fd[1]; + return 0; + } + } + + close(pipe_fd[0]); + close(pipe_fd[1]); + } + + return -1; +} + +void Communicator::io_unbind(IOService *service) +{ + int errno_bak = errno; + + if (mpoller_del(service->pipe_fd[0], this->mpoller) < 0) + { + /* Error occurred on pipe_fd or Communicator::deinit() called. */ + this->shutdown_io_service(service); + errno = errno_bak; + } +} + +#endif + diff --git a/src/kernel/Communicator.h b/src/kernel/Communicator.h new file mode 100644 index 0000000..2c5f7e0 --- /dev/null +++ b/src/kernel/Communicator.h @@ -0,0 +1,385 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _COMMUNICATOR_H_ +#define _COMMUNICATOR_H_ + +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/uio.h> +#include <time.h> +#include <stddef.h> +#include <pthread.h> +#include <openssl/ssl.h> +#include "list.h" +#include "poller.h" + +class CommConnection +{ +public: + virtual ~CommConnection() { } +}; + +class CommTarget +{ +public: + int init(const struct sockaddr *addr, socklen_t addrlen, + int connect_timeout, int response_timeout); + void deinit(); + +public: + void get_addr(const struct sockaddr **addr, socklen_t *addrlen) const + { + *addr = this->addr; + *addrlen = this->addrlen; + } + + int has_idle_conn() const { return !list_empty(&this->idle_list); } + +protected: + void set_ssl(SSL_CTX *ssl_ctx, int ssl_connect_timeout) + { + this->ssl_ctx = ssl_ctx; + this->ssl_connect_timeout = ssl_connect_timeout; + } + + SSL_CTX *get_ssl_ctx() const { return this->ssl_ctx; } + +private: + virtual int create_connect_fd() + { + return socket(this->addr->sa_family, SOCK_STREAM, 0); + } + + virtual CommConnection *new_connection(int connect_fd) + { + return new CommConnection; + } + + virtual int init_ssl(SSL *ssl) { return 0; } + +public: + virtual void release() { } + +private: + struct sockaddr *addr; + socklen_t addrlen; + int connect_timeout; + int response_timeout; + int ssl_connect_timeout; + SSL_CTX *ssl_ctx; + +private: + struct list_head idle_list; + pthread_mutex_t mutex; + +public: + virtual ~CommTarget() { } + friend class CommServiceTarget; + friend class Communicator; +}; + +class CommMessageOut +{ +private: + virtual int encode(struct iovec vectors[], int max) = 0; + +public: + virtual ~CommMessageOut() { } + friend class Communicator; +}; + +class CommMessageIn : private poller_message_t +{ +private: + virtual int append(const void *buf, size_t *size) = 0; + +protected: + /* Send small packet while receiving. Call only in append(). 
*/ + virtual int feedback(const void *buf, size_t size); + + /* In append(), reset the begin time of receiving to current time. */ + virtual void renew(); + + /* Return the deepest wrapped message. */ + virtual CommMessageIn *inner() { return this; } + +private: + struct CommConnEntry *entry; + +public: + virtual ~CommMessageIn() { } + friend class Communicator; +}; + +#define CS_STATE_SUCCESS 0 +#define CS_STATE_ERROR 1 +#define CS_STATE_STOPPED 2 +#define CS_STATE_TOREPLY 3 /* for service session only. */ + +class CommSession +{ +private: + virtual CommMessageOut *message_out() = 0; + virtual CommMessageIn *message_in() = 0; + virtual int send_timeout() { return -1; } + virtual int receive_timeout() { return -1; } + virtual int keep_alive_timeout() { return 0; } + virtual int first_timeout() { return 0; } + virtual void handle(int state, int error) = 0; + +protected: + CommTarget *get_target() const { return this->target; } + CommConnection *get_connection() const { return this->conn; } + CommMessageOut *get_message_out() const { return this->out; } + CommMessageIn *get_message_in() const { return this->in; } + long long get_seq() const { return this->seq; } + +private: + CommTarget *target; + CommConnection *conn; + CommMessageOut *out; + CommMessageIn *in; + long long seq; + +private: + struct timespec begin_time; + int timeout; + int passive; + +public: + CommSession() { this->passive = 0; } + virtual ~CommSession(); + friend class CommMessageIn; + friend class Communicator; +}; + +class CommService +{ +public: + int init(const struct sockaddr *bind_addr, socklen_t addrlen, + int listen_timeout, int response_timeout); + void deinit(); + + int drain(int max); + +public: + void get_addr(const struct sockaddr **addr, socklen_t *addrlen) const + { + *addr = this->bind_addr; + *addrlen = this->addrlen; + } + +protected: + void set_ssl(SSL_CTX *ssl_ctx, int ssl_accept_timeout) + { + this->ssl_ctx = ssl_ctx; + this->ssl_accept_timeout = ssl_accept_timeout; + } + + 
SSL_CTX *get_ssl_ctx() const { return this->ssl_ctx; } + +private: + virtual CommSession *new_session(long long seq, CommConnection *conn) = 0; + virtual void handle_stop(int error) { } + virtual void handle_unbound() = 0; + +private: + virtual int create_listen_fd() + { + return socket(this->bind_addr->sa_family, SOCK_STREAM, 0); + } + + virtual CommConnection *new_connection(int accept_fd) + { + return new CommConnection; + } + + virtual int init_ssl(SSL *ssl) { return 0; } + +private: + struct sockaddr *bind_addr; + socklen_t addrlen; + int listen_timeout; + int response_timeout; + int ssl_accept_timeout; + SSL_CTX *ssl_ctx; + +private: + void incref(); + void decref(); + +private: + int reliable; + int listen_fd; + int ref; + +private: + struct list_head alive_list; + pthread_mutex_t mutex; + +public: + virtual ~CommService() { } + friend class CommServiceTarget; + friend class Communicator; +}; + +#define SS_STATE_COMPLETE 0 +#define SS_STATE_ERROR 1 +#define SS_STATE_DISRUPTED 2 + +class SleepSession +{ +private: + virtual int duration(struct timespec *value) = 0; + virtual void handle(int state, int error) = 0; + +private: + void *timer; + int index; + +public: + virtual ~SleepSession() { } + friend class Communicator; +}; + +#ifdef __linux__ +# include "IOService_linux.h" +#else +# include "IOService_thread.h" +#endif + +class Communicator +{ +public: + int init(size_t poller_threads, size_t handler_threads); + void deinit(); + + int request(CommSession *session, CommTarget *target); + int reply(CommSession *session); + + int push(const void *buf, size_t size, CommSession *session); + + int shutdown(CommSession *session); + + int bind(CommService *service); + void unbind(CommService *service); + + int sleep(SleepSession *session); + int unsleep(SleepSession *session); + + int io_bind(IOService *service); + void io_unbind(IOService *service); + +public: + int is_handler_thread() const; + + int increase_handler_thread(); + int decrease_handler_thread(); + 
+private: + struct __mpoller *mpoller; + struct __msgqueue *msgqueue; + struct __thrdpool *thrdpool; + int stop_flag; + +private: + int create_poller(size_t poller_threads); + + int create_handler_threads(size_t handler_threads); + + void shutdown_service(CommService *service); + + void shutdown_io_service(IOService *service); + + int send_message_sync(struct iovec vectors[], int cnt, + struct CommConnEntry *entry); + int send_message_async(struct iovec vectors[], int cnt, + struct CommConnEntry *entry); + + int send_message(struct CommConnEntry *entry); + + int request_new_conn(CommSession *session, CommTarget *target); + int request_idle_conn(CommSession *session, CommTarget *target); + + int reply_message_unreliable(struct CommConnEntry *entry); + + int reply_reliable(CommSession *session, CommTarget *target); + int reply_unreliable(CommSession *session, CommTarget *target); + + void handle_incoming_request(struct poller_result *res); + void handle_incoming_reply(struct poller_result *res); + + void handle_request_result(struct poller_result *res); + void handle_reply_result(struct poller_result *res); + + void handle_write_result(struct poller_result *res); + void handle_read_result(struct poller_result *res); + + void handle_connect_result(struct poller_result *res); + void handle_listen_result(struct poller_result *res); + + void handle_recvfrom_result(struct poller_result *res); + + void handle_ssl_accept_result(struct poller_result *res); + + void handle_sleep_result(struct poller_result *res); + + void handle_aio_result(struct poller_result *res); + + static void handler_thread_routine(void *context); + + static int nonblock_connect(CommTarget *target); + static int nonblock_listen(CommService *service); + + static struct CommConnEntry *launch_conn(CommSession *session, + CommTarget *target); + static struct CommConnEntry *accept_conn(class CommServiceTarget *target, + CommService *service); + + static int first_timeout(CommSession *session); + static int 
next_timeout(CommSession *session); + + static int first_timeout_send(CommSession *session); + static int first_timeout_recv(CommSession *session); + + static int append_message(const void *buf, size_t *size, + poller_message_t *msg); + + static poller_message_t *create_request(void *context); + static poller_message_t *create_reply(void *context); + + static int recv_request(const void *buf, size_t size, + struct CommConnEntry *entry); + + static int partial_written(size_t n, void *context); + + static void *accept(const struct sockaddr *addr, socklen_t addrlen, + int sockfd, void *context); + + static void *recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context); + + static void callback(struct poller_result *res, void *context); + +public: + virtual ~Communicator() { } +}; + +#endif + diff --git a/src/kernel/ExecRequest.h b/src/kernel/ExecRequest.h new file mode 100644 index 0000000..fd2f45b --- /dev/null +++ b/src/kernel/ExecRequest.h @@ -0,0 +1,62 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _EXECREQUEST_H_ +#define _EXECREQUEST_H_ + +#include "SubTask.h" +#include "Executor.h" + +class ExecRequest : public SubTask, public ExecSession +{ +public: + ExecRequest(ExecQueue *queue, Executor *executor) + { + this->executor = executor; + this->queue = queue; + } + + ExecQueue *get_request_queue() const { return this->queue; } + void set_request_queue(ExecQueue *queue) { this->queue = queue; } + +public: + virtual void dispatch() + { + if (this->executor->request(this, this->queue) < 0) + this->handle(ES_STATE_ERROR, errno); + } + +protected: + int state; + int error; + +protected: + ExecQueue *queue; + Executor *executor; + +protected: + virtual void handle(int state, int error) + { + this->state = state; + this->error = error; + this->subtask_done(); + } +}; + +#endif + diff --git a/src/kernel/Executor.cc b/src/kernel/Executor.cc new file mode 100644 index 0000000..efb731d --- /dev/null +++ b/src/kernel/Executor.cc @@ -0,0 +1,158 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include "list.h" +#include "thrdpool.h" +#include "Executor.h" + +struct ExecSessionEntry +{ + struct list_head list; + ExecSession *session; + thrdpool_t *thrdpool; +}; + +int ExecQueue::init() +{ + int ret; + + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret == 0) + { + INIT_LIST_HEAD(&this->session_list); + return 0; + } + + errno = ret; + return -1; +} + +void ExecQueue::deinit() +{ + pthread_mutex_destroy(&this->mutex); +} + +int Executor::init(size_t nthreads) +{ + this->thrdpool = thrdpool_create(nthreads, 0); + if (this->thrdpool) + return 0; + + return -1; +} + +void Executor::deinit() +{ + thrdpool_destroy(Executor::executor_cancel, this->thrdpool); +} + +extern "C" void __thrdpool_schedule(const struct thrdpool_task *, void *, + thrdpool_t *); + +void Executor::executor_thread_routine(void *context) +{ + ExecQueue *queue = (ExecQueue *)context; + struct ExecSessionEntry *entry; + ExecSession *session; + int empty; + + entry = list_entry(queue->session_list.next, struct ExecSessionEntry, list); + pthread_mutex_lock(&queue->mutex); + list_del(&entry->list); + empty = list_empty(&queue->session_list); + pthread_mutex_unlock(&queue->mutex); + + session = entry->session; + if (!empty) + { + struct thrdpool_task task = { + .routine = Executor::executor_thread_routine, + .context = queue + }; + __thrdpool_schedule(&task, entry, entry->thrdpool); + } + else + free(entry); + + session->execute(); + session->handle(ES_STATE_FINISHED, 0); +} + +void Executor::executor_cancel(const struct thrdpool_task *task) +{ + ExecQueue *queue = (ExecQueue *)task->context; + struct ExecSessionEntry *entry; + struct list_head *pos, *tmp; + ExecSession *session; + + list_for_each_safe(pos, tmp, &queue->session_list) + { + entry = list_entry(pos, struct ExecSessionEntry, list); + list_del(pos); + session = entry->session; + free(entry); + + session->handle(ES_STATE_CANCELED, 0); + } +} + +int 
Executor::request(ExecSession *session, ExecQueue *queue) +{ + struct ExecSessionEntry *entry; + + session->queue = queue; + entry = (struct ExecSessionEntry *)malloc(sizeof (struct ExecSessionEntry)); + if (entry) + { + entry->session = session; + entry->thrdpool = this->thrdpool; + pthread_mutex_lock(&queue->mutex); + list_add_tail(&entry->list, &queue->session_list); + if (queue->session_list.next == &entry->list) + { + struct thrdpool_task task = { + .routine = Executor::executor_thread_routine, + .context = queue + }; + if (thrdpool_schedule(&task, this->thrdpool) < 0) + { + list_del(&entry->list); + free(entry); + entry = NULL; + } + } + + pthread_mutex_unlock(&queue->mutex); + } + + return -!entry; +} + +int Executor::increase_thread() +{ + return thrdpool_increase(this->thrdpool); +} + +int Executor::decrease_thread() +{ + return thrdpool_decrease(this->thrdpool); +} + diff --git a/src/kernel/Executor.h b/src/kernel/Executor.h new file mode 100644 index 0000000..0b84c9b --- /dev/null +++ b/src/kernel/Executor.h @@ -0,0 +1,86 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _EXECUTOR_H_ +#define _EXECUTOR_H_ + +#include +#include +#include "list.h" + +class ExecQueue +{ +public: + int init(); + void deinit(); + +private: + struct list_head session_list; + pthread_mutex_t mutex; + +public: + virtual ~ExecQueue() { } + friend class Executor; +}; + +#define ES_STATE_FINISHED 0 +#define ES_STATE_ERROR 1 +#define ES_STATE_CANCELED 2 + +class ExecSession +{ +private: + virtual void execute() = 0; + virtual void handle(int state, int error) = 0; + +protected: + ExecQueue *get_queue() const { return this->queue; } + +private: + ExecQueue *queue; + +public: + virtual ~ExecSession() { } + friend class Executor; +}; + +class Executor +{ +public: + int init(size_t nthreads); + void deinit(); + + int request(ExecSession *session, ExecQueue *queue); + +public: + int increase_thread(); + int decrease_thread(); + +private: + struct __thrdpool *thrdpool; + +private: + static void executor_thread_routine(void *context); + static void executor_cancel(const struct thrdpool_task *task); + +public: + virtual ~Executor() { } +}; + +#endif + diff --git a/src/kernel/IORequest.h b/src/kernel/IORequest.h new file mode 100644 index 0000000..c17bf2e --- /dev/null +++ b/src/kernel/IORequest.h @@ -0,0 +1,58 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _IOREQUEST_H_ +#define _IOREQUEST_H_ + +#include +#include "SubTask.h" +#include "Communicator.h" + +class IORequest : public SubTask, public IOSession +{ +public: + IORequest(IOService *service) + { + this->service = service; + } + +public: + virtual void dispatch() + { + if (this->service->request(this) < 0) + this->handle(IOS_STATE_ERROR, errno); + } + +protected: + int state; + int error; + +protected: + IOService *service; + +protected: + virtual void handle(int state, int error) + { + this->state = state; + this->error = error; + this->subtask_done(); + } +}; + +#endif + diff --git a/src/kernel/IOService_linux.cc b/src/kernel/IOService_linux.cc new file mode 100644 index 0000000..81acb0f --- /dev/null +++ b/src/kernel/IOService_linux.cc @@ -0,0 +1,362 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "list.h" +#include "IOService_linux.h" + +/* Linux async I/O interface from libaio.h */ + +typedef struct io_context *io_context_t; + +typedef enum io_iocb_cmd { + IO_CMD_PREAD = 0, + IO_CMD_PWRITE = 1, + + IO_CMD_FSYNC = 2, + IO_CMD_FDSYNC = 3, + + IO_CMD_POLL = 5, + IO_CMD_NOOP = 6, + IO_CMD_PREADV = 7, + IO_CMD_PWRITEV = 8, +} io_iocb_cmd_t; + +/* little endian, 32 bits */ +#if defined(__i386__) || (defined(__arm__) && !defined(__ARMEB__)) || \ + defined(__sh__) || defined(__bfin__) || defined(__MIPSEL__) || \ + defined(__cris__) || (defined(__riscv) && __riscv_xlen == 32) || \ + (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ + __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ && __SIZEOF_LONG__ == 4) +#define PADDED(x, y) x; unsigned y +#define PADDEDptr(x, y) x; unsigned y +#define PADDEDul(x, y) unsigned long x; unsigned y + +/* little endian, 64 bits */ +#elif defined(__ia64__) || defined(__x86_64__) || defined(__alpha__) || \ + (defined(__aarch64__) && defined(__AARCH64EL__)) || \ + (defined(__riscv) && __riscv_xlen == 64) || \ + (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ + __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ && __SIZEOF_LONG__ == 8) +#define PADDED(x, y) x, y +#define PADDEDptr(x, y) x +#define PADDEDul(x, y) unsigned long x + +/* big endian, 64 bits */ +#elif defined(__powerpc64__) || defined(__s390x__) || \ + (defined(__sparc__) && defined(__arch64__)) || \ + (defined(__aarch64__) && defined(__AARCH64EB__)) || \ + (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ + __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ && __SIZEOF_LONG__ == 8) +#define PADDED(x, y) unsigned y; x +#define PADDEDptr(x,y) x +#define PADDEDul(x, y) unsigned long x + +/* big endian, 32 bits */ +#elif defined(__PPC__) || defined(__s390__) || \ + (defined(__arm__) && defined(__ARMEB__)) || \ + defined(__sparc__) || defined(__MIPSEB__) || defined(__m68k__) 
|| \ + defined(__hppa__) || defined(__frv__) || defined(__avr32__) || \ + (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ + __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ && __SIZEOF_LONG__ == 4) +#define PADDED(x, y) unsigned y; x +#define PADDEDptr(x, y) unsigned y; x +#define PADDEDul(x, y) unsigned y; unsigned long x + +#else +#error endian? +#endif + +struct io_iocb_poll { + PADDED(int events, __pad1); +}; /* result code is the set of result flags or -'ve errno */ + +struct io_iocb_sockaddr { + struct sockaddr *addr; + int len; +}; /* result code is the length of the sockaddr, or -'ve errno */ + +struct io_iocb_common { + PADDEDptr(void *buf, __pad1); + PADDEDul(nbytes, __pad2); + long long offset; + long long __pad3; + unsigned flags; + unsigned resfd; +}; /* result code is the amount read or -'ve errno */ + +struct io_iocb_vector { + const struct iovec *vec; + int nr; + long long offset; +}; /* result code is the amount read or -'ve errno */ + +struct iocb { + PADDEDptr(void *data, __pad1); /* Return in the io completion event */ + /* key: For use in identifying io requests */ + /* aio_rw_flags: RWF_* flags (such as RWF_NOWAIT) */ + PADDED(unsigned key, aio_rw_flags); + + short aio_lio_opcode; + short aio_reqprio; + int aio_fildes; + + union { + struct io_iocb_common c; + struct io_iocb_vector v; + struct io_iocb_poll poll; + struct io_iocb_sockaddr saddr; + } u; +}; + +struct io_event { + PADDEDptr(void *data, __pad1); + PADDEDptr(struct iocb *obj, __pad2); + PADDEDul(res, __pad3); + PADDEDul(res2, __pad4); +}; + +#undef PADDED +#undef PADDEDptr +#undef PADDEDul + +/* Actual syscalls */ +static inline int io_setup(int maxevents, io_context_t *ctxp) +{ + return syscall(__NR_io_setup, maxevents, ctxp); +} + +static inline int io_destroy(io_context_t ctx) +{ + return syscall(__NR_io_destroy, ctx); +} + +static inline int io_submit(io_context_t ctx, long nr, struct iocb *ios[]) +{ + return syscall(__NR_io_submit, ctx, nr, ios); +} + +static inline int 
io_cancel(io_context_t ctx, struct iocb *iocb, + struct io_event *evt) +{ + return syscall(__NR_io_cancel, ctx, iocb, evt); +} + +static inline int io_getevents(io_context_t ctx_id, long min_nr, long nr, + struct io_event *events, + struct timespec *timeout) +{ + return syscall(__NR_io_getevents, ctx_id, min_nr, nr, events, timeout); +} + +static inline void io_set_eventfd(struct iocb *iocb, int eventfd) +{ + iocb->u.c.flags |= (1 << 0) /* IOCB_FLAG_RESFD */; + iocb->u.c.resfd = eventfd; +} + +void IOSession::prep_pread(int fd, void *buf, size_t count, long long offset) +{ + struct iocb *iocb = (struct iocb *)this->iocb_buf; + + memset(iocb, 0, sizeof(*iocb)); + iocb->aio_fildes = fd; + iocb->aio_lio_opcode = IO_CMD_PREAD; + iocb->u.c.buf = buf; + iocb->u.c.nbytes = count; + iocb->u.c.offset = offset; +} + +void IOSession::prep_pwrite(int fd, void *buf, size_t count, long long offset) +{ + struct iocb *iocb = (struct iocb *)this->iocb_buf; + + memset(iocb, 0, sizeof(*iocb)); + iocb->aio_fildes = fd; + iocb->aio_lio_opcode = IO_CMD_PWRITE; + iocb->u.c.buf = buf; + iocb->u.c.nbytes = count; + iocb->u.c.offset = offset; +} + +void IOSession::prep_preadv(int fd, const struct iovec *iov, int iovcnt, + long long offset) +{ + struct iocb *iocb = (struct iocb *)this->iocb_buf; + + memset(iocb, 0, sizeof(*iocb)); + iocb->aio_fildes = fd; + iocb->aio_lio_opcode = IO_CMD_PREADV; + iocb->u.c.buf = (void *)iov; + iocb->u.c.nbytes = iovcnt; + iocb->u.c.offset = offset; +} + +void IOSession::prep_pwritev(int fd, const struct iovec *iov, int iovcnt, + long long offset) +{ + struct iocb *iocb = (struct iocb *)this->iocb_buf; + + memset(iocb, 0, sizeof(*iocb)); + iocb->aio_fildes = fd; + iocb->aio_lio_opcode = IO_CMD_PWRITEV; + iocb->u.c.buf = (void *)iov; + iocb->u.c.nbytes = iovcnt; + iocb->u.c.offset = offset; +} + +void IOSession::prep_fsync(int fd) +{ + struct iocb *iocb = (struct iocb *)this->iocb_buf; + + memset(iocb, 0, sizeof(*iocb)); + iocb->aio_fildes = fd; + 
iocb->aio_lio_opcode = IO_CMD_FSYNC; +} + +void IOSession::prep_fdsync(int fd) +{ + struct iocb *iocb = (struct iocb *)this->iocb_buf; + + memset(iocb, 0, sizeof(*iocb)); + iocb->aio_fildes = fd; + iocb->aio_lio_opcode = IO_CMD_FDSYNC; +} + +int IOService::init(int maxevents) +{ + int ret; + + if (maxevents < 0) + { + errno = EINVAL; + return -1; + } + + this->io_ctx = NULL; + if (io_setup(maxevents, &this->io_ctx) >= 0) + { + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret == 0) + { + INIT_LIST_HEAD(&this->session_list); + this->event_fd = -1; + return 0; + } + + errno = ret; + io_destroy(this->io_ctx); + } + + return -1; +} + +void IOService::deinit() +{ + pthread_mutex_destroy(&this->mutex); + io_destroy(this->io_ctx); +} + +inline void IOService::incref() +{ + __sync_add_and_fetch(&this->ref, 1); +} + +void IOService::decref() +{ + IOSession *session; + struct io_event event; + int state, error; + + if (__sync_sub_and_fetch(&this->ref, 1) == 0) + { + while (!list_empty(&this->session_list)) + { + if (io_getevents(this->io_ctx, 1, 1, &event, NULL) > 0) + { + session = (IOSession *)event.data; + list_del(&session->list); + session->res = event.res; + if (session->res >= 0) + { + state = IOS_STATE_SUCCESS; + error = 0; + } + else + { + state = IOS_STATE_ERROR; + error = -session->res; + } + + session->handle(state, error); + } + } + + this->handle_unbound(); + } +} + +int IOService::request(IOSession *session) +{ + struct iocb *iocb = (struct iocb *)session->iocb_buf; + int ret = -1; + + pthread_mutex_lock(&this->mutex); + if (this->event_fd < 0) + errno = ENOENT; + else if (session->prepare() >= 0) + { + io_set_eventfd(iocb, this->event_fd); + iocb->data = session; + if (io_submit(this->io_ctx, 1, &iocb) > 0) + { + list_add_tail(&session->list, &this->session_list); + ret = 0; + } + } + + pthread_mutex_unlock(&this->mutex); + if (ret < 0) + session->res = -errno; + + return ret; +} + +void *IOService::aio_finish(void *context) +{ + IOService *service = 
(IOService *)context; + IOSession *session; + struct io_event event; + + if (io_getevents(service->io_ctx, 1, 1, &event, NULL) > 0) + { + service->incref(); + session = (IOSession *)event.data; + session->res = event.res; + return session; + } + + return NULL; +} + diff --git a/src/kernel/IOService_linux.h b/src/kernel/IOService_linux.h new file mode 100644 index 0000000..11103ea --- /dev/null +++ b/src/kernel/IOService_linux.h @@ -0,0 +1,106 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _IOSERVICE_LINUX_H_ +#define _IOSERVICE_LINUX_H_ + +#include +#include +#include +#include +#include "list.h" + +#define IOS_STATE_SUCCESS 0 +#define IOS_STATE_ERROR 1 + +class IOSession +{ +private: + virtual int prepare() = 0; + virtual void handle(int state, int error) = 0; + +protected: + /* prepare() has to call one of the the prep_ functions. 
*/ + void prep_pread(int fd, void *buf, size_t count, long long offset); + void prep_pwrite(int fd, void *buf, size_t count, long long offset); + void prep_preadv(int fd, const struct iovec *iov, int iovcnt, + long long offset); + void prep_pwritev(int fd, const struct iovec *iov, int iovcnt, + long long offset); + void prep_fsync(int fd); + void prep_fdsync(int fd); + +protected: + long get_res() const { return this->res; } + +private: + char iocb_buf[64]; + long res; + +private: + struct list_head list; + +public: + virtual ~IOSession() { } + friend class IOService; + friend class Communicator; +}; + +class IOService +{ +public: + int init(int maxevents); + void deinit(); + + int request(IOSession *session); + +private: + virtual void handle_stop(int error) { } + virtual void handle_unbound() = 0; + +private: + virtual int create_event_fd() + { + return eventfd(0, 0); + } + +private: + struct io_context *io_ctx; + +private: + void incref(); + void decref(); + +private: + int event_fd; + int ref; + +private: + struct list_head session_list; + pthread_mutex_t mutex; + +private: + static void *aio_finish(void *context); + +public: + virtual ~IOService() { } + friend class Communicator; +}; + +#endif + diff --git a/src/kernel/IOService_thread.cc b/src/kernel/IOService_thread.cc new file mode 100644 index 0000000..5ab45b4 --- /dev/null +++ b/src/kernel/IOService_thread.cc @@ -0,0 +1,314 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "list.h" +#include "IOService_thread.h" + +typedef enum io_iocb_cmd { + IO_CMD_PREAD = 0, + IO_CMD_PWRITE = 1, + + IO_CMD_FSYNC = 2, + IO_CMD_FDSYNC = 3, + + IO_CMD_NOOP = 6, + IO_CMD_PREADV = 7, + IO_CMD_PWRITEV = 8, +} io_iocb_cmd_t; + +void IOSession::prep_pread(int fd, void *buf, size_t count, long long offset) +{ + this->fd = fd; + this->op = IO_CMD_PREAD; + this->buf = buf; + this->count = count; + this->offset = offset; +} + +void IOSession::prep_pwrite(int fd, void *buf, size_t count, long long offset) +{ + this->fd = fd; + this->op = IO_CMD_PWRITE; + this->buf = buf; + this->count = count; + this->offset = offset; +} + +void IOSession::prep_preadv(int fd, const struct iovec *iov, int iovcnt, + long long offset) +{ + this->fd = fd; + this->op = IO_CMD_PREADV; + this->buf = (void *)iov; + this->count = iovcnt; + this->offset = offset; +} + +void IOSession::prep_pwritev(int fd, const struct iovec *iov, int iovcnt, + long long offset) +{ + this->fd = fd; + this->op = IO_CMD_PWRITEV; + this->buf = (void *)iov; + this->count = iovcnt; + this->offset = offset; +} + +void IOSession::prep_fsync(int fd) +{ + this->fd = fd; + this->op = IO_CMD_FSYNC; +} + +void IOSession::prep_fdsync(int fd) +{ + this->fd = fd; + this->op = IO_CMD_FDSYNC; +} + +int IOService::init(int maxevents) +{ + void *p; + int ret; + + if (maxevents <= 0) + { + errno = EINVAL; + return -1; + } + + ret = pthread_mutex_init(&this->mutex, NULL); + if (ret) + { + errno = ret; + return -1; + } + + p = dlsym(RTLD_DEFAULT, "preadv"); + if (p) + this->preadv = (ssize_t (*)(int, const struct iovec *, int, off_t))p; + else + this->preadv = IOService::preadv_emul; + + p = dlsym(RTLD_DEFAULT, "pwritev"); + if (p) + this->pwritev = (ssize_t (*)(int, const struct iovec *, int, off_t))p; + else + this->pwritev = IOService::pwritev_emul; + + this->maxevents = maxevents; + this->nevents = 0; + 
INIT_LIST_HEAD(&this->session_list); + this->pipe_fd[0] = -1; + this->pipe_fd[1] = -1; + return 0; +} + +void IOService::deinit() +{ + pthread_mutex_destroy(&this->mutex); +} + +inline void IOService::incref() +{ + __sync_add_and_fetch(&this->ref, 1); +} + +void IOService::decref() +{ + IOSession *session; + int state, error; + + if (__sync_sub_and_fetch(&this->ref, 1) == 0) + { + while (!list_empty(&this->session_list)) + { + session = list_entry(this->session_list.next, IOSession, list); + pthread_join(session->tid, NULL); + list_del(&session->list); + if (session->res >= 0) + { + state = IOS_STATE_SUCCESS; + error = 0; + } + else + { + state = IOS_STATE_ERROR; + error = -session->res; + } + + session->handle(state, error); + } + + pthread_mutex_lock(&this->mutex); + /* Wait for detached threads. */ + pthread_mutex_unlock(&this->mutex); + this->handle_unbound(); + } +} + +int IOService::request(IOSession *session) +{ + pthread_t tid; + int ret = -1; + + pthread_mutex_lock(&this->mutex); + if (this->pipe_fd[0] < 0) + errno = ENOENT; + else if (this->nevents >= this->maxevents) + errno = EAGAIN; + else if (session->prepare() >= 0) + { + session->service = this; + ret = pthread_create(&tid, NULL, IOService::io_routine, session); + if (ret == 0) + { + session->tid = tid; + list_add_tail(&session->list, &this->session_list); + this->nevents++; + } + else + { + errno = ret; + ret = -1; + } + } + + pthread_mutex_unlock(&this->mutex); + if (ret < 0) + session->res = -errno; + + return ret; +} + +#if _POSIX_SYNCHRONIZED_IO <= 0 +static inline int fdatasync(int fd) +{ + return fsync(fd); +} +#endif + +void *IOService::io_routine(void *arg) +{ + IOSession *session = (IOSession *)arg; + IOService *service = session->service; + int fd = session->fd; + ssize_t ret; + + switch (session->op) + { + case IO_CMD_PREAD: + ret = pread(fd, session->buf, session->count, session->offset); + break; + case IO_CMD_PWRITE: + ret = pwrite(fd, session->buf, session->count, session->offset); + 
break; + case IO_CMD_FSYNC: + ret = fsync(fd); + break; + case IO_CMD_FDSYNC: + ret = fdatasync(fd); + break; + case IO_CMD_PREADV: + ret = service->preadv(fd, (const struct iovec *)session->buf, + session->count, session->offset); + break; + case IO_CMD_PWRITEV: + ret = service->pwritev(fd, (const struct iovec *)session->buf, + session->count, session->offset); + break; + default: + errno = EINVAL; + ret = -1; + break; + } + + if (ret < 0) + ret = -errno; + + session->res = ret; + pthread_mutex_lock(&service->mutex); + if (service->pipe_fd[1] >= 0) + write(service->pipe_fd[1], &session, sizeof (void *)); + + service->nevents--; + pthread_mutex_unlock(&service->mutex); + return NULL; +} + +void *IOService::aio_finish(void *ptr, void *context) +{ + IOService *service = (IOService *)context; + IOSession *session = (IOSession *)ptr; + + service->incref(); + pthread_detach(session->tid); + return session; +} + +ssize_t IOService::preadv_emul(int fd, const struct iovec *iov, int iovcnt, + off_t offset) +{ + size_t total = 0; + ssize_t n; + int i; + + for (i = 0; i < iovcnt; i++) + { + n = pread(fd, iov[i].iov_base, iov[i].iov_len, offset); + if (n < 0) + return total == 0 ? -1 : total; + + total += n; + if ((size_t)n < iov[i].iov_len) + return total; + + offset += n; + } + + return total; +} + +ssize_t IOService::pwritev_emul(int fd, const struct iovec *iov, int iovcnt, + off_t offset) +{ + size_t total = 0; + ssize_t n; + int i; + + for (i = 0; i < iovcnt; i++) + { + n = pwrite(fd, iov[i].iov_base, iov[i].iov_len, offset); + if (n < 0) + return total == 0 ? -1 : total; + + total += n; + if ((size_t)n < iov[i].iov_len) + return total; + + offset += n; + } + + return total; +} + diff --git a/src/kernel/IOService_thread.h b/src/kernel/IOService_thread.h new file mode 100644 index 0000000..29bebad --- /dev/null +++ b/src/kernel/IOService_thread.h @@ -0,0 +1,122 @@ +/* + Copyright (c) 2020 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _IOSERVICE_THREAD_H_ +#define _IOSERVICE_THREAD_H_ + +#include +#include +#include +#include +#include "list.h" + +#define IOS_STATE_SUCCESS 0 +#define IOS_STATE_ERROR 1 + +class IOSession +{ +private: + virtual int prepare() = 0; + virtual void handle(int state, int error) = 0; + +protected: + /* prepare() has to call one of the the prep_ functions. */ + void prep_pread(int fd, void *buf, size_t count, long long offset); + void prep_pwrite(int fd, void *buf, size_t count, long long offset); + void prep_preadv(int fd, const struct iovec *iov, int iovcnt, + long long offset); + void prep_pwritev(int fd, const struct iovec *iov, int iovcnt, + long long offset); + void prep_fsync(int fd); + void prep_fdsync(int fd); + +protected: + long get_res() const { return this->res; } + +private: + int fd; + int op; + void *buf; + size_t count; + long long offset; + long res; + +private: + struct list_head list; + class IOService *service; + pthread_t tid; + +public: + virtual ~IOSession() { } + friend class IOService; + friend class Communicator; +}; + +class IOService +{ +public: + int init(int maxevents); + void deinit(); + + int request(IOSession *session); + +private: + virtual void handle_stop(int error) { } + virtual void handle_unbound() = 0; + +private: + virtual int create_pipe_fd(int pipe_fd[2]) + { + return pipe(pipe_fd); + } + +private: + int maxevents; + int nevents; + 
+private: + void incref(); + void decref(); + +private: + int pipe_fd[2]; + int ref; + +private: + struct list_head session_list; + pthread_mutex_t mutex; + +private: + static void *io_routine(void *arg); + static void *aio_finish(void *ptr, void *context); + +private: + static ssize_t preadv_emul(int fd, const struct iovec *iov, int iovcnt, + off_t offset); + static ssize_t pwritev_emul(int fd, const struct iovec *iov, int iovcnt, + off_t offset); + ssize_t (*preadv)(int, const struct iovec *, int, off_t); + ssize_t (*pwritev)(int, const struct iovec *, int, off_t); + +public: + virtual ~IOService() { } + friend class Communicator; +}; + +#endif + diff --git a/src/kernel/SleepRequest.h b/src/kernel/SleepRequest.h new file mode 100644 index 0000000..7f1f285 --- /dev/null +++ b/src/kernel/SleepRequest.h @@ -0,0 +1,65 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _SLEEPREQUEST_H_ +#define _SLEEPREQUEST_H_ + +#include +#include "SubTask.h" +#include "Communicator.h" +#include "CommScheduler.h" + +class SleepRequest : public SubTask, public SleepSession +{ +public: + SleepRequest(CommScheduler *scheduler) + { + this->scheduler = scheduler; + } + +public: + virtual void dispatch() + { + if (this->scheduler->sleep(this) < 0) + this->handle(SS_STATE_ERROR, errno); + } + +protected: + int cancel() + { + return this->scheduler->unsleep(this); + } + +protected: + int state; + int error; + +protected: + CommScheduler *scheduler; + +protected: + virtual void handle(int state, int error) + { + this->state = state; + this->error = error; + this->subtask_done(); + } +}; + +#endif + diff --git a/src/kernel/SubTask.cc b/src/kernel/SubTask.cc new file mode 100644 index 0000000..0afdf16 --- /dev/null +++ b/src/kernel/SubTask.cc @@ -0,0 +1,65 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+
+  Author: Xie Han (xiehan@sogou-inc.com)
+*/
+
+#include "SubTask.h"
+
+/*
+ * Called when a subtask has finished. done() returns the task to run
+ * next, if any: that task inherits the finished task's parent and is
+ * dispatched. When there is no next task but there is a parent, the
+ * parent's count of unfinished subtasks is decremented atomically; the
+ * caller that brings it to zero repeats the same procedure for the
+ * parent itself.
+ */
+void SubTask::subtask_done()
+{
+    SubTask *cur = this;
+    ParallelTask *parent;
+
+    while (1)
+    {
+        /* Read the parent before calling done(). */
+        parent = cur->parent;
+        cur = cur->done();
+        if (cur)
+        {
+            cur->parent = parent;
+            cur->dispatch();
+        }
+        else if (parent)
+        {
+            if (__sync_sub_and_fetch(&parent->nleft, 1) == 0)
+            {
+                /* Last subtask to finish: the parent is now done too. */
+                cur = parent;
+                continue;
+            }
+        }
+
+        break;
+    }
+}
+
+/*
+ * Dispatch every subtask with this task as its parent. nleft is set to
+ * the number of subtasks first; with no subtasks at all the parallel task
+ * finishes immediately.
+ */
+void ParallelTask::dispatch()
+{
+    SubTask **end = this->subtasks + this->subtasks_nr;
+    SubTask **p = this->subtasks;
+
+    this->nleft = this->subtasks_nr;
+    if (this->nleft != 0)
+    {
+        do
+        {
+            (*p)->parent = this;
+            (*p)->dispatch();
+        } while (++p != end);
+    }
+    else
+        this->subtask_done();
+}
+
diff --git a/src/kernel/SubTask.h b/src/kernel/SubTask.h
new file mode 100644
index 0000000..b06a164
--- /dev/null
+++ b/src/kernel/SubTask.h
@@ -0,0 +1,80 @@
+/*
+  Copyright (c) 2019 Sogou, Inc.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _SUBTASK_H_ +#define _SUBTASK_H_ + +#include + +class ParallelTask; + +class SubTask +{ +public: + virtual void dispatch() = 0; + +private: + virtual SubTask *done() = 0; + +protected: + void subtask_done(); + +public: + void *get_pointer() const { return this->pointer; } + void set_pointer(void *pointer) { this->pointer = pointer; } + +private: + ParallelTask *parent; + void *pointer; + +public: + SubTask() + { + this->parent = NULL; + this->pointer = NULL; + } + + virtual ~SubTask() { } + friend class ParallelTask; +}; + +class ParallelTask : public SubTask +{ +public: + virtual void dispatch(); + +protected: + SubTask **subtasks; + size_t subtasks_nr; + +private: + size_t nleft; + +public: + ParallelTask(SubTask **subtasks, size_t n) + { + this->subtasks = subtasks; + this->subtasks_nr = n; + } + + virtual ~ParallelTask() { } + friend class SubTask; +}; + +#endif + diff --git a/src/kernel/list.h b/src/kernel/list.h new file mode 100644 index 0000000..4054e5e --- /dev/null +++ b/src/kernel/list.h @@ -0,0 +1,331 @@ +#ifndef _LINUX_LIST_H +#define _LINUX_LIST_H + +/* + * Circular doubly linked list implementation. + * + * Some of the internal functions ("__xxx") are useful when + * manipulating whole lists rather than single entries, as + * sometimes we already know the next/prev entries and we can + * generate better code by using them directly rather than + * using the generic single-entry routines. + */ + +struct list_head { + struct list_head *next, *prev; +}; + +#define LIST_HEAD_INIT(name) { &(name), &(name) } + +#define LIST_HEAD(name) \ + struct list_head name = LIST_HEAD_INIT(name) + +/** + * INIT_LIST_HEAD - Initialize a list_head structure + * @list: list_head structure to be initialized. + * + * Initializes the list_head to point to itself. If it is a list header, + * the result is an empty list. 
+ */ +static inline void INIT_LIST_HEAD(struct list_head *list) +{ + list->next = list; + list->prev = list; +} + +/* + * Insert a new entry between two known consecutive entries. + * + * This is only for internal list manipulation where we know + * the prev/next entries already! + */ +static inline void __list_add(struct list_head *entry, + struct list_head *prev, + struct list_head *next) +{ + next->prev = entry; + entry->next = next; + entry->prev = prev; + prev->next = entry; +} + +/** + * list_add - add a new entry + * @entry: new entry to be added + * @head: list head to add it after + * + * Insert a new entry after the specified head. + * This is good for implementing stacks. + */ +static inline void list_add(struct list_head *entry, struct list_head *head) +{ + __list_add(entry, head, head->next); +} + +/** + * list_add_tail - add a new entry + * @entry: new entry to be added + * @head: list head to add it before + * + * Insert a new entry before the specified head. + * This is useful for implementing queues. + */ +static inline void list_add_tail(struct list_head *entry, + struct list_head *head) +{ + __list_add(entry, head->prev, head); +} + +/* + * Delete a list entry by making the prev/next entries + * point to each other. + * + * This is only for internal list manipulation where we know + * the prev/next entries already! + */ +static inline void __list_del(struct list_head *prev, struct list_head *next) +{ + next->prev = prev; + prev->next = next; +} + +/** + * list_del - deletes entry from list. + * @entry: the element to delete from the list. + * Note: list_empty() on entry does not return true after this, the entry is + * in an undefined state. 
+ */ +static inline void list_del(struct list_head *entry) +{ + __list_del(entry->prev, entry->next); +} + +/** + * list_move - delete from one list and add as another's head + * @entry: the entry to move + * @head: the head that will precede our entry + */ +static inline void list_move(struct list_head *entry, struct list_head *head) +{ + __list_del(entry->prev, entry->next); + list_add(entry, head); +} + +/** + * list_move_tail - delete from one list and add as another's tail + * @entry: the entry to move + * @head: the head that will follow our entry + */ +static inline void list_move_tail(struct list_head *entry, + struct list_head *head) +{ + __list_del(entry->prev, entry->next); + list_add_tail(entry, head); +} + +/** + * list_empty - tests whether a list is empty + * @head: the list to test. + */ +static inline int list_empty(const struct list_head *head) +{ + return head->next == head; +} + +static inline void __list_splice(const struct list_head *list, + struct list_head *prev, + struct list_head *next) +{ + struct list_head *first = list->next; + struct list_head *last = list->prev; + + first->prev = prev; + prev->next = first; + + last->next = next; + next->prev = last; +} + +/** + * list_splice - join two lists + * @list: the new list to add. + * @head: the place to add it in the first list. + */ +static inline void list_splice(const struct list_head *list, + struct list_head *head) +{ + if (!list_empty(list)) + __list_splice(list, head, head->next); +} + +/** + * list_splice_init - join two lists and reinitialise the emptied list. + * @list: the new list to add. + * @head: the place to add it in the first list. + * + * The list at @list is reinitialised + */ +static inline void list_splice_init(struct list_head *list, + struct list_head *head) +{ + if (!list_empty(list)) { + __list_splice(list, head, head->next); + INIT_LIST_HEAD(list); + } +} + +/** + * list_entry - get the struct for this entry + * @ptr: the &struct list_head pointer. 
+ * @type: the type of the struct this is embedded in. + * @member: the name of the list_struct within the struct. + */ +#define list_entry(ptr, type, member) \ + ((type *)((char *)(ptr)-(unsigned long)(&((type *)0)->member))) + +/** + * list_for_each - iterate over a list + * @pos: the &struct list_head to use as a loop counter. + * @head: the head for your list. + */ +#define list_for_each(pos, head) \ + for (pos = (head)->next; pos != (head); pos = pos->next) + +/** + * list_for_each_prev - iterate over a list backwards + * @pos: the &struct list_head to use as a loop counter. + * @head: the head for your list. + */ +#define list_for_each_prev(pos, head) \ + for (pos = (head)->prev; pos != (head); pos = pos->prev) + +/** + * list_for_each_safe - iterate over a list safe against removal of list entry + * @pos: the &struct list_head to use as a loop counter. + * @n: another &struct list_head to use as temporary storage + * @head: the head for your list. + */ +#define list_for_each_safe(pos, n, head) \ + for (pos = (head)->next, n = pos->next; pos != (head); \ + pos = n, n = pos->next) + +/** + * list_for_each_entry - iterate over list of given type + * @pos: the type * to use as a loop counter. + * @head: the head for your list. + * @member: the name of the list_struct within the struct. + */ +#define list_for_each_entry(pos, head, member) \ + for (pos = list_entry((head)->next, typeof (*pos), member); \ + &pos->member != (head); \ + pos = list_entry(pos->member.next, typeof (*pos), member)) + +/* + * Singly linked list implementation. 
+ */ + +struct slist_node { + struct slist_node *next; +}; + +struct slist_head { + struct slist_node first, *last; +}; + +#define SLIST_HEAD_INIT(name) { { (struct slist_node *)0 }, &(name).first } + +#define SLIST_HEAD(name) \ + struct slist_head name = SLIST_HEAD_INIT(name) + +static inline void INIT_SLIST_HEAD(struct slist_head *list) +{ + list->first.next = (struct slist_node *)0; + list->last = &list->first; +} + +static inline void slist_add_after(struct slist_node *entry, + struct slist_node *prev, + struct slist_head *list) +{ + entry->next = prev->next; + prev->next = entry; + if (!entry->next) + list->last = entry; +} + +static inline void slist_add_head(struct slist_node *entry, + struct slist_head *list) +{ + slist_add_after(entry, &list->first, list); +} + +static inline void slist_add_tail(struct slist_node *entry, + struct slist_head *list) +{ + entry->next = (struct slist_node *)0; + list->last->next = entry; + list->last = entry; +} + +static inline void slist_del_after(struct slist_node *prev, + struct slist_head *list) +{ + prev->next = prev->next->next; + if (!prev->next) + list->last = prev; +} + +static inline void slist_del_head(struct slist_head *list) +{ + slist_del_after(&list->first, list); +} + +static inline int slist_empty(const struct slist_head *list) +{ + return !list->first.next; +} + +static inline void __slist_splice(const struct slist_head *list, + struct slist_node *prev, + struct slist_head *head) +{ + list->last->next = prev->next; + prev->next = list->first.next; + if (!list->last->next) + head->last = list->last; +} + +static inline void slist_splice(const struct slist_head *list, + struct slist_node *prev, + struct slist_head *head) +{ + if (!slist_empty(list)) + __slist_splice(list, prev, head); +} + +static inline void slist_splice_init(struct slist_head *list, + struct slist_node *prev, + struct slist_head *head) +{ + if (!slist_empty(list)) { + __slist_splice(list, prev, head); + INIT_SLIST_HEAD(list); + } +} + 
+#define slist_entry(ptr, type, member) \ + ((type *)((char *)(ptr)-(unsigned long)(&((type *)0)->member))) + +#define slist_for_each(pos, head) \ + for (pos = (head)->first.next; pos; pos = pos->next) + +#define slist_for_each_safe(pos, prev, head) \ + for (prev = &(head)->first, pos = prev->next; pos; \ + prev = prev->next == pos ? pos : prev, pos = prev->next) + +#define slist_for_each_entry(pos, head, member) \ + for (pos = slist_entry((head)->first.next, typeof (*pos), member); \ + &pos->member != (struct slist_node *)0; \ + pos = slist_entry(pos->member.next, typeof (*pos), member)) + +#endif diff --git a/src/kernel/mpoller.c b/src/kernel/mpoller.c new file mode 100644 index 0000000..4c97190 --- /dev/null +++ b/src/kernel/mpoller.c @@ -0,0 +1,116 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include "poller.h" +#include "mpoller.h" + +extern poller_t *__poller_create(void **, const struct poller_params *); +extern void __poller_destroy(poller_t *); + +static int __mpoller_create(const struct poller_params *params, + mpoller_t *mpoller) +{ + void **nodes_buf = (void **)calloc(params->max_open_files, sizeof (void *)); + unsigned int i; + + if (nodes_buf) + { + for (i = 0; i < mpoller->nthreads; i++) + { + mpoller->poller[i] = __poller_create(nodes_buf, params); + if (!mpoller->poller[i]) + break; + } + + if (i == mpoller->nthreads) + { + mpoller->nodes_buf = nodes_buf; + return 0; + } + + while (i > 0) + __poller_destroy(mpoller->poller[--i]); + + free(nodes_buf); + } + + return -1; +} + +mpoller_t *mpoller_create(const struct poller_params *params, size_t nthreads) +{ + mpoller_t *mpoller; + size_t size; + + if (nthreads == 0) + nthreads = 1; + + size = offsetof(mpoller_t, poller) + nthreads * sizeof (void *); + mpoller = (mpoller_t *)malloc(size); + if (mpoller) + { + mpoller->nthreads = (unsigned int)nthreads; + if (__mpoller_create(params, mpoller) >= 0) + return mpoller; + + free(mpoller); + } + + return NULL; +} + +int mpoller_start(mpoller_t *mpoller) +{ + size_t i; + + for (i = 0; i < mpoller->nthreads; i++) + { + if (poller_start(mpoller->poller[i]) < 0) + break; + } + + if (i == mpoller->nthreads) + return 0; + + while (i > 0) + poller_stop(mpoller->poller[--i]); + + return -1; +} + +void mpoller_stop(mpoller_t *mpoller) +{ + size_t i; + + for (i = 0; i < mpoller->nthreads; i++) + poller_stop(mpoller->poller[i]); +} + +void mpoller_destroy(mpoller_t *mpoller) +{ + size_t i; + + for (i = 0; i < mpoller->nthreads; i++) + __poller_destroy(mpoller->poller[i]); + + free(mpoller->nodes_buf); + free(mpoller); +} + diff --git a/src/kernel/mpoller.h b/src/kernel/mpoller.h new file mode 100644 index 0000000..9013bc7 --- /dev/null +++ b/src/kernel/mpoller.h @@ -0,0 +1,89 @@ +/* + 
Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _MPOLLER_H_ +#define _MPOLLER_H_ + +#include +#include "poller.h" + +typedef struct __mpoller mpoller_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +mpoller_t *mpoller_create(const struct poller_params *params, size_t nthreads); +int mpoller_start(mpoller_t *mpoller); +void mpoller_stop(mpoller_t *mpoller); +void mpoller_destroy(mpoller_t *mpoller); + +#ifdef __cplusplus +} +#endif + +struct __mpoller +{ + void **nodes_buf; + unsigned int nthreads; + poller_t *poller[1]; +}; + +static inline int mpoller_add(const struct poller_data *data, int timeout, + mpoller_t *mpoller) +{ + int index = (unsigned int)data->fd % mpoller->nthreads; + return poller_add(data, timeout, mpoller->poller[index]); +} + +static inline int mpoller_del(int fd, mpoller_t *mpoller) +{ + int index = (unsigned int)fd % mpoller->nthreads; + return poller_del(fd, mpoller->poller[index]); +} + +static inline int mpoller_mod(const struct poller_data *data, int timeout, + mpoller_t *mpoller) +{ + int index = (unsigned int)data->fd % mpoller->nthreads; + return poller_mod(data, timeout, mpoller->poller[index]); +} + +static inline int mpoller_set_timeout(int fd, int timeout, mpoller_t *mpoller) +{ + int index = (unsigned int)fd % mpoller->nthreads; + return poller_set_timeout(fd, timeout, mpoller->poller[index]); +} + +static inline int mpoller_add_timer(const struct timespec 
*value, void *context, + void **timer, int *index, + mpoller_t *mpoller) +{ + static unsigned int n = 0; + *index = n++ % mpoller->nthreads; + return poller_add_timer(value, context, timer, mpoller->poller[*index]); +} + +static inline int mpoller_del_timer(void *timer, int index, mpoller_t *mpoller) +{ + return poller_del_timer(timer, mpoller->poller[index]); +} + +#endif + diff --git a/src/kernel/msgqueue.c b/src/kernel/msgqueue.c new file mode 100644 index 0000000..1d7e94b --- /dev/null +++ b/src/kernel/msgqueue.c @@ -0,0 +1,203 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +/* + * This message queue originates from the project of Sogou C++ Workflow: + * https://github.com/sogou/workflow + * + * The idea of this implementation is quite simple and obvious. When the + * get_list is not empty, the consumer takes a message. Otherwise the consumer + * waits till put_list is not empty, and swap two lists. This method performs + * well when the queue is very busy, and the number of consumers is big. 
+ */ + +#include +#include +#include +#include "msgqueue.h" + +struct __msgqueue +{ + size_t msg_max; + size_t msg_cnt; + int linkoff; + int nonblock; + void *head1; + void *head2; + void **get_head; + void **put_head; + void **put_tail; + pthread_mutex_t get_mutex; + pthread_mutex_t put_mutex; + pthread_cond_t get_cond; + pthread_cond_t put_cond; +}; + +void msgqueue_set_nonblock(msgqueue_t *queue) +{ + queue->nonblock = 1; + pthread_mutex_lock(&queue->put_mutex); + pthread_cond_signal(&queue->get_cond); + pthread_cond_broadcast(&queue->put_cond); + pthread_mutex_unlock(&queue->put_mutex); +} + +void msgqueue_set_block(msgqueue_t *queue) +{ + queue->nonblock = 0; +} + +void msgqueue_put(void *msg, msgqueue_t *queue) +{ + void **link = (void **)((char *)msg + queue->linkoff); + + *link = NULL; + pthread_mutex_lock(&queue->put_mutex); + while (queue->msg_cnt > queue->msg_max - 1 && !queue->nonblock) + pthread_cond_wait(&queue->put_cond, &queue->put_mutex); + + *queue->put_tail = link; + queue->put_tail = link; + queue->msg_cnt++; + pthread_mutex_unlock(&queue->put_mutex); + pthread_cond_signal(&queue->get_cond); +} + +void msgqueue_put_head(void *msg, msgqueue_t *queue) +{ + void **link = (void **)((char *)msg + queue->linkoff); + + pthread_mutex_lock(&queue->put_mutex); + while (*queue->get_head) + { + if (pthread_mutex_trylock(&queue->get_mutex) == 0) + { + pthread_mutex_unlock(&queue->put_mutex); + *link = *queue->get_head; + *queue->get_head = link; + pthread_mutex_unlock(&queue->get_mutex); + return; + } + } + + while (queue->msg_cnt > queue->msg_max - 1 && !queue->nonblock) + pthread_cond_wait(&queue->put_cond, &queue->put_mutex); + + *link = *queue->put_head; + if (*link == NULL) + queue->put_tail = link; + + *queue->put_head = link; + queue->msg_cnt++; + pthread_mutex_unlock(&queue->put_mutex); + pthread_cond_signal(&queue->get_cond); +} + +static size_t __msgqueue_swap(msgqueue_t *queue) +{ + void **get_head = queue->get_head; + size_t cnt; + + 
pthread_mutex_lock(&queue->put_mutex); + while (queue->msg_cnt == 0 && !queue->nonblock) + pthread_cond_wait(&queue->get_cond, &queue->put_mutex); + + cnt = queue->msg_cnt; + if (cnt > queue->msg_max - 1) + pthread_cond_broadcast(&queue->put_cond); + + queue->get_head = queue->put_head; + queue->put_head = get_head; + queue->put_tail = get_head; + queue->msg_cnt = 0; + pthread_mutex_unlock(&queue->put_mutex); + return cnt; +} + +void *msgqueue_get(msgqueue_t *queue) +{ + void *msg; + + pthread_mutex_lock(&queue->get_mutex); + if (*queue->get_head || __msgqueue_swap(queue) > 0) + { + msg = (char *)*queue->get_head - queue->linkoff; + *queue->get_head = *(void **)*queue->get_head; + } + else + msg = NULL; + + pthread_mutex_unlock(&queue->get_mutex); + return msg; +} + +msgqueue_t *msgqueue_create(size_t maxlen, int linkoff) +{ + msgqueue_t *queue = (msgqueue_t *)malloc(sizeof (msgqueue_t)); + int ret; + + if (!queue) + return NULL; + + ret = pthread_mutex_init(&queue->get_mutex, NULL); + if (ret == 0) + { + ret = pthread_mutex_init(&queue->put_mutex, NULL); + if (ret == 0) + { + ret = pthread_cond_init(&queue->get_cond, NULL); + if (ret == 0) + { + ret = pthread_cond_init(&queue->put_cond, NULL); + if (ret == 0) + { + queue->msg_max = maxlen; + queue->linkoff = linkoff; + queue->head1 = NULL; + queue->head2 = NULL; + queue->get_head = &queue->head1; + queue->put_head = &queue->head2; + queue->put_tail = &queue->head2; + queue->msg_cnt = 0; + queue->nonblock = 0; + return queue; + } + + pthread_cond_destroy(&queue->get_cond); + } + + pthread_mutex_destroy(&queue->put_mutex); + } + + pthread_mutex_destroy(&queue->get_mutex); + } + + errno = ret; + free(queue); + return NULL; +} + +void msgqueue_destroy(msgqueue_t *queue) +{ + pthread_cond_destroy(&queue->put_cond); + pthread_cond_destroy(&queue->get_cond); + pthread_mutex_destroy(&queue->put_mutex); + pthread_mutex_destroy(&queue->get_mutex); + free(queue); +} + diff --git a/src/kernel/msgqueue.h 
b/src/kernel/msgqueue.h new file mode 100644 index 0000000..3e5c1b0 --- /dev/null +++ b/src/kernel/msgqueue.h @@ -0,0 +1,50 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _MSGQUEUE_H_ +#define _MSGQUEUE_H_ + +#include + +typedef struct __msgqueue msgqueue_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +/* A simple implementation of message queue. The max pending messages may + * reach two times 'maxlen' when the queue is in blocking mode, and infinite + * in nonblocking mode. 'linkoff' is the offset from the head of each message, + * where spaces of one pointer size should be available for internal usage. + * 'linkoff' can be positive or negative or zero. */ + +msgqueue_t *msgqueue_create(size_t maxlen, int linkoff); +void *msgqueue_get(msgqueue_t *queue); +void msgqueue_put(void *msg, msgqueue_t *queue); +void msgqueue_put_head(void *msg, msgqueue_t *queue); +void msgqueue_set_nonblock(msgqueue_t *queue); +void msgqueue_set_block(msgqueue_t *queue); +void msgqueue_destroy(msgqueue_t *queue); + +#ifdef __cplusplus +} +#endif + +#endif + diff --git a/src/kernel/poller.c b/src/kernel/poller.c new file mode 100644 index 0000000..342a5f1 --- /dev/null +++ b/src/kernel/poller.c @@ -0,0 +1,1653 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#ifdef __linux__ +# include +# include +#else +# include +# undef LIST_HEAD +# undef SLIST_HEAD +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include "list.h" +#include "rbtree.h" +#include "poller.h" + +#define POLLER_BUFSIZE (256 * 1024) +#define POLLER_EVENTS_MAX 256 + +struct __poller_node +{ + int state; + int error; + struct poller_data data; +#pragma pack(1) + union + { + struct list_head list; + struct rb_node rb; + }; +#pragma pack() + char in_rbtree; + char removed; + int event; + struct timespec timeout; + struct __poller_node *res; +}; + +struct __poller +{ + size_t max_open_files; + void (*callback)(struct poller_result *, void *); + void *context; + + pthread_t tid; + int pfd; + int timerfd; + int pipe_rd; + int pipe_wr; + int stopped; + struct rb_root timeo_tree; + struct rb_node *tree_first; + struct rb_node *tree_last; + struct list_head timeo_list; + struct list_head no_timeo_list; + struct __poller_node **nodes; + pthread_mutex_t mutex; + char buf[POLLER_BUFSIZE]; +}; + +#ifdef __linux__ + +static inline int __poller_create_pfd() +{ + return epoll_create(1); +} + +static inline int __poller_add_fd(int fd, int event, void *data, + poller_t *poller) +{ + struct epoll_event ev = { + .events = event, + .data = { + .ptr = data + } + }; + return epoll_ctl(poller->pfd, EPOLL_CTL_ADD, fd, &ev); +} + +static inline int __poller_del_fd(int fd, int event, poller_t *poller) +{ + return epoll_ctl(poller->pfd, EPOLL_CTL_DEL, fd, NULL); +} 
+ +static inline int __poller_mod_fd(int fd, int old_event, + int new_event, void *data, + poller_t *poller) +{ + struct epoll_event ev = { + .events = new_event, + .data = { + .ptr = data + } + }; + return epoll_ctl(poller->pfd, EPOLL_CTL_MOD, fd, &ev); +} + +static inline int __poller_create_timerfd() +{ + return timerfd_create(CLOCK_MONOTONIC, 0); +} + +static inline int __poller_close_timerfd(int fd) +{ + return close(fd); +} + +static inline int __poller_add_timerfd(int fd, poller_t *poller) +{ + struct epoll_event ev = { + .events = EPOLLIN | EPOLLET, + .data = { + .ptr = NULL + } + }; + return epoll_ctl(poller->pfd, EPOLL_CTL_ADD, fd, &ev); +} + +static inline int __poller_set_timerfd(int fd, const struct timespec *abstime, + poller_t *poller) +{ + struct itimerspec timer = { + .it_interval = { }, + .it_value = *abstime + }; + return timerfd_settime(fd, TFD_TIMER_ABSTIME, &timer, NULL); +} + +typedef struct epoll_event __poller_event_t; + +static inline int __poller_wait(__poller_event_t *events, int maxevents, + poller_t *poller) +{ + return epoll_wait(poller->pfd, events, maxevents, -1); +} + +static inline void *__poller_event_data(const __poller_event_t *event) +{ + return event->data.ptr; +} + +#else /* BSD, macOS */ + +static inline int __poller_create_pfd() +{ + return kqueue(); +} + +static inline int __poller_add_fd(int fd, int event, void *data, + poller_t *poller) +{ + struct kevent ev; + EV_SET(&ev, fd, event, EV_ADD, 0, 0, data); + return kevent(poller->pfd, &ev, 1, NULL, 0, NULL); +} + +static inline int __poller_del_fd(int fd, int event, poller_t *poller) +{ + struct kevent ev; + EV_SET(&ev, fd, event, EV_DELETE, 0, 0, NULL); + return kevent(poller->pfd, &ev, 1, NULL, 0, NULL); +} + +static inline int __poller_mod_fd(int fd, int old_event, + int new_event, void *data, + poller_t *poller) +{ + struct kevent ev[2]; + EV_SET(&ev[0], fd, old_event, EV_DELETE, 0, 0, NULL); + EV_SET(&ev[1], fd, new_event, EV_ADD, 0, 0, data); + return 
kevent(poller->pfd, ev, 2, NULL, 0, NULL); +} + +static inline int __poller_create_timerfd() +{ + return 0; +} + +static inline int __poller_close_timerfd(int fd) +{ + return 0; +} + +static inline int __poller_add_timerfd(int fd, poller_t *poller) +{ + return 0; +} + +static int __poller_set_timerfd(int fd, const struct timespec *abstime, + poller_t *poller) +{ + struct timespec curtime; + long long nseconds; + struct kevent ev; + int flags; + + if (abstime->tv_sec || abstime->tv_nsec) + { + clock_gettime(CLOCK_MONOTONIC, &curtime); + nseconds = 1000000000LL * (abstime->tv_sec - curtime.tv_sec); + nseconds += abstime->tv_nsec - curtime.tv_nsec; + flags = EV_ADD; + } + else + { + nseconds = 0; + flags = EV_DELETE; + } + + EV_SET(&ev, fd, EVFILT_TIMER, flags, NOTE_NSECONDS, nseconds, NULL); + return kevent(poller->pfd, &ev, 1, NULL, 0, NULL); +} + +typedef struct kevent __poller_event_t; + +static inline int __poller_wait(__poller_event_t *events, int maxevents, + poller_t *poller) +{ + return kevent(poller->pfd, NULL, 0, events, maxevents, NULL); +} + +static inline void *__poller_event_data(const __poller_event_t *event) +{ + return event->udata; +} + +#define EPOLLIN EVFILT_READ +#define EPOLLOUT EVFILT_WRITE +#define EPOLLET 0 + +#endif + +static inline long __timeout_cmp(const struct __poller_node *node1, + const struct __poller_node *node2) +{ + long ret = node1->timeout.tv_sec - node2->timeout.tv_sec; + + if (ret == 0) + ret = node1->timeout.tv_nsec - node2->timeout.tv_nsec; + + return ret; +} + +static void __poller_tree_insert(struct __poller_node *node, poller_t *poller) +{ + struct rb_node **p = &poller->timeo_tree.rb_node; + struct rb_node *parent = NULL; + struct __poller_node *entry; + + entry = rb_entry(poller->tree_last, struct __poller_node, rb); + if (!*p) + { + poller->tree_first = &node->rb; + poller->tree_last = &node->rb; + } + else if (__timeout_cmp(node, entry) >= 0) + { + parent = poller->tree_last; + p = &parent->rb_right; + 
poller->tree_last = &node->rb; + } + else + { + do + { + parent = *p; + entry = rb_entry(*p, struct __poller_node, rb); + if (__timeout_cmp(node, entry) < 0) + p = &(*p)->rb_left; + else + p = &(*p)->rb_right; + } while (*p); + + if (p == &poller->tree_first->rb_left) + poller->tree_first = &node->rb; + } + + node->in_rbtree = 1; + rb_link_node(&node->rb, parent, p); + rb_insert_color(&node->rb, &poller->timeo_tree); +} + +static inline void __poller_tree_erase(struct __poller_node *node, + poller_t *poller) +{ + if (&node->rb == poller->tree_first) + poller->tree_first = rb_next(&node->rb); + + if (&node->rb == poller->tree_last) + poller->tree_last = rb_prev(&node->rb); + + rb_erase(&node->rb, &poller->timeo_tree); + node->in_rbtree = 0; +} + +static int __poller_remove_node(struct __poller_node *node, poller_t *poller) +{ + int removed; + + pthread_mutex_lock(&poller->mutex); + removed = node->removed; + if (!removed) + { + poller->nodes[node->data.fd] = NULL; + + if (node->in_rbtree) + __poller_tree_erase(node, poller); + else + list_del(&node->list); + + __poller_del_fd(node->data.fd, node->event, poller); + } + + pthread_mutex_unlock(&poller->mutex); + return removed; +} + +static int __poller_append_message(const void *buf, size_t *n, + struct __poller_node *node, + poller_t *poller) +{ + poller_message_t *msg = node->data.message; + struct __poller_node *res; + int ret; + + if (!msg) + { + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + if (!res) + return -1; + + msg = node->data.create_message(node->data.context); + if (!msg) + { + free(res); + return -1; + } + + node->data.message = msg; + node->res = res; + } + else + res = node->res; + + ret = msg->append(buf, n, msg); + if (ret > 0) + { + res->data = node->data; + res->error = 0; + res->state = PR_ST_SUCCESS; + poller->callback((struct poller_result *)res, poller->context); + + node->data.message = NULL; + node->res = NULL; + } + + return ret; +} + +static int 
__poller_handle_ssl_error(struct __poller_node *node, int ret, + poller_t *poller) +{ + int error = SSL_get_error(node->data.ssl, ret); + int event; + + switch (error) + { + case SSL_ERROR_WANT_READ: + event = EPOLLIN | EPOLLET; + break; + case SSL_ERROR_WANT_WRITE: + event = EPOLLOUT | EPOLLET; + break; + default: + errno = -error; + case SSL_ERROR_SYSCALL: + return -1; + } + + if (event == node->event) + return 0; + + pthread_mutex_lock(&poller->mutex); + if (!node->removed) + { + ret = __poller_mod_fd(node->data.fd, node->event, event, node, poller); + if (ret >= 0) + node->event = event; + } + else + ret = 0; + + pthread_mutex_unlock(&poller->mutex); + return ret; +} + +static void __poller_handle_read(struct __poller_node *node, + poller_t *poller) +{ + ssize_t nleft; + size_t n; + char *p; + + while (1) + { + p = poller->buf; + if (!node->data.ssl) + { + nleft = read(node->data.fd, p, POLLER_BUFSIZE); + if (nleft < 0) + { + if (errno == EAGAIN) + return; + } + } + else + { + nleft = SSL_read(node->data.ssl, p, POLLER_BUFSIZE); + if (nleft < 0) + { + if (__poller_handle_ssl_error(node, nleft, poller) >= 0) + return; + } + } + + if (nleft <= 0) + break; + + do + { + n = nleft; + if (__poller_append_message(p, &n, node, poller) >= 0) + { + nleft -= n; + p += n; + } + else + nleft = -1; + } while (nleft > 0); + + if (nleft < 0) + break; + } + + if (__poller_remove_node(node, poller)) + return; + + if (nleft == 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = errno; + node->state = PR_ST_ERROR; + } + + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + +#ifndef IOV_MAX +# ifdef UIO_MAXIOV +# define IOV_MAX UIO_MAXIOV +# else +# define IOV_MAX 1024 +# endif +#endif + +static void __poller_handle_write(struct __poller_node *node, + poller_t *poller) +{ + struct iovec *iov = node->data.write_iov; + size_t count = 0; + ssize_t nleft; + int iovcnt; + int ret; + + while (node->data.iovcnt > 0) 
+ { + if (!node->data.ssl) + { + iovcnt = node->data.iovcnt; + if (iovcnt > IOV_MAX) + iovcnt = IOV_MAX; + + nleft = writev(node->data.fd, iov, iovcnt); + if (nleft < 0) + { + ret = errno == EAGAIN ? 0 : -1; + break; + } + } + else if (iov->iov_len > 0) + { + nleft = SSL_write(node->data.ssl, iov->iov_base, iov->iov_len); + if (nleft <= 0) + { + ret = __poller_handle_ssl_error(node, nleft, poller); + break; + } + } + else + nleft = 0; + + count += nleft; + do + { + if (nleft >= iov->iov_len) + { + nleft -= iov->iov_len; + iov->iov_base = (char *)iov->iov_base + iov->iov_len; + iov->iov_len = 0; + iov++; + node->data.iovcnt--; + } + else + { + iov->iov_base = (char *)iov->iov_base + nleft; + iov->iov_len -= nleft; + break; + } + } while (node->data.iovcnt > 0); + } + + node->data.write_iov = iov; + if (node->data.iovcnt > 0 && ret >= 0) + { + if (count == 0) + return; + + if (node->data.partial_written(count, node->data.context) >= 0) + return; + } + + if (__poller_remove_node(node, poller)) + return; + + if (node->data.iovcnt == 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = errno; + node->state = PR_ST_ERROR; + } + + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_listen(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *res = node->res; + struct sockaddr_storage ss; + struct sockaddr *addr = (struct sockaddr *)&ss; + socklen_t addrlen; + void *result; + int sockfd; + + while (1) + { + addrlen = sizeof (struct sockaddr_storage); + sockfd = accept(node->data.fd, addr, &addrlen); + if (sockfd < 0) + { + if (errno == EAGAIN || errno == EMFILE || errno == ENFILE) + return; + else if (errno == ECONNABORTED) + continue; + else + break; + } + + result = node->data.accept(addr, addrlen, sockfd, node->data.context); + if (!result) + break; + + res->data = node->data; + res->data.result = result; + res->error = 0; + res->state = PR_ST_SUCCESS; + 
poller->callback((struct poller_result *)res, poller->context); + + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + node->res = res; + if (!res) + break; + } + + if (__poller_remove_node(node, poller)) + return; + + node->error = errno; + node->state = PR_ST_ERROR; + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_connect(struct __poller_node *node, + poller_t *poller) +{ + socklen_t len = sizeof (int); + int error; + + if (getsockopt(node->data.fd, SOL_SOCKET, SO_ERROR, &error, &len) < 0) + error = errno; + + if (__poller_remove_node(node, poller)) + return; + + if (error == 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = error; + node->state = PR_ST_ERROR; + } + + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_recvfrom(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *res = node->res; + struct sockaddr_storage ss; + struct sockaddr *addr = (struct sockaddr *)&ss; + socklen_t addrlen; + void *result; + ssize_t n; + + while (1) + { + addrlen = sizeof (struct sockaddr_storage); + n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0, + addr, &addrlen); + if (n < 0) + { + if (errno == EAGAIN) + return; + else + break; + } + + result = node->data.recvfrom(addr, addrlen, poller->buf, n, + node->data.context); + if (!result) + break; + + res->data = node->data; + res->data.result = result; + res->error = 0; + res->state = PR_ST_SUCCESS; + poller->callback((struct poller_result *)res, poller->context); + + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + node->res = res; + if (!res) + break; + } + + if (__poller_remove_node(node, poller)) + return; + + node->error = errno; + node->state = PR_ST_ERROR; + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_ssl_accept(struct 
__poller_node *node, + poller_t *poller) +{ + int ret = SSL_accept(node->data.ssl); + + if (ret <= 0) + { + if (__poller_handle_ssl_error(node, ret, poller) >= 0) + return; + } + + if (__poller_remove_node(node, poller)) + return; + + if (ret > 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = errno; + node->state = PR_ST_ERROR; + } + + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_ssl_connect(struct __poller_node *node, + poller_t *poller) +{ + int ret = SSL_connect(node->data.ssl); + + if (ret <= 0) + { + if (__poller_handle_ssl_error(node, ret, poller) >= 0) + return; + } + + if (__poller_remove_node(node, poller)) + return; + + if (ret > 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = errno; + node->state = PR_ST_ERROR; + } + + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_ssl_shutdown(struct __poller_node *node, + poller_t *poller) +{ + int ret = SSL_shutdown(node->data.ssl); + + if (ret <= 0) + { + if (__poller_handle_ssl_error(node, ret, poller) >= 0) + return; + } + + if (__poller_remove_node(node, poller)) + return; + + if (ret > 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = errno; + node->state = PR_ST_ERROR; + } + + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_event(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *res = node->res; + unsigned long long cnt = 0; + unsigned long long value; + void *result; + ssize_t n; + + while (1) + { + n = read(node->data.fd, &value, sizeof (unsigned long long)); + if (n == sizeof (unsigned long long)) + cnt += value; + else + { + if (n >= 0) + errno = EINVAL; + break; + } + } + + if (errno == EAGAIN) + { + while (1) + { + if (cnt == 0) + return; + + cnt--; + result = node->data.event(node->data.context); + if 
(!result) + break; + + res->data = node->data; + res->data.result = result; + res->error = 0; + res->state = PR_ST_SUCCESS; + poller->callback((struct poller_result *)res, poller->context); + + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + node->res = res; + if (!res) + break; + } + } + + if (cnt != 0) + write(node->data.fd, &cnt, sizeof (unsigned long long)); + + if (__poller_remove_node(node, poller)) + return; + + node->error = errno; + node->state = PR_ST_ERROR; + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + +static void __poller_handle_notify(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *res = node->res; + void *result; + ssize_t n; + + while (1) + { + n = read(node->data.fd, &result, sizeof (void *)); + if (n == sizeof (void *)) + { + result = node->data.notify(result, node->data.context); + if (!result) + break; + + res->data = node->data; + res->data.result = result; + res->error = 0; + res->state = PR_ST_SUCCESS; + poller->callback((struct poller_result *)res, poller->context); + + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + node->res = res; + if (!res) + break; + } + else if (n < 0 && errno == EAGAIN) + return; + else + { + if (n > 0) + errno = EINVAL; + break; + } + } + + if (__poller_remove_node(node, poller)) + return; + + if (n == 0) + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + else + { + node->error = errno; + node->state = PR_ST_ERROR; + } + + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + +static int __poller_handle_pipe(poller_t *poller) +{ + struct __poller_node **node = (struct __poller_node **)poller->buf; + int stop = 0; + int n; + int i; + + n = read(poller->pipe_rd, node, POLLER_BUFSIZE) / sizeof (void *); + for (i = 0; i < n; i++) + { + if (node[i]) + { + free(node[i]->res); + poller->callback((struct poller_result *)node[i], poller->context); + } + else + stop 
= 1; + } + + return stop; +} + +static void __poller_handle_timeout(const struct __poller_node *time_node, + poller_t *poller) +{ + struct __poller_node *node; + struct list_head *pos, *tmp; + LIST_HEAD(timeo_list); + + pthread_mutex_lock(&poller->mutex); + list_for_each_safe(pos, tmp, &poller->timeo_list) + { + node = list_entry(pos, struct __poller_node, list); + if (__timeout_cmp(node, time_node) > 0) + break; + + if (node->data.fd >= 0) + { + poller->nodes[node->data.fd] = NULL; + __poller_del_fd(node->data.fd, node->event, poller); + } + else + node->removed = 1; + + list_move_tail(pos, &timeo_list); + } + + while (poller->tree_first) + { + node = rb_entry(poller->tree_first, struct __poller_node, rb); + if (__timeout_cmp(node, time_node) > 0) + break; + + if (node->data.fd >= 0) + { + poller->nodes[node->data.fd] = NULL; + __poller_del_fd(node->data.fd, node->event, poller); + } + else + node->removed = 1; + + poller->tree_first = rb_next(poller->tree_first); + rb_erase(&node->rb, &poller->timeo_tree); + list_add_tail(&node->list, &timeo_list); + if (!poller->tree_first) + poller->tree_last = NULL; + } + + pthread_mutex_unlock(&poller->mutex); + list_for_each_safe(pos, tmp, &timeo_list) + { + node = list_entry(pos, struct __poller_node, list); + if (node->data.fd >= 0) + { + node->error = ETIMEDOUT; + node->state = PR_ST_ERROR; + } + else + { + node->error = 0; + node->state = PR_ST_FINISHED; + } + + free(node->res); + poller->callback((struct poller_result *)node, poller->context); + } +} + +static void __poller_set_timer(poller_t *poller) +{ + struct __poller_node *node = NULL; + struct __poller_node *first; + struct timespec abstime; + + pthread_mutex_lock(&poller->mutex); + if (!list_empty(&poller->timeo_list)) + node = list_entry(poller->timeo_list.next, struct __poller_node, list); + + if (poller->tree_first) + { + first = rb_entry(poller->tree_first, struct __poller_node, rb); + if (!node || __timeout_cmp(first, node) < 0) + node = first; + } + + if 
(node) + abstime = node->timeout; + else + { + abstime.tv_sec = 0; + abstime.tv_nsec = 0; + } + + __poller_set_timerfd(poller->timerfd, &abstime, poller); + pthread_mutex_unlock(&poller->mutex); +} + +static void *__poller_thread_routine(void *arg) +{ + poller_t *poller = (poller_t *)arg; + __poller_event_t events[POLLER_EVENTS_MAX]; + struct __poller_node time_node; + struct __poller_node *node; + int has_pipe_event; + int nevents; + int i; + + while (1) + { + __poller_set_timer(poller); + nevents = __poller_wait(events, POLLER_EVENTS_MAX, poller); + clock_gettime(CLOCK_MONOTONIC, &time_node.timeout); + has_pipe_event = 0; + for (i = 0; i < nevents; i++) + { + node = (struct __poller_node *)__poller_event_data(&events[i]); + if (node <= (struct __poller_node *)1) + { + if (node == (struct __poller_node *)1) + has_pipe_event = 1; + continue; + } + + switch (node->data.operation) + { + case PD_OP_READ: + __poller_handle_read(node, poller); + break; + case PD_OP_WRITE: + __poller_handle_write(node, poller); + break; + case PD_OP_LISTEN: + __poller_handle_listen(node, poller); + break; + case PD_OP_CONNECT: + __poller_handle_connect(node, poller); + break; + case PD_OP_RECVFROM: + __poller_handle_recvfrom(node, poller); + break; + case PD_OP_SSL_ACCEPT: + __poller_handle_ssl_accept(node, poller); + break; + case PD_OP_SSL_CONNECT: + __poller_handle_ssl_connect(node, poller); + break; + case PD_OP_SSL_SHUTDOWN: + __poller_handle_ssl_shutdown(node, poller); + break; + case PD_OP_EVENT: + __poller_handle_event(node, poller); + break; + case PD_OP_NOTIFY: + __poller_handle_notify(node, poller); + break; + } + } + + if (has_pipe_event) + { + if (__poller_handle_pipe(poller)) + break; + } + + __poller_handle_timeout(&time_node, poller); + } + + return NULL; +} + +static int __poller_open_pipe(poller_t *poller) +{ + int pipefd[2]; + + if (pipe(pipefd) >= 0) + { + if (__poller_add_fd(pipefd[0], EPOLLIN, (void *)1, poller) >= 0) + { + poller->pipe_rd = pipefd[0]; + 
poller->pipe_wr = pipefd[1]; + return 0; + } + + close(pipefd[0]); + close(pipefd[1]); + } + + return -1; +} + +static int __poller_create_timer(poller_t *poller) +{ + int timerfd = __poller_create_timerfd(); + + if (timerfd >= 0) + { + if (__poller_add_timerfd(timerfd, poller) >= 0) + { + poller->timerfd = timerfd; + return 0; + } + + __poller_close_timerfd(timerfd); + } + + return -1; +} + +poller_t *__poller_create(void **nodes_buf, const struct poller_params *params) +{ + poller_t *poller = (poller_t *)malloc(sizeof (poller_t)); + int ret; + + if (!poller) + return NULL; + + poller->pfd = __poller_create_pfd(); + if (poller->pfd >= 0) + { + if (__poller_create_timer(poller) >= 0) + { + ret = pthread_mutex_init(&poller->mutex, NULL); + if (ret == 0) + { + poller->nodes = (struct __poller_node **)nodes_buf; + poller->max_open_files = params->max_open_files; + poller->callback = params->callback; + poller->context = params->context; + + poller->timeo_tree.rb_node = NULL; + poller->tree_first = NULL; + poller->tree_last = NULL; + INIT_LIST_HEAD(&poller->timeo_list); + INIT_LIST_HEAD(&poller->no_timeo_list); + + poller->stopped = 1; + return poller; + } + + errno = ret; + close(poller->timerfd); + } + + close(poller->pfd); + } + + free(poller); + return NULL; +} + +poller_t *poller_create(const struct poller_params *params) +{ + void **nodes_buf = (void **)calloc(params->max_open_files, sizeof (void *)); + poller_t *poller; + + if (nodes_buf) + { + poller = __poller_create(nodes_buf, params); + if (poller) + return poller; + + free(nodes_buf); + } + + return NULL; +} + +void __poller_destroy(poller_t *poller) +{ + pthread_mutex_destroy(&poller->mutex); + __poller_close_timerfd(poller->timerfd); + close(poller->pfd); + free(poller); +} + +void poller_destroy(poller_t *poller) +{ + free(poller->nodes); + __poller_destroy(poller); +} + +int poller_start(poller_t *poller) +{ + pthread_t tid; + int ret; + + pthread_mutex_lock(&poller->mutex); + if 
(__poller_open_pipe(poller) >= 0) + { + ret = pthread_create(&tid, NULL, __poller_thread_routine, poller); + if (ret == 0) + { + poller->tid = tid; + poller->stopped = 0; + } + else + { + errno = ret; + close(poller->pipe_wr); + close(poller->pipe_rd); + } + } + + pthread_mutex_unlock(&poller->mutex); + return -poller->stopped; +} + +static void __poller_insert_node(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *end; + + end = list_entry(poller->timeo_list.prev, struct __poller_node, list); + if (list_empty(&poller->timeo_list)) + { + list_add(&node->list, &poller->timeo_list); + end = rb_entry(poller->tree_first, struct __poller_node, rb); + } + else if (__timeout_cmp(node, end) >= 0) + { + list_add_tail(&node->list, &poller->timeo_list); + return; + } + else + { + __poller_tree_insert(node, poller); + if (&node->rb != poller->tree_first) + return; + + end = list_entry(poller->timeo_list.next, struct __poller_node, list); + } + + if (!poller->tree_first || __timeout_cmp(node, end) < 0) + __poller_set_timerfd(poller->timerfd, &node->timeout, poller); +} + +static void __poller_node_set_timeout(int timeout, struct __poller_node *node) +{ + clock_gettime(CLOCK_MONOTONIC, &node->timeout); + node->timeout.tv_sec += timeout / 1000; + node->timeout.tv_nsec += timeout % 1000 * 1000000; + if (node->timeout.tv_nsec >= 1000000000) + { + node->timeout.tv_nsec -= 1000000000; + node->timeout.tv_sec++; + } +} + +static int __poller_data_get_event(int *event, const struct poller_data *data) +{ + switch (data->operation) + { + case PD_OP_READ: + *event = EPOLLIN | EPOLLET; + return !!data->message; + case PD_OP_WRITE: + *event = EPOLLOUT | EPOLLET; + return 0; + case PD_OP_LISTEN: + *event = EPOLLIN; + return 1; + case PD_OP_CONNECT: + *event = EPOLLOUT | EPOLLET; + return 0; + case PD_OP_RECVFROM: + *event = EPOLLIN | EPOLLET; + return 1; + case PD_OP_SSL_ACCEPT: + *event = EPOLLIN | EPOLLET; + return 0; + case PD_OP_SSL_CONNECT: + *event = EPOLLOUT | 
EPOLLET; + return 0; + case PD_OP_SSL_SHUTDOWN: + *event = EPOLLOUT | EPOLLET; + return 0; + case PD_OP_EVENT: + *event = EPOLLIN | EPOLLET; + return 1; + case PD_OP_NOTIFY: + *event = EPOLLIN | EPOLLET; + return 1; + default: + errno = EINVAL; + return -1; + } +} + +static struct __poller_node *__poller_new_node(const struct poller_data *data, + int timeout, poller_t *poller) +{ + struct __poller_node *res = NULL; + struct __poller_node *node; + int need_res; + int event; + + if ((size_t)data->fd >= poller->max_open_files) + { + errno = data->fd < 0 ? EBADF : EMFILE; + return NULL; + } + + need_res = __poller_data_get_event(&event, data); + if (need_res < 0) + return NULL; + + if (need_res) + { + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + if (!res) + return NULL; + } + + node = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + if (node) + { + node->data = *data; + node->event = event; + node->in_rbtree = 0; + node->removed = 0; + node->res = res; + if (timeout >= 0) + __poller_node_set_timeout(timeout, node); + } + + return node; +} + +int poller_add(const struct poller_data *data, int timeout, poller_t *poller) +{ + struct __poller_node *node; + + node = __poller_new_node(data, timeout, poller); + if (!node) + return -1; + + pthread_mutex_lock(&poller->mutex); + if (!poller->nodes[data->fd]) + { + if (__poller_add_fd(data->fd, node->event, node, poller) >= 0) + { + if (timeout >= 0) + __poller_insert_node(node, poller); + else + list_add_tail(&node->list, &poller->no_timeo_list); + + poller->nodes[data->fd] = node; + node = NULL; + } + } + else + errno = EEXIST; + + pthread_mutex_unlock(&poller->mutex); + if (node == NULL) + return 0; + + free(node->res); + free(node); + return -1; +} + +int poller_del(int fd, poller_t *poller) +{ + struct __poller_node *node; + int stopped = 0; + + if ((size_t)fd >= poller->max_open_files) + { + errno = fd < 0 ? 
EBADF : EMFILE; + return -1; + } + + pthread_mutex_lock(&poller->mutex); + node = poller->nodes[fd]; + if (node) + { + poller->nodes[fd] = NULL; + + if (node->in_rbtree) + __poller_tree_erase(node, poller); + else + list_del(&node->list); + + __poller_del_fd(fd, node->event, poller); + + node->error = 0; + node->state = PR_ST_DELETED; + stopped = poller->stopped; + if (!stopped) + { + node->removed = 1; + write(poller->pipe_wr, &node, sizeof (void *)); + } + } + else + errno = ENOENT; + + pthread_mutex_unlock(&poller->mutex); + if (stopped) + { + free(node->res); + poller->callback((struct poller_result *)node, poller->context); + } + + return -!node; +} + +int poller_mod(const struct poller_data *data, int timeout, poller_t *poller) +{ + struct __poller_node *node; + struct __poller_node *orig; + int stopped = 0; + + node = __poller_new_node(data, timeout, poller); + if (!node) + return -1; + + pthread_mutex_lock(&poller->mutex); + orig = poller->nodes[data->fd]; + if (orig) + { + if (__poller_mod_fd(data->fd, orig->event, node->event, node, poller) >= 0) + { + if (orig->in_rbtree) + __poller_tree_erase(orig, poller); + else + list_del(&orig->list); + + orig->error = 0; + orig->state = PR_ST_MODIFIED; + stopped = poller->stopped; + if (!stopped) + { + orig->removed = 1; + write(poller->pipe_wr, &orig, sizeof (void *)); + } + + if (timeout >= 0) + __poller_insert_node(node, poller); + else + list_add_tail(&node->list, &poller->no_timeo_list); + + poller->nodes[data->fd] = node; + node = NULL; + } + } + else + errno = ENOENT; + + pthread_mutex_unlock(&poller->mutex); + if (stopped) + { + free(orig->res); + poller->callback((struct poller_result *)orig, poller->context); + } + + if (node == NULL) + return 0; + + free(node->res); + free(node); + return -1; +} + +int poller_set_timeout(int fd, int timeout, poller_t *poller) +{ + struct __poller_node time_node; + struct __poller_node *node; + + if ((size_t)fd >= poller->max_open_files) + { + errno = fd < 0 ? 
EBADF : EMFILE; + return -1; + } + + if (timeout >= 0) + __poller_node_set_timeout(timeout, &time_node); + + pthread_mutex_lock(&poller->mutex); + node = poller->nodes[fd]; + if (node) + { + if (node->in_rbtree) + __poller_tree_erase(node, poller); + else + list_del(&node->list); + + if (timeout >= 0) + { + node->timeout = time_node.timeout; + __poller_insert_node(node, poller); + } + else + list_add_tail(&node->list, &poller->no_timeo_list); + } + else + errno = ENOENT; + + pthread_mutex_unlock(&poller->mutex); + return -!node; +} + +int poller_add_timer(const struct timespec *value, void *context, void **timer, + poller_t *poller) +{ + struct __poller_node *node; + + if (value->tv_nsec < 0 || value->tv_nsec >= 1000000000) + { + errno = EINVAL; + return -1; + } + + node = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + if (node) + { + memset(&node->data, 0, sizeof (struct poller_data)); + node->data.operation = PD_OP_TIMER; + node->data.fd = -1; + node->data.context = context; + node->in_rbtree = 0; + node->removed = 0; + node->res = NULL; + + if (value->tv_sec >= 0) + { + clock_gettime(CLOCK_MONOTONIC, &node->timeout); + node->timeout.tv_sec += value->tv_sec; + node->timeout.tv_nsec += value->tv_nsec; + if (node->timeout.tv_nsec >= 1000000000) + { + node->timeout.tv_nsec -= 1000000000; + node->timeout.tv_sec++; + } + } + + *timer = node; + pthread_mutex_lock(&poller->mutex); + if (value->tv_sec >= 0) + __poller_insert_node(node, poller); + else + list_add_tail(&node->list, &poller->no_timeo_list); + + pthread_mutex_unlock(&poller->mutex); + return 0; + } + + return -1; +} + +int poller_del_timer(void *timer, poller_t *poller) +{ + struct __poller_node *node = (struct __poller_node *)timer; + + pthread_mutex_lock(&poller->mutex); + if (!node->removed) + { + node->removed = 1; + + if (node->in_rbtree) + __poller_tree_erase(node, poller); + else + list_del(&node->list); + } + else + { + errno = ENOENT; + node = NULL; + } + + 
pthread_mutex_unlock(&poller->mutex); + if (node) + { + node->error = 0; + node->state = PR_ST_DELETED; + poller->callback((struct poller_result *)node, poller->context); + return 0; + } + + return -1; +} + +void poller_stop(poller_t *poller) +{ + struct __poller_node *node; + struct list_head *pos, *tmp; + LIST_HEAD(node_list); + void *p = NULL; + + write(poller->pipe_wr, &p, sizeof (void *)); + pthread_join(poller->tid, NULL); + poller->stopped = 1; + + pthread_mutex_lock(&poller->mutex); + close(poller->pipe_wr); + __poller_handle_pipe(poller); + close(poller->pipe_rd); + + poller->tree_first = NULL; + poller->tree_last = NULL; + while (poller->timeo_tree.rb_node) + { + node = rb_entry(poller->timeo_tree.rb_node, struct __poller_node, rb); + rb_erase(&node->rb, &poller->timeo_tree); + list_add(&node->list, &node_list); + } + + list_splice_init(&poller->timeo_list, &node_list); + list_splice_init(&poller->no_timeo_list, &node_list); + list_for_each(pos, &node_list) + { + node = list_entry(pos, struct __poller_node, list); + if (node->data.fd >= 0) + { + poller->nodes[node->data.fd] = NULL; + __poller_del_fd(node->data.fd, node->event, poller); + } + else + node->removed = 1; + } + + pthread_mutex_unlock(&poller->mutex); + list_for_each_safe(pos, tmp, &node_list) + { + node = list_entry(pos, struct __poller_node, list); + node->error = 0; + node->state = PR_ST_STOPPED; + free(node->res); + poller->callback((struct poller_result *)node, poller->context); + } +} + diff --git a/src/kernel/poller.h b/src/kernel/poller.h new file mode 100644 index 0000000..71ff70c --- /dev/null +++ b/src/kernel/poller.h @@ -0,0 +1,117 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _POLLER_H_ +#define _POLLER_H_ + +#include +#include +#include +#include + +typedef struct __poller poller_t; +typedef struct __poller_message poller_message_t; + +struct __poller_message +{ + int (*append)(const void *, size_t *, poller_message_t *); + char data[0]; +}; + +struct poller_data +{ +#define PD_OP_TIMER 0 +#define PD_OP_READ 1 +#define PD_OP_WRITE 2 +#define PD_OP_LISTEN 3 +#define PD_OP_CONNECT 4 +#define PD_OP_RECVFROM 5 +#define PD_OP_SSL_READ PD_OP_READ +#define PD_OP_SSL_WRITE PD_OP_WRITE +#define PD_OP_SSL_ACCEPT 6 +#define PD_OP_SSL_CONNECT 7 +#define PD_OP_SSL_SHUTDOWN 8 +#define PD_OP_EVENT 9 +#define PD_OP_NOTIFY 10 + short operation; + unsigned short iovcnt; + int fd; + SSL *ssl; + union + { + poller_message_t *(*create_message)(void *); + int (*partial_written)(size_t, void *); + void *(*accept)(const struct sockaddr *, socklen_t, int, void *); + void *(*recvfrom)(const struct sockaddr *, socklen_t, + const void *, size_t, void *); + void *(*event)(void *); + void *(*notify)(void *, void *); + }; + void *context; + union + { + poller_message_t *message; + struct iovec *write_iov; + void *result; + }; +}; + +struct poller_result +{ +#define PR_ST_SUCCESS 0 +#define PR_ST_FINISHED 1 +#define PR_ST_ERROR 2 +#define PR_ST_DELETED 3 +#define PR_ST_MODIFIED 4 +#define PR_ST_STOPPED 5 + int state; + int error; + struct poller_data data; + /* In callback, spaces of six pointers are available from here. 
*/ +}; + +struct poller_params +{ + size_t max_open_files; + void (*callback)(struct poller_result *, void *); + void *context; +}; + +#ifdef __cplusplus +extern "C" +{ +#endif + +poller_t *poller_create(const struct poller_params *params); +int poller_start(poller_t *poller); +int poller_add(const struct poller_data *data, int timeout, poller_t *poller); +int poller_del(int fd, poller_t *poller); +int poller_mod(const struct poller_data *data, int timeout, poller_t *poller); +int poller_set_timeout(int fd, int timeout, poller_t *poller); +int poller_add_timer(const struct timespec *value, void *context, void **timer, + poller_t *poller); +int poller_del_timer(void *timer, poller_t *poller); +void poller_stop(poller_t *poller); +void poller_destroy(poller_t *poller); + +#ifdef __cplusplus +} +#endif + +#endif + diff --git a/src/kernel/rbtree.c b/src/kernel/rbtree.c new file mode 100644 index 0000000..5b6ccb8 --- /dev/null +++ b/src/kernel/rbtree.c @@ -0,0 +1,387 @@ +/* + Red Black Trees + (C) 1999 Andrea Arcangeli + (C) 2002 David Woodhouse + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. 
+ + You should have received a copy of the GNU General Public License + along with this program; if not, write to the Free Software + Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + + linux/lib/rbtree.c +*/ + +#include "rbtree.h" + +static void __rb_rotate_left(struct rb_node *node, struct rb_root *root) +{ + struct rb_node *right = node->rb_right; + + if ((node->rb_right = right->rb_left)) + right->rb_left->rb_parent = node; + right->rb_left = node; + + if ((right->rb_parent = node->rb_parent)) + { + if (node == node->rb_parent->rb_left) + node->rb_parent->rb_left = right; + else + node->rb_parent->rb_right = right; + } + else + root->rb_node = right; + node->rb_parent = right; +} + +static void __rb_rotate_right(struct rb_node *node, struct rb_root *root) +{ + struct rb_node *left = node->rb_left; + + if ((node->rb_left = left->rb_right)) + left->rb_right->rb_parent = node; + left->rb_right = node; + + if ((left->rb_parent = node->rb_parent)) + { + if (node == node->rb_parent->rb_right) + node->rb_parent->rb_right = left; + else + node->rb_parent->rb_left = left; + } + else + root->rb_node = left; + node->rb_parent = left; +} + +void rb_insert_color(struct rb_node *node, struct rb_root *root) +{ + struct rb_node *parent, *gparent; + + while ((parent = node->rb_parent) && parent->rb_color == RB_RED) + { + gparent = parent->rb_parent; + + if (parent == gparent->rb_left) + { + { + register struct rb_node *uncle = gparent->rb_right; + if (uncle && uncle->rb_color == RB_RED) + { + uncle->rb_color = RB_BLACK; + parent->rb_color = RB_BLACK; + gparent->rb_color = RB_RED; + node = gparent; + continue; + } + } + + if (parent->rb_right == node) + { + register struct rb_node *tmp; + __rb_rotate_left(parent, root); + tmp = parent; + parent = node; + node = tmp; + } + + parent->rb_color = RB_BLACK; + gparent->rb_color = RB_RED; + __rb_rotate_right(gparent, root); + } else { + { + register struct rb_node *uncle = gparent->rb_left; + if (uncle && 
uncle->rb_color == RB_RED) + { + uncle->rb_color = RB_BLACK; + parent->rb_color = RB_BLACK; + gparent->rb_color = RB_RED; + node = gparent; + continue; + } + } + + if (parent->rb_left == node) + { + register struct rb_node *tmp; + __rb_rotate_right(parent, root); + tmp = parent; + parent = node; + node = tmp; + } + + parent->rb_color = RB_BLACK; + gparent->rb_color = RB_RED; + __rb_rotate_left(gparent, root); + } + } + + root->rb_node->rb_color = RB_BLACK; +} + +static void __rb_erase_color(struct rb_node *node, struct rb_node *parent, + struct rb_root *root) +{ + struct rb_node *other; + + while ((!node || node->rb_color == RB_BLACK) && node != root->rb_node) + { + if (parent->rb_left == node) + { + other = parent->rb_right; + if (other->rb_color == RB_RED) + { + other->rb_color = RB_BLACK; + parent->rb_color = RB_RED; + __rb_rotate_left(parent, root); + other = parent->rb_right; + } + if ((!other->rb_left || + other->rb_left->rb_color == RB_BLACK) + && (!other->rb_right || + other->rb_right->rb_color == RB_BLACK)) + { + other->rb_color = RB_RED; + node = parent; + parent = node->rb_parent; + } + else + { + if (!other->rb_right || + other->rb_right->rb_color == RB_BLACK) + { + register struct rb_node *o_left; + if ((o_left = other->rb_left)) + o_left->rb_color = RB_BLACK; + other->rb_color = RB_RED; + __rb_rotate_right(other, root); + other = parent->rb_right; + } + other->rb_color = parent->rb_color; + parent->rb_color = RB_BLACK; + if (other->rb_right) + other->rb_right->rb_color = RB_BLACK; + __rb_rotate_left(parent, root); + node = root->rb_node; + break; + } + } + else + { + other = parent->rb_left; + if (other->rb_color == RB_RED) + { + other->rb_color = RB_BLACK; + parent->rb_color = RB_RED; + __rb_rotate_right(parent, root); + other = parent->rb_left; + } + if ((!other->rb_left || + other->rb_left->rb_color == RB_BLACK) + && (!other->rb_right || + other->rb_right->rb_color == RB_BLACK)) + { + other->rb_color = RB_RED; + node = parent; + parent = 
node->rb_parent; + } + else + { + if (!other->rb_left || + other->rb_left->rb_color == RB_BLACK) + { + register struct rb_node *o_right; + if ((o_right = other->rb_right)) + o_right->rb_color = RB_BLACK; + other->rb_color = RB_RED; + __rb_rotate_left(other, root); + other = parent->rb_left; + } + other->rb_color = parent->rb_color; + parent->rb_color = RB_BLACK; + if (other->rb_left) + other->rb_left->rb_color = RB_BLACK; + __rb_rotate_right(parent, root); + node = root->rb_node; + break; + } + } + } + if (node) + node->rb_color = RB_BLACK; +} + +void rb_erase(struct rb_node *node, struct rb_root *root) +{ + struct rb_node *child, *parent; + int color; + + if (!node->rb_left) + child = node->rb_right; + else if (!node->rb_right) + child = node->rb_left; + else + { + struct rb_node *old = node, *left; + + node = node->rb_right; + while ((left = node->rb_left)) + node = left; + child = node->rb_right; + parent = node->rb_parent; + color = node->rb_color; + + if (child) + child->rb_parent = parent; + if (parent) + { + if (parent->rb_left == node) + parent->rb_left = child; + else + parent->rb_right = child; + } + else + root->rb_node = child; + + if (node->rb_parent == old) + parent = node; + node->rb_parent = old->rb_parent; + node->rb_color = old->rb_color; + node->rb_right = old->rb_right; + node->rb_left = old->rb_left; + + if (old->rb_parent) + { + if (old->rb_parent->rb_left == old) + old->rb_parent->rb_left = node; + else + old->rb_parent->rb_right = node; + } else + root->rb_node = node; + + old->rb_left->rb_parent = node; + if (old->rb_right) + old->rb_right->rb_parent = node; + goto color; + } + + parent = node->rb_parent; + color = node->rb_color; + + if (child) + child->rb_parent = parent; + if (parent) + { + if (parent->rb_left == node) + parent->rb_left = child; + else + parent->rb_right = child; + } + else + root->rb_node = child; + + color: + if (color == RB_BLACK) + __rb_erase_color(child, parent, root); +} + +/* + * This function returns the first 
node (in sort order) of the tree. + */ +struct rb_node *rb_first(struct rb_root *root) +{ + struct rb_node *n; + + n = root->rb_node; + if (!n) + return (struct rb_node *)0; + while (n->rb_left) + n = n->rb_left; + return n; +} + +struct rb_node *rb_last(struct rb_root *root) +{ + struct rb_node *n; + + n = root->rb_node; + if (!n) + return (struct rb_node *)0; + while (n->rb_right) + n = n->rb_right; + return n; +} + +struct rb_node *rb_next(struct rb_node *node) +{ + /* If we have a right-hand child, go down and then left as far + as we can. */ + if (node->rb_right) { + node = node->rb_right; + while (node->rb_left) + node = node->rb_left; + return node; + } + + /* No right-hand children. Everything down and left is + smaller than us, so any 'next' node must be in the general + direction of our parent. Go up the tree; any time the + ancestor is a right-hand child of its parent, keep going + up. First time it's a left-hand child of its parent, said + parent is our 'next' node. */ + while (node->rb_parent && node == node->rb_parent->rb_right) + node = node->rb_parent; + + return node->rb_parent; +} + +struct rb_node *rb_prev(struct rb_node *node) +{ + /* If we have a left-hand child, go down and then right as far + as we can. */ + if (node->rb_left) { + node = node->rb_left; + while (node->rb_right) + node = node->rb_right; + return node; + } + + /* No left-hand children. 
Go up till we find an ancestor which + is a right-hand child of its parent */ + while (node->rb_parent && node == node->rb_parent->rb_left) + node = node->rb_parent; + + return node->rb_parent; +} + +void rb_replace_node(struct rb_node *victim, struct rb_node *newnode, + struct rb_root *root) +{ + struct rb_node *parent = victim->rb_parent; + + /* Set the surrounding nodes to point to the replacement */ + if (parent) { + if (victim == parent->rb_left) + parent->rb_left = newnode; + else + parent->rb_right = newnode; + } else { + root->rb_node = newnode; + } + if (victim->rb_left) + victim->rb_left->rb_parent = newnode; + if (victim->rb_right) + victim->rb_right->rb_parent = newnode; + + /* Copy the pointers/colour from the victim to the replacement */ + *newnode = *victim; +} + diff --git a/src/kernel/rbtree.h b/src/kernel/rbtree.h new file mode 100644 index 0000000..1a5fff5 --- /dev/null +++ b/src/kernel/rbtree.h @@ -0,0 +1,150 @@ +/* + Red Black Trees + (C) 1999 Andrea Arcangeli + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program; if not, write to the Free Software + Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + + linux/include/linux/rbtree.h + + To use rbtrees you'll have to implement your own insert and search cores. + This will avoid us to use callbacks and to drop drammatically performances. + I know it's not the cleaner way, but in C (not in C++) to get + performances and genericity... 
+ + Some example of insert and search follows here. The search is a plain + normal search over an ordered tree. The insert instead must be implemented + int two steps: as first thing the code must insert the element in + order as a red leaf in the tree, then the support library function + rb_insert_color() must be called. Such function will do the + not trivial work to rebalance the rbtree if necessary. + +----------------------------------------------------------------------- +static inline struct page * rb_search_page_cache(struct inode * inode, + unsigned long offset) +{ + rb_node_t * n = inode->i_rb_page_cache.rb_node; + struct page * page; + + while (n) + { + page = rb_entry(n, struct page, rb_page_cache); + + if (offset < page->offset) + n = n->rb_left; + else if (offset > page->offset) + n = n->rb_right; + else + return page; + } + return NULL; +} + +static inline struct page * __rb_insert_page_cache(struct inode * inode, + unsigned long offset, + rb_node_t * node) +{ + rb_node_t ** p = &inode->i_rb_page_cache.rb_node; + rb_node_t * parent = NULL; + struct page * page; + + while (*p) + { + parent = *p; + page = rb_entry(parent, struct page, rb_page_cache); + + if (offset < page->offset) + p = &(*p)->rb_left; + else if (offset > page->offset) + p = &(*p)->rb_right; + else + return page; + } + + rb_link_node(node, parent, p); + + return NULL; +} + +static inline struct page * rb_insert_page_cache(struct inode * inode, + unsigned long offset, + rb_node_t * node) +{ + struct page * ret; + if ((ret = __rb_insert_page_cache(inode, offset, node))) + goto out; + rb_insert_color(node, &inode->i_rb_page_cache); + out: + return ret; +} +----------------------------------------------------------------------- +*/ + +#ifndef _LINUX_RBTREE_H +#define _LINUX_RBTREE_H + +#pragma pack(1) +struct rb_node +{ + struct rb_node *rb_parent; + struct rb_node *rb_right; + struct rb_node *rb_left; + char rb_color; +#define RB_RED 0 +#define RB_BLACK 1 +}; +#pragma pack() + +struct 
rb_root +{ + struct rb_node *rb_node; +}; + +#define RB_ROOT (struct rb_root){ (struct rb_node *)0, } +#define rb_entry(ptr, type, member) \ + ((type *)((char *)(ptr)-(unsigned long)(&((type *)0)->member))) + +#ifdef __cplusplus +extern "C" +{ +#endif + +extern void rb_insert_color(struct rb_node *node, struct rb_root *root); +extern void rb_erase(struct rb_node *node, struct rb_root *root); + +/* Find logical next and previous nodes in a tree */ +extern struct rb_node *rb_next(struct rb_node *); +extern struct rb_node *rb_prev(struct rb_node *); +extern struct rb_node *rb_first(struct rb_root *); +extern struct rb_node *rb_last(struct rb_root *); + +/* Fast replacement of a single node without remove/rebalance/add/rebalance */ +extern void rb_replace_node(struct rb_node *victim, struct rb_node *newnode, + struct rb_root *root); + +#ifdef __cplusplus +} +#endif + +static inline void rb_link_node(struct rb_node *node, struct rb_node *parent, + struct rb_node **link) +{ + node->rb_parent = parent; + node->rb_color = RB_RED; + node->rb_left = node->rb_right = (struct rb_node *)0; + + *link = node; +} + +#endif diff --git a/src/kernel/thrdpool.c b/src/kernel/thrdpool.c new file mode 100644 index 0000000..c1be399 --- /dev/null +++ b/src/kernel/thrdpool.c @@ -0,0 +1,293 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include "msgqueue.h" +#include "thrdpool.h" + +struct __thrdpool +{ + msgqueue_t *msgqueue; + size_t nthreads; + size_t stacksize; + pthread_t tid; + pthread_mutex_t mutex; + pthread_key_t key; + pthread_cond_t *terminate; +}; + +struct __thrdpool_task_entry +{ + void *link; + struct thrdpool_task task; +}; + +static pthread_t __zero_tid; + +static void __thrdpool_exit_routine(void *context) +{ + thrdpool_t *pool = (thrdpool_t *)context; + pthread_t tid; + + /* One thread joins another. Don't need to keep all thread IDs. */ + pthread_mutex_lock(&pool->mutex); + tid = pool->tid; + pool->tid = pthread_self(); + if (--pool->nthreads == 0 && pool->terminate) + pthread_cond_signal(pool->terminate); + + pthread_mutex_unlock(&pool->mutex); + if (!pthread_equal(tid, __zero_tid)) + pthread_join(tid, NULL); + + pthread_exit(NULL); +} + +static void *__thrdpool_routine(void *arg) +{ + thrdpool_t *pool = (thrdpool_t *)arg; + struct __thrdpool_task_entry *entry; + void (*task_routine)(void *); + void *task_context; + + pthread_setspecific(pool->key, pool); + while (!pool->terminate) + { + entry = (struct __thrdpool_task_entry *)msgqueue_get(pool->msgqueue); + if (!entry) + break; + + task_routine = entry->task.routine; + task_context = entry->task.context; + free(entry); + task_routine(task_context); + + if (pool->nthreads == 0) + { + /* Thread pool was destroyed by the task. */ + free(pool); + return NULL; + } + } + + __thrdpool_exit_routine(pool); + return NULL; +} + +static void __thrdpool_terminate(int in_pool, thrdpool_t *pool) +{ + pthread_cond_t term = PTHREAD_COND_INITIALIZER; + + pthread_mutex_lock(&pool->mutex); + msgqueue_set_nonblock(pool->msgqueue); + pool->terminate = &term; + + if (in_pool) + { + /* Thread pool destroyed in a pool thread is legal. 
*/ + pthread_detach(pthread_self()); + pool->nthreads--; + } + + while (pool->nthreads > 0) + pthread_cond_wait(&term, &pool->mutex); + + pthread_mutex_unlock(&pool->mutex); + if (!pthread_equal(pool->tid, __zero_tid)) + pthread_join(pool->tid, NULL); +} + +static int __thrdpool_create_threads(size_t nthreads, thrdpool_t *pool) +{ + pthread_attr_t attr; + pthread_t tid; + int ret; + + ret = pthread_attr_init(&attr); + if (ret == 0) + { + if (pool->stacksize) + pthread_attr_setstacksize(&attr, pool->stacksize); + + while (pool->nthreads < nthreads) + { + ret = pthread_create(&tid, &attr, __thrdpool_routine, pool); + if (ret == 0) + pool->nthreads++; + else + break; + } + + pthread_attr_destroy(&attr); + if (pool->nthreads == nthreads) + return 0; + + __thrdpool_terminate(0, pool); + } + + errno = ret; + return -1; +} + +thrdpool_t *thrdpool_create(size_t nthreads, size_t stacksize) +{ + thrdpool_t *pool; + int ret; + + pool = (thrdpool_t *)malloc(sizeof (thrdpool_t)); + if (!pool) + return NULL; + + pool->msgqueue = msgqueue_create(0, 0); + if (pool->msgqueue) + { + ret = pthread_mutex_init(&pool->mutex, NULL); + if (ret == 0) + { + ret = pthread_key_create(&pool->key, NULL); + if (ret == 0) + { + pool->stacksize = stacksize; + pool->nthreads = 0; + pool->tid = __zero_tid; + pool->terminate = NULL; + if (__thrdpool_create_threads(nthreads, pool) >= 0) + return pool; + + pthread_key_delete(pool->key); + } + + pthread_mutex_destroy(&pool->mutex); + } + + errno = ret; + msgqueue_destroy(pool->msgqueue); + } + + free(pool); + return NULL; +} + +inline void __thrdpool_schedule(const struct thrdpool_task *task, void *buf, + thrdpool_t *pool); + +void __thrdpool_schedule(const struct thrdpool_task *task, void *buf, + thrdpool_t *pool) +{ + ((struct __thrdpool_task_entry *)buf)->task = *task; + msgqueue_put(buf, pool->msgqueue); +} + +int thrdpool_schedule(const struct thrdpool_task *task, thrdpool_t *pool) +{ + void *buf = malloc(sizeof (struct __thrdpool_task_entry)); + + 
if (buf) + { + __thrdpool_schedule(task, buf, pool); + return 0; + } + + return -1; +} + +inline int thrdpool_in_pool(thrdpool_t *pool); + +int thrdpool_in_pool(thrdpool_t *pool) +{ + return pthread_getspecific(pool->key) == pool; +} + +int thrdpool_increase(thrdpool_t *pool) +{ + pthread_attr_t attr; + pthread_t tid; + int ret; + + ret = pthread_attr_init(&attr); + if (ret == 0) + { + if (pool->stacksize) + pthread_attr_setstacksize(&attr, pool->stacksize); + + pthread_mutex_lock(&pool->mutex); + ret = pthread_create(&tid, &attr, __thrdpool_routine, pool); + if (ret == 0) + pool->nthreads++; + + pthread_mutex_unlock(&pool->mutex); + pthread_attr_destroy(&attr); + if (ret == 0) + return 0; + } + + errno = ret; + return -1; +} + +int thrdpool_decrease(thrdpool_t *pool) +{ + void *buf = malloc(sizeof (struct __thrdpool_task_entry)); + struct __thrdpool_task_entry *entry; + + if (buf) + { + entry = (struct __thrdpool_task_entry *)buf; + entry->task.routine = __thrdpool_exit_routine; + entry->task.context = pool; + msgqueue_put_head(entry, pool->msgqueue); + return 0; + } + + return -1; +} + +void thrdpool_exit(thrdpool_t *pool) +{ + if (thrdpool_in_pool(pool)) + __thrdpool_exit_routine(pool); +} + +void thrdpool_destroy(void (*pending)(const struct thrdpool_task *), + thrdpool_t *pool) +{ + int in_pool = thrdpool_in_pool(pool); + struct __thrdpool_task_entry *entry; + + __thrdpool_terminate(in_pool, pool); + while (1) + { + entry = (struct __thrdpool_task_entry *)msgqueue_get(pool->msgqueue); + if (!entry) + break; + + if (pending && entry->task.routine != __thrdpool_exit_routine) + pending(&entry->task); + + free(entry); + } + + pthread_key_delete(pool->key); + pthread_mutex_destroy(&pool->mutex); + msgqueue_destroy(pool->msgqueue); + if (!in_pool) + free(pool); +} + diff --git a/src/kernel/thrdpool.h b/src/kernel/thrdpool.h new file mode 100644 index 0000000..96c1896 --- /dev/null +++ b/src/kernel/thrdpool.h @@ -0,0 +1,63 @@ +/* + Copyright (c) 2019 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _THRDPOOL_H_ +#define _THRDPOOL_H_ + +#include + +typedef struct __thrdpool thrdpool_t; + +struct thrdpool_task +{ + void (*routine)(void *); + void *context; +}; + +#ifdef __cplusplus +extern "C" +{ +#endif + +/* + * Thread pool originates from project Sogou C++ Workflow + * https://github.com/sogou/workflow + * + * A thread task can be scheduled by another task, which is very important, + * even if the pool is being destroyed. Because thread task is hard to know + * what's happening to the pool. + * The thread pool can also be destroyed by a thread task. This may sound + * strange, but it's very logical. Destroying thread pool in thread task + * does not end the task thread. It'll run till the end of task. 
+ */ + +thrdpool_t *thrdpool_create(size_t nthreads, size_t stacksize); +int thrdpool_schedule(const struct thrdpool_task *task, thrdpool_t *pool); +int thrdpool_in_pool(thrdpool_t *pool); +int thrdpool_increase(thrdpool_t *pool); +int thrdpool_decrease(thrdpool_t *pool); +void thrdpool_exit(thrdpool_t *pool); +void thrdpool_destroy(void (*pending)(const struct thrdpool_task *), + thrdpool_t *pool); + +#ifdef __cplusplus +} +#endif + +#endif + diff --git a/src/kernel/xmake.lua b/src/kernel/xmake.lua new file mode 100644 index 0000000..7a926b7 --- /dev/null +++ b/src/kernel/xmake.lua @@ -0,0 +1,9 @@ +target("kernel") + set_kind("object") + add_files("*.cc") + add_files("*.c") + if is_plat("linux", "android") then + remove_files("IOService_thread.cc") + else + remove_files("IOService_linux.cc") + end diff --git a/src/manager/CMakeLists.txt b/src/manager/CMakeLists.txt new file mode 100644 index 0000000..d171fb5 --- /dev/null +++ b/src/manager/CMakeLists.txt @@ -0,0 +1,17 @@ +cmake_minimum_required(VERSION 3.6) +project(manager) + +set(SRC + DnsCache.cc + RouteManager.cc + WFGlobal.cc +) + +if (NOT UPSTREAM STREQUAL "n") + set(SRC + ${SRC} + UpstreamManager.cc + ) +endif () + +add_library(${PROJECT_NAME} OBJECT ${SRC}) diff --git a/src/manager/DnsCache.cc b/src/manager/DnsCache.cc new file mode 100644 index 0000000..3c0cee9 --- /dev/null +++ b/src/manager/DnsCache.cc @@ -0,0 +1,107 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include "DnsCache.h" + +#define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() + +#define TTL_INC 5 + +const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, + int type) +{ + int64_t cur = GET_CURRENT_SECOND; + std::lock_guard lock(mutex_); + const DnsHandle *handle = cache_pool_.get(host_port); + + if (handle && ((type == GET_TYPE_TTL && cur > handle->value.expire_time) || + (type == GET_TYPE_CONFIDENT && cur > handle->value.confident_time))) + { + if (!handle->value.delayed()) + { + DnsHandle *h = const_cast(handle); + if (type == GET_TYPE_TTL) + h->value.expire_time += TTL_INC; + else + h->value.confident_time += TTL_INC; + + h->value.addrinfo->ai_flags |= 2; + } + + cache_pool_.release(handle); + return NULL; + } + + return handle; +} + +const DnsCache::DnsHandle *DnsCache::put(const HostPort& host_port, + struct addrinfo *addrinfo, + unsigned int dns_ttl_default, + unsigned int dns_ttl_min) +{ + int64_t expire_time; + int64_t confident_time; + int64_t cur_time = GET_CURRENT_SECOND; + + if (dns_ttl_min > dns_ttl_default) + dns_ttl_min = dns_ttl_default; + + if (dns_ttl_min == (unsigned int)-1) + confident_time = INT64_MAX; + else + confident_time = cur_time + dns_ttl_min; + + if (dns_ttl_default == (unsigned int)-1) + expire_time = INT64_MAX; + else + expire_time = cur_time + dns_ttl_default; + + std::lock_guard lock(mutex_); + return cache_pool_.put(host_port, {addrinfo, confident_time, expire_time}); +} + +const DnsCache::DnsHandle *DnsCache::get(const DnsCache::HostPort& host_port) +{ + std::lock_guard lock(mutex_); + return cache_pool_.get(host_port); +} + +void DnsCache::release(const DnsCache::DnsHandle *handle) +{ + std::lock_guard lock(mutex_); + cache_pool_.release(handle); +} + +void DnsCache::del(const DnsCache::HostPort& key) +{ + std::lock_guard lock(mutex_); + 
cache_pool_.del(key); +} + +DnsCache::DnsCache() +{ +} + +DnsCache::~DnsCache() +{ +} + diff --git a/src/manager/DnsCache.h b/src/manager/DnsCache.h new file mode 100644 index 0000000..bd30146 --- /dev/null +++ b/src/manager/DnsCache.h @@ -0,0 +1,170 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _DNSCACHE_H_ +#define _DNSCACHE_H_ + +#include +#include +#include +#include +#include +#include "LRUCache.h" +#include "DnsUtil.h" + +#define GET_TYPE_TTL 0 +#define GET_TYPE_CONFIDENT 1 + +struct DnsCacheValue +{ + struct addrinfo *addrinfo; + int64_t confident_time; + int64_t expire_time; + + bool delayed() const + { + return addrinfo->ai_flags & 2; + } +}; + +// RAII: NO. 
Release handle by user +// Thread safety: YES +// MUST call release when handle no longer used +class DnsCache +{ +public: + using HostPort = std::pair; + using DnsHandle = LRUHandle; + +public: + // get handler + // Need call release when handle no longer needed + //Handle *get(const KEY &key); + const DnsHandle *get(const HostPort& host_port); + + const DnsHandle *get(const std::string& host, unsigned short port) + { + return get(HostPort(host, port)); + } + + const DnsHandle *get(const char *host, unsigned short port) + { + return get(std::string(host), port); + } + + const DnsHandle *get_ttl(const HostPort& host_port) + { + return get_inner(host_port, GET_TYPE_TTL); + } + + const DnsHandle *get_ttl(const std::string& host, unsigned short port) + { + return get_ttl(HostPort(host, port)); + } + + const DnsHandle *get_ttl(const char *host, unsigned short port) + { + return get_ttl(std::string(host), port); + } + + const DnsHandle *get_confident(const HostPort& host_port) + { + return get_inner(host_port, GET_TYPE_CONFIDENT); + } + + const DnsHandle *get_confident(const std::string& host, unsigned short port) + { + return get_confident(HostPort(host, port)); + } + + const DnsHandle *get_confident(const char *host, unsigned short port) + { + return get_confident(std::string(host), port); + } + + const DnsHandle *put(const HostPort& host_port, + struct addrinfo *addrinfo, + unsigned int dns_ttl_default, + unsigned int dns_ttl_min); + + const DnsHandle *put(const std::string& host, + unsigned short port, + struct addrinfo *addrinfo, + unsigned int dns_ttl_default, + unsigned int dns_ttl_min) + { + return put(HostPort(host, port), addrinfo, dns_ttl_default, dns_ttl_min); + } + + const DnsHandle *put(const char *host, + unsigned short port, + struct addrinfo *addrinfo, + unsigned int dns_ttl_default, + unsigned int dns_ttl_min) + { + return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min); + } + + // release handle by get/put + void release(const 
DnsHandle *handle); + + // delete from cache, deleter delay called when all inuse-handle release. + void del(const HostPort& key); + + void del(const std::string& host, unsigned short port) + { + del(HostPort(host, port)); + } + + void del(const char *host, unsigned short port) + { + del(std::string(host), port); + } + +private: + const DnsHandle *get_inner(const HostPort& host_port, int type); + + std::mutex mutex_; + + class ValueDeleter + { + public: + void operator() (const DnsCacheValue& value) const + { + struct addrinfo *ai = value.addrinfo; + + if (ai) + { + if (ai->ai_flags) + freeaddrinfo(ai); + else + protocol::DnsUtil::freeaddrinfo(ai); + } + } + }; + + LRUCache cache_pool_; + +public: + // To prevent inline calling LRUCache's constructor and deconstructor. + DnsCache(); + ~DnsCache(); +}; + +#endif + diff --git a/src/manager/EndpointParams.h b/src/manager/EndpointParams.h new file mode 100644 index 0000000..57e85c3 --- /dev/null +++ b/src/manager/EndpointParams.h @@ -0,0 +1,60 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _ENDPOINTPARAMS_H_ +#define _ENDPOINTPARAMS_H_ + +#include +#include + +/** + * @file EndpointParams.h + * @brief Network config for client task + */ + +enum TransportType +{ + TT_TCP, + TT_UDP, + TT_SCTP, + TT_TCP_SSL, + TT_SCTP_SSL, +}; + +struct EndpointParams +{ + int address_family; + size_t max_connections; + int connect_timeout; + int response_timeout; + int ssl_connect_timeout; + bool use_tls_sni; +}; + +static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = +{ + .address_family = AF_UNSPEC, + .max_connections = 200, + .connect_timeout = 10 * 1000, + .response_timeout = 10 * 1000, + .ssl_connect_timeout = 10 * 1000, + .use_tls_sni = false, +}; + +#endif + diff --git a/src/manager/RouteManager.cc b/src/manager/RouteManager.cc new file mode 100644 index 0000000..0789f67 --- /dev/null +++ b/src/manager/RouteManager.cc @@ -0,0 +1,579 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "list.h" +#include "rbtree.h" +#include "WFGlobal.h" +#include "CommScheduler.h" +#include "EndpointParams.h" +#include "RouteManager.h" +#include "StringUtil.h" + +#define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() +#define MTTR_SECOND 30 + +using RouteTargetTCP = RouteManager::RouteTarget; + +class RouteTargetUDP : public RouteManager::RouteTarget +{ +private: + virtual int create_connect_fd() + { + const struct sockaddr *addr; + socklen_t addrlen; + + this->get_addr(&addr, &addrlen); + return socket(addr->sa_family, SOCK_DGRAM, 0); + } +}; + +class RouteTargetSCTP : public RouteManager::RouteTarget +{ +private: +#ifdef IPPROTO_SCTP + virtual int create_connect_fd() + { + const struct sockaddr *addr; + socklen_t addrlen; + + this->get_addr(&addr, &addrlen); + return socket(addr->sa_family, SOCK_STREAM, IPPROTO_SCTP); + } +#else + virtual int create_connect_fd() + { + errno = EPROTONOSUPPORT; + return -1; + } +#endif +}; + +/* To support TLS SNI. */ +class RouteTargetTCPSNI : public RouteTargetTCP +{ +private: + virtual int init_ssl(SSL *ssl) + { + if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0) + return 0; + else + return -1; + } + +private: + std::string hostname; + +public: + RouteTargetTCPSNI(const std::string& name) : hostname(name) + { + } +}; + +class RouteTargetSCTPSNI : public RouteTargetSCTP +{ +private: + virtual int init_ssl(SSL *ssl) + { + if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0) + return 0; + else + return -1; + } + +private: + std::string hostname; + +public: + RouteTargetSCTPSNI(const std::string& name) : hostname(name) + { + } +}; + +// protocol_name\n user\n pass\n dbname\n ai_addr ai_addrlen \n.... 
+// + +struct RouteParams +{ + enum TransportType transport_type; + const struct addrinfo *addrinfo; + uint64_t key; + SSL_CTX *ssl_ctx; + size_t max_connections; + int connect_timeout; + int response_timeout; + int ssl_connect_timeout; + bool use_tls_sni; + const std::string& hostname; +}; + +class RouteResultEntry +{ +public: + struct rb_node rb; + CommSchedObject *request_object; + CommSchedGroup *group; + std::mutex mutex; + std::vector targets; + struct list_head breaker_list; + uint64_t key; + int nleft; + int nbreak; + + RouteResultEntry(): + request_object(NULL), + group(NULL) + { + INIT_LIST_HEAD(&this->breaker_list); + this->nleft = 0; + this->nbreak = 0; + } + +public: + int init(const struct RouteParams *params); + void deinit(); + + void notify_unavailable(RouteManager::RouteTarget *target); + void notify_available(RouteManager::RouteTarget *target); + void check_breaker(); + +private: + void free_list(); + RouteManager::RouteTarget *create_target(const struct RouteParams *params, + const struct addrinfo *addrinfo); + int add_group_targets(const struct RouteParams *params); +}; + +struct __breaker_node +{ + RouteManager::RouteTarget *target; + int64_t timeout; + struct list_head breaker_list; +}; + +RouteManager::RouteTarget * +RouteResultEntry::create_target(const struct RouteParams *params, + const struct addrinfo *addr) +{ + RouteManager::RouteTarget *target; + + switch (params->transport_type) + { + case TT_TCP_SSL: + if (params->use_tls_sni) + target = new RouteTargetTCPSNI(params->hostname); + else + case TT_TCP: + target = new RouteTargetTCP(); + break; + case TT_UDP: + target = new RouteTargetUDP(); + break; + case TT_SCTP_SSL: + if (params->use_tls_sni) + target = new RouteTargetSCTPSNI(params->hostname); + else + case TT_SCTP: + target = new RouteTargetSCTP(); + break; + default: + errno = EINVAL; + return NULL; + } + + if (target->init(addr->ai_addr, addr->ai_addrlen, params->ssl_ctx, + params->connect_timeout, params->ssl_connect_timeout, + 
params->response_timeout, params->max_connections) < 0) + { + delete target; + target = NULL; + } + + return target; +} + +int RouteResultEntry::init(const struct RouteParams *params) +{ + const struct addrinfo *addr = params->addrinfo; + RouteManager::RouteTarget *target; + + if (addr == NULL)//0 + { + errno = EINVAL; + return -1; + } + + if (addr->ai_next == NULL)//1 + { + target = this->create_target(params, addr); + if (target) + { + this->targets.push_back(target); + this->request_object = target; + this->key = params->key; + return 0; + } + + return -1; + } + + this->group = new CommSchedGroup(); + if (this->group->init() >= 0) + { + if (this->add_group_targets(params) >= 0) + { + this->request_object = this->group; + this->key = params->key; + return 0; + } + + this->group->deinit(); + } + + delete this->group; + return -1; +} + +int RouteResultEntry::add_group_targets(const struct RouteParams *params) +{ + RouteManager::RouteTarget *target; + const struct addrinfo *addr; + + for (addr = params->addrinfo; addr; addr = addr->ai_next) + { + target = this->create_target(params, addr); + if (target) + { + if (this->group->add(target) >= 0) + { + this->targets.push_back(target); + this->nleft++; + continue; + } + + target->deinit(); + delete target; + } + + for (auto *target : this->targets) + { + this->group->remove(target); + target->deinit(); + delete target; + } + + return -1; + } + + return 0; +} + +void RouteResultEntry::deinit() +{ + for (auto *target : this->targets) + { + if (this->group) + this->group->remove(target); + + target->deinit(); + delete target; + } + + if (this->group) + { + this->group->deinit(); + delete this->group; + } + + struct list_head *pos, *tmp; + __breaker_node *node; + + list_for_each_safe(pos, tmp, &this->breaker_list) + { + node = list_entry(pos, __breaker_node, breaker_list); + list_del(pos); + delete node; + } +} + +void RouteResultEntry::notify_unavailable(RouteManager::RouteTarget *target) +{ + if (this->targets.size() <= 
1) + return; + + int errno_bak = errno; + std::lock_guard lock(this->mutex); + + if (this->nleft <= 1) + return; + + if (this->group->remove(target) < 0) + { + errno = errno_bak; + return; + } + + auto *node = new __breaker_node; + + node->target = target; + node->timeout = GET_CURRENT_SECOND + MTTR_SECOND; + list_add_tail(&node->breaker_list, &this->breaker_list); + this->nbreak++; + this->nleft--; +} + +void RouteResultEntry::notify_available(RouteManager::RouteTarget *target) +{ + if (this->targets.size() <= 1 || this->nbreak == 0) + return; + + int errno_bak = errno; + std::lock_guard lock(this->mutex); + + if (this->group->add(target) == 0) + this->nleft++; + else + errno = errno_bak; +} + +void RouteResultEntry::check_breaker() +{ + if (this->targets.size() <= 1 || this->nbreak == 0) + return; + + struct list_head *pos, *tmp; + __breaker_node *node; + int errno_bak = errno; + int64_t cur_time = GET_CURRENT_SECOND; + std::lock_guard lock(this->mutex); + + list_for_each_safe(pos, tmp, &this->breaker_list) + { + node = list_entry(pos, __breaker_node, breaker_list); + if (cur_time >= node->timeout) + { + if (this->group->add(node->target) == 0) + this->nleft++; + else + errno = errno_bak; + + list_del(pos); + delete node; + this->nbreak--; + } + } +} + +static inline int __addr_cmp(const struct addrinfo *x, const struct addrinfo *y) +{ + //todo ai_protocol + if (x->ai_addrlen == y->ai_addrlen) + return memcmp(x->ai_addr, y->ai_addr, x->ai_addrlen); + else if (x->ai_addrlen < y->ai_addrlen) + return -1; + else + return 1; +} + +static inline bool __addr_less(const struct addrinfo *x, const struct addrinfo *y) +{ + return __addr_cmp(x, y) < 0; +} + +static uint64_t __fnv_hash(const unsigned char *data, size_t size) +{ + uint64_t hash = 14695981039346656037ULL; + + while (size) + { + hash ^= (const uint64_t)*data++; + hash *= 1099511628211ULL; + size--; + } + + return hash; +} + +static uint64_t __generate_key(enum TransportType type, + const struct addrinfo 
*addrinfo, + const std::string& other_info, + const struct EndpointParams *ep_params, + const std::string& hostname, + SSL_CTX *ssl_ctx) +{ + const int params[] = { + ep_params->address_family, (int)ep_params->max_connections, + ep_params->connect_timeout, ep_params->response_timeout + }; + std::string buf((const char *)&type, sizeof (enum TransportType)); + + if (!other_info.empty()) + buf += other_info; + + buf.append((const char *)params, sizeof params); + if (type == TT_TCP_SSL || type == TT_SCTP_SSL) + { + buf.append((const char *)&ssl_ctx, sizeof (void *)); + buf.append((const char *)&ep_params->ssl_connect_timeout, sizeof (int)); + if (ep_params->use_tls_sni) + { + buf += hostname; + buf += '\n'; + } + } + + if (addrinfo->ai_next) + { + std::vector sorted_addr; + + sorted_addr.push_back(addrinfo); + addrinfo = addrinfo->ai_next; + do + { + sorted_addr.push_back(addrinfo); + addrinfo = addrinfo->ai_next; + } while (addrinfo); + + std::sort(sorted_addr.begin(), sorted_addr.end(), __addr_less); + for (const struct addrinfo *p : sorted_addr) + { + buf.append((const char *)&p->ai_addrlen, sizeof (socklen_t)); + buf.append((const char *)p->ai_addr, p->ai_addrlen); + } + } + else + { + buf.append((const char *)&addrinfo->ai_addrlen, sizeof (socklen_t)); + buf.append((const char *)addrinfo->ai_addr, addrinfo->ai_addrlen); + } + + return __fnv_hash((const unsigned char *)buf.c_str(), buf.size()); +} + +RouteManager::~RouteManager() +{ + RouteResultEntry *entry; + + while (cache_.rb_node) + { + entry = rb_entry(cache_.rb_node, RouteResultEntry, rb); + rb_erase(cache_.rb_node, &cache_); + entry->deinit(); + delete entry; + } +} + +int RouteManager::get(enum TransportType type, + const struct addrinfo *addrinfo, + const std::string& other_info, + const struct EndpointParams *ep_params, + const std::string& hostname, SSL_CTX *ssl_ctx, + RouteResult& result) +{ + if (type == TT_TCP_SSL || type == TT_SCTP_SSL) + { + static SSL_CTX *global_client_ctx = 
WFGlobal::get_ssl_client_ctx(); + if (ssl_ctx == NULL) + ssl_ctx = global_client_ctx; + } + else + ssl_ctx = NULL; + + uint64_t key = __generate_key(type, addrinfo, other_info, ep_params, + hostname, ssl_ctx); + struct rb_node **p = &cache_.rb_node; + struct rb_node *parent = NULL; + RouteResultEntry *bound = NULL; + RouteResultEntry *entry; + std::lock_guard lock(mutex_); + + while (*p) + { + parent = *p; + entry = rb_entry(*p, RouteResultEntry, rb); + if (key <= entry->key) + { + bound = entry; + p = &(*p)->rb_left; + } + else + p = &(*p)->rb_right; + } + + if (bound && bound->key == key) + { + entry = bound; + entry->check_breaker(); + } + else + { + struct RouteParams params = { + .transport_type = type, + .addrinfo = addrinfo, + .key = key, + .ssl_ctx = ssl_ctx, + .max_connections = ep_params->max_connections, + .connect_timeout = ep_params->connect_timeout, + .response_timeout = ep_params->response_timeout, + .ssl_connect_timeout = ep_params->ssl_connect_timeout, + .use_tls_sni = ep_params->use_tls_sni, + .hostname = hostname, + }; + + entry = new RouteResultEntry; + if (entry->init(¶ms) >= 0) + { + rb_link_node(&entry->rb, parent, p); + rb_insert_color(&entry->rb, &cache_); + } + else + { + delete entry; + return -1; + } + } + + result.cookie = entry; + result.request_object = entry->request_object; + return 0; +} + +void RouteManager::notify_unavailable(void *cookie, CommTarget *target) +{ + if (cookie && target) + ((RouteResultEntry *)cookie)->notify_unavailable((RouteTarget *)target); +} + +void RouteManager::notify_available(void *cookie, CommTarget *target) +{ + if (cookie && target) + ((RouteResultEntry *)cookie)->notify_available((RouteTarget *)target); +} + diff --git a/src/manager/RouteManager.h b/src/manager/RouteManager.h new file mode 100644 index 0000000..cd660d4 --- /dev/null +++ b/src/manager/RouteManager.h @@ -0,0 +1,113 @@ +/* + Copyright (c) 2019 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _ROUTEMANAGER_H_ +#define _ROUTEMANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include "rbtree.h" +#include "WFConnection.h" +#include "EndpointParams.h" +#include "CommScheduler.h" + +class RouteManager +{ +public: + class RouteResult + { + public: + void *cookie; + CommSchedObject *request_object; + + public: + RouteResult(): cookie(NULL), request_object(NULL) { } + void clear() { cookie = NULL; request_object = NULL; } + }; + + class RouteTarget : public CommSchedTarget + { +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + public: + int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, + int connect_timeout, int ssl_connect_timeout, int response_timeout, + size_t max_connections) + { + int ret = this->CommSchedTarget::init(addr, addrlen, ssl_ctx, + connect_timeout, ssl_connect_timeout, + response_timeout, max_connections); + + if (ret >= 0 && ssl_ctx) + SSL_CTX_up_ref(ssl_ctx); + + return ret; + } + + void deinit() + { + SSL_CTX *ssl_ctx = this->get_ssl_ctx(); + + this->CommSchedTarget::deinit(); + if (ssl_ctx) + SSL_CTX_free(ssl_ctx); + } +#endif + + public: + int state; + + private: + virtual WFConnection *new_connection(int connect_fd) + { + return new WFConnection; + } + + public: + RouteTarget() : state(0) { } + }; + +public: + int get(enum TransportType type, + const struct addrinfo *addrinfo, + const std::string& 
other_info, + const struct EndpointParams *ep_params, + const std::string& hostname, SSL_CTX *ssl_ctx, + RouteResult& result); + + RouteManager() + { + cache_.rb_node = NULL; + } + + ~RouteManager(); + +private: + std::mutex mutex_; + struct rb_root cache_; + +public: + static void notify_unavailable(void *cookie, CommTarget *target); + static void notify_available(void *cookie, CommTarget *target); +}; + +#endif + diff --git a/src/manager/UpstreamManager.cc b/src/manager/UpstreamManager.cc new file mode 100644 index 0000000..8bfce20 --- /dev/null +++ b/src/manager/UpstreamManager.cc @@ -0,0 +1,267 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include "UpstreamManager.h" +#include "WFNameService.h" +#include "WFGlobal.h" +#include "UpstreamPolicies.h" + +class __UpstreamManager +{ +public: + static __UpstreamManager *get_instance() + { + static __UpstreamManager kInstance; + return &kInstance; + } + + void add_upstream_policy(UPSGroupPolicy *policy) + { + pthread_mutex_lock(&this->mutex); + this->upstream_policies.push_back(policy); + pthread_mutex_unlock(&this->mutex); + } + +private: + __UpstreamManager() : + mutex(PTHREAD_MUTEX_INITIALIZER) + { + } + + ~__UpstreamManager() + { + for (UPSGroupPolicy *policy : this->upstream_policies) + delete policy; + } + + pthread_mutex_t mutex; + std::vector upstream_policies; +}; + +int UpstreamManager::upstream_create_round_robin(const std::string& name, + bool try_another) +{ + WFNameService *ns = WFGlobal::get_name_service(); + auto *policy = new UPSRoundRobinPolicy(try_another); + + if (ns->add_policy(name.c_str(), policy) >= 0) + { + __UpstreamManager::get_instance()->add_upstream_policy(policy); + return 0; + } + + delete policy; + return -1; +} + +static unsigned int __default_consistent_hash(const char *path, + const char *query, + const char *fragment) +{ + static std::hash std_hash; + std::string str(path); + + str += query; + str += fragment; + return std_hash(str); +} + +int UpstreamManager::upstream_create_consistent_hash(const std::string& name, + upstream_route_t consistent_hash) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSConsistentHashPolicy *policy; + + policy = new UPSConsistentHashPolicy( + consistent_hash ? 
std::move(consistent_hash) : + __default_consistent_hash); + if (ns->add_policy(name.c_str(), policy) >= 0) + { + __UpstreamManager::get_instance()->add_upstream_policy(policy); + return 0; + } + + delete policy; + return -1; +} + +int UpstreamManager::upstream_create_weighted_random(const std::string& name, + bool try_another) +{ + WFNameService *ns = WFGlobal::get_name_service(); + auto *policy = new UPSWeightedRandomPolicy(try_another); + + if (ns->add_policy(name.c_str(), policy) >= 0) + { + __UpstreamManager::get_instance()->add_upstream_policy(policy); + return 0; + } + + delete policy; + return -1; +} + +int UpstreamManager::upstream_create_vnswrr(const std::string& name) +{ + WFNameService *ns = WFGlobal::get_name_service(); + auto *policy = new UPSVNSWRRPolicy(); + + if (ns->add_policy(name.c_str(), policy) >= 0) + { + __UpstreamManager::get_instance()->add_upstream_policy(policy); + return 0; + } + + delete policy; + return -1; +} + +int UpstreamManager::upstream_create_manual(const std::string& name, + upstream_route_t select, + bool try_another, + upstream_route_t consistent_hash) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSManualPolicy *policy; + + policy = new UPSManualPolicy(try_another, std::move(select), + consistent_hash ? 
std::move(consistent_hash) : + __default_consistent_hash); + if (ns->add_policy(name.c_str(), policy) >= 0) + { + __UpstreamManager::get_instance()->add_upstream_policy(policy); + return 0; + } + + delete policy; + return -1; +} + +int UpstreamManager::upstream_add_server(const std::string& name, + const std::string& address) +{ + return UpstreamManager::upstream_add_server(name, address, + &ADDRESS_PARAMS_DEFAULT); +} + +int UpstreamManager::upstream_add_server(const std::string& name, + const std::string& address, + const AddressParams *address_params) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); + + if (policy) + { + policy->add_server(address, address_params); + return 0; + } + + errno = ENOENT; + return -1; +} + +int UpstreamManager::upstream_remove_server(const std::string& name, + const std::string& address) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); + + if (policy) + return policy->remove_server(address); + + errno = ENOENT; + return -1; +} + +int UpstreamManager::upstream_delete(const std::string& name) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->del_policy(name.c_str())); + + if (policy) + return 0; + + errno = ENOENT; + return -1; +} + +std::vector +UpstreamManager::upstream_main_address_list(const std::string& name) +{ + std::vector address; + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); + + if (policy) + policy->get_main_address(address); + + return address; +} + +int UpstreamManager::upstream_disable_server(const std::string& name, + const std::string& address) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); + + if (policy) + { + policy->disable_server(address); + return 0; + } + + 
errno = ENOENT; + return -1; +} + +int UpstreamManager::upstream_enable_server(const std::string& name, + const std::string& address) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); + + if (policy) + { + policy->enable_server(address); + return 0; + } + + errno = ENOENT; + return -1; +} + +int UpstreamManager::upstream_replace_server(const std::string& name, + const std::string& address, + const struct AddressParams *address_params) +{ + WFNameService *ns = WFGlobal::get_name_service(); + UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); + + if (policy) + { + policy->replace_server(address, address_params); + return 0; + } + + errno = ENOENT; + return -1; +} + diff --git a/src/manager/UpstreamManager.h b/src/manager/UpstreamManager.h new file mode 100644 index 0000000..7188a1a --- /dev/null +++ b/src/manager/UpstreamManager.h @@ -0,0 +1,222 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _UPSTREAM_MANAGER_H_ +#define _UPSTREAM_MANAGER_H_ + +#include +#include +#include "WFServiceGovernance.h" +#include "UpstreamPolicies.h" +#include "WFGlobal.h" + +/** + * @file UpstreamManager.h + * @brief Local Reverse Proxy & Load Balance & Service Discovery + * @details + * - This is very similar with Nginx-Upstream. 
+ * - Do not cost any other network resource, We just simulate in local to choose one target properly. + * - This is working only for the current process. + */ + +/** + * @brief Upstream Management Class + * @details + * - We support three modes: + * 1. Weighted-Random + * 2. Consistent-Hash + * 3. Manual-Select + * - Additional, we support Main-backup & Group for server and working well in any mode. + * + * @code{.cc} + upstream_create_weighted_random("abc.sogou", true); //UPSTREAM_WEIGHTED_RANDOM + upstream_add_server("abc.sogou", "192.168.2.100:8081"); //weight=1, max_fails=200 + upstream_add_server("abc.sogou", "192.168.2.100:9090"); //weight=1, max_fails=200 + AddressParams params = ADDRESS_PARAMS_DEFAULT; + params.weight = 2; + params.max_fails = 6; + upstream_add_server("abc.sogou", "www.sogou.com", ¶ms); //weight=2, max_fails=6 + + //send request with url like http://abc.sogou/somepath/somerequest + + upstream_create_consistent_hash("def.sogou", + [](const char *path, + const char *query, + const char *fragment) -> int { + return somehash(...)); + }); //UPSTREAM_CONSISTENT_HASH + upstream_create_manual("xyz.sogou", + [](const char *path, + const char *query, + const char *fragment) -> int { + return select_xxx(...)); + }, + true, + [](const char *path, + const char *query, + const char *fragment) -> int { + return rehash(...)); + },); //UPSTREAM_MANUAL + * @endcode + */ +class UpstreamManager +{ +public: + /** + * @brief MODE 0: round-robin select + * @param[in] name upstream name + * @param[in] try_another when first choice is failed, try another one or not + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @note + * when first choose server is already down: + * - if try_another==false, request will be failed + * - if try_another==true, upstream will choose the next + */ + static int upstream_create_round_robin(const std::string& name, + bool try_another); + + /** + * @brief MODE 1: consistent-hashing select + * 
@param[in] name upstream name + * @param[in] consitent_hash consistent-hash functional + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @note consitent_hash need to return value in 0~(2^31-1) Balance/Monotonicity/Spread/Smoothness + * @note if consitent_hash==nullptr, upstream will use std::hash with request uri + */ + static int upstream_create_consistent_hash(const std::string& name, + upstream_route_t consitent_hash); + + /** + * @brief MODE 2: weighted-random select + * @param[in] name upstream name + * @param[in] try_another when first choice is failed, try another one or not + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @note + * when first choose server is already down: + * - if try_another==false, request will be failed + * - if try_another==true, upstream will choose from alive-servers by weight-random strategy + */ + static int upstream_create_weighted_random(const std::string& name, + bool try_another); + + /** + * @brief MODE 3: manual select + * @param[in] name upstream name + * @param[in] select manual select functional, just tell us main-index. + * @param[in] try_another when first choice is failed, try another one or not + * @param[in] consitent_hash consistent-hash functional + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @note + * when first choose server is already down: + * - if try_another==false, request will be failed, consistent_hash value will be ignored + * - if try_another==true, upstream will work with consistent hash mode,if consitent_hash==NULL, upstream will use std::hash with request uri + * @warning select functional cannot be nullptr! 
+ */ + static int upstream_create_manual(const std::string& name, + upstream_route_t select, + bool try_another, + upstream_route_t consitent_hash); + + /** + * @brief MODE 4: VNSWRR select + * @param[in] name upstream name + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @note + */ + static int upstream_create_vnswrr(const std::string& name); + + /** + * @brief Delete one upstream + * @param[in] name upstream name + * @return success/fail + * @retval 0 success + * @retval -1 fail, not found + */ + static int upstream_delete(const std::string& name); + +public: + /** + * @brief Add server into one upstream, with default config + * @param[in] name upstream name + * @param[in] address ip OR host OR ip:port OR host:port OR /unix-domain-socket + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @warning Same address add twice, means two different server + */ + static int upstream_add_server(const std::string& name, + const std::string& address); + + /** + * @brief Add server into one upstream, with custom config + * @param[in] name upstream name + * @param[in] address ip OR host OR ip:port OR host:port OR /unix-domain-socket + * @param[in] address_params custom config for this target server + * @return success/fail + * @retval 0 success + * @retval -1 fail, more info see errno + * @warning Same address with different params, means two different server + * @warning Same address with exactly same params, still means two different server + */ + static int upstream_add_server(const std::string& name, + const std::string& address, + const struct AddressParams *address_params); + + /** + * @brief Remove server from one upstream + * @param[in] name upstream name + * @param[in] address same as address when add by upstream_add_server + * @return success/fail + * @retval >=0 success, the amount of being removed server + * @retval -1 fail, upstream name not found + * @warning If server servers has 
the same address in this upstream, we will remove them all + */ + static int upstream_remove_server(const std::string& name, + const std::string& address); + + /** + * @brief get all main servers address list from one upstream + * @param[in] name upstream name + * @return all main servers' address list + * @warning If server servers has the same address in this upstream, then will appear in the vector multiply times + */ + static std::vector upstream_main_address_list(const std::string& name); + +public: + /// @breif for plugin + static int upstream_disable_server(const std::string& name, const std::string& address); + static int upstream_enable_server(const std::string& name, const std::string& address); + + static int upstream_replace_server(const std::string& name, + const std::string& address, + const struct AddressParams *address_params); + +}; + +#endif + diff --git a/src/manager/WFFacilities.h b/src/manager/WFFacilities.h new file mode 100644 index 0000000..c51a455 --- /dev/null +++ b/src/manager/WFFacilities.h @@ -0,0 +1,91 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFFACILITIES_H_ +#define _WFFACILITIES_H_ + +#include +#include "WFFuture.h" +#include "WFTaskFactory.h" + +class WFFacilities +{ +public: + static void usleep(unsigned int microseconds); + static WFFuture async_usleep(unsigned int microseconds); + +public: + template + static void go(const std::string& queue_name, FUNC&& func, ARGS&&... args); + +public: + template + struct WFNetworkResult + { + RESP resp; + long long seqid; + int task_state; + int task_error; + }; + + template + static WFNetworkResult request(enum TransportType type, const std::string& url, REQ&& req, int retry_max); + + template + static WFFuture> async_request(enum TransportType type, const std::string& url, REQ&& req, int retry_max); + +public:// async fileIO + static WFFuture async_pread(int fd, void *buf, size_t count, off_t offset); + static WFFuture async_pwrite(int fd, const void *buf, size_t count, off_t offset); + static WFFuture async_preadv(int fd, const struct iovec *iov, int iovcnt, off_t offset); + static WFFuture async_pwritev(int fd, const struct iovec *iov, int iovcnt, off_t offset); + static WFFuture async_fsync(int fd); + static WFFuture async_fdatasync(int fd); + +public: + class WaitGroup + { + public: + WaitGroup(int n); + ~WaitGroup(); + + void done(); + void wait() const; + std::future_status wait(int timeout) const; + + private: + static void __wait_group_callback(WFCounterTask *task); + + std::atomic nleft; + WFCounterTask *task; + WFFuture future; + }; + +private: + static void __timer_future_callback(WFTimerTask *task); + static void __fio_future_callback(WFFileIOTask *task); + static void __fvio_future_callback(WFFileVIOTask *task); + static void __fsync_future_callback(WFFileSyncTask *task); +}; + +#include "WFFacilities.inl" + +#endif + diff --git a/src/manager/WFFacilities.inl b/src/manager/WFFacilities.inl new file mode 100644 
index 0000000..ea3ab40 --- /dev/null +++ b/src/manager/WFFacilities.inl @@ -0,0 +1,232 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +inline void WFFacilities::usleep(unsigned int microseconds) +{ + async_usleep(microseconds).get(); +} + +inline WFFuture WFFacilities::async_usleep(unsigned int microseconds) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_timer_task(microseconds, __timer_future_callback); + + task->user_data = pr; + task->start(); + return fr; +} + +template +void WFFacilities::go(const std::string& queue_name, FUNC&& func, ARGS&&... 
args) +{ + WFTaskFactory::create_go_task(queue_name, std::forward(func), std::forward(args)...)->start(); +} + +template +WFFacilities::WFNetworkResult WFFacilities::request(enum TransportType type, const std::string& url, REQ&& req, int retry_max) +{ + return async_request(type, url, std::forward(req), retry_max).get(); +} + +template +WFFuture> WFFacilities::async_request(enum TransportType type, const std::string& url, REQ&& req, int retry_max) +{ + ParsedURI uri; + auto *pr = new WFPromise>(); + auto fr = pr->get_future(); + auto *task = new WFComplexClientTask(retry_max, [pr](WFNetworkTask *task) { + WFNetworkResult res; + + res.seqid = task->get_task_seq(); + res.task_state = task->get_state(); + res.task_error = task->get_error(); + if (res.task_state == WFT_STATE_SUCCESS) + res.resp = std::move(*task->get_resp()); + + pr->set_value(std::move(res)); + delete pr; + }); + + URIParser::parse(url, uri); + task->init(std::move(uri)); + task->set_transport_type(type); + *task->get_req() = std::forward(req); + task->start(); + return fr; +} + +inline WFFuture WFFacilities::async_pread(int fd, void *buf, size_t count, off_t offset) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_pread_task(fd, buf, count, offset, __fio_future_callback); + + task->user_data = pr; + task->start(); + return fr; +} + +inline WFFuture WFFacilities::async_pwrite(int fd, const void *buf, size_t count, off_t offset) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_pwrite_task(fd, buf, count, offset, __fio_future_callback); + + task->user_data = pr; + task->start(); + return fr; +} + +inline WFFuture WFFacilities::async_preadv(int fd, const struct iovec *iov, int iovcnt, off_t offset) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_preadv_task(fd, iov, iovcnt, offset, __fvio_future_callback); + + task->user_data = pr; + task->start(); + 
return fr; +} + +inline WFFuture WFFacilities::async_pwritev(int fd, const struct iovec *iov, int iovcnt, off_t offset) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_pwritev_task(fd, iov, iovcnt, offset, __fvio_future_callback); + + task->user_data = pr; + task->start(); + return fr; +} + +inline WFFuture WFFacilities::async_fsync(int fd) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_fsync_task(fd, __fsync_future_callback); + + task->user_data = pr; + task->start(); + return fr; +} + +inline WFFuture WFFacilities::async_fdatasync(int fd) +{ + auto *pr = new WFPromise(); + auto fr = pr->get_future(); + auto *task = WFTaskFactory::create_fdsync_task(fd, __fsync_future_callback); + + task->user_data = pr; + task->start(); + return fr; +} + +inline void WFFacilities::__timer_future_callback(WFTimerTask *task) +{ + auto *pr = static_cast *>(task->user_data); + + pr->set_value(); + delete pr; +} + +inline void WFFacilities::__fio_future_callback(WFFileIOTask *task) +{ + auto *pr = static_cast *>(task->user_data); + + pr->set_value(task->get_retval()); + delete pr; +} + +inline void WFFacilities::__fvio_future_callback(WFFileVIOTask *task) +{ + auto *pr = static_cast *>(task->user_data); + + pr->set_value(task->get_retval()); + delete pr; +} + +inline void WFFacilities::__fsync_future_callback(WFFileSyncTask *task) +{ + auto *pr = static_cast *>(task->user_data); + + pr->set_value(task->get_retval()); + delete pr; +} + +inline WFFacilities::WaitGroup::WaitGroup(int n) : nleft(n) +{ + if (n <= 0) + { + this->nleft = -1; + return; + } + + auto *pr = new WFPromise(); + + this->task = WFTaskFactory::create_counter_task(1, __wait_group_callback); + this->future = pr->get_future(); + this->task->user_data = pr; + this->task->start(); +} + +inline WFFacilities::WaitGroup::~WaitGroup() +{ + if (this->nleft > 0) + this->task->count(); +} + +inline void 
WFFacilities::WaitGroup::done() +{ + if (--this->nleft == 0) + { + this->task->count(); + } +} + +inline void WFFacilities::WaitGroup::wait() const +{ + if (this->nleft < 0) + return; + + this->future.wait(); +} + +inline std::future_status WFFacilities::WaitGroup::wait(int timeout) const +{ + if (this->nleft < 0) + return std::future_status::ready; + + if (timeout < 0) + { + this->future.wait(); + return std::future_status::ready; + } + + return this->future.wait_for(std::chrono::milliseconds(timeout)); +} + +inline void WFFacilities::WaitGroup::__wait_group_callback(WFCounterTask *task) +{ + auto *pr = static_cast *>(task->user_data); + + pr->set_value(); + delete pr; +} + diff --git a/src/manager/WFFuture.h b/src/manager/WFFuture.h new file mode 100644 index 0000000..23f0029 --- /dev/null +++ b/src/manager/WFFuture.h @@ -0,0 +1,161 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#ifndef _WFFUTURE_H_ +#define _WFFUTURE_H_ + +#include +#include +#include +#include "CommScheduler.h" +#include "WFGlobal.h" + +template +class WFFuture +{ +public: + WFFuture(std::future&& fr) : + future(std::move(fr)) + { + } + + WFFuture() = default; + WFFuture(const WFFuture&) = delete; + WFFuture(WFFuture&& move) = default; + + WFFuture& operator=(const WFFuture&) = delete; + WFFuture& operator=(WFFuture&& move) = default; + + void wait() const; + + template + std::future_status wait_for(const std::chrono::duration& time_duration) const; + + template + std::future_status wait_until(const std::chrono::time_point& timeout_time) const; + + RES get() + { + this->wait(); + return this->future.get(); + } + + bool valid() const { return this->future.valid(); } + +private: + std::future future; +}; + +template +class WFPromise +{ +public: + WFPromise() = default; + WFPromise(const WFPromise& promise) = delete; + WFPromise(WFPromise&& move) = default; + WFPromise& operator=(const WFPromise& promise) = delete; + WFPromise& operator=(WFPromise&& move) = default; + + WFFuture get_future() + { + return WFFuture(this->promise.get_future()); + } + + void set_value(const RES& value) { this->promise.set_value(value); } + void set_value(RES&& value) { this->promise.set_value(std::move(value)); } + +private: + std::promise promise; +}; + +template +void WFFuture::wait() const +{ + if (this->future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) + { + int cookie = WFGlobal::sync_operation_begin(); + this->future.wait(); + WFGlobal::sync_operation_end(cookie); + } +} + +template +template +std::future_status WFFuture::wait_for(const std::chrono::duration& time_duration) const +{ + std::future_status status = std::future_status::ready; + + if (this->future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) + { + int cookie = WFGlobal::sync_operation_begin(); + status = 
this->future.wait_for(time_duration); + WFGlobal::sync_operation_end(cookie); + } + + return status; +} + +template +template +std::future_status WFFuture::wait_until(const std::chrono::time_point& timeout_time) const +{ + std::future_status status = std::future_status::ready; + + if (this->future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) + { + int cookie = WFGlobal::sync_operation_begin(); + status = this->future.wait_until(timeout_time); + WFGlobal::sync_operation_end(cookie); + } + + return status; +} + +///// WFFuture template specialization +template<> +inline void WFFuture::get() +{ + this->wait(); + this->future.get(); +} + +template<> +class WFPromise +{ +public: + WFPromise() = default; + WFPromise(const WFPromise& promise) = delete; + WFPromise(WFPromise&& move) = default; + WFPromise& operator=(const WFPromise& promise) = delete; + WFPromise& operator=(WFPromise&& move) = default; + + WFFuture get_future() + { + return WFFuture(this->promise.get_future()); + } + + void set_value() { this->promise.set_value(); } +// void set_value(const RES& value) { this->promise.set_value(value); } +// void set_value(RES&& value) { this->promise.set_value(std::move(value)); } + +private: + std::promise promise; +}; + +#endif + diff --git a/src/manager/WFGlobal.cc b/src/manager/WFGlobal.cc new file mode 100644 index 0000000..dc006ed --- /dev/null +++ b/src/manager/WFGlobal.cc @@ -0,0 +1,973 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if OPENSSL_VERSION_NUMBER < 0x10100000L +# include +# include +# include +# include +#endif +#include "CommScheduler.h" +#include "Executor.h" +#include "WFResourcePool.h" +#include "WFTaskError.h" +#include "WFDnsClient.h" +#include "WFGlobal.h" +#include "URIParser.h" + +class __WFGlobal +{ +public: + static __WFGlobal *get_instance() + { + static __WFGlobal kInstance; + return &kInstance; + } + + const char *get_default_port(const std::string& scheme) + { + const auto it = static_scheme_port_.find(scheme); + + if (it != static_scheme_port_.end()) + return it->second; + + const char *port = NULL; + user_scheme_port_mutex_.lock(); + const auto it2 = user_scheme_port_.find(scheme); + + if (it2 != user_scheme_port_.end()) + port = it2->second.c_str(); + + user_scheme_port_mutex_.unlock(); + return port; + } + + void register_scheme_port(const std::string& scheme, unsigned short port) + { + user_scheme_port_mutex_.lock(); + user_scheme_port_[scheme] = std::to_string(port); + user_scheme_port_mutex_.unlock(); + } + + void sync_operation_begin() + { + bool inc; + + sync_mutex_.lock(); + inc = ++sync_count_ > sync_max_; + if (inc) + sync_max_ = sync_count_; + + sync_mutex_.unlock(); + if (inc) + WFGlobal::increase_handler_thread(); + } + + void sync_operation_end() + { + int dec = 0; + + sync_mutex_.lock(); + if (--sync_count_ < (sync_max_ + 1) / 2) + { + dec = sync_max_ - 2 * sync_count_; + sync_max_ -= dec; + } + + sync_mutex_.unlock(); + while (dec > 0) + { + WFGlobal::decrease_handler_thread(); + dec--; + } + } + +private: + __WFGlobal(); + +private: + std::unordered_map static_scheme_port_; + std::unordered_map 
user_scheme_port_; + std::mutex user_scheme_port_mutex_; + std::mutex sync_mutex_; + int sync_count_; + int sync_max_; +}; + +__WFGlobal::__WFGlobal() +{ + static_scheme_port_["dns"] = "53"; + static_scheme_port_["Dns"] = "53"; + static_scheme_port_["DNS"] = "53"; + + static_scheme_port_["dnss"] = "853"; + static_scheme_port_["Dnss"] = "853"; + static_scheme_port_["DNSs"] = "853"; + static_scheme_port_["DNSS"] = "853"; + + static_scheme_port_["http"] = "80"; + static_scheme_port_["Http"] = "80"; + static_scheme_port_["HTTP"] = "80"; + + static_scheme_port_["https"] = "443"; + static_scheme_port_["Https"] = "443"; + static_scheme_port_["HTTPs"] = "443"; + static_scheme_port_["HTTPS"] = "443"; + + static_scheme_port_["redis"] = "6379"; + static_scheme_port_["Redis"] = "6379"; + static_scheme_port_["REDIS"] = "6379"; + + static_scheme_port_["rediss"] = "6379"; + static_scheme_port_["Rediss"] = "6379"; + static_scheme_port_["REDISs"] = "6379"; + static_scheme_port_["REDISS"] = "6379"; + + static_scheme_port_["mysql"] = "3306"; + static_scheme_port_["Mysql"] = "3306"; + static_scheme_port_["MySql"] = "3306"; + static_scheme_port_["MySQL"] = "3306"; + static_scheme_port_["MYSQL"] = "3306"; + + static_scheme_port_["mysqls"] = "3306"; + static_scheme_port_["Mysqls"] = "3306"; + static_scheme_port_["MySqls"] = "3306"; + static_scheme_port_["MySQLs"] = "3306"; + static_scheme_port_["MYSQLs"] = "3306"; + static_scheme_port_["MYSQLS"] = "3306"; + + static_scheme_port_["kafka"] = "9092"; + static_scheme_port_["Kafka"] = "9092"; + static_scheme_port_["KAFKA"] = "9092"; + + static_scheme_port_["kafkas"] = "9093"; + static_scheme_port_["Kafkas"] = "9093"; + static_scheme_port_["KAFKAs"] = "9093"; + static_scheme_port_["KAFKAS"] = "9093"; + + sync_count_ = 0; + sync_max_ = 0; +} + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::mutex *__ssl_mutex; + +static void ssl_locking_callback(int mode, int type, const char* file, int line) +{ + if (mode & CRYPTO_LOCK) + 
__ssl_mutex[type].lock(); + else if (mode & CRYPTO_UNLOCK) + __ssl_mutex[type].unlock(); +} +#endif + +class __SSLManager +{ +public: + static __SSLManager *get_instance() + { + static __SSLManager kInstance; + return &kInstance; + } + + SSL_CTX *get_ssl_client_ctx() { return ssl_client_ctx_; } + SSL_CTX *new_ssl_server_ctx() { return SSL_CTX_new(SSLv23_server_method()); } + +private: + __SSLManager() + { +#if OPENSSL_VERSION_NUMBER < 0x10100000L + __ssl_mutex = new std::mutex[CRYPTO_num_locks()]; + CRYPTO_set_locking_callback(ssl_locking_callback); + SSL_library_init(); + SSL_load_error_strings(); + //ERR_load_crypto_strings(); + //OpenSSL_add_all_algorithms(); +#endif + + ssl_client_ctx_ = SSL_CTX_new(SSLv23_client_method()); + if (ssl_client_ctx_ == NULL) + abort(); + } + + ~__SSLManager() + { + SSL_CTX_free(ssl_client_ctx_); + +#if OPENSSL_VERSION_NUMBER < 0x10100000L + //free ssl to avoid memory leak + FIPS_mode_set(0); + CRYPTO_set_locking_callback(NULL); +# ifdef CRYPTO_LOCK_ECDH + CRYPTO_THREADID_set_callback(NULL); +# else + CRYPTO_set_id_callback(NULL); +# endif + ENGINE_cleanup(); + CONF_modules_unload(1); + ERR_free_strings(); + EVP_cleanup(); +# ifdef CRYPTO_LOCK_ECDH + ERR_remove_thread_state(NULL); +# else + ERR_remove_state(0); +# endif + CRYPTO_cleanup_all_ex_data(); + sk_SSL_COMP_free(SSL_COMP_get_compression_methods()); + delete []__ssl_mutex; +#endif + } + +private: + SSL_CTX *ssl_client_ctx_; +}; + +class __FileIOService : public IOService +{ +public: + __FileIOService(CommScheduler *scheduler): + scheduler_(scheduler), + flag_(true) + {} + + int bind() + { + mutex_.lock(); + flag_ = false; + + int ret = scheduler_->io_bind(this); + + if (ret < 0) + flag_ = true; + + mutex_.unlock(); + return ret; + } + + void deinit() + { + std::unique_lock lock(mutex_); + while (!flag_) + cond_.wait(lock); + + lock.unlock(); + IOService::deinit(); + } + +private: + virtual void handle_unbound() + { + mutex_.lock(); + flag_ = true; + cond_.notify_one(); + 
mutex_.unlock(); + } + + virtual void handle_stop(int error) + { + scheduler_->io_unbind(this); + } + + CommScheduler *scheduler_; + std::mutex mutex_; + std::condition_variable cond_; + bool flag_; +}; + +class __ThreadDnsManager +{ +public: + static __ThreadDnsManager *get_instance() + { + static __ThreadDnsManager kInstance; + return &kInstance; + } + + ExecQueue *get_dns_queue() { return &dns_queue_; } + Executor *get_dns_executor() { return &dns_executor_; } + + __ThreadDnsManager() + { + int ret; + + ret = dns_queue_.init(); + if (ret < 0) + abort(); + + ret = dns_executor_.init(WFGlobal::get_global_settings()->dns_threads); + if (ret < 0) + abort(); + } + + ~__ThreadDnsManager() + { + dns_executor_.deinit(); + dns_queue_.deinit(); + } + +private: + ExecQueue dns_queue_; + Executor dns_executor_; +}; + +class __CommManager +{ +public: + static __CommManager *get_instance() + { + static __CommManager kInstance; + __CommManager::created_ = true; + return &kInstance; + } + + CommScheduler *get_scheduler() { return &scheduler_; } + IOService *get_io_service(); + static bool is_created() { return created_; } + +private: + __CommManager(): + fio_service_(NULL), + fio_flag_(false) + { + const auto *settings = WFGlobal::get_global_settings(); + if (scheduler_.init(settings->poller_threads, + settings->handler_threads) < 0) + abort(); + + signal(SIGPIPE, SIG_IGN); + } + + ~__CommManager() + { + // scheduler_.deinit() will triger fio_service to stop + scheduler_.deinit(); + if (fio_service_) + { + fio_service_->deinit(); + delete fio_service_; + } + } + +private: + CommScheduler scheduler_; + __FileIOService *fio_service_; + volatile bool fio_flag_; + std::mutex fio_mutex_; + +private: + static bool created_; +}; + +bool __CommManager::created_ = false; + +inline IOService *__CommManager::get_io_service() +{ + if (!fio_flag_) + { + fio_mutex_.lock(); + if (!fio_flag_) + { + int maxevents = WFGlobal::get_global_settings()->fio_max_events; + + fio_service_ = new 
__FileIOService(&scheduler_); + if (fio_service_->init(maxevents) < 0) + abort(); + + if (fio_service_->bind() < 0) + abort(); + + fio_flag_ = true; + } + + fio_mutex_.unlock(); + } + + return fio_service_; +} + +class __ExecManager +{ +protected: + using ExecQueueMap = std::unordered_map; + +public: + static __ExecManager *get_instance() + { + static __ExecManager kInstance; + return &kInstance; + } + + ExecQueue *get_exec_queue(const std::string& queue_name); + Executor *get_compute_executor() { return &compute_executor_; } + +private: + __ExecManager(): + rwlock_(PTHREAD_RWLOCK_INITIALIZER) + { + int compute_threads = WFGlobal::get_global_settings()->compute_threads; + + if (compute_threads < 0) + compute_threads = sysconf(_SC_NPROCESSORS_ONLN); + + if (compute_executor_.init(compute_threads) < 0) + abort(); + } + + ~__ExecManager() + { + compute_executor_.deinit(); + + for (auto& kv : queue_map_) + { + kv.second->deinit(); + delete kv.second; + } + } + +private: + pthread_rwlock_t rwlock_; + ExecQueueMap queue_map_; + Executor compute_executor_; +}; + +inline ExecQueue *__ExecManager::get_exec_queue(const std::string& queue_name) +{ + ExecQueue *queue = NULL; + ExecQueueMap::const_iterator iter; + + pthread_rwlock_rdlock(&rwlock_); + iter = queue_map_.find(queue_name); + if (iter != queue_map_.cend()) + queue = iter->second; + + pthread_rwlock_unlock(&rwlock_); + if (queue) + return queue; + + pthread_rwlock_wrlock(&rwlock_); + iter = queue_map_.find(queue_name); + if (iter == queue_map_.cend()) + { + queue = new ExecQueue(); + if (queue->init() >= 0) + queue_map_.emplace(queue_name, queue); + else + { + delete queue; + queue = NULL; + } + } + else + queue = iter->second; + + pthread_rwlock_unlock(&rwlock_); + return queue; +} + +static std::string __dns_server_url(const std::string& url, + const struct addrinfo *hints) +{ + std::string host; + ParsedURI uri; + struct addrinfo *res; + struct in6_addr buf; + + if (strncasecmp(url.c_str(), "dns://", 6) == 0 || + 
/**
 * Test whether the token [p, q) starts with the option name r.
 *
 * @return pointer just past the matched prefix inside the token, or NULL
 *         when the token is shorter than r or does not start with it.
 */
static inline const char *__try_options(const char *p, const char *q,
										const char *r)
{
	size_t len = strlen(r);
	if ((size_t)(q - p) >= len && strncmp(p, r, len) == 0)
		return p + len;
	return NULL;
}

/**
 * Parse the remainder of a resolv.conf "options" line.
 *
 * @param p        text following the "options" keyword; it must begin with
 *                 whitespace, otherwise the line is ignored.
 * @param ndots    set from an "ndots:<n>" token.
 * @param attempts set from an "attempts:<n>" token.
 * @param rotate   set to true by a token starting with "rotate".
 *
 * Tokens are separated by whitespace; parsing stops at the end of the
 * string or at a token that begins with '#' or ';'. Unknown tokens are
 * skipped and outputs that are not mentioned keep their previous values.
 *
 * Every byte is converted to unsigned char before being handed to
 * isspace(): passing a negative char value (any byte >= 0x80 where char
 * is signed) is undefined behavior.
 */
static void __set_options(const char *p,
						  int *ndots, int *attempts, bool *rotate)
{
	const char *start;
	const char *opt;

	if (!isspace((unsigned char)*p))
		return;

	while (1)
	{
		while (isspace((unsigned char)*p))
			p++;

		start = p;
		while (*p && *p != '#' && *p != ';' && !isspace((unsigned char)*p))
			p++;

		// Empty token: end of string or start of a comment.
		if (start == p)
			break;

		if ((opt = __try_options(start, p, "ndots:")) != NULL)
			*ndots = atoi(opt);
		else if ((opt = __try_options(start, p, "attempts:")) != NULL)
			*attempts = atoi(opt);
		else if ((opt = __try_options(start, p, "rotate")) != NULL)
			*rotate = true;
	}
}
= WFGlobal::get_global_settings(); + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG | AI_NUMERICHOST | AI_NUMERICSERV, + .ai_family = settings->dns_server_params.address_family, + .ai_socktype = SOCK_STREAM, + }; + + while ((ret = getline(&line, &bufsize, fp)) > 0) + { + if (strncmp(line, "nameserver", 10) == 0) + __split_merge_str(line + 10, true, &hints, url); + else if (strncmp(line, "search", 6) == 0) + __split_merge_str(line + 6, false, &hints, search_list); + else if (strncmp(line, "options", 7) == 0) + __set_options(line + 7, ndots, attempts, rotate); + } + + ret = ferror(fp) ? -1 : 0; + free(line); + fclose(fp); + return ret; +} + +class __DnsClientManager +{ +public: + static __DnsClientManager *get_instance() + { + static __DnsClientManager kInstance; + return &kInstance; + } + +public: + WFDnsClient *get_dns_client() { return client_; } + WFResourcePool *get_dns_respool() { return &respool_; }; + +private: + __DnsClientManager() : respool_(WFGlobal::get_global_settings()-> + dns_server_params.max_connections) + { + const char *path = WFGlobal::get_global_settings()->resolv_conf_path; + + client_ = NULL; + if (path && path[0]) + { + int ndots = 1; + int attempts = 2; + bool rotate = false; + std::string url; + std::string search; + + __parse_resolv_conf(path, url, search, &ndots, &attempts, &rotate); + if (url.size() == 0) + url = "8.8.8.8"; + + client_ = new WFDnsClient; + if (client_->init(url, search, ndots, attempts, rotate) >= 0) + return; + + delete client_; + client_ = NULL; + } + } + + ~__DnsClientManager() + { + if (client_) + { + client_->deinit(); + delete client_; + } + } + + WFDnsClient *client_; + WFResourcePool respool_; +}; + +struct WFGlobalSettings WFGlobal::settings_ = GLOBAL_SETTINGS_DEFAULT; +RouteManager WFGlobal::route_manager_; +DnsCache WFGlobal::dns_cache_; +WFDnsResolver WFGlobal::dns_resolver_; +WFNameService WFGlobal::name_service_(&WFGlobal::dns_resolver_); + +bool WFGlobal::is_scheduler_created() +{ + return 
__CommManager::is_created(); +} + +CommScheduler *WFGlobal::get_scheduler() +{ + return __CommManager::get_instance()->get_scheduler(); +} + +SSL_CTX *WFGlobal::get_ssl_client_ctx() +{ + return __SSLManager::get_instance()->get_ssl_client_ctx(); +} + +SSL_CTX *WFGlobal::new_ssl_server_ctx() +{ + return __SSLManager::get_instance()->new_ssl_server_ctx(); +} + +ExecQueue *WFGlobal::get_exec_queue(const std::string& queue_name) +{ + return __ExecManager::get_instance()->get_exec_queue(queue_name); +} + +Executor *WFGlobal::get_compute_executor() +{ + return __ExecManager::get_instance()->get_compute_executor(); +} + +IOService *WFGlobal::get_io_service() +{ + return __CommManager::get_instance()->get_io_service(); +} + +ExecQueue *WFGlobal::get_dns_queue() +{ + return __ThreadDnsManager::get_instance()->get_dns_queue(); +} + +Executor *WFGlobal::get_dns_executor() +{ + return __ThreadDnsManager::get_instance()->get_dns_executor(); +} + +WFDnsClient *WFGlobal::get_dns_client() +{ + return __DnsClientManager::get_instance()->get_dns_client(); +} + +WFResourcePool *WFGlobal::get_dns_respool() +{ + return __DnsClientManager::get_instance()->get_dns_respool(); +} + +const char *WFGlobal::get_default_port(const std::string& scheme) +{ + return __WFGlobal::get_instance()->get_default_port(scheme); +} + +void WFGlobal::register_scheme_port(const std::string& scheme, + unsigned short port) +{ + __WFGlobal::get_instance()->register_scheme_port(scheme, port); +} + +int WFGlobal::sync_operation_begin() +{ + if (WFGlobal::is_scheduler_created() && + WFGlobal::get_scheduler()->is_handler_thread()) + { + __WFGlobal::get_instance()->sync_operation_begin(); + return 1; + } + + return 0; +} + +void WFGlobal::sync_operation_end(int cookie) +{ + if (cookie) + __WFGlobal::get_instance()->sync_operation_end(); +} + +static inline const char *__get_ssl_error_string(int error) +{ + switch (error) + { + case SSL_ERROR_NONE: + return "SSL Error None"; + + case SSL_ERROR_ZERO_RETURN: + return 
"SSL Error Zero Return"; + + case SSL_ERROR_WANT_READ: + return "SSL Error Want Read"; + + case SSL_ERROR_WANT_WRITE: + return "SSL Error Want Write"; + + case SSL_ERROR_WANT_CONNECT: + return "SSL Error Want Connect"; + + case SSL_ERROR_WANT_ACCEPT: + return "SSL Error Want Accept"; + + case SSL_ERROR_WANT_X509_LOOKUP: + return "SSL Error Want X509 Lookup"; + +#ifdef SSL_ERROR_WANT_ASYNC + case SSL_ERROR_WANT_ASYNC: + return "SSL Error Want Async"; +#endif + +#ifdef SSL_ERROR_WANT_ASYNC_JOB + case SSL_ERROR_WANT_ASYNC_JOB: + return "SSL Error Want Async Job"; +#endif + +#ifdef SSL_ERROR_WANT_CLIENT_HELLO_CB + case SSL_ERROR_WANT_CLIENT_HELLO_CB: + return "SSL Error Want Client Hello CB"; +#endif + + case SSL_ERROR_SYSCALL: + return "SSL System Error"; + + case SSL_ERROR_SSL: + return "SSL Error SSL"; + + default: + break; + } + + return "Unknown"; +} + +static inline const char *__get_task_error_string(int error) +{ + switch (error) + { + case WFT_ERR_URI_PARSE_FAILED: + return "URI Parse Failed"; + + case WFT_ERR_URI_SCHEME_INVALID: + return "URI Scheme Invalid"; + + case WFT_ERR_URI_PORT_INVALID: + return "URI Port Invalid"; + + case WFT_ERR_UPSTREAM_UNAVAILABLE: + return "Upstream Unavailable"; + + case WFT_ERR_HTTP_BAD_REDIRECT_HEADER: + return "Http Bad Redirect Header"; + + case WFT_ERR_HTTP_PROXY_CONNECT_FAILED: + return "Http Proxy Connect Failed"; + + case WFT_ERR_REDIS_ACCESS_DENIED: + return "Redis Access Denied"; + + case WFT_ERR_REDIS_COMMAND_DISALLOWED: + return "Redis Command Disallowed"; + + case WFT_ERR_MYSQL_HOST_NOT_ALLOWED: + return "MySQL Host Not Allowed"; + + case WFT_ERR_MYSQL_ACCESS_DENIED: + return "MySQL Access Denied"; + + case WFT_ERR_MYSQL_INVALID_CHARACTER_SET: + return "MySQL Invalid Character Set"; + + case WFT_ERR_MYSQL_COMMAND_DISALLOWED: + return "MySQL Command Disallowed"; + + case WFT_ERR_MYSQL_QUERY_NOT_SET: + return "MySQL Query Not Set"; + + case WFT_ERR_MYSQL_SSL_NOT_SUPPORTED: + return "MySQL SSL Not Supported"; + + case 
WFT_ERR_KAFKA_PARSE_RESPONSE_FAILED: + return "Kafka parse response failed"; + + case WFT_ERR_KAFKA_PRODUCE_FAILED: + return "Kafka produce api failed"; + + case WFT_ERR_KAFKA_FETCH_FAILED: + return "Kafka fetch api failed"; + + case WFT_ERR_KAFKA_CGROUP_FAILED: + return "Kafka cgroup failed"; + + case WFT_ERR_KAFKA_COMMIT_FAILED: + return "Kafka commit api failed"; + + case WFT_ERR_KAFKA_META_FAILED: + return "Kafka meta api failed"; + + case WFT_ERR_KAFKA_LEAVEGROUP_FAILED: + return "Kafka leavegroup failed"; + + case WFT_ERR_KAFKA_API_UNKNOWN: + return "Kafka api type unknown"; + + case WFT_ERR_KAFKA_VERSION_DISALLOWED: + return "Kafka broker version not supported"; + + case WFT_ERR_KAFKA_SASL_DISALLOWED: + return "Kafka sasl disallowed"; + + case WFT_ERR_KAFKA_ARRANGE_FAILED: + return "Kafka arrange failed"; + + case WFT_ERR_KAFKA_LIST_OFFSETS_FAILED: + return "Kafka list offsets failed"; + + case WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED: + return "Kafka cgroup assign failed"; + + case WFT_ERR_CONSUL_API_UNKNOWN: + return "Consul api type unknown"; + + case WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED: + return "Consul check response failed"; + + default: + break; + } + + return "Unknown"; +} + +const char *WFGlobal::get_error_string(int state, int error) +{ + switch (state) + { + case WFT_STATE_SUCCESS: + return "Success"; + + case WFT_STATE_TOREPLY: + return "To Reply"; + + case WFT_STATE_NOREPLY: + return "No Reply"; + + case WFT_STATE_SYS_ERROR: + return strerror(error); + + case WFT_STATE_SSL_ERROR: + return __get_ssl_error_string(error); + + case WFT_STATE_DNS_ERROR: + return gai_strerror(error); + + case WFT_STATE_TASK_ERROR: + return __get_task_error_string(error); + + case WFT_STATE_ABORTED: + return "Aborted"; + + case WFT_STATE_UNDEFINED: + return "Undefined"; + + default: + break; + } + + return "Unknown"; +} + +void WORKFLOW_library_init(const struct WFGlobalSettings *settings) +{ + WFGlobal::set_global_settings(settings); +} + diff --git 
a/src/manager/WFGlobal.h b/src/manager/WFGlobal.h new file mode 100644 index 0000000..8e8b4f5 --- /dev/null +++ b/src/manager/WFGlobal.h @@ -0,0 +1,194 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _WFGLOBAL_H_ +#define _WFGLOBAL_H_ + +#if __cplusplus < 201100 +#error CPLUSPLUS VERSION required at least C++11. Please use "-std=c++11". +#include +#endif + +#include +#include +#include "CommScheduler.h" +#include "DnsCache.h" +#include "RouteManager.h" +#include "Executor.h" +#include "EndpointParams.h" +#include "WFResourcePool.h" +#include "WFNameService.h" +#include "WFDnsResolver.h" + +/** + * @file WFGlobal.h + * @brief Workflow Global Settings & Workflow Global APIs + */ + +/** + * @brief Workflow Library Global Setting + * @details + * If you want set different settings with default, please call WORKFLOW_library_init at the beginning of the process +*/ +struct WFGlobalSettings +{ + struct EndpointParams endpoint_params; + struct EndpointParams dns_server_params; + unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success + unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail + int dns_threads; + int poller_threads; + int handler_threads; + int compute_threads; ///< auto-set by system CPU number if value<0 + int fio_max_events; + const char *resolv_conf_path; + const char *hosts_path; +}; + +/** + * @brief Default Workflow 
Library Global Settings + */ +static constexpr struct WFGlobalSettings GLOBAL_SETTINGS_DEFAULT = +{ + .endpoint_params = ENDPOINT_PARAMS_DEFAULT, + .dns_server_params = ENDPOINT_PARAMS_DEFAULT, + .dns_ttl_default = 3600, + .dns_ttl_min = 60, + .dns_threads = 4, + .poller_threads = 4, + .handler_threads = 20, + .compute_threads = -1, + .fio_max_events = 4096, + .resolv_conf_path = "/etc/resolv.conf", + .hosts_path = "/etc/hosts", +}; + +/** + * @brief Reset Workflow Library Global Setting + * @param[in] settings custom settings pointer +*/ +extern void WORKFLOW_library_init(const struct WFGlobalSettings *settings); + +/** + * @brief Workflow Global Management Class + * @details Workflow Global APIs + */ +class WFGlobal +{ +public: + /** + * @brief register default port for one scheme string + * @param[in] scheme scheme string + * @param[in] port default port value + * @warning No effect when scheme is "http"/"https"/"redis"/"rediss"/"mysql"/"kafka" + */ + static void register_scheme_port(const std::string& scheme, + unsigned short port); + /** + * @brief get default port string for one scheme string + * @param[in] scheme scheme string + * @return port string const pointer + * @retval NULL fail, scheme not found + * @retval not NULL success + */ + static const char *get_default_port(const std::string& scheme); + /** + * @brief get current global settings + * @return current global settings const pointer + * @note returnval never NULL + */ + static const struct WFGlobalSettings *get_global_settings() + { + return &settings_; + } + + static void set_global_settings(const struct WFGlobalSettings *settings) + { + settings_ = *settings; + } + + static const char *get_error_string(int state, int error); + + static bool increase_handler_thread() + { + return WFGlobal::get_scheduler()->increase_handler_thread() == 0; + } + + static bool decrease_handler_thread() + { + return WFGlobal::get_scheduler()->decrease_handler_thread() == 0; + } + + static bool 
increase_compute_thread() + { + return WFGlobal::get_compute_executor()->increase_thread() == 0; + } + + static bool decrease_compute_thread() + { + return WFGlobal::get_compute_executor()->decrease_thread() == 0; + } + + // Internal usage only +public: + static bool is_scheduler_created(); + static class CommScheduler *get_scheduler(); + static SSL_CTX *get_ssl_client_ctx(); + static SSL_CTX *new_ssl_server_ctx(); + static class ExecQueue *get_exec_queue(const std::string& queue_name); + static class Executor *get_compute_executor(); + static class IOService *get_io_service(); + static class ExecQueue *get_dns_queue(); + static class Executor *get_dns_executor(); + static class WFDnsClient *get_dns_client(); + static class WFResourcePool *get_dns_respool(); + + static class RouteManager *get_route_manager() + { + return &route_manager_; + } + + static class DnsCache *get_dns_cache() + { + return &dns_cache_; + } + + static class WFDnsResolver *get_dns_resolver() + { + return &dns_resolver_; + } + + static class WFNameService *get_name_service() + { + return &name_service_; + } + +public: + static int sync_operation_begin(); + static void sync_operation_end(int cookie); + +private: + static struct WFGlobalSettings settings_; + static RouteManager route_manager_; + static DnsCache dns_cache_; + static WFDnsResolver dns_resolver_; + static WFNameService name_service_; +}; + +#endif + diff --git a/src/manager/xmake.lua b/src/manager/xmake.lua new file mode 100644 index 0000000..a305562 --- /dev/null +++ b/src/manager/xmake.lua @@ -0,0 +1,6 @@ +target("manager") + add_files("*.cc") + set_kind("object") + if not has_config("upstream") then + remove_files("UpstreamManager.cc") + end diff --git a/src/nameservice/CMakeLists.txt b/src/nameservice/CMakeLists.txt new file mode 100644 index 0000000..e915c89 --- /dev/null +++ b/src/nameservice/CMakeLists.txt @@ -0,0 +1,17 @@ +cmake_minimum_required(VERSION 3.6) +project(nameservice) + +set(SRC + WFNameService.cc + 
WFDnsResolver.cc +) + +if (NOT UPSTREAM STREQUAL "n") + set(SRC + ${SRC} + WFServiceGovernance.cc + UpstreamPolicies.cc + ) +endif () + +add_library(${PROJECT_NAME} OBJECT ${SRC}) diff --git a/src/nameservice/UpstreamPolicies.cc b/src/nameservice/UpstreamPolicies.cc new file mode 100644 index 0000000..6d41632 --- /dev/null +++ b/src/nameservice/UpstreamPolicies.cc @@ -0,0 +1,729 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include "rbtree.h" +#include "URIParser.h" +#include "UpstreamPolicies.h" + +class EndpointGroup +{ +public: + EndpointGroup(int group_id, UPSGroupPolicy *policy) : + mutex(PTHREAD_MUTEX_INITIALIZER), + gen(random()) + { + this->id = group_id; + this->policy = policy; + this->nalives = 0; + this->weight = 0; + } + + EndpointAddress *get_one(WFNSTracing *tracing); + EndpointAddress *get_one_backup(WFNSTracing *tracing); + +public: + int id; + UPSGroupPolicy *policy; + struct rb_node rb; + pthread_mutex_t mutex; + std::mt19937 gen; + std::vector mains; + std::vector backups; + std::atomic nalives; + int weight; +}; + +UPSAddrParams::UPSAddrParams(const struct AddressParams *params, + const std::string& address) : + PolicyAddrParams(params) +{ + this->weight = params->weight; + this->server_type = params->server_type; + this->group_id = params->group_id; + + if (this->group_id < 0) + 
this->group_id = -1; + + if (this->weight == 0) + this->weight = 1; +} + +void UPSGroupPolicy::get_main_address(std::vector& addr_list) +{ + UPSAddrParams *params; + pthread_rwlock_rdlock(&this->rwlock); + + for (const EndpointAddress *server : this->servers) + { + params = static_cast(server->params); + if (params->server_type == 0) + addr_list.push_back(server->address); + } + + pthread_rwlock_unlock(&this->rwlock); +} + +UPSGroupPolicy::UPSGroupPolicy() +{ + this->group_map.rb_node = NULL; + this->default_group = new EndpointGroup(-1, this); + rb_link_node(&this->default_group->rb, NULL, &this->group_map.rb_node); + rb_insert_color(&this->default_group->rb, &this->group_map); +} + +UPSGroupPolicy::~UPSGroupPolicy() +{ + EndpointGroup *group; + + while (this->group_map.rb_node) + { + group = rb_entry(this->group_map.rb_node, EndpointGroup, rb); + rb_erase(this->group_map.rb_node, &this->group_map); + delete group; + } +} + +inline bool UPSGroupPolicy::is_alive(const EndpointAddress *addr) const +{ + UPSAddrParams *params = static_cast(addr->params); + return ((params->group_id < 0 && + addr->fail_count < addr->params->max_fails) || + (params->group_id >= 0 && + params->group->nalives > 0)); +} + +void UPSGroupPolicy::recover_one_server(const EndpointAddress *addr) +{ + this->nalives++; + UPSAddrParams *params = static_cast(addr->params); + params->group->nalives++; +} + +void UPSGroupPolicy::fuse_one_server(const EndpointAddress *addr) +{ + this->nalives--; + UPSAddrParams *params = static_cast(addr->params); + params->group->nalives--; +} + +void UPSGroupPolicy::add_server(const std::string& address, + const AddressParams *params) +{ + EndpointAddress *addr = new EndpointAddress(address, + new UPSAddrParams(params, address)); + + pthread_rwlock_wrlock(&this->rwlock); + this->add_server_locked(addr); + pthread_rwlock_unlock(&this->rwlock); +} + +int UPSGroupPolicy::replace_server(const std::string& address, + const AddressParams *params) +{ + int ret; + 
EndpointAddress *addr = new EndpointAddress(address, + new UPSAddrParams(params, address)); + + pthread_rwlock_wrlock(&this->rwlock); + ret = this->remove_server_locked(address); + this->add_server_locked(addr); + pthread_rwlock_unlock(&this->rwlock); + return ret; +} + +bool UPSGroupPolicy::select(const ParsedURI& uri, WFNSTracing *tracing, + EndpointAddress **addr) +{ + pthread_rwlock_rdlock(&this->rwlock); + unsigned int n = (unsigned int)this->servers.size(); + + if (n == 0) + { + pthread_rwlock_unlock(&this->rwlock); + return false; + } + + this->check_breaker(); + + // select_addr == NULL will happen only in consistent_hash + EndpointAddress *select_addr = this->first_strategy(uri, tracing); + + if (!select_addr || select_addr->fail_count >= select_addr->params->max_fails) + { + if (select_addr) + select_addr = this->check_and_get(select_addr, true, tracing); + + if (!select_addr && this->try_another) + select_addr = this->another_strategy(uri, tracing); + } + + if (!select_addr) + select_addr = this->default_group->get_one_backup(tracing); + + if (select_addr) + { + *addr = select_addr; + ++select_addr->ref; + } + + pthread_rwlock_unlock(&this->rwlock); + return !!select_addr; +} + +/* + * addr_failed true: return an available one. If not exists, return NULL. + * false: means addr maybe group-alive. + * If addr is not available, get one from addr->group. 
+ */ +EndpointAddress *UPSGroupPolicy::check_and_get(EndpointAddress *addr, + bool addr_failed, + WFNSTracing *tracing) +{ + UPSAddrParams *params = static_cast(addr->params); + + if (addr_failed) // means fail_count >= max_fails + { + if (params->group_id == -1) + return NULL; + + return params->group->get_one(tracing); + } + + if (addr && addr->fail_count >= addr->params->max_fails && + params->group_id >= 0) + { + EndpointAddress *tmp = params->group->get_one(tracing); + if (tmp) + addr = tmp; + } + + return addr; +} + +EndpointAddress *EndpointGroup::get_one(WFNSTracing *tracing) +{ + if (this->nalives == 0) + return NULL; + + EndpointAddress *server; + EndpointAddress *addr = NULL; + pthread_mutex_lock(&this->mutex); + + std::shuffle(this->mains.begin(), this->mains.end(), this->gen); + for (size_t i = 0; i < this->mains.size(); i++) + { + server = this->mains[i]; + if (server->fail_count < server->params->max_fails && + WFServiceGovernance::in_select_history(tracing, server) == false) + { + addr = server; + break; + } + } + + if (!addr) + { + std::shuffle(this->backups.begin(), this->backups.end(), this->gen); + for (size_t i = 0; i < this->backups.size(); i++) + { + server = this->backups[i]; + if (server->fail_count < server->params->max_fails && + WFServiceGovernance::in_select_history(tracing, server) == false) + { + addr = server; + break; + } + } + } + + pthread_mutex_unlock(&this->mutex); + return addr; +} + +EndpointAddress *EndpointGroup::get_one_backup(WFNSTracing *tracing) +{ + if (this->nalives == 0) + return NULL; + + EndpointAddress *server; + EndpointAddress *addr = NULL; + + pthread_mutex_lock(&this->mutex); + + std::shuffle(this->backups.begin(), this->backups.end(), this->gen); + for (size_t i = 0; i < this->backups.size(); i++) + { + server = this->backups[i]; + if (server->fail_count < server->params->max_fails && + WFServiceGovernance::in_select_history(tracing, server) == false) + { + addr = server; + break; + } + } + + 
pthread_mutex_unlock(&this->mutex); + return addr; +} + +void UPSGroupPolicy::add_server_locked(EndpointAddress *addr) +{ + UPSAddrParams *params = static_cast(addr->params); + int group_id = params->group_id; + rb_node **p = &this->group_map.rb_node; + rb_node *parent = NULL; + EndpointGroup *group; + + this->server_map[addr->address].push_back(addr); + + if (params->server_type == 0) + this->servers.push_back(addr); + + while (*p) + { + parent = *p; + group = rb_entry(*p, EndpointGroup, rb); + + if (group_id < group->id) + p = &(*p)->rb_left; + else if (group_id > group->id) + p = &(*p)->rb_right; + else + break; + } + + if (*p == NULL) + { + group = new EndpointGroup(group_id, this); + rb_link_node(&group->rb, parent, p); + rb_insert_color(&group->rb, &this->group_map); + } + + pthread_mutex_lock(&group->mutex); + params->group = group; + this->recover_one_server(addr); + if (params->server_type == 0) + { + group->mains.push_back(addr); + group->weight += params->weight; + } + else + group->backups.push_back(addr); + pthread_mutex_unlock(&group->mutex); +} + +int UPSGroupPolicy::remove_server_locked(const std::string& address) +{ + const auto map_it = this->server_map.find(address); + size_t n = this->servers.size(); + size_t new_n = 0; + int ret = 0; + + for (size_t i = 0; i < n; i++) + { + if (this->servers[i]->address != address) + this->servers[new_n++] = this->servers[i]; + } + + this->servers.resize(new_n); + + if (map_it != this->server_map.cend()) + { + for (EndpointAddress *addr : map_it->second) + { + UPSAddrParams *params = static_cast(addr->params); + EndpointGroup *group = params->group; + std::vector *vec; + + if (params->server_type == 0) + vec = &group->mains; + else + vec = &group->backups; + + pthread_mutex_lock(&group->mutex); + if (params->server_type == 0) + group->weight -= params->weight; + + for (auto it = vec->begin(); it != vec->end(); ++it) + { + if (*it == addr) + { + vec->erase(it); + break; + } + } + + if (--addr->ref == 0) + { + 
this->pre_delete_server(addr); + delete addr; + } + + pthread_mutex_unlock(&group->mutex); + ret++; + } + + this->server_map.erase(map_it); + } + + return ret; +} + +EndpointAddress *UPSGroupPolicy::consistent_hash_with_group(unsigned int hash, + WFNSTracing *tracing) +{ + if (this->nalives == 0) + return NULL; + + std::map::iterator it; + it = this->addr_hash.lower_bound(hash); + + if (it == this->addr_hash.end()) + it = this->addr_hash.begin(); + + while (!this->is_alive(it->second)) + { + it++; + if (it == this->addr_hash.end()) + it = this->addr_hash.begin(); + } + + return this->check_and_get(it->second, false, tracing); +} + +void UPSGroupPolicy::hash_map_add_addr(EndpointAddress *addr) +{ + UPSAddrParams *params = static_cast(addr->params); + + if (params->server_type == 0) + { + static std::hash std_hash; + unsigned int hash_value; + size_t ip_count = this->server_map[addr->address].size(); + + for (int i = 0; i < VIRTUAL_GROUP_SIZE * params->weight; i++) + { + hash_value = std_hash(addr->address + "|v" + std::to_string(i) + + "|n" + std::to_string(ip_count)); + this->addr_hash.insert(std::make_pair(hash_value, addr)); + } + } +} + +void UPSGroupPolicy::hash_map_remove_addr(const std::string& address) +{ + std::map::iterator it; + + for (it = this->addr_hash.begin(); it != this->addr_hash.end();) + { + if (it->second->address == address) + this->addr_hash.erase(it++); + else + it++; + } +} + +int UPSRoundRobinPolicy::remove_server_locked(const std::string& address) +{ + if (servers.size() != 0) + { + size_t cur_idx = this->cur_idx % servers.size(); + + for (size_t i = 0; i < cur_idx; i++) + { + if (this->servers[i]->address == address) + this->cur_idx--; + } + } + + return UPSGroupPolicy::remove_server_locked(address); +} + +EndpointAddress *UPSRoundRobinPolicy::first_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + return this->servers[this->cur_idx++ % this->servers.size()]; +} + +EndpointAddress *UPSRoundRobinPolicy::another_strategy(const 
ParsedURI& uri, + WFNSTracing *tracing) +{ + EndpointAddress *addr = this->servers[this->cur_idx++ % this->servers.size()]; + return this->check_and_get(addr, false, tracing); +} + +void UPSWeightedRandomPolicy::add_server_locked(EndpointAddress *addr) +{ + UPSAddrParams *params = static_cast(addr->params); + + UPSGroupPolicy::add_server_locked(addr); + if (params->server_type == 0) + this->total_weight += params->weight; +} + +int UPSWeightedRandomPolicy::remove_server_locked(const std::string& address) +{ + UPSAddrParams *params; + const auto map_it = this->server_map.find(address); + + if (map_it != this->server_map.cend()) + { + for (EndpointAddress *addr : map_it->second) + { + params = static_cast(addr->params); + if (params->server_type == 0) + this->total_weight -= params->weight; + } + } + + return UPSGroupPolicy::remove_server_locked(address); +} + +int UPSWeightedRandomPolicy::select_history_weight(WFNSTracing *tracing) +{ + struct TracingData *tracing_data = (struct TracingData *)tracing->data; + + if (!tracing_data) + return 0; + + int ret = 0; + + for (EndpointAddress *server : tracing_data->history) + ret += ((UPSAddrParams *)server->params)->weight; + + return ret; +} + +EndpointAddress *UPSWeightedRandomPolicy::first_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + int x = 0; + int s = 0; + size_t idx; + UPSAddrParams *params; + int temp_weight = this->total_weight; + temp_weight -= UPSWeightedRandomPolicy::select_history_weight(tracing); + + if (temp_weight > 0) + x = rand() % temp_weight; + + for (idx = 0; idx < this->servers.size(); idx++) + { + if (WFServiceGovernance::in_select_history(tracing, this->servers[idx])) + continue; + + params = static_cast(this->servers[idx]->params); + s += params->weight; + if (s > x) + break; + } + if (idx == this->servers.size()) + idx--; + + return this->servers[idx]; +} + +EndpointAddress *UPSWeightedRandomPolicy::another_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + /* When all 
servers are down, recover all servers if any server + * reaches fusing timeout. */ + if (this->available_weight == 0) + this->try_clear_breaker(); + + int temp_weight = this->available_weight; + if (temp_weight == 0) + return NULL; + + UPSAddrParams *params; + EndpointAddress *addr = NULL; + int x = rand() % temp_weight; + int s = 0; + + for (EndpointAddress *server : this->servers) + { + if (this->is_alive(server)) + { + addr = server; + params = static_cast(server->params); + s += params->weight; + if (s > x) + break; + } + } + + if (!addr) + return NULL; + + return this->check_and_get(addr, false, tracing); +} + +void UPSWeightedRandomPolicy::recover_one_server(const EndpointAddress *addr) +{ + UPSAddrParams *params = static_cast(addr->params); + + this->nalives++; + if (params->group->nalives++ == 0 && params->group->id > 0) + this->available_weight += params->group->weight; + + if (params->group_id < 0 && params->server_type == 0) + this->available_weight += params->weight; +} + +void UPSWeightedRandomPolicy::fuse_one_server(const EndpointAddress *addr) +{ + UPSAddrParams *params = static_cast(addr->params); + + this->nalives--; + if (--params->group->nalives == 0 && params->group->id > 0) + this->available_weight -= params->group->weight; + + if (params->group_id < 0 && params->server_type == 0) + this->available_weight -= params->weight; +} + +EndpointAddress *UPSVNSWRRPolicy::first_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + int idx = this->cur_idx.fetch_add(1); + int pos = 0; + for (int i = 0; i < this->total_weight; i++, idx++) + { + pos = this->pre_generated_vec[idx % this->pre_generated_vec.size()]; + if (WFServiceGovernance::in_select_history(tracing, this->servers[pos])) + continue; + + break; + } + return this->servers[pos]; +} + +void UPSVNSWRRPolicy::init_virtual_nodes() +{ + UPSAddrParams *params; + size_t start_pos = this->pre_generated_vec.size(); + size_t end_pos = this->total_weight; + this->pre_generated_vec.resize(end_pos); + 
+ for (size_t i = start_pos; i < end_pos; i++) + { + for (size_t j = 0; j < this->servers.size(); j++) + { + const EndpointAddress *server = this->servers[j]; + params = static_cast(server->params); + this->current_weight_vec[j] += params->weight; + } + std::vector::iterator biggest = std::max_element(this->current_weight_vec.begin(), + this->current_weight_vec.end()); + this->pre_generated_vec[i] = std::distance(this->current_weight_vec.begin(), biggest); + this->current_weight_vec[this->pre_generated_vec[i]] -= this->total_weight; + } +} + +void UPSVNSWRRPolicy::init() +{ + if (this->total_weight <= 0) + return; + + this->pre_generated_vec.clear(); + this->cur_idx = rand() % this->total_weight; + std::vector t(this->servers.size(), 0); + this->current_weight_vec.swap(t); + this->init_virtual_nodes(); +} + +void UPSVNSWRRPolicy::add_server_locked(EndpointAddress *addr) +{ + UPSWeightedRandomPolicy::add_server_locked(addr); + init(); +} + +int UPSVNSWRRPolicy::remove_server_locked(const std::string& address) +{ + int ret = UPSWeightedRandomPolicy::remove_server_locked(address); + init(); + return ret; +} + +EndpointAddress *UPSConsistentHashPolicy::first_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + unsigned int hash_value = this->consistent_hash( + uri.path ? uri.path : "", + uri.query ? uri.query : "", + uri.fragment ? uri.fragment : ""); + return this->consistent_hash_with_group(hash_value, tracing); +} + +void UPSConsistentHashPolicy::add_server_locked(EndpointAddress *addr) +{ + UPSGroupPolicy::add_server_locked(addr); + this->hash_map_add_addr(addr); +} + +int UPSConsistentHashPolicy::remove_server_locked(const std::string& address) +{ + this->hash_map_remove_addr(address); + + return UPSGroupPolicy::remove_server_locked(address); +} + +EndpointAddress *UPSManualPolicy::first_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + unsigned int idx = this->manual_select(uri.path ? uri.path : "", + uri.query ? uri.query : "", + uri.fragment ? 
uri.fragment : ""); + + if (idx >= this->servers.size()) + idx %= this->servers.size(); + + return this->servers[idx]; +} + +EndpointAddress *UPSManualPolicy::another_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + unsigned int hash_value = this->another_select( + uri.path ? uri.path : "", + uri.query ? uri.query : "", + uri.fragment ? uri.fragment : ""); + return this->consistent_hash_with_group(hash_value, tracing); +} + +void UPSManualPolicy::add_server_locked(EndpointAddress *addr) +{ + UPSGroupPolicy::add_server_locked(addr); + + if (this->try_another) + this->hash_map_add_addr(addr); +} + +int UPSManualPolicy::remove_server_locked(const std::string& address) +{ + if (this->try_another) + this->hash_map_remove_addr(address); + + return UPSGroupPolicy::remove_server_locked(address); +} + diff --git a/src/nameservice/UpstreamPolicies.h b/src/nameservice/UpstreamPolicies.h new file mode 100644 index 0000000..07e2909 --- /dev/null +++ b/src/nameservice/UpstreamPolicies.h @@ -0,0 +1,207 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#ifndef _UPSTREAMPOLICIES_H_ +#define _UPSTREAMPOLICIES_H_ + +#include +#include +#include +#include +#include +#include "URIParser.h" +#include "EndpointParams.h" +#include "WFNameService.h" +#include "WFServiceGovernance.h" + +using upstream_route_t = std::function; + +class EndpointGroup; +class UPSGroupPolicy; + +class UPSAddrParams : public PolicyAddrParams +{ +public: + unsigned short weight; + short server_type; + int group_id; + EndpointGroup *group; + + UPSAddrParams(const struct AddressParams *params, + const std::string& address); +}; + +class UPSGroupPolicy : public WFServiceGovernance +{ +public: + UPSGroupPolicy(); + virtual ~UPSGroupPolicy(); + +public: + virtual bool select(const ParsedURI& uri, WFNSTracing *tracing, + EndpointAddress **addr); + virtual void add_server(const std::string& address, + const struct AddressParams *params); + virtual int replace_server(const std::string& address, + const struct AddressParams *params); + void get_main_address(std::vector& addr_list); + +protected: + struct rb_root group_map; + EndpointGroup *default_group; + +private: + virtual void recover_one_server(const EndpointAddress *addr); + virtual void fuse_one_server(const EndpointAddress *addr); + +protected: + virtual void add_server_locked(EndpointAddress *addr); + virtual int remove_server_locked(const std::string& address); + + EndpointAddress *check_and_get(EndpointAddress *addr, bool addr_failed, + WFNSTracing *tracing); + + bool is_alive(const EndpointAddress *addr) const; + +protected: + EndpointAddress *consistent_hash_with_group(unsigned int hash, + WFNSTracing *tracing); + void hash_map_add_addr(EndpointAddress *addr); + void hash_map_remove_addr(const std::string& address); + + std::map addr_hash; +}; + +class UPSRoundRobinPolicy : public UPSGroupPolicy +{ +public: + UPSRoundRobinPolicy(bool try_another) : + cur_idx(0) + { + this->try_another = 
try_another; + } + +protected: + virtual EndpointAddress *first_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + virtual EndpointAddress *another_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + +protected: + virtual int remove_server_locked(const std::string& address); + +protected: + std::atomic cur_idx; +}; + +class UPSWeightedRandomPolicy : public UPSGroupPolicy +{ +public: + UPSWeightedRandomPolicy(bool try_another) + { + this->total_weight = 0; + this->available_weight = 0; + this->try_another = try_another; + } + +protected: + virtual EndpointAddress *first_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + virtual EndpointAddress *another_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + +protected: + virtual void add_server_locked(EndpointAddress *addr); + virtual int remove_server_locked(const std::string& address); + int total_weight; + int available_weight; + +private: + virtual void recover_one_server(const EndpointAddress *addr); + virtual void fuse_one_server(const EndpointAddress *addr); + static int select_history_weight(WFNSTracing *tracing); +}; + +class UPSVNSWRRPolicy : public UPSWeightedRandomPolicy +{ +public: + UPSVNSWRRPolicy() : UPSWeightedRandomPolicy(false) + { + this->cur_idx = 0; + this->try_another = false; + }; + +protected: + virtual EndpointAddress *first_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + +private: + virtual void add_server_locked(EndpointAddress *addr); + virtual int remove_server_locked(const std::string& address); + void init(); + void init_virtual_nodes(); + std::vector pre_generated_vec; + std::vector current_weight_vec; + std::atomic cur_idx; +}; + +class UPSConsistentHashPolicy : public UPSGroupPolicy +{ +public: + UPSConsistentHashPolicy(upstream_route_t consistent_hash) : + consistent_hash(std::move(consistent_hash)) + { + } + +protected: + virtual EndpointAddress *first_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + +private: + virtual void 
add_server_locked(EndpointAddress *addr); + virtual int remove_server_locked(const std::string& address); + upstream_route_t consistent_hash; +}; + +class UPSManualPolicy : public UPSGroupPolicy +{ +public: + UPSManualPolicy(bool try_another, upstream_route_t select, + upstream_route_t try_another_select) : + manual_select(std::move(select)), + another_select(std::move(try_another_select)) + { + this->try_another = try_another; + } + + EndpointAddress *first_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + EndpointAddress *another_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + +private: + virtual void add_server_locked(EndpointAddress *addr); + virtual int remove_server_locked(const std::string& address); + upstream_route_t manual_select; + upstream_route_t another_select; +}; + +#endif diff --git a/src/nameservice/WFDnsResolver.cc b/src/nameservice/WFDnsResolver.cc new file mode 100644 index 0000000..9a955e3 --- /dev/null +++ b/src/nameservice/WFDnsResolver.cc @@ -0,0 +1,767 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "EndpointParams.h" +#include "RouteManager.h" +#include "WFGlobal.h" +#include "WFTaskFactory.h" +#include "WFResourcePool.h" +#include "WFNameService.h" +#include "DnsCache.h" +#include "DnsUtil.h" +#include "WFDnsClient.h" +#include "WFDnsResolver.h" + +#define HOSTS_LINEBUF_INIT_SIZE 128 +#define PORT_STR_MAX 5 + +class DnsInput +{ +public: + DnsInput() : + port_(0), + numeric_host_(false), + family_(AF_UNSPEC) + {} + + DnsInput(const std::string& host, unsigned short port, + bool numeric_host, int family) : + host_(host), + port_(port), + numeric_host_(numeric_host), + family_(family) + {} + + void reset(const std::string& host, unsigned short port) + { + host_.assign(host); + port_ = port; + numeric_host_ = false; + family_ = AF_UNSPEC; + } + + void reset(const std::string& host, unsigned short port, + bool numeric_host, int family) + { + host_.assign(host); + port_ = port; + numeric_host_ = numeric_host; + family_ = family; + } + + const std::string& get_host() const { return host_; } + unsigned short get_port() const { return port_; } + bool is_numeric_host() const { return numeric_host_; } + +protected: + std::string host_; + unsigned short port_; + bool numeric_host_; + int family_; + + friend class DnsRoutine; +}; + +class DnsOutput +{ +public: + DnsOutput(): + error_(0), + addrinfo_(NULL) + {} + + ~DnsOutput() + { + if (addrinfo_) + { + if (addrinfo_->ai_flags) + freeaddrinfo(addrinfo_); + else + free(addrinfo_); + } + } + + int get_error() const { return error_; } + const struct addrinfo *get_addrinfo() const { return addrinfo_; } + + //if DONOT want DnsOutput release addrinfo, use move_addrinfo in callback + struct addrinfo *move_addrinfo() + { + struct addrinfo *p = addrinfo_; + addrinfo_ = NULL; + return 
p; + } + +protected: + int error_; + struct addrinfo *addrinfo_; + + friend class DnsRoutine; +}; + +class DnsRoutine +{ +public: + static void run(const DnsInput *in, DnsOutput *out); + static void create(DnsOutput *out, int error, struct addrinfo *ai) + { + if (out->addrinfo_) + { + if (out->addrinfo_->ai_flags) + freeaddrinfo(out->addrinfo_); + else + free(out->addrinfo_); + } + + out->error_ = error; + out->addrinfo_ = ai; + } + +private: + static void run_local_path(const std::string& path, DnsOutput *out); +}; + +void DnsRoutine::run_local_path(const std::string& path, DnsOutput *out) +{ + struct sockaddr_un *sun = NULL; + + if (path.size() + 1 <= sizeof sun->sun_path) + { + size_t size = sizeof (struct addrinfo) + sizeof (struct sockaddr_un); + + out->addrinfo_ = (struct addrinfo *)calloc(size, 1); + if (out->addrinfo_) + { + sun = (struct sockaddr_un *)(out->addrinfo_ + 1); + sun->sun_family = AF_UNIX; + memcpy(sun->sun_path, path.c_str(), path.size()); + + out->addrinfo_->ai_family = AF_UNIX; + out->addrinfo_->ai_socktype = SOCK_STREAM; + out->addrinfo_->ai_addr = (struct sockaddr *)sun; + size = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; + out->addrinfo_->ai_addrlen = size; + out->error_ = 0; + return; + } + } + else + errno = EINVAL; + + out->error_ = EAI_SYSTEM; +} + +void DnsRoutine::run(const DnsInput *in, DnsOutput *out) +{ + if (in->host_[0] == '/') + { + run_local_path(in->host_, out); + return; + } + + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV, + .ai_family = in->family_, + .ai_socktype = SOCK_STREAM, + }; + char port_str[PORT_STR_MAX + 1]; + + if (in->is_numeric_host()) + hints.ai_flags |= AI_NUMERICHOST; + + snprintf(port_str, PORT_STR_MAX + 1, "%u", in->port_); + out->error_ = getaddrinfo(in->host_.c_str(), port_str, + &hints, &out->addrinfo_); + if (out->error_ == 0) + out->addrinfo_->ai_flags = 1; +} + +// Dns Thread task. For internal usage only. 
+using ThreadDnsTask = WFThreadTask; +using thread_dns_callback_t = std::function; + +struct DnsContext +{ + int state; + int error; + int eai_error; + unsigned short port; + struct addrinfo *ai; +}; + +static int __default_family() +{ + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG, + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + }; + struct addrinfo *res; + struct addrinfo *cur; + int family = AF_UNSPEC; + bool v4 = false; + bool v6 = false; + + if (getaddrinfo(NULL, "1", &hints, &res) == 0) + { + for (cur = res; cur; cur = cur->ai_next) + { + if (cur->ai_family == AF_INET) + v4 = true; + else if (cur->ai_family == AF_INET6) + v6 = true; + } + + freeaddrinfo(res); + if (v4 ^ v6) + family = v4 ? AF_INET : AF_INET6; + } + + return family; +} + +// hosts line format: IP canonical_name [aliases...] [# Comment] +static int __readaddrinfo_line(char *p, const char *name, const char *port, + const struct addrinfo *hints, + struct addrinfo **res) +{ + const char *ip = NULL; + char *start; + + start = p; + while (*start != '\0' && *start != '#') + start++; + *start = '\0'; + + while (1) + { + while (isspace(*p)) + p++; + + start = p; + while (*p != '\0' && !isspace(*p)) + p++; + + if (start == p) + break; + + if (*p != '\0') + *p++ = '\0'; + + if (ip == NULL) + { + ip = start; + continue; + } + + if (strcasecmp(name, start) == 0) + { + if (getaddrinfo(ip, port, hints, res) == 0) + return 0; + } + } + + return 1; +} + +static int __readaddrinfo(const char *path, + const char *name, unsigned short port, + const struct addrinfo *hints, + struct addrinfo **res) +{ + char port_str[PORT_STR_MAX + 1]; + size_t bufsize = 0; + char *line = NULL; + int count = 0; + int errno_bak; + FILE *fp; + int ret; + + fp = fopen(path, "r"); + if (!fp) + return EAI_SYSTEM; + + snprintf(port_str, PORT_STR_MAX + 1, "%u", port); + + errno_bak = errno; + while ((ret = getline(&line, &bufsize, fp)) > 0) + { + if (__readaddrinfo_line(line, name, port_str, hints, res) == 0) + { + 
count++; + res = &(*res)->ai_next; + } + } + + ret = ferror(fp) ? EAI_SYSTEM : EAI_NONAME; + free(line); + fclose(fp); + if (count != 0) + { + errno = errno_bak; + return 0; + } + + return ret; +} + +static ThreadDnsTask *__create_thread_dns_task(const std::string& host, + unsigned short port, + int family, + thread_dns_callback_t callback) +{ + auto *task = WFThreadTaskFactory:: + create_thread_task(WFGlobal::get_dns_queue(), + WFGlobal::get_dns_executor(), + DnsRoutine::run, + std::move(callback)); + + task->get_input()->reset(host, port, false, family); + return task; +} + +static std::string __get_cache_host(const std::string& hostname, + int family) +{ + char c; + + if (family == AF_UNSPEC) + c = '*'; + else if (family == AF_INET) + c = '4'; + else if (family == AF_INET6) + c = '6'; + else + c = '?'; + + return hostname + c; +} + +static std::string __get_guard_name(const std::string& cache_host, + unsigned short port) +{ + std::string guard_name("INTERNAL-dns:"); + guard_name.append(cache_host).append(":"); + guard_name.append(std::to_string(port)); + return guard_name; +} + +void WFResolverTask::dispatch() +{ + if (this->msg_) + { + this->state = WFT_STATE_DNS_ERROR; + this->error = (intptr_t)msg_; + this->subtask_done(); + return; + } + + const ParsedURI& uri = ns_params_.uri; + host_ = uri.host ? uri.host : ""; + port_ = uri.port ? 
atoi(uri.port) : 0; + + DnsCache *dns_cache = WFGlobal::get_dns_cache(); + const DnsCache::DnsHandle *addr_handle; + std::string hostname = host_; + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(hostname, family); + + if (ns_params_.retry_times == 0) + addr_handle = dns_cache->get_ttl(cache_host, port_); + else + addr_handle = dns_cache->get_confident(cache_host, port_); + + if (in_guard_ && (addr_handle == NULL || addr_handle->value.delayed())) + { + if (addr_handle) + dns_cache->release(addr_handle); + + this->request_dns(); + return; + } + + if (addr_handle) + { + RouteManager *route_manager = WFGlobal::get_route_manager(); + struct addrinfo *addrinfo = addr_handle->value.addrinfo; + struct addrinfo first; + + if (ns_params_.fixed_addr && addrinfo->ai_next) + { + first = *addrinfo; + first.ai_next = NULL; + addrinfo = &first; + } + + if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, + &ep_params_, hostname, ns_params_.ssl_ctx, + this->result) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + } + else + this->state = WFT_STATE_SUCCESS; + + dns_cache->release(addr_handle); + this->subtask_done(); + return; + } + + if (*host_) + { + char front = host_[0]; + char back = host_[hostname.size() - 1]; + struct in6_addr addr; + int ret; + + if (strchr(host_, ':')) + ret = inet_pton(AF_INET6, host_, &addr); + else if (isdigit(back) && isdigit(front)) + ret = inet_pton(AF_INET, host_, &addr); + else if (front == '/') + ret = 1; + else + ret = 0; + + if (ret == 1) + { + // 'true' means numeric host + DnsInput dns_in(hostname, port_, true, AF_UNSPEC); + DnsOutput dns_out; + + DnsRoutine::run(&dns_in, &dns_out); + dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1); + this->subtask_done(); + return; + } + } + + const char *hosts = WFGlobal::get_global_settings()->hosts_path; + if (hosts) + { + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV | AI_NUMERICHOST, + 
.ai_family = ep_params_.address_family, + .ai_socktype = SOCK_STREAM, + }; + struct addrinfo *ai; + int ret; + + ret = __readaddrinfo(hosts, host_, port_, &hints, &ai); + if (ret == 0) + { + DnsOutput out; + DnsRoutine::create(&out, ret, ai); + dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); + this->subtask_done(); + return; + } + } + + std::string guard_name = __get_guard_name(cache_host, port_); + WFConditional *guard = WFTaskFactory::create_guard(guard_name, this, &msg_); + + in_guard_ = true; + has_next_ = true; + + series_of(this)->push_front(guard); + this->subtask_done(); +} + +void WFResolverTask::request_dns() +{ + WFDnsClient *client = WFGlobal::get_dns_client(); + if (client) + { + static int default_family = __default_family(); + WFResourcePool *respool = WFGlobal::get_dns_respool(); + + int family = ep_params_.address_family; + if (family == AF_UNSPEC) + family = default_family; + + if (family == AF_INET || family == AF_INET6) + { + auto&& cb = std::bind(&WFResolverTask::dns_single_callback, + this, + std::placeholders::_1); + WFDnsTask *dns_task = client->create_dns_task(host_, std::move(cb)); + + if (family == AF_INET6) + dns_task->get_req()->set_question_type(DNS_TYPE_AAAA); + + WFConditional *cond = respool->get(dns_task); + series_of(this)->push_front(cond); + } + else + { + struct DnsContext *dctx = new struct DnsContext[2]; + WFDnsTask *task_v4; + WFDnsTask *task_v6; + ParallelWork *pwork; + + dctx[0].ai = NULL; + dctx[1].ai = NULL; + dctx[0].port = port_; + dctx[1].port = port_; + + task_v4 = client->create_dns_task(host_, dns_partial_callback); + task_v4->user_data = dctx; + + task_v6 = client->create_dns_task(host_, dns_partial_callback); + task_v6->get_req()->set_question_type(DNS_TYPE_AAAA); + task_v6->user_data = dctx + 1; + + auto&& cb = std::bind(&WFResolverTask::dns_parallel_callback, + this, + std::placeholders::_1); + + pwork = Workflow::create_parallel_work(std::move(cb)); + pwork->set_context(dctx); + + WFConditional 
*cond_v4 = respool->get(task_v4); + WFConditional *cond_v6 = respool->get(task_v6); + pwork->add_series(Workflow::create_series_work(cond_v4, nullptr)); + pwork->add_series(Workflow::create_series_work(cond_v6, nullptr)); + + series_of(this)->push_front(pwork); + } + } + else + { + ThreadDnsTask *dns_task; + auto&& cb = std::bind(&WFResolverTask::thread_dns_callback, + this, + std::placeholders::_1); + dns_task = __create_thread_dns_task(host_, port_, + ep_params_.address_family, + std::move(cb)); + series_of(this)->push_front(dns_task); + } + + has_next_ = true; + this->subtask_done(); +} + +SubTask *WFResolverTask::done() +{ + SeriesWork *series = series_of(this); + + if (!has_next_) + task_callback(); + else + has_next_ = false; + + return series->pop(); +} + +void WFResolverTask::dns_callback_internal(void *thrd_dns_output, + unsigned int ttl_default, + unsigned int ttl_min) +{ + DnsOutput *dns_out = (DnsOutput *)thrd_dns_output; + int dns_error = dns_out->get_error(); + + if (dns_error) + { + if (dns_error == EAI_SYSTEM) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + } + else + { + this->state = WFT_STATE_DNS_ERROR; + this->error = dns_error; + } + } + else + { + RouteManager *route_manager = WFGlobal::get_route_manager(); + DnsCache *dns_cache = WFGlobal::get_dns_cache(); + struct addrinfo *addrinfo = dns_out->move_addrinfo(); + const DnsCache::DnsHandle *addr_handle; + std::string hostname = host_; + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(hostname, family); + + addr_handle = dns_cache->put(cache_host, port_, addrinfo, + (unsigned int)ttl_default, + (unsigned int)ttl_min); + if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, + &ep_params_, hostname, ns_params_.ssl_ctx, + this->result) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + } + else + this->state = WFT_STATE_SUCCESS; + + dns_cache->release(addr_handle); + } +} + +void 
WFResolverTask::dns_single_callback(void *net_dns_task) +{ + WFDnsTask *dns_task = (WFDnsTask *)net_dns_task; + WFGlobal::get_dns_respool()->post(NULL); + + if (dns_task->get_state() == WFT_STATE_SUCCESS) + { + struct addrinfo *ai = NULL; + int ret; + + ret = protocol::DnsUtil::getaddrinfo(dns_task->get_resp(), port_, &ai); + DnsOutput out; + DnsRoutine::create(&out, ret, ai); + dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); + } + else + { + this->state = dns_task->get_state(); + this->error = dns_task->get_error(); + } + + task_callback(); +} + +void WFResolverTask::dns_partial_callback(void *net_dns_task) +{ + WFDnsTask *dns_task = (WFDnsTask *)net_dns_task; + WFGlobal::get_dns_respool()->post(NULL); + + struct DnsContext *ctx = (struct DnsContext *)dns_task->user_data; + ctx->ai = NULL; + ctx->state = dns_task->get_state(); + ctx->error = dns_task->get_error(); + if (ctx->state == WFT_STATE_SUCCESS) + { + protocol::DnsResponse *resp = dns_task->get_resp(); + ctx->eai_error = protocol::DnsUtil::getaddrinfo(resp, ctx->port, + &ctx->ai); + } + else + ctx->eai_error = EAI_NONAME; +} + +void WFResolverTask::dns_parallel_callback(const void *parallel) +{ + const ParallelWork *pwork = (const ParallelWork *)parallel; + struct DnsContext *c4 = (struct DnsContext *)(pwork->get_context()); + struct DnsContext *c6 = c4 + 1; + DnsOutput out; + + if (c4->state != WFT_STATE_SUCCESS && c6->state != WFT_STATE_SUCCESS) + { + this->state = c4->state; + this->error = c4->error; + } + else if (c4->eai_error != 0 && c6->eai_error != 0) + { + DnsRoutine::create(&out, c4->eai_error, NULL); + dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); + } + else + { + struct addrinfo *ai = NULL; + struct addrinfo **pai = &ai; + + if (c4->ai != NULL) + { + *pai = c4->ai; + while (*pai) + pai = &(*pai)->ai_next; + } + + if (c6->ai != NULL) + *pai = c6->ai; + + DnsRoutine::create(&out, 0, ai); + dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); + } + + delete[] 
c4; + + task_callback(); +} + +void WFResolverTask::thread_dns_callback(void *thrd_dns_task) +{ + ThreadDnsTask *dns_task = (ThreadDnsTask *)thrd_dns_task; + + if (dns_task->get_state() == WFT_STATE_SUCCESS) + { + DnsOutput *out = dns_task->get_output(); + dns_callback_internal(out, dns_ttl_default_, dns_ttl_min_); + } + else + { + this->state = dns_task->get_state(); + this->error = dns_task->get_error(); + } + + task_callback(); +} + +void WFResolverTask::task_callback() +{ + if (in_guard_) + { + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(host_, family); + std::string guard_name = __get_guard_name(cache_host, port_); + + if (this->state == WFT_STATE_DNS_ERROR) + msg_ = (void *)(intptr_t)this->error; + + WFTaskFactory::release_guard_safe(guard_name, msg_); + } + + if (this->callback) + this->callback(this); + + delete this; +} + +WFRouterTask *WFDnsResolver::create_router_task(const struct WFNSParams *params, + router_callback_t callback) +{ + const struct WFGlobalSettings *settings = WFGlobal::get_global_settings(); + unsigned int dns_ttl_default = settings->dns_ttl_default; + unsigned int dns_ttl_min = settings->dns_ttl_min; + const struct EndpointParams *ep_params = &settings->endpoint_params; + return new WFResolverTask(params, dns_ttl_default, dns_ttl_min, ep_params, + std::move(callback)); +} + diff --git a/src/nameservice/WFDnsResolver.h b/src/nameservice/WFDnsResolver.h new file mode 100644 index 0000000..4584d96 --- /dev/null +++ b/src/nameservice/WFDnsResolver.h @@ -0,0 +1,100 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFDNSRESOLVER_H_ +#define _WFDNSRESOLVER_H_ + +#include +#include +#include "EndpointParams.h" +#include "WFNameService.h" + +class WFResolverTask : public WFRouterTask +{ +public: + WFResolverTask(const struct WFNSParams *ns_params, + unsigned int dns_ttl_default, unsigned int dns_ttl_min, + const struct EndpointParams *ep_params, + router_callback_t&& cb) : + WFRouterTask(std::move(cb)), + ns_params_(*ns_params), + ep_params_(*ep_params) + { + if (ns_params_.fixed_conn) + ep_params_.max_connections = 1; + + dns_ttl_default_ = dns_ttl_default; + dns_ttl_min_ = dns_ttl_min; + has_next_ = false; + in_guard_ = false; + msg_ = NULL; + } + + WFResolverTask(const struct WFNSParams *ns_params, + router_callback_t&& cb) : + WFRouterTask(std::move(cb)), + ns_params_(*ns_params) + { + if (ns_params_.fixed_conn) + ep_params_.max_connections = 1; + + has_next_ = false; + in_guard_ = false; + msg_ = NULL; + } + +protected: + virtual void dispatch(); + virtual SubTask *done(); + void set_has_next() { has_next_ = true; } + +private: + void thread_dns_callback(void *thrd_dns_task); + void dns_single_callback(void *net_dns_task); + static void dns_partial_callback(void *net_dns_task); + void dns_parallel_callback(const void *parallel); + void dns_callback_internal(void *thrd_dns_output, + unsigned int ttl_default, + unsigned int ttl_min); + + void request_dns(); + void task_callback(); + +protected: + struct WFNSParams ns_params_; + unsigned int dns_ttl_default_; + unsigned int dns_ttl_min_; + struct EndpointParams 
ep_params_; + +private: + const char *host_; + unsigned short port_; + bool has_next_; + bool in_guard_; + void *msg_; +}; + +class WFDnsResolver : public WFNSPolicy +{ +public: + virtual WFRouterTask *create_router_task(const struct WFNSParams *params, + router_callback_t callback); +}; + +#endif + diff --git a/src/nameservice/WFNameService.cc b/src/nameservice/WFNameService.cc new file mode 100644 index 0000000..fb2191f --- /dev/null +++ b/src/nameservice/WFNameService.cc @@ -0,0 +1,145 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "rbtree.h" +#include "WFNameService.h" + +struct WFNSPolicyEntry +{ + struct rb_node rb; + WFNSPolicy *policy; + char name[1]; +}; + +int WFNameService::add_policy(const char *name, WFNSPolicy *policy) +{ + struct rb_node **p = &this->root.rb_node; + struct rb_node *parent = NULL; + struct WFNSPolicyEntry *entry; + int n, ret = -1; + + pthread_rwlock_wrlock(&this->rwlock); + while (*p) + { + parent = *p; + entry = rb_entry(*p, struct WFNSPolicyEntry, rb); + n = strcasecmp(name, entry->name); + if (n < 0) + p = &(*p)->rb_left; + else if (n > 0) + p = &(*p)->rb_right; + else + break; + } + + if (!*p) + { + size_t len = strlen(name); + size_t size = offsetof(struct WFNSPolicyEntry, name) + len + 1; + + entry = (struct WFNSPolicyEntry *)malloc(size); + if (entry) + { + memcpy(entry->name, name, len + 1); + entry->policy = policy; + rb_link_node(&entry->rb, parent, p); + rb_insert_color(&entry->rb, &this->root); + ret = 0; + } + } + else + errno = EEXIST; + + pthread_rwlock_unlock(&this->rwlock); + return ret; +} + +inline struct WFNSPolicyEntry *WFNameService::get_policy_entry(const char *name) +{ + struct rb_node *p = this->root.rb_node; + struct WFNSPolicyEntry *entry; + int n; + + while (p) + { + entry = rb_entry(p, struct WFNSPolicyEntry, rb); + n = strcasecmp(name, entry->name); + if (n < 0) + p = p->rb_left; + else if (n > 0) + p = p->rb_right; + else + return entry; + } + + return NULL; +} + +WFNSPolicy *WFNameService::get_policy(const char *name) +{ + WFNSPolicy *policy = this->default_policy; + struct WFNSPolicyEntry *entry; + + if (this->root.rb_node) + { + pthread_rwlock_rdlock(&this->rwlock); + entry = this->get_policy_entry(name); + if (entry) + policy = entry->policy; + + pthread_rwlock_unlock(&this->rwlock); + } + + return policy; +} + +WFNSPolicy *WFNameService::del_policy(const char *name) +{ + WFNSPolicy *policy = NULL; + struct WFNSPolicyEntry 
*entry; + + pthread_rwlock_wrlock(&this->rwlock); + entry = this->get_policy_entry(name); + if (entry) + { + policy = entry->policy; + rb_erase(&entry->rb, &this->root); + } + + pthread_rwlock_unlock(&this->rwlock); + free(entry); + return policy; +} + +WFNameService::~WFNameService() +{ + struct WFNSPolicyEntry *entry; + + while (this->root.rb_node) + { + entry = rb_entry(this->root.rb_node, struct WFNSPolicyEntry, rb); + rb_erase(&entry->rb, &this->root); + free(entry); + } +} + diff --git a/src/nameservice/WFNameService.h b/src/nameservice/WFNameService.h new file mode 100644 index 0000000..d11b5d2 --- /dev/null +++ b/src/nameservice/WFNameService.h @@ -0,0 +1,156 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFNAMESERVICE_H_ +#define _WFNAMESERVICE_H_ + +#include +#include +#include +#include "rbtree.h" +#include "Communicator.h" +#include "Workflow.h" +#include "WFTask.h" +#include "RouteManager.h" +#include "URIParser.h" +#include "EndpointParams.h" + +class WFRouterTask : public WFGenericTask +{ +public: + RouteManager::RouteResult *get_result() { return &this->result; } + +public: + void set_state(int state) { this->state = state; } + void set_error(int error) { this->error = error; } + +protected: + RouteManager::RouteResult result; + std::function callback; + +protected: + virtual SubTask *done() + { + SeriesWork *series = series_of(this); + + if (this->callback) + this->callback(this); + + delete this; + return series->pop(); + } + +public: + WFRouterTask(std::function&& cb) : + callback(std::move(cb)) + { + } +}; + +class WFNSTracing +{ +public: + void *data; + void (*deleter)(void *); + +public: + WFNSTracing() + { + this->data = NULL; + this->deleter = NULL; + } +}; + +struct WFNSParams +{ + enum TransportType type; + ParsedURI& uri; + const char *info; + SSL_CTX *ssl_ctx; + bool fixed_addr; + bool fixed_conn; + int retry_times; + WFNSTracing *tracing; +}; + +using router_callback_t = std::function; + +class WFNSPolicy +{ +public: + virtual WFRouterTask *create_router_task(const struct WFNSParams *params, + router_callback_t callback) = 0; + + virtual void success(RouteManager::RouteResult *result, + WFNSTracing *tracing, + CommTarget *target) + { + RouteManager::notify_available(result->cookie, target); + } + + virtual void failed(RouteManager::RouteResult *result, + WFNSTracing *tracing, + CommTarget *target) + { + if (target) + RouteManager::notify_unavailable(result->cookie, target); + } + +public: + virtual ~WFNSPolicy() { } +}; + +class WFNameService +{ +public: + int add_policy(const char *name, WFNSPolicy *policy); + WFNSPolicy *get_policy(const char *name); + WFNSPolicy *del_policy(const char 
*name); + +public: + WFNSPolicy *get_default_policy() const + { + return this->default_policy; + } + + void set_default_policy(WFNSPolicy *policy) + { + this->default_policy = policy; + } + +private: + WFNSPolicy *default_policy; + struct rb_root root; + pthread_rwlock_t rwlock; + +private: + struct WFNSPolicyEntry *get_policy_entry(const char *name); + +public: + WFNameService(WFNSPolicy *default_policy) : + rwlock(PTHREAD_RWLOCK_INITIALIZER) + { + this->root.rb_node = NULL; + this->default_policy = default_policy; + } + + virtual ~WFNameService(); +}; + +#endif + diff --git a/src/nameservice/WFServiceGovernance.cc b/src/nameservice/WFServiceGovernance.cc new file mode 100644 index 0000000..5689d80 --- /dev/null +++ b/src/nameservice/WFServiceGovernance.cc @@ -0,0 +1,487 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "URIParser.h" +#include "WFTaskError.h" +#include "StringUtil.h" +#include "WFGlobal.h" +#include "WFNameService.h" +#include "WFDnsResolver.h" +#include "WFServiceGovernance.h" + +#define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() + +#define DNS_CACHE_LEVEL_1 1 +#define DNS_CACHE_LEVEL_2 2 + +PolicyAddrParams::PolicyAddrParams() +{ + const struct AddressParams *params = &ADDRESS_PARAMS_DEFAULT; + this->endpoint_params = params->endpoint_params; + this->dns_ttl_default = params->dns_ttl_default; + this->dns_ttl_min = params->dns_ttl_min; + this->max_fails = params->max_fails; +} + +PolicyAddrParams::PolicyAddrParams(const struct AddressParams *params) : + endpoint_params(params->endpoint_params) +{ + this->dns_ttl_default = params->dns_ttl_default; + this->dns_ttl_min = params->dns_ttl_min; + this->max_fails = params->max_fails; +} + +EndpointAddress::EndpointAddress(const std::string& address, + PolicyAddrParams *address_params) +{ + std::vector arr = StringUtil::split(address, ':'); + + this->params = address_params; + if (this->params->max_fails == 0) + this->params->max_fails = 1; + + this->address = address; + this->fail_count = 0; + this->ref = 1; + this->entry.list.next = NULL; + this->entry.ptr = this; + + if (arr.size() == 0) + this->host = ""; + else + this->host = arr[0]; + + if (arr.size() <= 1) + this->port = ""; + else + this->port = arr[1]; +} + +class WFSGResolverTask : public WFResolverTask +{ +public: + WFSGResolverTask(const struct WFNSParams *params, + WFServiceGovernance *sg, + router_callback_t&& cb) : + WFResolverTask(params, std::move(cb)) + { + sg_ = sg; + } + +protected: + virtual void dispatch(); + +protected: + WFServiceGovernance *sg_; +}; + +static void copy_host_port(ParsedURI& uri, const EndpointAddress *addr) +{ + if 
(!addr->host.empty()) + { + free(uri.host); + uri.host = strdup(addr->host.c_str()); + } + + if (!addr->port.empty()) + { + free(uri.port); + uri.port = strdup(addr->port.c_str()); + } +} + +void WFSGResolverTask::dispatch() +{ + WFNSTracing *tracing = ns_params_.tracing; + EndpointAddress *addr; + + if (!sg_) + { + this->WFResolverTask::dispatch(); + return; + } + + if (sg_->pre_select_) + { + WFConditional *cond = sg_->pre_select_(this); + if (cond) + { + series_of(this)->push_front(cond); + this->set_has_next(); + this->subtask_done(); + return; + } + else if (this->state != WFT_STATE_UNDEFINED) + { + this->subtask_done(); + return; + } + } + + if (sg_->select(ns_params_.uri, tracing, &addr)) + { + auto *tracing_data = (WFServiceGovernance::TracingData *)tracing->data; + if (!tracing_data) + { + tracing_data = new WFServiceGovernance::TracingData; + tracing_data->sg = sg_; + tracing->data = tracing_data; + tracing->deleter = WFServiceGovernance::tracing_deleter; + } + + tracing_data->history.push_back(addr); + sg_ = NULL; + + copy_host_port(ns_params_.uri, addr); + dns_ttl_default_ = addr->params->dns_ttl_default; + dns_ttl_min_ = addr->params->dns_ttl_min; + ep_params_ = addr->params->endpoint_params; + this->WFResolverTask::dispatch(); + } + else + { + this->state = WFT_STATE_TASK_ERROR; + this->error = WFT_ERR_UPSTREAM_UNAVAILABLE; + this->subtask_done(); + } +} + +WFRouterTask *WFServiceGovernance::create_router_task(const struct WFNSParams *params, + router_callback_t callback) +{ + return new WFSGResolverTask(params, this, std::move(callback)); +} + +void WFServiceGovernance::tracing_deleter(void *data) +{ + struct TracingData *tracing_data = (struct TracingData *)data; + + for (EndpointAddress *addr : tracing_data->history) + { + if (--addr->ref == 0) + { + pthread_rwlock_wrlock(&tracing_data->sg->rwlock); + tracing_data->sg->pre_delete_server(addr); + pthread_rwlock_unlock(&tracing_data->sg->rwlock); + delete addr; + } + } + + delete tracing_data; +} + 
+bool WFServiceGovernance::in_select_history(WFNSTracing *tracing, + EndpointAddress *addr) +{ + struct TracingData *tracing_data = (struct TracingData *)tracing->data; + + if (!tracing_data) + return false; + + for (EndpointAddress *server : tracing_data->history) + { + if (server == addr) + return true; + } + + return false; +} + +void WFServiceGovernance::recover_server_from_breaker(EndpointAddress *addr) +{ + addr->fail_count = 0; + pthread_mutex_lock(&this->breaker_lock); + if (addr->entry.list.next) + { + list_del(&addr->entry.list); + addr->entry.list.next = NULL; + this->recover_one_server(addr); + } + pthread_mutex_unlock(&this->breaker_lock); +} + +void WFServiceGovernance::fuse_server_to_breaker(EndpointAddress *addr) +{ + pthread_mutex_lock(&this->breaker_lock); + if (!addr->entry.list.next) + { + addr->broken_timeout = GET_CURRENT_SECOND + this->mttr_second; + list_add_tail(&addr->entry.list, &this->breaker_list); + this->fuse_one_server(addr); + } + pthread_mutex_unlock(&this->breaker_lock); +} + +void WFServiceGovernance::pre_delete_server(EndpointAddress *addr) +{ + pthread_mutex_lock(&this->breaker_lock); + if (addr->entry.list.next) + { + list_del(&addr->entry.list); + addr->entry.list.next = NULL; + } + else + this->fuse_one_server(addr); + pthread_mutex_unlock(&this->breaker_lock); +} + +void WFServiceGovernance::success(RouteManager::RouteResult *result, + WFNSTracing *tracing, + CommTarget *target) +{ + struct TracingData *tracing_data = (struct TracingData *)tracing->data; + auto *v = &tracing_data->history; + EndpointAddress *server = (*v)[v->size() - 1]; + + pthread_rwlock_wrlock(&this->rwlock); + this->recover_server_from_breaker(server); + pthread_rwlock_unlock(&this->rwlock); + + this->WFNSPolicy::success(result, tracing, target); +} + +void WFServiceGovernance::failed(RouteManager::RouteResult *result, + WFNSTracing *tracing, + CommTarget *target) +{ + struct TracingData *tracing_data = (struct TracingData *)tracing->data; + auto *v = 
&tracing_data->history; + EndpointAddress *server = (*v)[v->size() - 1]; + + pthread_rwlock_wrlock(&this->rwlock); + if (++server->fail_count == server->params->max_fails) + this->fuse_server_to_breaker(server); + pthread_rwlock_unlock(&this->rwlock); + + this->WFNSPolicy::failed(result, tracing, target); +} + +void WFServiceGovernance::check_breaker_locked(int64_t cur_time) +{ + struct list_head *pos, *tmp; + struct EndpointAddress::address_entry *entry; + EndpointAddress *addr; + + list_for_each_safe(pos, tmp, &this->breaker_list) + { + entry = list_entry(pos, struct EndpointAddress::address_entry, list); + addr = entry->ptr; + + if (cur_time >= addr->broken_timeout) + { + addr->fail_count = addr->params->max_fails - 1; + this->recover_one_server(addr); + list_del(pos); + pos->next = NULL; + } + else + break; + } +} + +void WFServiceGovernance::check_breaker() +{ + pthread_mutex_lock(&this->breaker_lock); + if (!list_empty(&this->breaker_list)) + this->check_breaker_locked(GET_CURRENT_SECOND); + pthread_mutex_unlock(&this->breaker_lock); +} + +void WFServiceGovernance::try_clear_breaker() +{ + pthread_mutex_lock(&this->breaker_lock); + if (!list_empty(&this->breaker_list)) + { + struct list_head *pos = this->breaker_list.next; + struct EndpointAddress::address_entry *entry; + entry = list_entry(pos, struct EndpointAddress::address_entry, list); + if (GET_CURRENT_SECOND >= entry->ptr->broken_timeout) + this->check_breaker_locked(INT64_MAX); + } + pthread_mutex_unlock(&this->breaker_lock); +} + +EndpointAddress *WFServiceGovernance::first_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + unsigned int idx = rand() % this->servers.size(); + return this->servers[idx]; +} + +EndpointAddress *WFServiceGovernance::another_strategy(const ParsedURI& uri, + WFNSTracing *tracing) +{ + return this->first_strategy(uri, tracing); +} + +bool WFServiceGovernance::select(const ParsedURI& uri, WFNSTracing *tracing, + EndpointAddress **addr) +{ + 
pthread_rwlock_rdlock(&this->rwlock); + unsigned int n = (unsigned int)this->servers.size(); + + if (n == 0) + { + pthread_rwlock_unlock(&this->rwlock); + return false; + } + + this->check_breaker(); + if (this->nalives == 0) + { + pthread_rwlock_unlock(&this->rwlock); + return false; + } + + EndpointAddress *select_addr = this->first_strategy(uri, tracing); + + if (!select_addr || + select_addr->fail_count >= select_addr->params->max_fails) + { + if (this->try_another) + select_addr = this->another_strategy(uri, tracing); + } + + if (select_addr) + { + *addr = select_addr; + ++select_addr->ref; + } + + pthread_rwlock_unlock(&this->rwlock); + return !!select_addr; +} + +void WFServiceGovernance::add_server_locked(EndpointAddress *addr) +{ + this->server_map[addr->address].push_back(addr); + this->servers.push_back(addr); + this->recover_one_server(addr); +} + +int WFServiceGovernance::remove_server_locked(const std::string& address) +{ + const auto map_it = this->server_map.find(address); + size_t n = this->servers.size(); + size_t new_n = 0; + int ret = 0; + + for (size_t i = 0; i < n; i++) + { + if (this->servers[i]->address != address) + this->servers[new_n++] = this->servers[i]; + } + + this->servers.resize(new_n); + + if (map_it != this->server_map.cend()) + { + for (EndpointAddress *addr : map_it->second) + { + if (--addr->ref == 0) + { + this->pre_delete_server(addr); + delete addr; + } + + ret++; + } + + this->server_map.erase(map_it); + } + + return ret; +} + +void WFServiceGovernance::add_server(const std::string& address, + const AddressParams *params) +{ + EndpointAddress *addr = new EndpointAddress(address, + new PolicyAddrParams(params)); + + pthread_rwlock_wrlock(&this->rwlock); + this->add_server_locked(addr); + pthread_rwlock_unlock(&this->rwlock); +} + +int WFServiceGovernance::remove_server(const std::string& address) +{ + int ret; + pthread_rwlock_wrlock(&this->rwlock); + ret = this->remove_server_locked(address); + 
pthread_rwlock_unlock(&this->rwlock); + return ret; +} + +int WFServiceGovernance::replace_server(const std::string& address, + const AddressParams *params) +{ + int ret; + EndpointAddress *addr = new EndpointAddress(address, + new PolicyAddrParams(params)); + + pthread_rwlock_wrlock(&this->rwlock); + ret = this->remove_server_locked(address); + this->add_server_locked(addr); + pthread_rwlock_unlock(&this->rwlock); + return ret; +} + +void WFServiceGovernance::enable_server(const std::string& address) +{ + pthread_rwlock_wrlock(&this->rwlock); + const auto map_it = this->server_map.find(address); + if (map_it != this->server_map.cend()) + { + for (EndpointAddress *addr : map_it->second) + this->recover_server_from_breaker(addr); + } + pthread_rwlock_unlock(&this->rwlock); +} + +void WFServiceGovernance::disable_server(const std::string& address) +{ + pthread_rwlock_wrlock(&this->rwlock); + const auto map_it = this->server_map.find(address); + if (map_it != this->server_map.cend()) + { + for (EndpointAddress *addr : map_it->second) + { + addr->fail_count = addr->params->max_fails; + this->fuse_server_to_breaker(addr); + } + } + pthread_rwlock_unlock(&this->rwlock); +} + +void WFServiceGovernance::get_current_address(std::vector& addr_list) +{ + pthread_rwlock_rdlock(&this->rwlock); + + for (const EndpointAddress *server : this->servers) + addr_list.push_back(server->address); + + pthread_rwlock_unlock(&this->rwlock); +} + diff --git a/src/nameservice/WFServiceGovernance.h b/src/nameservice/WFServiceGovernance.h new file mode 100644 index 0000000..b3acd84 --- /dev/null +++ b/src/nameservice/WFServiceGovernance.h @@ -0,0 +1,204 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFSERVICEGOVERNANCE_H_ +#define _WFSERVICEGOVERNANCE_H_ + +#include +#include +#include +#include +#include +#include +#include "URIParser.h" +#include "EndpointParams.h" +#include "WFNameService.h" + +#define MTTR_SECOND_DEFAULT 30 +#define VIRTUAL_GROUP_SIZE 16 + +struct AddressParams +{ + struct EndpointParams endpoint_params; ///< Connection config + unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success + unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail +/** + * - The max_fails directive sets the number of consecutive unsuccessful attempts to communicate with the server. + * - After 30s following the server failure, upstream probe the server with some live client`s requests. + * - If the probes have been successful, the server is marked as a live one. + * - If max_fails is set to 1, it means server would out of upstream selection in 30 seconds when failed only once + */ + unsigned int max_fails; ///< [1, INT32_MAX] max_fails = 0 means max_fails = 1 + unsigned short weight; ///< [1, 65535] weight = 0 means weight = 1. only for main server + int server_type; ///< 0 for main and 1 for backup + int group_id; ///< -1 means no group. 
Backup without group will be backup for any main +}; + +static constexpr struct AddressParams ADDRESS_PARAMS_DEFAULT = +{ + .endpoint_params = ENDPOINT_PARAMS_DEFAULT, + .dns_ttl_default = 12 * 3600, + .dns_ttl_min = 180, + .max_fails = 200, + .weight = 1, + .server_type = 0, /* 0 for main and 1 for backup. */ + .group_id = -1, +}; + +class PolicyAddrParams +{ +public: + struct EndpointParams endpoint_params; + unsigned int dns_ttl_default; + unsigned int dns_ttl_min; + unsigned int max_fails; + +public: + PolicyAddrParams(); + PolicyAddrParams(const struct AddressParams *params); + virtual ~PolicyAddrParams() { } +}; + +class EndpointAddress +{ +public: + std::string address; + std::string host; + std::string port; + unsigned int fail_count; + std::atomic ref; + long long broken_timeout; + PolicyAddrParams *params; + + struct address_entry + { + struct list_head list; + EndpointAddress *ptr; + } entry; + +public: + EndpointAddress(const std::string& address, PolicyAddrParams *params); + virtual ~EndpointAddress() { delete this->params; } +}; + +class WFServiceGovernance : public WFNSPolicy +{ +public: + virtual WFRouterTask *create_router_task(const struct WFNSParams *params, + router_callback_t callback); + virtual void success(RouteManager::RouteResult *result, + WFNSTracing *tracing, + CommTarget *target); + virtual void failed(RouteManager::RouteResult *result, + WFNSTracing *tracing, + CommTarget *target); + + virtual void add_server(const std::string& address, + const struct AddressParams *params); + int remove_server(const std::string& address); + virtual int replace_server(const std::string& address, + const struct AddressParams *params); + + void enable_server(const std::string& address); + void disable_server(const std::string& address); + virtual void get_current_address(std::vector& addr_list); + + void set_mttr_second(unsigned int second) { this->mttr_second = second; } + static bool in_select_history(WFNSTracing *tracing, EndpointAddress *addr); + 
+public: + using pre_select_t = std::function; + + void set_pre_select(pre_select_t pre_select) + { + pre_select_ = std::move(pre_select); + } + +public: + WFServiceGovernance() : + breaker_lock(PTHREAD_MUTEX_INITIALIZER), + rwlock(PTHREAD_RWLOCK_INITIALIZER) + { + this->nalives = 0; + this->try_another = false; + this->mttr_second = MTTR_SECOND_DEFAULT; + INIT_LIST_HEAD(&this->breaker_list); + } + + virtual ~WFServiceGovernance() + { + for (EndpointAddress *addr : this->servers) + delete addr; + } + +private: + virtual bool select(const ParsedURI& uri, WFNSTracing *tracing, + EndpointAddress **addr); + + virtual void recover_one_server(const EndpointAddress *addr) + { + this->nalives++; + } + + virtual void fuse_one_server(const EndpointAddress *addr) + { + this->nalives--; + } + + virtual void add_server_locked(EndpointAddress *addr); + virtual int remove_server_locked(const std::string& address); + + void recover_server_from_breaker(EndpointAddress *addr); + void fuse_server_to_breaker(EndpointAddress *addr); + void check_breaker_locked(int64_t cur_time); + +private: + struct list_head breaker_list; + pthread_mutex_t breaker_lock; + unsigned int mttr_second; + pre_select_t pre_select_; + +protected: + virtual EndpointAddress *first_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + virtual EndpointAddress *another_strategy(const ParsedURI& uri, + WFNSTracing *tracing); + void check_breaker(); + void try_clear_breaker(); + void pre_delete_server(EndpointAddress *addr); + + struct TracingData + { + std::vector history; + WFServiceGovernance *sg; + }; + + static void tracing_deleter(void *data); + + std::vector servers; + std::unordered_map> server_map; + pthread_rwlock_t rwlock; + std::atomic nalives; + bool try_another; + friend class WFSGResolverTask; +}; + +#endif + diff --git a/src/nameservice/xmake.lua b/src/nameservice/xmake.lua new file mode 100644 index 0000000..f419a1b --- /dev/null +++ b/src/nameservice/xmake.lua @@ -0,0 +1,7 @@ 
+target("nameservice")
+    add_files("*.cc")
+    set_kind("object")
+    if not has_config("upstream") then
+        remove_files("WFServiceGovernance.cc", "UpstreamPolicies.cc")
+    end
+
diff --git a/src/protocol/CMakeLists.txt b/src/protocol/CMakeLists.txt
new file mode 100644
index 0000000..6e53a11
--- /dev/null
+++ b/src/protocol/CMakeLists.txt
@@ -0,0 +1,47 @@
+cmake_minimum_required(VERSION 3.6)
+project(protocol)
+
+set(SRC	# sources that are always part of the protocol object library
+	PackageWrapper.cc
+	SSLWrapper.cc
+	dns_parser.c
+	DnsMessage.cc
+	DnsUtil.cc
+	http_parser.c
+	HttpMessage.cc
+	HttpUtil.cc
+	TLVMessage.cc
+)
+
+if (NOT MYSQL STREQUAL "n")	# MySQL sources are built unless MYSQL is exactly "n"
+	set(SRC
+		${SRC}
+		mysql_stream.c
+		mysql_parser.c
+		mysql_byteorder.c
+		MySQLMessage.cc
+		MySQLResult.cc
+		MySQLUtil.cc
+	)
+endif ()
+
+if (NOT REDIS STREQUAL "n")	# Redis sources are built unless REDIS is exactly "n"
+	set(SRC
+		${SRC}
+		redis_parser.c
+		RedisMessage.cc
+	)
+endif ()
+
+add_library(${PROJECT_NAME} OBJECT ${SRC})	# object library named after project(protocol)
+
+if (KAFKA STREQUAL "y")	# Kafka support is opt-in
+	set(SRC	# SRC is overwritten: protocol_kafka gets only the Kafka sources
+		kafka_parser.c
+		KafkaMessage.cc
+		KafkaDataTypes.cc
+		KafkaResult.cc
+	)
+	add_library("protocol_kafka" OBJECT ${SRC})
+	set_property(SOURCE KafkaMessage.cc APPEND PROPERTY COMPILE_OPTIONS "-fno-rtti")	# this one file is compiled without RTTI
+endif ()
diff --git a/src/protocol/ConsulDataTypes.h b/src/protocol/ConsulDataTypes.h
new file mode 100644
index 0000000..bcbc957
--- /dev/null
+++ b/src/protocol/ConsulDataTypes.h
@@ -0,0 +1,353 @@
+/*
+  Copyright (c) 2022 Sogou, Inc.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ + Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) +*/ + +#ifndef _CONSULDATATYPES_H_ +#define _CONSULDATATYPES_H_ + +#include +#include +#include +#include +#include + +namespace protocol +{ + +class ConsulConfig +{ +public: + // common config + void set_token(const std::string& token) { this->ptr->token = token; } + std::string get_token() const { return this->ptr->token; } + + // discover config + + void set_datacenter(const std::string& data_center) + { + this->ptr->dc = data_center; + } + std::string get_datacenter() const { return this->ptr->dc; } + + void set_near_node(const std::string& near_node) + { + this->ptr->near = near_node; + } + std::string get_near_node() const { return this->ptr->near; } + + void set_filter_expr(const std::string& filter_expr) + { + this->ptr->filter = filter_expr; + } + std::string get_filter_expr() const { return this->ptr->filter; } + + // blocking query wait, limited to 10 minutes, default:5m, unit:ms + void set_wait_ttl(int wait_ttl) { this->ptr->wait_ttl = wait_ttl; } + int get_wait_ttl() const { return this->ptr->wait_ttl; } + + // enable blocking query + void set_blocking_query(bool enable_flag) + { + this->ptr->blocking_query = enable_flag; + } + bool blocking_query() const { return this->ptr->blocking_query; } + + // only get health passing status service instance + void set_passing(bool passing) { this->ptr->passing = passing; } + bool get_passing() const { return this->ptr->passing; } + + + // register config + + void set_replace_checks(bool replace_checks) + { + this->ptr->replace_checks = replace_checks; + } + bool get_replace_checks() const + { + return this->ptr->replace_checks; + } + + void set_check_name(const std::string& check_name) + { + this->ptr->check_cfg.check_name = check_name; + } + std::string get_check_name() const { return this->ptr->check_cfg.check_name; } + + void set_check_http_url(const std::string& http_url) + { + this->ptr->check_cfg.http_url = http_url; + } + std::string 
get_check_http_url() const + { + return this->ptr->check_cfg.http_url; + } + + void set_check_http_method(const std::string& method) + { + this->ptr->check_cfg.http_method = method; + } + std::string get_check_http_method() const + { + return this->ptr->check_cfg.http_method; + } + + void add_http_header(const std::string& key, + const std::vector& values) + { + this->ptr->check_cfg.headers.emplace(key, values); + } + const std::map> *get_http_headers() const + { + return &this->ptr->check_cfg.headers; + } + + void set_http_body(const std::string& body) + { + this->ptr->check_cfg.http_body = body; + } + std::string get_http_body() const { return this->ptr->check_cfg.http_body; } + + void set_check_interval(int interval) + { + this->ptr->check_cfg.interval = interval; + } + int get_check_interval() const { return this->ptr->check_cfg.interval; } + + void set_check_timeout(int timeout) + { + this->ptr->check_cfg.timeout = timeout; + } + int get_check_timeout() const { return this->ptr->check_cfg.timeout; } + + void set_check_notes(const std::string& notes) + { + this->ptr->check_cfg.notes = notes; + } + std::string get_check_notes() const { return this->ptr->check_cfg.notes; } + + void set_check_tcp(const std::string& tcp_address) + { + this->ptr->check_cfg.tcp_address = tcp_address; + } + std::string get_check_tcp() const { return this->ptr->check_cfg.tcp_address; } + + void set_initial_status(const std::string& initial_status) + { + this->ptr->check_cfg.initial_status = initial_status; + } + std::string get_initial_status() const + { + return this->ptr->check_cfg.initial_status; + } + + void set_auto_deregister_time(int milliseconds) + { + this->ptr->check_cfg.auto_deregister_time = milliseconds; + } + int get_auto_deregister_time() const + { + return this->ptr->check_cfg.auto_deregister_time; + } + + // set success times before passing, refer to success_before_passing, default:0 + void set_success_times(int times) + { + this->ptr->check_cfg.success_times = times; 
+ } + int get_success_times() const { return this->ptr->check_cfg.success_times; } + + // set failure times before critical, refer to failures_before_critical, default:0 + void set_failure_times(int times) { this->ptr->check_cfg.failure_times = times; } + int get_failure_times() const { return this->ptr->check_cfg.failure_times; } + + void set_health_check(bool enable_flag) + { + this->ptr->check_cfg.health_check = enable_flag; + } + bool get_health_check() const + { + return this->ptr->check_cfg.health_check; + } + +public: + ConsulConfig() + { + this->ptr = new Config; + this->ptr->blocking_query = false; + this->ptr->passing = false; + this->ptr->replace_checks = false; + this->ptr->wait_ttl = 300 * 1000; + this->ptr->check_cfg.interval = 5000; + this->ptr->check_cfg.timeout = 10000; + this->ptr->check_cfg.http_method = "GET"; + this->ptr->check_cfg.initial_status = "critical"; + this->ptr->check_cfg.auto_deregister_time = 10 * 60 * 1000; + this->ptr->check_cfg.success_times = 0; + this->ptr->check_cfg.failure_times = 0; + this->ptr->check_cfg.health_check = false; + + this->ref = new std::atomic(1); + } + + virtual ~ConsulConfig() + { + if (--*this->ref == 0) + { + delete this->ptr; + delete this->ref; + } + } + + ConsulConfig(ConsulConfig&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new Config; + move.ref = new std::atomic(1); + } + + ConsulConfig(const ConsulConfig& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + ++(*this->ref); + } + + ConsulConfig& operator= (ConsulConfig&& move) + { + if (this != &move) + { + this->~ConsulConfig(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new Config; + move.ref = new std::atomic(1); + } + + return *this; + } + + ConsulConfig& operator= (const ConsulConfig& copy) + { + if (this != ©) + { + this->~ConsulConfig(); + this->ptr = copy.ptr; + this->ref = copy.ref; + ++(*this->ref); + } + + return *this; + } + +private: + // register health check config + struct 
HealthCheckConfig + { + std::string check_name; + std::string notes; + std::string http_url; + std::string http_method; + std::string http_body; + std::string tcp_address; + std::string initial_status; // passing or critical, default:critical + std::map> headers; + int auto_deregister_time; // refer to deregister_critical_service_after + int interval; + int timeout; // default 10000 + int success_times; // default:0 success times before passing + int failure_times; // default:0 failure_before_critical + bool health_check; + }; + + struct Config + { + // common config + std::string token; + + // discover config + std::string dc; + std::string near; + std::string filter; + int wait_ttl; + bool blocking_query; + bool passing; + + // register config + bool replace_checks; //refer to replace_existing_checks + HealthCheckConfig check_cfg; + }; + +private: + struct Config *ptr; + std::atomic *ref; +}; + +// k:address, v:port +using ConsulAddress = std::pair; + +struct ConsulService +{ + std::string service_name; + std::string service_namespace; + std::string service_id; + std::vector tags; + ConsulAddress service_address; + ConsulAddress lan; + ConsulAddress lan_ipv4; + ConsulAddress lan_ipv6; + ConsulAddress virtual_address; + ConsulAddress wan; + ConsulAddress wan_ipv4; + ConsulAddress wan_ipv6; + std::map meta; + bool tag_override; +}; + +struct ConsulServiceInstance +{ + // node info + std::string node_id; + std::string node_name; + std::string node_address; + std::string dc; + std::map node_meta; + long long create_index; + long long modify_index; + + // service info + struct ConsulService service; + + // service health check + std::string check_name; + std::string check_id; + std::string check_notes; + std::string check_output; + std::string check_status; + std::string check_type; +}; + +struct ConsulServiceTags +{ + std::string service_name; + std::vector tags; +}; + +} + +#endif diff --git a/src/protocol/DnsMessage.cc b/src/protocol/DnsMessage.cc new file mode 
100644 index 0000000..5937217 --- /dev/null +++ b/src/protocol/DnsMessage.cc @@ -0,0 +1,357 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include "DnsMessage.h" + +#define DNS_LABELS_MAX 63 +#define DNS_MESSAGE_MAX_UDP_SIZE 512 + +namespace protocol +{ + +static inline void __append_uint8(std::string& s, uint8_t tmp) +{ + s.append((const char *)&tmp, sizeof (uint8_t)); +} + +static inline void __append_uint16(std::string& s, uint16_t tmp) +{ + tmp = htons(tmp); + s.append((const char *)&tmp, sizeof (uint16_t)); +} + +static inline void __append_uint32(std::string& s, uint32_t tmp) +{ + tmp = htonl(tmp); + s.append((const char *)&tmp, sizeof (uint32_t)); +} + +static inline int __append_name(std::string& s, const char *p) +{ + const char *name; + size_t len; + + while (*p) + { + name = p; + while (*p && *p != '.') + p++; + + len = p - name; + if (len > DNS_LABELS_MAX || (len == 0 && *p && *(p + 1))) + { + errno = EINVAL; + return -1; + } + + if (len > 0) + { + __append_uint8(s, len); + s.append(name, len); + } + + if (*p == '.') + p++; + } + + len = 0; + __append_uint8(s, len); + + return 0; +} + +static inline int __append_record_list(std::string& s, int *count, + dns_record_cursor_t *cursor) +{ + int cnt = 0; + struct dns_record *record; + std::string record_buf; + std::string rdata_buf; + int ret; + + while (dns_record_cursor_next(&record, cursor) == 0) + { + 
record_buf.clear(); + ret = __append_name(record_buf, record->name); + if (ret < 0) + return ret; + + __append_uint16(record_buf, record->type); + __append_uint16(record_buf, record->rclass); + __append_uint32(record_buf, record->ttl); + + switch (record->type) + { + case DNS_TYPE_A: + case DNS_TYPE_AAAA: + __append_uint16(record_buf, record->rdlength); + record_buf.append((const char *)record->rdata, record->rdlength); + break; + + case DNS_TYPE_NS: + case DNS_TYPE_CNAME: + case DNS_TYPE_PTR: + rdata_buf.clear(); + ret = __append_name(rdata_buf, (const char *)record->rdata); + if (ret < 0) + return ret; + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + + case DNS_TYPE_SOA: + { + auto *soa = (struct dns_record_soa *)record->rdata; + + rdata_buf.clear(); + ret = __append_name(rdata_buf, soa->mname); + if (ret < 0) + return ret; + ret = __append_name(rdata_buf, soa->rname); + if (ret < 0) + return ret; + + __append_uint32(rdata_buf, soa->serial); + __append_uint32(rdata_buf, soa->refresh); + __append_uint32(rdata_buf, soa->retry); + __append_uint32(rdata_buf, soa->expire); + __append_uint32(rdata_buf, soa->minimum); + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + } + + case DNS_TYPE_SRV: + { + auto *srv = (struct dns_record_srv *)record->rdata; + + rdata_buf.clear(); + __append_uint16(rdata_buf, srv->priority); + __append_uint16(rdata_buf, srv->weight); + __append_uint16(rdata_buf, srv->port); + ret = __append_name(rdata_buf, srv->target); + if (ret < 0) + return ret; + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + } + case DNS_TYPE_MX: + { + auto *mx = (struct dns_record_mx *)record->rdata; + rdata_buf.clear(); + __append_uint16(rdata_buf, mx->preference); + ret = __append_name(rdata_buf, mx->exchange); + if (ret < 0) + return ret; + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + 
} + default: + // TODO not implement + continue; + } + + cnt++; + s.append(record_buf); + } + + if (count) + *count = cnt; + + return 0; +} + +DnsMessage::DnsMessage(DnsMessage&& msg) : + ProtocolMessage(std::move(msg)) +{ + this->parser = msg.parser; + msg.parser = NULL; + + this->cur_size = msg.cur_size; + msg.cur_size = 0; +} + +DnsMessage& DnsMessage::operator = (DnsMessage&& msg) +{ + if (&msg != this) + { + *(ProtocolMessage *)this = std::move(msg); + + if (this->parser) + { + dns_parser_deinit(this->parser); + delete this->parser; + } + + this->parser = msg.parser; + msg.parser = NULL; + + this->cur_size = msg.cur_size; + msg.cur_size = 0; + } + return *this; +} + +int DnsMessage::encode_reply() +{ + dns_record_cursor_t cursor; + struct dns_header h; + std::string tmpbuf; + const char *p; + int ancount; + int nscount; + int arcount; + int ret; + + msgbuf.clear(); + msgsize = 0; + + // TODO + // this is an incomplete and inefficient way, compress not used, + // pointers can only be used for occurances of a domain name where + // the format is not class specific + dns_answer_cursor_init(&cursor, this->parser); + ret = __append_record_list(tmpbuf, &ancount, &cursor); + dns_record_cursor_deinit(&cursor); + if (ret < 0) + return ret; + + dns_authority_cursor_init(&cursor, this->parser); + ret = __append_record_list(tmpbuf, &nscount, &cursor); + dns_record_cursor_deinit(&cursor); + if (ret < 0) + return ret; + + dns_additional_cursor_init(&cursor, this->parser); + ret = __append_record_list(tmpbuf, &arcount, &cursor); + dns_record_cursor_deinit(&cursor); + if (ret < 0) + return ret; + + h = this->parser->header; + h.id = htons(h.id); + h.qdcount = htons(1); + h.ancount = htons(ancount); + h.nscount = htons(nscount); + h.arcount = htons(arcount); + + msgbuf.append((const char *)&h, sizeof (struct dns_header)); + p = parser->question.qname ? 
parser->question.qname : "."; + ret = __append_name(msgbuf, p); + if (ret < 0) + return ret; + + __append_uint16(msgbuf, parser->question.qtype); + __append_uint16(msgbuf, parser->question.qclass); + + msgbuf.append(tmpbuf); + + if (msgbuf.size() >= (1 << 16)) + { + errno = EOVERFLOW; + return -1; + } + + msgsize = htons(msgbuf.size()); + + return 0; +} + +int DnsMessage::encode(struct iovec vectors[], int) +{ + struct iovec *p = vectors; + + if (this->encode_reply() < 0) + return -1; + + // TODO + // if this is a request, it won't exceed the 512 bytes UDP limit + // if this is a response and exceed 512 bytes, we need a TrunCation reply + + if (!this->is_single_packet()) + { + p->iov_base = &this->msgsize; + p->iov_len = sizeof (uint16_t); + p++; + } + + p->iov_base = (void *)this->msgbuf.data(); + p->iov_len = msgbuf.size(); + return p - vectors + 1; +} + +int DnsMessage::append(const void *buf, size_t *size) +{ + int ret = dns_parser_append_message(buf, size, this->parser); + + if (ret >= 0) + { + this->cur_size += *size; + if (this->cur_size > this->size_limit) + { + errno = EMSGSIZE; + ret = -1; + } + } + else if (ret == -2) + { + errno = EBADMSG; + ret = -1; + } + + return ret; +} + +int DnsResponse::append(const void *buf, size_t *size) +{ + int ret = this->DnsMessage::append(buf, size); + const char *qname = this->parser->question.qname; + + if (ret >= 1 && (this->request_id != this->get_id() || + strcasecmp(this->request_name.c_str(), qname) != 0)) + { + if (!this->is_single_packet()) + { + errno = EBADMSG; + ret = -1; + } + else + { + dns_parser_deinit(this->parser); + dns_parser_init(this->parser); + ret = 0; + } + } + + return ret; +} + +} + diff --git a/src/protocol/DnsMessage.h b/src/protocol/DnsMessage.h new file mode 100644 index 0000000..c9e191a --- /dev/null +++ b/src/protocol/DnsMessage.h @@ -0,0 +1,371 @@ +/* + Copyright (c) 2021 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _DNSMESSAGE_H_ +#define _DNSMESSAGE_H_ + +/** + * @file DnsMessage.h + * @brief Dns Protocol Interface + */ + +#include +#include +#include "ProtocolMessage.h" +#include "dns_parser.h" + +namespace protocol +{ + +class DnsMessage : public ProtocolMessage +{ +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +public: + int get_id() const + { + return parser->header.id; + } + int get_qr() const + { + return parser->header.qr; + } + int get_opcode() const + { + return parser->header.opcode; + } + int get_aa() const + { + return parser->header.aa; + } + int get_tc() const + { + return parser->header.tc; + } + int get_rd() const + { + return parser->header.rd; + } + int get_ra() const + { + return parser->header.ra; + } + int get_rcode() const + { + return parser->header.rcode; + } + int get_qdcount() const + { + return parser->header.qdcount; + } + int get_ancount() const + { + return parser->header.ancount; + } + int get_nscount() const + { + return parser->header.nscount; + } + int get_arcount() const + { + return parser->header.arcount; + } + + void set_id(int id) + { + parser->header.id = id; + } + void set_qr(int qr) + { + parser->header.qr = qr; + } + void set_opcode(int opcode) + { + parser->header.opcode = opcode; + } + void set_aa(int aa) + { + parser->header.aa = aa; + } + void set_tc(int tc) + { + 
parser->header.tc = tc; + } + void set_rd(int rd) + { + parser->header.rd = rd; + } + void set_ra(int ra) + { + parser->header.ra = ra; + } + void set_rcode(int rcode) + { + parser->header.rcode = rcode; + } + + int get_question_type() const + { + return parser->question.qtype; + } + int get_question_class() const + { + return parser->question.qclass; + } + std::string get_question_name() const + { + const char *name = parser->question.qname; + if (name == NULL) + return ""; + return name; + } + void set_question_type(int qtype) + { + parser->question.qtype = qtype; + } + void set_question_class(int qclass) + { + parser->question.qclass = qclass; + } + void set_question_name(const std::string& name) + { + char *pname = parser->question.qname; + if (pname != NULL) + free(pname); + parser->question.qname = strdup(name.c_str()); + } + + int add_a_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + const void *data) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_raw_record(name, DNS_TYPE_A, rclass, ttl, 4, data, list); + } + + int add_aaaa_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + const void *data) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_raw_record(name, DNS_TYPE_AAAA, rclass, ttl, + 16, data, list); + } + + int add_ns_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + const char *data) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_str_record(name, DNS_TYPE_NS, rclass, ttl, data, list); + } + + int add_cname_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + const char *data) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_str_record(name, DNS_TYPE_CNAME, rclass, ttl, + data, list); + } + + int add_ptr_record(int section, const char *name, + uint16_t rclass, 
uint32_t ttl, + const char *data) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_str_record(name, DNS_TYPE_PTR, rclass, ttl, data, list); + } + + int add_soa_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_soa_record(name, rclass, ttl, mname, rname, serial, + refresh, retry, expire, minimum, list); + } + + int add_srv_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, uint16_t port, + const char *target) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_srv_record(name, rclass, ttl, + priority, weight, port, target, list); + } + + int add_mx_record(int section, const char *name, + uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange) + { + struct list_head *list = get_section(section); + + if (!list) + return -1; + + return dns_add_mx_record(name, rclass, ttl, preference, exchange, list); + } + + // Inner use only + bool is_single_packet() const + { + return parser->single_packet; + } + void set_single_packet(bool single) + { + parser->single_packet = single; + } + +public: + DnsMessage() : + parser(new dns_parser_t) + { + dns_parser_init(parser); + this->cur_size = 0; + } + + virtual ~DnsMessage() + { + if (this->parser) + { + dns_parser_deinit(parser); + delete this->parser; + } + } + + DnsMessage(DnsMessage&& msg); + DnsMessage& operator = (DnsMessage&& msg); + +protected: + dns_parser_t *parser; + std::string msgbuf; + size_t cur_size; + +private: + int encode_reply(); + int encode_truncation_reply(); + + struct list_head *get_section(int section) + { + switch (section) + { + case DNS_ANSWER_SECTION: + return &(parser->answer_list); + 
case DNS_AUTHORITY_SECTION: + return &(parser->authority_list); + case DNS_ADDITIONAL_SECTION: + return &(parser->additional_list); + default: + errno = EINVAL; + return NULL; + } + } + + // size of msgbuf, but in network byte order + uint16_t msgsize; +}; + +class DnsRequest : public DnsMessage +{ +public: + DnsRequest() = default; + DnsRequest(DnsRequest&& req) = default; + DnsRequest& operator = (DnsRequest&& req) = default; + + void set_question(const char *host, uint16_t qtype, uint16_t qclass) + { + dns_parser_set_question(host, qtype, qclass, this->parser); + } +}; + +class DnsResponse : public DnsMessage +{ +public: + DnsResponse() + { + this->request_id = 0; + } + + DnsResponse(DnsResponse&& req) = default; + DnsResponse& operator = (DnsResponse&& req) = default; + + const dns_parser_t *get_parser() const + { + return this->parser; + } + + void set_request_id(uint16_t id) + { + this->request_id = id; + } + + void set_request_name(const std::string& name) + { + std::string& req_name = this->request_name; + + req_name = name; + while (req_name.size() > 1 && req_name.back() == '.') + req_name.pop_back(); + } + +protected: + virtual int append(const void *buf, size_t *size); + +private: + uint16_t request_id; + std::string request_name; +}; + +} + +#endif + diff --git a/src/protocol/DnsUtil.cc b/src/protocol/DnsUtil.cc new file mode 100644 index 0000000..8909a52 --- /dev/null +++ b/src/protocol/DnsUtil.cc @@ -0,0 +1,147 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// Build an addrinfo list from the A/AAAA answers of a DNS reply.
//
// The question name is first followed through the reply's CNAME records;
// every class-IN A or AAAA answer owned by the resulting name then becomes
// one list node carrying `port`. The first node also receives a heap copy
// of that name in ai_canonname.
//
// Returns 0 and stores the list in *addrinfo on success (release it with
// DnsUtil::freeaddrinfo()). Otherwise returns an EAI_* code and leaves
// *addrinfo untouched.
int DnsUtil::getaddrinfo(const DnsResponse *resp,
						 unsigned short port,
						 struct addrinfo **addrinfo)
{
	const int rcode = resp->get_rcode();
	int remaining = resp->get_ancount();
	struct addrinfo *head = NULL;
	struct addrinfo **tail = &head;
	struct dns_record *rr;

	// Map server-side failures straight to resolver error codes.
	if (rcode == DNS_RCODE_NAME_ERROR)
		return EAI_NONAME;
	if (rcode == DNS_RCODE_SERVER_FAILURE)
		return EAI_AGAIN;
	if (rcode == DNS_RCODE_FORMAT_ERROR ||
		rcode == DNS_RCODE_NOT_IMPLEMENTED ||
		rcode == DNS_RCODE_REFUSED)
	{
		return EAI_FAIL;
	}

	std::string qname = resp->get_question_name();
	const char *cname = qname.c_str();
	DnsResultCursor cursor(resp);

	// Follow the CNAME chain. Every hop consumes one unit of the answer
	// count, so a looping chain cannot spin forever.
	cursor.reset_answer_cursor();
	while (cursor.find_cname(cname, &cname))
	{
		if (remaining-- <= 0)
			break;
	}

	if (rcode == DNS_RCODE_NO_ERROR && remaining <= 0)
		return EAI_NODATA;

	cursor.reset_answer_cursor();
	while (cursor.next(&rr))
	{
		if (rr->rclass != DNS_CLASS_IN)
			continue;
		if (rr->type != DNS_TYPE_A && rr->type != DNS_TYPE_AAAA)
			continue;
		if (strcasecmp(rr->name, cname) != 0)
			continue;

		const bool is_v4 = (rr->type == DNS_TYPE_A);
		const int family = is_v4 ? AF_INET : AF_INET6;
		const int addrlen = is_v4 ? (int)sizeof (struct sockaddr_in)
								  : (int)sizeof (struct sockaddr_in6);

		// One zeroed allocation holds the node and, right behind it, the
		// socket address.
		struct addrinfo *ai;

		ai = (struct addrinfo *)calloc(sizeof (struct addrinfo) + addrlen, 1);
		if (ai == NULL)
		{
			if (head)
				DnsUtil::freeaddrinfo(head);
			return EAI_SYSTEM;
		}

		ai->ai_family = family;
		ai->ai_addrlen = addrlen;
		ai->ai_addr = (struct sockaddr *)(ai + 1);
		ai->ai_addr->sa_family = family;

		if (is_v4)
		{
			struct sockaddr_in *sin = (struct sockaddr_in *)(ai->ai_addr);

			sin->sin_port = htons(port);
			memcpy(&sin->sin_addr, rr->rdata, sizeof (struct in_addr));
		}
		else
		{
			struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)(ai->ai_addr);

			sin6->sin6_port = htons(port);
			memcpy(&sin6->sin6_addr, rr->rdata, sizeof (struct in6_addr));
		}

		// Link at the tail to keep the answers in reply order.
		*tail = ai;
		tail = &ai->ai_next;
	}

	if (head == NULL)
		return EAI_NODATA;

	if (cname)
		head->ai_canonname = strdup(cname);

	*addrinfo = head;

	return 0;
}
+ + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _DNSUTIL_H_ +#define _DNSUTIL_H_ + +#include +#include "DnsMessage.h" + +/** + * @file DnsUtil.h + * @brief Dns toolbox + */ + +namespace protocol +{ + +class DnsUtil +{ +public: + static int getaddrinfo(const DnsResponse *resp, + unsigned short port, + struct addrinfo **res); + static void freeaddrinfo(struct addrinfo *ai); +}; + +class DnsResultCursor +{ +public: + DnsResultCursor(const DnsResponse *resp) : + parser(resp->get_parser()) + { + dns_answer_cursor_init(&cursor, parser); + record = NULL; + } + + DnsResultCursor(DnsResultCursor&& move) = delete; + DnsResultCursor& operator=(DnsResultCursor&& move) = delete; + + virtual ~DnsResultCursor() { } + + void reset_answer_cursor() + { + dns_answer_cursor_init(&cursor, parser); + } + + void reset_authority_cursor() + { + dns_authority_cursor_init(&cursor, parser); + } + + void reset_additional_cursor() + { + dns_additional_cursor_init(&cursor, parser); + } + + bool next(struct dns_record **next_record) + { + int ret = dns_record_cursor_next(&record, &cursor); + if (ret != 0) + record = NULL; + else + *next_record = record; + + return ret == 0; + } + + bool find_cname(const char *name, const char **cname) + { + return dns_record_cursor_find_cname(name, cname, &cursor) == 0; + } + +private: + const dns_parser_t *parser; + dns_record_cursor_t cursor; + struct dns_record *record; +}; + +} + +#endif + diff --git a/src/protocol/HttpMessage.cc b/src/protocol/HttpMessage.cc new file mode 100644 index 0000000..1030380 --- /dev/null +++ b/src/protocol/HttpMessage.cc @@ -0,0 +1,402 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include "HttpMessage.h" + +namespace protocol +{ + +struct HttpMessageBlock +{ + struct list_head list; + const void *ptr; + size_t size; +}; + +bool HttpMessage::append_output_body(const void *buf, size_t size) +{ + size_t n = sizeof (struct HttpMessageBlock) + size; + struct HttpMessageBlock *block = (struct HttpMessageBlock *)malloc(n); + + if (block) + { + memcpy(block + 1, buf, size); + block->ptr = block + 1; + block->size = size; + list_add_tail(&block->list, &this->output_body); + this->output_body_size += size; + return true; + } + + return false; +} + +bool HttpMessage::append_output_body_nocopy(const void *buf, size_t size) +{ + size_t n = sizeof (struct HttpMessageBlock); + struct HttpMessageBlock *block = (struct HttpMessageBlock *)malloc(n); + + if (block) + { + block->ptr = buf; + block->size = size; + list_add_tail(&block->list, &this->output_body); + this->output_body_size += size; + return true; + } + + return false; +} + +size_t HttpMessage::get_output_body_blocks(const void *buf[], size_t size[], + size_t max) const +{ + struct HttpMessageBlock *block; + struct list_head *pos; + size_t n = 0; + + list_for_each(pos, &this->output_body) + { + if (n == max) + break; + + block = list_entry(pos, struct HttpMessageBlock, list); + buf[n] = block->ptr; + size[n] = block->size; + n++; + } + + return n; +} + +bool HttpMessage::get_output_body_merged(void *buf, size_t *size) const +{ + struct HttpMessageBlock *block; + struct list_head *pos; + + if (*size < 
this->output_body_size) + { + errno = ENOSPC; + return false; + } + + list_for_each(pos, &this->output_body) + { + block = list_entry(pos, struct HttpMessageBlock, list); + memcpy(buf, block->ptr, block->size); + buf = (char *)buf + block->size; + } + + *size = this->output_body_size; + return true; +} + +void HttpMessage::clear_output_body() +{ + struct HttpMessageBlock *block; + struct list_head *pos, *tmp; + + list_for_each_safe(pos, tmp, &this->output_body) + { + block = list_entry(pos, struct HttpMessageBlock, list); + list_del(pos); + free(block); + } + + this->output_body_size = 0; +} + +struct list_head *HttpMessage::combine_from(struct list_head *pos, size_t size) +{ + size_t n = sizeof (struct HttpMessageBlock) + size; + struct HttpMessageBlock *block = (struct HttpMessageBlock *)malloc(n); + struct HttpMessageBlock *entry; + char *ptr; + + if (block) + { + block->ptr = block + 1; + block->size = size; + ptr = (char *)block->ptr; + + do + { + entry = list_entry(pos, struct HttpMessageBlock, list); + pos = pos->next; + list_del(&entry->list); + memcpy(ptr, entry->ptr, entry->size); + ptr += entry->size; + free(entry); + } while (pos != &this->output_body); + + list_add_tail(&block->list, &this->output_body); + return &block->list; + } + + return NULL; +} + +int HttpMessage::encode(struct iovec vectors[], int max) +{ + const char *start_line[3]; + http_header_cursor_t cursor; + struct HttpMessageHeader header; + struct HttpMessageBlock *block; + struct list_head *pos; + size_t size; + int i; + + start_line[0] = http_parser_get_method(this->parser); + if (start_line[0]) + { + start_line[1] = http_parser_get_uri(this->parser); + start_line[2] = http_parser_get_version(this->parser); + } + else + { + start_line[0] = http_parser_get_version(this->parser); + start_line[1] = http_parser_get_code(this->parser); + start_line[2] = http_parser_get_phrase(this->parser); + } + + if (!start_line[0] || !start_line[1] || !start_line[2]) + { + errno = EBADMSG; + return -1; 
+ } + + vectors[0].iov_base = (void *)start_line[0]; + vectors[0].iov_len = strlen(start_line[0]); + vectors[1].iov_base = (void *)" "; + vectors[1].iov_len = 1; + + vectors[2].iov_base = (void *)start_line[1]; + vectors[2].iov_len = strlen(start_line[1]); + vectors[3].iov_base = (void *)" "; + vectors[3].iov_len = 1; + + vectors[4].iov_base = (void *)start_line[2]; + vectors[4].iov_len = strlen(start_line[2]); + vectors[5].iov_base = (void *)"\r\n"; + vectors[5].iov_len = 2; + + i = 6; + http_header_cursor_init(&cursor, this->parser); + while (http_header_cursor_next(&header.name, &header.name_len, + &header.value, &header.value_len, + &cursor) == 0) + { + if (i == max) + break; + + vectors[i].iov_base = (void *)header.name; + vectors[i].iov_len = header.name_len + 2 + header.value_len + 2; + i++; + } + + http_header_cursor_deinit(&cursor); + if (i + 1 >= max) + { + errno = EOVERFLOW; + return -1; + } + + vectors[i].iov_base = (void *)"\r\n"; + vectors[i].iov_len = 2; + i++; + + size = this->output_body_size; + list_for_each(pos, &this->output_body) + { + if (i + 1 == max && pos != this->output_body.prev) + { + pos = this->combine_from(pos, size); + if (!pos) + return -1; + } + + block = list_entry(pos, struct HttpMessageBlock, list); + vectors[i].iov_base = (void *)block->ptr; + vectors[i].iov_len = block->size; + size -= block->size; + i++; + } + + return i; +} + +inline int HttpMessage::append(const void *buf, size_t *size) +{ + int ret = http_parser_append_message(buf, size, this->parser); + + if (ret >= 0) + { + this->cur_size += *size; + if (this->cur_size > this->size_limit) + { + errno = EMSGSIZE; + ret = -1; + } + } + else if (ret == -2) + { + errno = EBADMSG; + ret = -1; + } + + return ret; +} + +HttpMessage::HttpMessage(HttpMessage&& msg) : + ProtocolMessage(std::move(msg)) +{ + this->parser = msg.parser; + msg.parser = NULL; + + INIT_LIST_HEAD(&this->output_body); + list_splice_init(&msg.output_body, &this->output_body); + this->output_body_size = 
msg.output_body_size; + msg.output_body_size = 0; + + this->cur_size = msg.cur_size; + msg.cur_size = 0; +} + +HttpMessage& HttpMessage::operator = (HttpMessage&& msg) +{ + if (&msg != this) + { + *(ProtocolMessage *)this = std::move(msg); + + if (this->parser) + { + http_parser_deinit(this->parser); + delete this->parser; + } + + this->parser = msg.parser; + msg.parser = NULL; + + this->clear_output_body(); + list_splice_init(&msg.output_body, &this->output_body); + this->output_body_size = msg.output_body_size; + msg.output_body_size = 0; + + this->cur_size = msg.cur_size; + msg.cur_size = 0; + } + + return *this; +} + +#define HTTP_100_STATUS_LINE "HTTP/1.1 100 Continue" +#define HTTP_400_STATUS_LINE "HTTP/1.1 400 Bad Request" +#define HTTP_413_STATUS_LINE "HTTP/1.1 413 Request Entity Too Large" +#define HTTP_417_STATUS_LINE "HTTP/1.1 417 Expectation Failed" +#define CONTENT_LENGTH_ZERO "Content-Length: 0" +#define CONNECTION_CLOSE "Connection: close" +#define CRLF "\r\n" + +#define HTTP_100_RESP HTTP_100_STATUS_LINE CRLF \ + CRLF +#define HTTP_400_RESP HTTP_400_STATUS_LINE CRLF \ + CONTENT_LENGTH_ZERO CRLF \ + CONNECTION_CLOSE CRLF \ + CRLF +#define HTTP_413_RESP HTTP_413_STATUS_LINE CRLF \ + CONTENT_LENGTH_ZERO CRLF \ + CONNECTION_CLOSE CRLF \ + CRLF +#define HTTP_417_RESP HTTP_417_STATUS_LINE CRLF \ + CONTENT_LENGTH_ZERO CRLF \ + CONNECTION_CLOSE CRLF \ + CRLF + +int HttpRequest::handle_expect_continue() +{ + size_t trans_len = this->parser->transfer_length; + int ret; + + if (trans_len != (size_t)-1) + { + if (this->parser->header_offset + trans_len > this->size_limit) + { + this->feedback(HTTP_417_RESP, strlen(HTTP_417_RESP)); + errno = EMSGSIZE; + return -1; + } + } + + ret = this->feedback(HTTP_100_RESP, strlen(HTTP_100_RESP)); + if (ret != strlen(HTTP_100_RESP)) + { + if (ret >= 0) + errno = ENOBUFS; + return -1; + } + + return 0; +} + +int HttpRequest::append(const void *buf, size_t *size) +{ + int ret = HttpMessage::append(buf, size); + + if (ret == 0) 
+ { + if (this->parser->expect_continue && + http_parser_header_complete(this->parser)) + { + this->parser->expect_continue = 0; + ret = this->handle_expect_continue(); + } + } + else if (ret < 0) + { + if (errno == EBADMSG) + this->feedback(HTTP_400_RESP, strlen(HTTP_400_RESP)); + else if (errno == EMSGSIZE) + this->feedback(HTTP_413_RESP, strlen(HTTP_413_RESP)); + } + + return ret; +} + +int HttpResponse::append(const void *buf, size_t *size) +{ + int ret = HttpMessage::append(buf, size); + + if (ret > 0) + { + if (strcmp(http_parser_get_code(this->parser), "100") == 0) + { + http_parser_deinit(this->parser); + http_parser_init(1, this->parser); + ret = 0; + } + } + + return ret; +} + +} + diff --git a/src/protocol/HttpMessage.h b/src/protocol/HttpMessage.h new file mode 100644 index 0000000..a2589f8 --- /dev/null +++ b/src/protocol/HttpMessage.h @@ -0,0 +1,410 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _HTTPMESSAGE_H_ +#define _HTTPMESSAGE_H_ + +#include +#include +#include +#include "list.h" +#include "ProtocolMessage.h" +#include "http_parser.h" + +/** + * @file HttpMessage.h + * @brief Http Protocol Interface + */ + +namespace protocol +{ + +struct HttpMessageHeader +{ + const void *name; + size_t name_len; + const void *value; + size_t value_len; +}; + +class HttpMessage : public ProtocolMessage +{ +public: + const char *get_http_version() const + { + return http_parser_get_version(this->parser); + } + + bool set_http_version(const char *version) + { + return http_parser_set_version(version, this->parser) == 0; + } + + bool is_chunked() const + { + return http_parser_chunked(this->parser); + } + + bool is_keep_alive() const + { + return http_parser_keep_alive(this->parser); + } + + bool add_header(const struct HttpMessageHeader *header) + { + return http_parser_add_header(header->name, header->name_len, + header->value, header->value_len, + this->parser) == 0; + } + + bool add_header_pair(const char *name, const char *value) + { + return http_parser_add_header(name, strlen(name), + value, strlen(value), + this->parser) == 0; + } + + bool set_header(const struct HttpMessageHeader *header) + { + return http_parser_set_header(header->name, header->name_len, + header->value, header->value_len, + this->parser) == 0; + } + + bool set_header_pair(const char *name, const char *value) + { + return http_parser_set_header(name, strlen(name), + value, strlen(value), + this->parser) == 0; + } + + bool get_parsed_body(const void **body, size_t *size) const + { + return http_parser_get_body(body, size, this->parser) == 0; + } + + /* Output body is for sending. 
Want to transfer a message received, maybe: + * msg->get_parsed_body(&body, &size); + * msg->append_output_body_nocopy(body, size); */ + bool append_output_body(const void *buf, size_t size); + + bool append_output_body(const char *buf) + { + return this->append_output_body(buf, strlen(buf)); + } + + bool append_output_body_nocopy(const void *buf, size_t size); + + bool append_output_body_nocopy(const char *buf) + { + return this->append_output_body_nocopy(buf, strlen(buf)); + } + + size_t get_output_body_size() const + { + return this->output_body_size; + } + + size_t get_output_body_blocks(const void *buf[], size_t size[], + size_t max) const; + + bool get_output_body_merged(void *buf, size_t *size) const; + + void clear_output_body(); + + /* std::string interfaces */ +public: + bool get_http_version(std::string& version) const + { + const char *str = this->get_http_version(); + + if (str) + { + version.assign(str); + return true; + } + + return false; + } + + bool set_http_version(const std::string& version) + { + return this->set_http_version(version.c_str()); + } + + bool add_header_pair(const std::string& name, const std::string& value) + { + return http_parser_add_header(name.c_str(), name.size(), + value.c_str(), value.size(), + this->parser) == 0; + } + + bool set_header_pair(const std::string& name, const std::string& value) + { + return http_parser_set_header(name.c_str(), name.size(), + value.c_str(), value.size(), + this->parser) == 0; + } + + bool append_output_body(const std::string& buf) + { + return this->append_output_body(buf.c_str(), buf.size()); + } + + bool append_output_body_nocopy(const std::string& buf) + { + return this->append_output_body_nocopy(buf.c_str(), buf.size()); + } + + bool get_output_body_merged(std::string& body) const + { + size_t size = this->output_body_size; + body.resize(size); + return this->get_output_body_merged((void *)body.data(), &size); + } + + /* for http task implementations. 
*/ +public: + bool is_header_complete() const + { + return http_parser_header_complete(this->parser); + } + + bool has_connection_header() const + { + return http_parser_has_connection(this->parser); + } + + bool has_content_length_header() const + { + return http_parser_has_content_length(this->parser); + } + + bool has_keep_alive_header() const + { + return http_parser_has_keep_alive(this->parser); + } + + void end_parsing() + { + http_parser_close_message(this->parser); + } + + /* for header cursor implementations. */ + const http_parser_t *get_parser() const + { + return this->parser; + } + +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +protected: + http_parser_t *parser; + size_t cur_size; + +private: + struct list_head *combine_from(struct list_head *pos, size_t size); + +private: + struct list_head output_body; + size_t output_body_size; + +public: + HttpMessage(bool is_resp) : parser(new http_parser_t) + { + http_parser_init(is_resp, this->parser); + INIT_LIST_HEAD(&this->output_body); + this->output_body_size = 0; + this->cur_size = 0; + } + + virtual ~HttpMessage() + { + this->clear_output_body(); + if (this->parser) + { + http_parser_deinit(this->parser); + delete this->parser; + } + } + +public: + HttpMessage(HttpMessage&& msg); + HttpMessage& operator = (HttpMessage&& msg); +}; + +class HttpRequest : public HttpMessage +{ +public: + const char *get_method() const + { + return http_parser_get_method(this->parser); + } + + const char *get_request_uri() const + { + return http_parser_get_uri(this->parser); + } + + bool set_method(const char *method) + { + return http_parser_set_method(method, this->parser) == 0; + } + + bool set_request_uri(const char *uri) + { + return http_parser_set_uri(uri, this->parser) == 0; + } + + /* std::string interfaces */ +public: + bool get_method(std::string& method) const + { + const char *str = this->get_method(); + + if (str) + { + 
method.assign(str); + return true; + } + + return false; + } + + bool get_request_uri(std::string& uri) const + { + const char *str = this->get_request_uri(); + + if (str) + { + uri.assign(str); + return true; + } + + return false; + } + + bool set_method(const std::string& method) + { + return this->set_method(method.c_str()); + } + + bool set_request_uri(const std::string& uri) + { + return this->set_request_uri(uri.c_str()); + } + +protected: + virtual int append(const void *buf, size_t *size); + +private: + int handle_expect_continue(); + +public: + HttpRequest() : HttpMessage(false) { } + +public: + HttpRequest(HttpRequest&& req) = default; + HttpRequest& operator = (HttpRequest&& req) = default; +}; + +class HttpResponse : public HttpMessage +{ +public: + const char *get_status_code() const + { + return http_parser_get_code(this->parser); + } + + const char *get_reason_phrase() const + { + return http_parser_get_phrase(this->parser); + } + + bool set_status_code(const char *code) + { + return http_parser_set_code(code, this->parser) == 0; + } + + bool set_reason_phrase(const char *phrase) + { + return http_parser_set_phrase(phrase, this->parser) == 0; + } + + /* std::string interfaces */ +public: + bool get_status_code(std::string& code) const + { + const char *str = this->get_status_code(); + + if (str) + { + code.assign(str); + return true; + } + + return false; + } + + bool get_reason_phrase(std::string& phrase) const + { + const char *str = this->get_reason_phrase(); + + if (str) + { + phrase.assign(str); + return true; + } + + return false; + } + + bool set_status_code(const std::string& code) + { + return this->set_status_code(code.c_str()); + } + + bool set_reason_phrase(const std::string& phrase) + { + return this->set_reason_phrase(phrase.c_str()); + } + +public: + /* Tell the parser, it is a HEAD response. For implementations. 
*/ + void parse_zero_body() + { + this->parser->transfer_length = 0; + } + +protected: + virtual int append(const void *buf, size_t *size); + +public: + HttpResponse() : HttpMessage(true) { } + +public: + HttpResponse(HttpResponse&& resp) = default; + HttpResponse& operator = (HttpResponse&& resp) = default; +}; + +} + +#endif + diff --git a/src/protocol/HttpUtil.cc b/src/protocol/HttpUtil.cc new file mode 100644 index 0000000..58244ca --- /dev/null +++ b/src/protocol/HttpUtil.cc @@ -0,0 +1,487 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include "http_parser.h" +#include "HttpMessage.h" +#include "HttpUtil.h" + +namespace protocol +{ + +HttpHeaderMap::HttpHeaderMap(const HttpMessage *message) +{ + http_header_cursor_t cursor; + struct HttpMessageHeader header; + + http_header_cursor_init(&cursor, message->get_parser()); + while (http_header_cursor_next(&header.name, &header.name_len, + &header.value, &header.value_len, + &cursor) == 0) + { + std::string key((const char *)header.name, header.name_len); + + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + header_map_[key].emplace_back((const char *)header.value, header.value_len); + } + + http_header_cursor_deinit(&cursor); +} + +bool HttpHeaderMap::key_exists(std::string key) +{ + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + return header_map_.count(key) > 0; +} + +std::string HttpHeaderMap::get(std::string key) +{ + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + const auto it = header_map_.find(key); + + if (it == header_map_.end() || it->second.empty()) + return std::string(); + + return it->second[0]; +} + +bool HttpHeaderMap::get(std::string key, std::string& value) +{ + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + const auto it = header_map_.find(key); + + if (it == header_map_.end() || it->second.empty()) + return false; + + value = it->second[0]; + return true; +} + +std::vector HttpHeaderMap::get_strict(std::string key) +{ + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + return header_map_[key]; +} + +bool HttpHeaderMap::get_strict(std::string key, std::vector& values) +{ + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + const auto it = header_map_.find(key); + + if (it == header_map_.end() || it->second.empty()) + return false; + + values = it->second; + return true; +} + +std::string 
HttpUtil::decode_chunked_body(const HttpMessage *msg) +{ + const void *body; + size_t body_len; + const void *chunk; + size_t chunk_size; + std::string decode_result; + HttpChunkCursor cursor(msg); + + if (msg->get_parsed_body(&body, &body_len)) + { + decode_result.reserve(body_len); + while (cursor.next(&chunk, &chunk_size)) + decode_result.append((const char *)chunk, chunk_size); + } + + return decode_result; +} + +void HttpUtil::set_response_status(HttpResponse *resp, int status_code) +{ + char buf[32]; + sprintf(buf, "%d", status_code); + resp->set_status_code(buf); + + switch (status_code) + { + case 100: + resp->set_reason_phrase("Continue"); + break; + + case 101: + resp->set_reason_phrase("Switching Protocols"); + break; + + case 102: + resp->set_reason_phrase("Processing"); + break; + + case 200: + resp->set_reason_phrase("OK"); + break; + + case 201: + resp->set_reason_phrase("Created"); + break; + + case 202: + resp->set_reason_phrase("Accepted"); + break; + + case 203: + resp->set_reason_phrase("Non-Authoritative Information"); + break; + + case 204: + resp->set_reason_phrase("No Content"); + break; + + case 205: + resp->set_reason_phrase("Reset Content"); + break; + + case 206: + resp->set_reason_phrase("Partial Content"); + break; + + case 207: + resp->set_reason_phrase("Multi-Status"); + break; + + case 208: + resp->set_reason_phrase("Already Reported"); + break; + + case 226: + resp->set_reason_phrase("IM Used"); + break; + + case 300: + resp->set_reason_phrase("Multiple Choices"); + break; + + case 301: + resp->set_reason_phrase("Moved Permanently"); + break; + + case 302: + resp->set_reason_phrase("Found"); + break; + + case 303: + resp->set_reason_phrase("See Other"); + break; + + case 304: + resp->set_reason_phrase("Not Modified"); + break; + + case 305: + resp->set_reason_phrase("Use Proxy"); + break; + + case 306: + resp->set_reason_phrase("Switch Proxy"); + break; + + case 307: + resp->set_reason_phrase("Temporary Redirect"); + break; + + 
case 308: + resp->set_reason_phrase("Permanent Redirect"); + break; + + case 400: + resp->set_reason_phrase("Bad Request"); + break; + + case 401: + resp->set_reason_phrase("Unauthorized"); + break; + + case 402: + resp->set_reason_phrase("Payment Required"); + break; + + case 403: + resp->set_reason_phrase("Forbidden"); + break; + + case 404: + resp->set_reason_phrase("Not Found"); + break; + + case 405: + resp->set_reason_phrase("Method Not Allowed"); + break; + + case 406: + resp->set_reason_phrase("Not Acceptable"); + break; + + case 407: + resp->set_reason_phrase("Proxy Authentication Required"); + break; + + case 408: + resp->set_reason_phrase("Request Timeout"); + break; + + case 409: + resp->set_reason_phrase("Conflict"); + break; + + case 410: + resp->set_reason_phrase("Gone"); + break; + + case 411: + resp->set_reason_phrase("Length Required"); + break; + + case 412: + resp->set_reason_phrase("Precondition Failed"); + break; + + case 413: + resp->set_reason_phrase("Request Entity Too Large"); + break; + + case 414: + resp->set_reason_phrase("Request-URI Too Long"); + break; + + case 415: + resp->set_reason_phrase("Unsupported Media Type"); + break; + + case 416: + resp->set_reason_phrase("Requested Range Not Satisfiable"); + break; + + case 417: + resp->set_reason_phrase("Expectation Failed"); + break; + + case 418: + resp->set_reason_phrase("I'm a teapot"); + break; + + case 420: + resp->set_reason_phrase("Enhance Your Caim"); + break; + + case 421: + resp->set_reason_phrase("Misdirected Request"); + break; + + case 422: + resp->set_reason_phrase("Unprocessable Entity"); + break; + + case 423: + resp->set_reason_phrase("Locked"); + break; + + case 424: + resp->set_reason_phrase("Failed Dependency"); + break; + + case 425: + resp->set_reason_phrase("Too Early"); + break; + + case 426: + resp->set_reason_phrase("Upgrade Required"); + break; + + case 428: + resp->set_reason_phrase("Precondition Required"); + break; + + case 429: + 
resp->set_reason_phrase("Too Many Requests"); + break; + + case 431: + resp->set_reason_phrase("Request Header Fields Too Large"); + break; + + case 444: + resp->set_reason_phrase("No Response"); + break; + + case 450: + resp->set_reason_phrase("Blocked by Windows Parental Controls"); + break; + + case 451: + resp->set_reason_phrase("Unavailable For Legal Reasons"); + break; + + case 494: + resp->set_reason_phrase("Request Header Too Large"); + break; + + case 500: + resp->set_reason_phrase("Internal Server Error"); + break; + + case 501: + resp->set_reason_phrase("Not Implemented"); + break; + + case 502: + resp->set_reason_phrase("Bad Gateway"); + break; + + case 503: + resp->set_reason_phrase("Service Unavailable"); + break; + + case 504: + resp->set_reason_phrase("Gateway Timeout"); + break; + + case 505: + resp->set_reason_phrase("HTTP Version Not Supported"); + break; + + case 506: + resp->set_reason_phrase("Variant Also Negotiates"); + break; + + case 507: + resp->set_reason_phrase("Insufficient Storage"); + break; + + case 508: + resp->set_reason_phrase("Loop Detected"); + break; + + case 510: + resp->set_reason_phrase("Not Extended"); + break; + + case 511: + resp->set_reason_phrase("Network Authentication Required"); + break; + + default: + resp->set_reason_phrase("Unknown"); + break; + } +} + +bool HttpHeaderCursor::next(std::string& name, std::string& value) +{ + struct HttpMessageHeader header; + + if (this->next(&header)) + { + name.assign((const char *)header.name, header.name_len); + value.assign((const char *)header.value, header.value_len); + return true; + } + + return false; +} + +bool HttpHeaderCursor::find(const std::string& name, std::string& value) +{ + struct HttpMessageHeader header = { + .name = name.c_str(), + .name_len = name.size(), + }; + + if (this->find(&header)) + { + value.assign((const char *)header.value, header.value_len); + return true; + } + + return false; +} + +bool HttpHeaderCursor::find_and_erase(const std::string& name) 
+{ + struct HttpMessageHeader header = { + .name = name.c_str(), + .name_len = name.size(), + }; + return this->find_and_erase(&header); +} + +HttpChunkCursor::HttpChunkCursor(const HttpMessage *msg) +{ + if (msg->get_parsed_body(&this->body, &this->body_len)) + { + this->pos = this->body; + this->chunked = msg->is_chunked(); + this->end = false; + } + else + { + this->body = NULL; + this->end = true; + } +} + +bool HttpChunkCursor::next(const void **chunk, size_t *size) +{ + if (this->end) + return false; + + if (!this->chunked) + { + *chunk = this->body; + *size = this->body_len; + this->end = true; + return true; + } + + const char *cur = (const char *)this->pos; + char *end; + + *size = strtol(cur, &end, 16); + if (*size == 0) + { + this->end = true; + return false; + } + + cur = strchr(end, '\r'); + *chunk = cur + 2; + cur += *size + 4; + this->pos = cur; + return true; +} + +void HttpChunkCursor::rewind() +{ + if (this->body) + { + this->pos = this->body; + this->end = false; + } +} + +} + diff --git a/src/protocol/HttpUtil.h b/src/protocol/HttpUtil.h new file mode 100644 index 0000000..6306a9a --- /dev/null +++ b/src/protocol/HttpUtil.h @@ -0,0 +1,232 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _HTTPUTIL_H_ +#define _HTTPUTIL_H_ + +#include +#include +#include +#include "http_parser.h" +#include "HttpMessage.h" + +/** + * @file HttpUtil.h + * @brief Http toolbox + */ + +#define HttpMethodGet "GET" +#define HttpMethodHead "HEAD" +#define HttpMethodPost "POST" +#define HttpMethodPut "PUT" +#define HttpMethodPatch "PATCH" +#define HttpMethodDelete "DELETE" +#define HttpMethodConnect "CONNECT" +#define HttpMethodOptions "OPTIONS" +#define HttpMethodTrace "TRACE" + +enum +{ + HttpStatusContinue = 100, // RFC 7231, 6.2.1 + HttpStatusSwitchingProtocols = 101, // RFC 7231, 6.2.2 + HttpStatusProcessing = 102, // RFC 2518, 10.1 + + HttpStatusOK = 200, // RFC 7231, 6.3.1 + HttpStatusCreated = 201, // RFC 7231, 6.3.2 + HttpStatusAccepted = 202, // RFC 7231, 6.3.3 + HttpStatusNonAuthoritativeInfo = 203, // RFC 7231, 6.3.4 + HttpStatusNoContent = 204, // RFC 7231, 6.3.5 + HttpStatusResetContent = 205, // RFC 7231, 6.3.6 + HttpStatusPartialContent = 206, // RFC 7233, 4.1 + HttpStatusMultiStatus = 207, // RFC 4918, 11.1 + HttpStatusAlreadyReported = 208, // RFC 5842, 7.1 + HttpStatusIMUsed = 226, // RFC 3229, 10.4.1 + + HttpStatusMultipleChoices = 300, // RFC 7231, 6.4.1 + HttpStatusMovedPermanently = 301, // RFC 7231, 6.4.2 + HttpStatusFound = 302, // RFC 7231, 6.4.3 + HttpStatusSeeOther = 303, // RFC 7231, 6.4.4 + HttpStatusNotModified = 304, // RFC 7232, 4.1 + HttpStatusUseProxy = 305, // RFC 7231, 6.4.5 + + HttpStatusTemporaryRedirect = 307, // RFC 7231, 6.4.7 + HttpStatusPermanentRedirect = 308, // RFC 7538, 3 + + HttpStatusBadRequest = 400, // RFC 7231, 6.5.1 + HttpStatusUnauthorized = 401, // RFC 7235, 3.1 + HttpStatusPaymentRequired = 402, // RFC 7231, 6.5.2 + HttpStatusForbidden = 403, // RFC 7231, 6.5.3 + HttpStatusNotFound = 404, // RFC 7231, 6.5.4 + HttpStatusMethodNotAllowed = 405, // RFC 7231, 6.5.5 + HttpStatusNotAcceptable = 406, // RFC 7231, 6.5.6 + 
HttpStatusProxyAuthRequired = 407, // RFC 7235, 3.2 + HttpStatusRequestTimeout = 408, // RFC 7231, 6.5.7 + HttpStatusConflict = 409, // RFC 7231, 6.5.8 + HttpStatusGone = 410, // RFC 7231, 6.5.9 + HttpStatusLengthRequired = 411, // RFC 7231, 6.5.10 + HttpStatusPreconditionFailed = 412, // RFC 7232, 4.2 + HttpStatusRequestEntityTooLarge = 413, // RFC 7231, 6.5.11 + HttpStatusRequestURITooLong = 414, // RFC 7231, 6.5.12 + HttpStatusUnsupportedMediaType = 415, // RFC 7231, 6.5.13 + HttpStatusRequestedRangeNotSatisfiable = 416, // RFC 7233, 4.4 + HttpStatusExpectationFailed = 417, // RFC 7231, 6.5.14 + HttpStatusTeapot = 418, // RFC 7168, 2.3.3 + HttpStatusEnhanceYourCaim = 420, // Twitter Search + HttpStatusMisdirectedRequest = 421, // RFC 7540, 9.1.2 + HttpStatusUnprocessableEntity = 422, // RFC 4918, 11.2 + HttpStatusLocked = 423, // RFC 4918, 11.3 + HttpStatusFailedDependency = 424, // RFC 4918, 11.4 + HttpStatusTooEarly = 425, // RFC 8470, 5.2. + HttpStatusUpgradeRequired = 426, // RFC 7231, 6.5.15 + HttpStatusPreconditionRequired = 428, // RFC 6585, 3 + HttpStatusTooManyRequests = 429, // RFC 6585, 4 + HttpStatusRequestHeaderFieldsTooLarge = 431, // RFC 6585, 5 + HttpStatusNoResponse = 444, // Nginx + HttpStatusBlocked = 450, // Windows + HttpStatusUnavailableForLegalReasons = 451, // RFC 7725, 3 + HttpStatusTooLargeForNginx = 494, // Nginx + + HttpStatusInternalServerError = 500, // RFC 7231, 6.6.1 + HttpStatusNotImplemented = 501, // RFC 7231, 6.6.2 + HttpStatusBadGateway = 502, // RFC 7231, 6.6.3 + HttpStatusServiceUnavailable = 503, // RFC 7231, 6.6.4 + HttpStatusGatewayTimeout = 504, // RFC 7231, 6.6.5 + HttpStatusHTTPVersionNotSupported = 505, // RFC 7231, 6.6.6 + HttpStatusVariantAlsoNegotiates = 506, // RFC 2295, 8.1 + HttpStatusInsufficientStorage = 507, // RFC 4918, 11.5 + HttpStatusLoopDetected = 508, // RFC 5842, 7.2 + HttpStatusNotExtended = 510, // RFC 2774, 7 + HttpStatusNetworkAuthenticationRequired = 511, // RFC 6585, 6 +}; + +namespace protocol 
+{ + +// static class +class HttpUtil +{ +public: + static void set_response_status(HttpResponse *resp, int status_code); + static std::string decode_chunked_body(const HttpMessage *msg); +}; + +class HttpHeaderMap +{ +public: + HttpHeaderMap(const HttpMessage *message); + + bool key_exists(std::string key); + std::string get(std::string key); + bool get(std::string key, std::string& value); + std::vector get_strict(std::string key); + bool get_strict(std::string key, std::vector& values); + +private: + std::unordered_map> header_map_; +}; + +class HttpHeaderCursor +{ +public: + HttpHeaderCursor(const HttpMessage *message); + virtual ~HttpHeaderCursor(); + +public: + bool next(struct HttpMessageHeader *header); + bool find(struct HttpMessageHeader *header); + bool erase(); + bool find_and_erase(struct HttpMessageHeader *header); + void rewind(); + + /* std::string interface */ +public: + bool next(std::string& name, std::string& value); + bool find(const std::string& name, std::string& value); + bool find_and_erase(const std::string& name); + +protected: + http_header_cursor_t cursor; +}; + +class HttpChunkCursor +{ +public: + HttpChunkCursor(const HttpMessage *message); + virtual ~HttpChunkCursor() { } + +public: + bool next(const void **chunk, size_t *size); + void rewind(); + +protected: + const void *body; + size_t body_len; + const void *pos; + bool chunked; + bool end; +}; + +//////////////////// + +inline HttpHeaderCursor::HttpHeaderCursor(const HttpMessage *message) +{ + http_header_cursor_init(&this->cursor, message->get_parser()); +} + +inline HttpHeaderCursor::~HttpHeaderCursor() +{ + http_header_cursor_deinit(&this->cursor); +} + +inline bool HttpHeaderCursor::next(struct HttpMessageHeader *header) +{ + return http_header_cursor_next(&header->name, &header->name_len, + &header->value, &header->value_len, + &this->cursor) == 0; +} + +inline bool HttpHeaderCursor::find(struct HttpMessageHeader *header) +{ + return http_header_cursor_find(header->name, 
header->name_len, + &header->value, &header->value_len, + &this->cursor) == 0; +} + +inline bool HttpHeaderCursor::erase() +{ + return http_header_cursor_erase(&this->cursor) == 0; +} + +inline bool HttpHeaderCursor::find_and_erase(struct HttpMessageHeader *header) +{ + if (this->find(header)) + return this->erase(); + + return false; +} + +inline void HttpHeaderCursor::rewind() +{ + http_header_cursor_rewind(&this->cursor); +} + +} + +#endif + diff --git a/src/protocol/KafkaDataTypes.cc b/src/protocol/KafkaDataTypes.cc new file mode 100644 index 0000000..870366b --- /dev/null +++ b/src/protocol/KafkaDataTypes.cc @@ -0,0 +1,667 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#include +#include +#include +#include "KafkaDataTypes.h" + +#define MIN(x, y) ((x) <= (y) ? 
(x) : (y))

namespace protocol
{

/* Builds "<mechanisms>|<username>|<password>|" for the PLAIN mechanism and
 * for any mechanism whose name starts with "SCRAM" (both compared
 * case-insensitively); returns an empty string otherwise. */
std::string KafkaConfig::get_sasl_info() const
{
	const char *mechanisms = this->ptr->mechanisms;
	std::string info;

	if (strcasecmp(mechanisms, "plain") == 0 ||
		strncasecmp(mechanisms, "SCRAM", 5) == 0)
	{
		info += mechanisms;
		info += "|";
		info += this->ptr->username;
		info += "|";
		info += this->ptr->password;
		info += "|";
	}

	return info;
}

/* Orders members by member_id (strcmp order). */
static bool compare_member(const kafka_member_t *m1, const kafka_member_t *m2)
{
	return strcmp(m1->member_id, m2->member_id) < 0;
}

/* True when `member` has `topic` in its toppar_list. */
static bool member_subscribes(kafka_member_t *member, const char *topic)
{
	struct list_head *pos;
	KafkaToppar *toppar;

	list_for_each(pos, &member->toppar_list)
	{
		toppar = list_entry(pos, KafkaToppar, list);
		if (strcmp(topic, toppar->get_topic()) == 0)
			return true;
	}

	return false;
}

inline void KafkaMetaSubscriber::sort_by_member()
{
	std::sort(this->member_vec.begin(), this->member_vec.end(), compare_member);
}

/* Orders subscribers by topic name (strcmp order). */
static bool operator<(const KafkaMetaSubscriber& s1, const KafkaMetaSubscriber& s2)
{
	return strcmp(s1.get_meta()->get_topic(), s2.get_meta()->get_topic()) < 0;
}

/*
 * For example, suppose there are two consumers C0 and C1, two topics t0 and t1, and each topic has 3 partitions,
 * resulting in partitions t0p0, t0p1, t0p2, t1p0, t1p1, and t1p2.
 *
 * The assignment will be:
 * C0: [t0p0, t0p1, t1p0, t1p1]
 * C1: [t0p2, t1p2]
 *
 * `meta_topic` points to a std::vector<KafkaMetaSubscriber>; `members` and
 * `member_elements` are not used by this assignor. Returns 0 on success,
 * -1 when a KafkaToppar cannot be set up.
 */
int KafkaCgroup::kafka_range_assignor(kafka_member_t **members,
									  int member_elements,
									  void *meta_topic)
{
	/* NOTE(review): the template arguments on this cast were lost in
	 * extraction; KafkaMetaSubscriber is what the loop body requires. */
	std::vector<KafkaMetaSubscriber> *subscribers =
		static_cast<std::vector<KafkaMetaSubscriber> *>(meta_topic);

	/* The range assignor works on a per-topic basis. */
	for (auto& subscriber : *subscribers)
	{
		subscriber.sort_by_member();

		int partition_cnt = subscriber.get_meta()->get_partition_elements();
		int member_cnt = (int)subscriber.get_member()->size();

		/* Nothing to hand out to, and no division by zero below. */
		if (member_cnt == 0)
			continue;

		int num_partitions_per_consumer = partition_cnt / member_cnt;

		/* If it does not evenly divide, then the first few consumers
		 * will have one extra partition. */
		int consumers_with_extra_partition = partition_cnt % member_cnt;

		for (int i = 0; i < member_cnt; i++)
		{
			int start = num_partitions_per_consumer * i +
						MIN(i, consumers_with_extra_partition);
			int length = num_partitions_per_consumer +
						 (i + 1 > consumers_with_extra_partition ? 0 : 1);

			if (length == 0)
				continue;

			for (int j = start; j < length + start; ++j)
			{
				KafkaToppar *toppar = new KafkaToppar;
				if (!toppar->set_topic_partition(subscriber.get_meta()->get_topic(), j))
				{
					delete toppar;
					return -1;
				}

				list_add_tail(&toppar->list,
							  &subscriber.get_member()->at(i)->assigned_toppar_list);
			}
		}
	}

	return 0;
}

/*
 * For example, suppose there are two consumers C0 and C1, two topics t0 and
 * t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1,
 * t0p2, t1p0, t1p1, and t1p2.
 *
 * The assignment will be:
 * C0: [t0p0, t0p2, t1p1]
 * C1: [t0p1, t1p0, t1p2]
 *
 * `meta_topic` points to a std::vector<KafkaMetaSubscriber>. Both the
 * subscriber vector and the `members` array are sorted in place. Returns 0
 * on success, -1 when no member subscribes to a topic or a KafkaToppar
 * cannot be set up.
 */
int KafkaCgroup::kafka_roundrobin_assignor(kafka_member_t **members,
										   int member_elements,
										   void *meta_topic)
{
	std::vector<KafkaMetaSubscriber> *subscribers =
		static_cast<std::vector<KafkaMetaSubscriber> *>(meta_topic);

	/* Index into `members` of the member that got the previous partition. */
	int cursor = -1;

	std::sort(subscribers->begin(), subscribers->end());
	std::sort(members, members + member_elements, compare_member);

	for (const auto& subscriber : *subscribers)
	{
		const char *topic = subscriber.get_meta()->get_topic();
		int partition_elements = subscriber.get_meta()->get_partition_elements();

		for (int partition = 0; partition < partition_elements; ++partition)
		{
			int i;

			if (member_elements <= 0)
				return -1;

			/* Walk the member array circularly, starting after the last
			 * assignee, until a member subscribed to this topic is found.
			 * The index is wrapped so it never leaves the array, and the
			 * member that was found is the one that gets the partition. */
			cursor = (cursor + 1) % member_elements;
			for (i = 0; i < member_elements; i++)
			{
				int idx = (cursor + i) % member_elements;

				if (member_subscribes(members[idx], topic))
				{
					cursor = idx;
					break;
				}
			}

			if (i >= member_elements)
				return -1;

			KafkaToppar *toppar = new KafkaToppar;
			if (!toppar->set_topic_partition(topic, partition))
			{
				delete toppar;
				return -1;
			}

			list_add_tail(toppar->get_list(),
						  &members[cursor]->assigned_toppar_list);
		}
	}

	return 0;
}

/* Replaces the partition array with `partition_cnt` freshly initialized
 * partitions. The old array is released only after every new partition has
 * been allocated; on allocation failure the old array is left untouched and
 * false is returned. A count <= 0 is a no-op that returns true. */
bool KafkaMeta::create_partitions(int partition_cnt)
{
	if (partition_cnt <= 0)
		return true;

	kafka_partition_t **partitions;
	partitions = (kafka_partition_t **)malloc(sizeof(void *) * partition_cnt);
	if (!partitions)
		return false;

	int i;

	for (i = 0; i < partition_cnt; ++i)
	{
		partitions[i] = (kafka_partition_t *)malloc(sizeof(kafka_partition_t));
		if (!partitions[i])
			break;

		kafka_partition_init(partitions[i]);
	}

	if (i != partition_cnt)
	{
		/* Roll back the partitions created so far. */
		while (--i >= 0)
		{
			kafka_partition_deinit(partitions[i]);
			free(partitions[i]);
		}

		free(partitions);
		return false;
	}

	for (i = 0; i < this->ptr->partition_elements; ++i)
	{
		kafka_partition_deinit(this->ptr->partitions[i]);
		free(this->ptr->partitions[i]);
	}

	free(this->ptr->partitions);

	this->ptr->partitions = partitions;
	this->ptr->partition_elements = partition_cnt;
	return true;
}

/* For every topic in `meta_list`, collects the group members that have the
 * topic in their toppar_list and appends one KafkaMetaSubscriber per topic
 * to `subscribers`. Topics without any subscribed member are skipped. */
void KafkaCgroup::add_subscriber(KafkaMetaList *meta_list,
								 std::vector<KafkaMetaSubscriber> *subscribers)
{
	meta_list->rewind();
	KafkaMeta *meta;

	while ((meta = meta_list->get_next()) != NULL)
	{
		KafkaMetaSubscriber subscriber;

		subscriber.set_meta(meta);
		for (int i = 0; i < this->get_member_elements(); ++i)
		{
			kafka_member_t *member = this->get_members()[i];

			if (member_subscribes(member, meta->get_topic()))
				subscriber.add_member(member);
		}

		if (!subscriber.get_member()->empty())
			subscribers->emplace_back(subscriber);
	}
}

/* Runs the assignor registered under `protocol_name` over the subscribers
 * derived from `meta_list`. Returns the assignor's result, or -1 with
 * errno = EBADMSG when no such protocol is registered. */
int KafkaCgroup::run_assignor(KafkaMetaList *meta_list,
							  const char *protocol_name)
{
	std::vector<KafkaMetaSubscriber> subscribers;
	this->add_subscriber(meta_list, &subscribers);

	struct list_head *pos;
	kafka_group_protocol_t *protocol = NULL;

	list_for_each(pos, this->get_group_protocol())
	{
		kafka_group_protocol_t *entry;

		entry = list_entry(pos, kafka_group_protocol_t, list);
		if (strcmp(protocol_name, entry->protocol_name) == 0)
		{
			protocol = entry;
			break;
		}
	}

	if (!protocol)
	{
		errno = EBADMSG;
		return -1;
	}

	return protocol->assignor(this->get_members(), this->get_member_elements(),
							  &subscribers);
}

/* Creates an empty group with the two built-in assignors registered in this
 * order: "range", then "roundrobin". */
KafkaCgroup::KafkaCgroup()
{
	this->ptr = new kafka_cgroup_t;
	kafka_cgroup_init(this->ptr);

	kafka_group_protocol_t *protocol = new kafka_group_protocol_t;
	protocol->protocol_name = new char[strlen("range") + 1];
	memcpy(protocol->protocol_name, "range", strlen("range") + 1);
	protocol->assignor = kafka_range_assignor;
	list_add_tail(&protocol->list, &this->ptr->group_protocol_list);

	protocol = new kafka_group_protocol_t;
	protocol->protocol_name = new char[strlen("roundrobin") + 1];
	memcpy(protocol->protocol_name, "roundrobin", strlen("roundrobin") + 1);
	protocol->assignor = kafka_roundrobin_assignor;
	list_add_tail(&protocol->list, &this->ptr->group_protocol_list);

	/* NOTE(review): the template argument of std::atomic was lost in
	 * extraction; int is assumed - confirm against KafkaDataTypes.h. */
	this->ref = new std::atomic<int>(1);
	this->coordinator = NULL;
}

/* Drops one reference to the shared group data and frees it when this was
 * the last one. The coordinator is per-object and always deleted. */
KafkaCgroup::~KafkaCgroup()
{
	if (--*this->ref == 0)
	{
		for (int i = 0; i < this->ptr->member_elements; ++i)
		{
			kafka_member_t *member = this->ptr->members[i];
			KafkaToppar *toppar;
			struct list_head *pos, *tmp;

			list_for_each_safe(pos, tmp, &member->toppar_list)
			{
				toppar = list_entry(pos, KafkaToppar, list);
				list_del(pos);
				delete toppar;
			}

			list_for_each_safe(pos, tmp, &member->assigned_toppar_list)
			{
				toppar = list_entry(pos, KafkaToppar, list);
				list_del(pos);
				delete toppar;
			}
		}

		kafka_cgroup_deinit(this->ptr);

		struct list_head *tmp, *pos;
		KafkaToppar *toppar;

		list_for_each_safe(pos, tmp, &this->ptr->assigned_toppar_list)
		{
			toppar = list_entry(pos, KafkaToppar, list);
			list_del(pos);
			delete toppar;
		}

		kafka_group_protocol_t *protocol;
		list_for_each_safe(pos, tmp, &this->ptr->group_protocol_list)
		{
			protocol = list_entry(pos, kafka_group_protocol_t, list);
			list_del(pos);
			delete []protocol->protocol_name;
			delete protocol;
		}

		delete []this->ptr->group_name;

		delete this->ptr;
		delete this->ref;
	}

	delete this->coordinator;
}

/* Move construction: takes over the shared data, reference counter and
 * coordinator of `move`; `move` is left with freshly initialized data and a
 * new counter so that it can still be destroyed. */
KafkaCgroup::KafkaCgroup(KafkaCgroup&& move)
{
	this->ptr = move.ptr;
	this->ref = move.ref;
	move.ptr = new kafka_cgroup_t;
	kafka_cgroup_init(move.ptr);
	move.ref = new std::atomic<int>(1);
	this->coordinator = move.coordinator;
	move.coordinator = NULL;
}

/* Move assignment: releases the current state, then behaves like move
 * construction. */
KafkaCgroup& KafkaCgroup::operator= (KafkaCgroup&& move)
{
	if (this != &move)
	{
		this->~KafkaCgroup();
		this->ptr = move.ptr;
		this->ref = move.ref;
		move.ptr = new kafka_cgroup_t;
		kafka_cgroup_init(move.ptr);
		move.ref = new std::atomic<int>(1);
		this->coordinator = move.coordinator;
		move.coordinator = NULL;
	}

	return *this;
}

/* Copy construction: shares the group data (reference count incremented)
 * and creates a separate KafkaBroker from the source's coordinator. */
KafkaCgroup::KafkaCgroup(const KafkaCgroup& copy)
{
	this->ptr = copy.ptr;
	this->ref = copy.ref;
	++*this->ref;

	if (copy.coordinator)
		this->coordinator = new KafkaBroker(copy.coordinator->get_raw_ptr());
	else
		this->coordinator = NULL;
}

/* Copy assignment: releases the current state, then shares the data of
 * `copy`. Self-assignment is a no-op; without the check the object would
 * release its own data and coordinator and then read them back. */
KafkaCgroup& KafkaCgroup::operator= (const KafkaCgroup& copy)
{
	if (this != &copy)
	{
		this->~KafkaCgroup();
		this->ptr = copy.ptr;
		this->ref = copy.ref;
		++*this->ref;

		if (copy.coordinator)
			this->coordinator = new KafkaBroker(copy.coordinator->get_raw_ptr());
		else
			this->coordinator = NULL;
	}

	return *this;
}

bool KafkaCgroup::create_members(int member_cnt)
{
	if (member_cnt == 0)
		return true;

	kafka_member_t **members;
	members = (kafka_member_t **)malloc(sizeof(void *) * member_cnt);
	if (!members)
		return false;

	int i;

	for (i = 0; i < member_cnt; ++i)
	{
		members[i] = (kafka_member_t
*)malloc(sizeof(kafka_member_t)); + if (!members[i]) + break; + + kafka_member_init(members[i]); + INIT_LIST_HEAD(&members[i]->toppar_list); + INIT_LIST_HEAD(&members[i]->assigned_toppar_list); + } + + if (i != member_cnt) + { + while (--i >= 0) + { + KafkaToppar *toppar; + struct list_head *pos, *tmp; + list_for_each_safe(pos, tmp, &members[i]->toppar_list) + { + toppar = list_entry(pos, KafkaToppar, list); + list_del(pos); + delete toppar; + } + + list_for_each_safe(pos, tmp, &members[i]->assigned_toppar_list) + { + toppar = list_entry(pos, KafkaToppar, list); + list_del(pos); + delete toppar; + } + + kafka_member_deinit(members[i]); + free(members[i]); + } + + free(members); + return false; + } + + for (i = 0; i < this->ptr->member_elements; ++i) + { + KafkaToppar *toppar; + struct list_head *pos, *tmp; + + list_for_each_safe(pos, tmp, &this->ptr->members[i]->toppar_list) + { + toppar = list_entry(pos, KafkaToppar, list); + list_del(pos); + delete toppar; + } + + list_for_each_safe(pos, tmp, &this->ptr->members[i]->assigned_toppar_list) + { + toppar = list_entry(pos, KafkaToppar, list); + list_del(pos); + delete toppar; + } + + kafka_member_deinit(this->ptr->members[i]); + free(this->ptr->members[i]); + } + + free(this->ptr->members); + + this->ptr->members = members; + this->ptr->member_elements = member_cnt; + return true; +} + +void KafkaCgroup::add_assigned_toppar(KafkaToppar *toppar) +{ + list_add_tail(toppar->get_list(), &this->ptr->assigned_toppar_list); +} + +void KafkaCgroup::assigned_toppar_rewind() +{ + this->curpos = &this->ptr->assigned_toppar_list; +} + +KafkaToppar *KafkaCgroup::get_assigned_toppar_next() +{ + if (this->curpos->next == &this->ptr->assigned_toppar_list) + return NULL; + + this->curpos = this->curpos->next; + return list_entry(this->curpos, KafkaToppar, list); +} + +void KafkaCgroup::del_assigned_toppar_cur() +{ + assert(this->curpos != &this->ptr->assigned_toppar_list); + this->curpos = this->curpos->prev; + 
list_del(this->curpos->next); +} + +bool KafkaRecord::add_header_pair(const void *key, size_t key_len, + const void *val, size_t val_len) +{ + kafka_record_header_t *header; + + header = (kafka_record_header_t *)malloc(sizeof(kafka_record_header_t)); + if (!header) + return false; + + kafka_record_header_init(header); + if (kafka_record_header_set_kv(key, key_len, val, val_len, header) < 0) + { + free(header); + return false; + } + + list_add_tail(&header->list, &this->ptr->header_list); + return true; +} + +bool KafkaRecord::add_header_pair(const std::string& key, + const std::string& val) +{ + return add_header_pair(key.c_str(), key.size(), val.c_str(), val.size()); +} + +KafkaToppar::~KafkaToppar() +{ + if (--*this->ref == 0) + { + kafka_topic_partition_deinit(this->ptr); + + struct list_head *tmp, *pos; + KafkaRecord *record; + list_for_each_safe(pos, tmp, &this->ptr->record_list) + { + record = list_entry(pos, KafkaRecord, list); + list_del(pos); + delete record; + } + + delete this->ptr; + delete this->ref; + } +} + +void KafkaBuffer::list_splice(KafkaBuffer *buffer) +{ + struct list_head *pre_insert; + struct list_head *pre_tail; + + this->buf_size -= this->insert_buf_size; + + pre_insert = this->insert_pos->next; + __list_splice(buffer->get_head(), this->insert_pos, pre_insert); + + pre_tail = this->block_list.get_tail(); + buffer->get_head()->prev->next = this->block_list.get_head(); + this->block_list.get_head()->prev = buffer->get_head()->prev; + + buffer->get_head()->next = pre_insert; + buffer->get_head()->prev = pre_tail; + pre_tail->next = buffer->get_head(); + + pre_insert->prev = buffer->get_head(); + + this->buf_size += buffer->get_size(); +} + +size_t KafkaBuffer::peek(const char **buf) +{ + if (!this->inited) + { + this->inited = true; + this->cur_pos = std::make_pair(this->block_list.get_next(), 0); + } + + if (this->cur_pos.first == this->block_list.get_tail_entry() && + this->cur_pos.second == this->block_list.get_tail_entry()->get_len()) + { 
+ *buf = NULL; + return 0; + } + + KafkaBlock *block = this->cur_pos.first; + + if (this->cur_pos.second >= block->get_len()) + { + block = this->block_list.get_next(); + this->cur_pos = std::make_pair(block, 0); + } + + *buf = (char *)block->get_block() + this->cur_pos.second; + + return block->get_len() - this->cur_pos.second; +} + +KafkaToppar *get_toppar(const char *topic, int partition, + KafkaTopparList *toppar_list) +{ + struct list_head *pos; + KafkaToppar *toppar; + list_for_each(pos, toppar_list->get_head()) + { + toppar = list_entry(pos, KafkaToppar, list); + if (strcmp(toppar->get_topic(), topic) == 0 && + toppar->get_partition() == partition) + return toppar; + } + + return NULL; +} + +const KafkaMeta *get_meta(const char *topic, KafkaMetaList *meta_list) +{ + struct list_head *pos; + const KafkaMeta *meta; + list_for_each(pos, meta_list->get_head()) + { + meta = list_entry(pos, KafkaMeta, list); + if (strcmp(meta->get_topic(), topic) == 0) + return meta; + } + + return NULL; +} + +} /* namespace protocol */ diff --git a/src/protocol/KafkaDataTypes.h b/src/protocol/KafkaDataTypes.h new file mode 100644 index 0000000..325b332 --- /dev/null +++ b/src/protocol/KafkaDataTypes.h @@ -0,0 +1,1650 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#ifndef _KAFKA_DATATYPES_H_ +#define _KAFKA_DATATYPES_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "list.h" +#include "rbtree.h" +#include "kafka_parser.h" + +namespace protocol +{ + +template +class KafkaList +{ +public: + KafkaList() + { + this->t_list = new struct list_head; + + INIT_LIST_HEAD(this->t_list); + this->ref = new std::atomic(1); + this->curpos = this->t_list; + } + + ~KafkaList() + { + if (--*this->ref == 0) + { + struct list_head *pos, *tmp; + T *t; + + list_for_each_safe(pos, tmp, this->t_list) + { + t = list_entry(pos, T, list); + list_del(pos); + delete t; + } + + delete this->t_list; + delete this->ref; + } + } + + KafkaList(KafkaList&& move) + { + this->t_list = move.t_list; + move.t_list = new struct list_head; + INIT_LIST_HEAD(move.t_list); + this->ref = move.ref; + this->curpos = this->t_list; + move.ref = new std::atomic(1); + } + + KafkaList& operator= (KafkaList&& move) + { + if (this != &move) + { + this->~KafkaList(); + this->t_list = move.t_list; + move.t_list = new struct list_head; + INIT_LIST_HEAD(move.t_list); + this->ref = move.ref; + this->curpos = this->t_list; + move.ref = new std::atomic(1); + } + + return *this; + } + + KafkaList(const KafkaList& copy) + { + this->ref = copy.ref; + ++*this->ref; + this->t_list = copy.t_list; + this->curpos = copy.curpos; + } + + KafkaList& operator= (const KafkaList& copy) + { + if (this != ©) + { + this->~KafkaList(); + this->ref = copy.ref; + ++*this->ref; + this->t_list = copy.t_list; + this->curpos = copy.curpos; + } + return *this; + } + + T *add_item(T&& move) + { + T *t = new T; + + *t = std::move(move); + list_add_tail(t->get_list(), this->t_list); + return t; + } + + void add_item(T& obj) + { + T *t = new T; + + *t = obj; + list_add_tail(t->get_list(), this->t_list); + } + + struct list_head *get_head() { return this->t_list; } + + struct 
list_head *get_tail() { return this->t_list->prev; } + + T *get_first_entry() + { + if (this->t_list == this->t_list->next) + return NULL; + + return list_entry(this->t_list->next, T, list); + } + + T *get_tail_entry() + { + if (this->t_list == this->get_tail()) + return NULL; + + return list_entry(this->get_tail(), T, list); + } + + T *get_entry(struct list_head *pos) + { + return list_entry(pos, T, list); + } + + void rewind() + { + this->curpos = this->t_list; + } + + T *get_next() + { + if (this->curpos->next == this->t_list) + return NULL; + + this->curpos = this->curpos->next; + return list_entry(this->curpos, T, list); + } + + void insert_pos(struct list_head *list, struct list_head *pos) + { + __list_add(list, pos, pos->next); + } + + void del_cur() + { + assert(this->curpos != this->t_list); + this->curpos = this->curpos->prev; + list_del(this->curpos->next); + } + +private: + struct list_head *t_list; + std::atomic *ref; + struct list_head *curpos; +}; + +template +class KafkaMap +{ +public: + KafkaMap() + { + this->t_map = new struct rb_root; + this->t_map->rb_node = NULL; + + this->ref = new std::atomic(1); + } + + ~KafkaMap() + { + if (--*this->ref == 0) + { + T *t; + while (this->t_map->rb_node) + { + t = rb_entry(this->t_map->rb_node, T, rb); + rb_erase(this->t_map->rb_node, this->t_map); + delete t; + } + + delete this->t_map; + delete this->ref; + } + } + + KafkaMap(const KafkaMap& copy) + { + this->ref = copy.ref; + ++*this->ref; + this->t_map = copy.t_map; + } + + KafkaMap& operator= (const KafkaMap& copy) + { + if (this != ©) + { + this->~KafkaMap(); + this->ref = copy.ref; + ++*this->ref; + this->t_map = copy.t_map; + } + return *this; + } + + T *find_item(const T& v) const + { + rb_node **p = &this->t_map->rb_node; + T *t; + + while (*p) + { + t = rb_entry(*p, T, rb); + + if (v < *t) + p = &(*p)->rb_left; + else if (v > *t) + p = &(*p)->rb_right; + else + break; + } + + return *p ? 
t : NULL; + } + + void add_item(T& obj) + { + rb_node **p = &this->t_map->rb_node; + rb_node *parent = NULL; + T *t; + + while (*p) + { + parent = *p; + t = rb_entry(*p, T, rb); + + if (obj < *t) + p = &(*p)->rb_left; + else if (obj > *t) + p = &(*p)->rb_right; + else + break; + } + + if (*p == NULL) + { + T *nt = new T; + + *nt = obj; + rb_link_node(nt->get_rb(), parent, p); + rb_insert_color(nt->get_rb(), this->t_map); + } + } + + T *find_item(int id) const + { + rb_node **p = &this->t_map->rb_node; + T *t; + + while (*p) + { + t = rb_entry(*p, T, rb); + + if (id < t->get_id()) + p = &(*p)->rb_left; + else if (id > t->get_id()) + p = &(*p)->rb_right; + else + break; + } + + return *p ? t : NULL; + } + + void add_item(T& obj, int id) + { + rb_node **p = &this->t_map->rb_node; + rb_node *parent = NULL; + T *t; + + while (*p) + { + parent = *p; + t = rb_entry(*p, T, rb); + + if (id < t->get_id()) + p = &(*p)->rb_left; + else if (id > t->get_id()) + p = &(*p)->rb_right; + else + break; + } + + if (*p == NULL) + { + T *nt = new T; + + *nt = obj; + rb_link_node(nt->get_rb(), parent, p); + rb_insert_color(nt->get_rb(), this->t_map); + } + } + + T *get_first_entry() + { + struct rb_node *p = rb_first(this->t_map); + return rb_entry(p, T, rb); + } + + T *get_tail_entry() + { + struct rb_node *p = rb_last(this->t_map); + return rb_entry(p, T, rb); + } + +private: + struct rb_root *t_map; + std::atomic *ref; +}; + +class KafkaConfig +{ +public: + void set_produce_timeout(int ms) { this->ptr->produce_timeout = ms; } + int get_produce_timeout() const { return this->ptr->produce_timeout; } + + void set_produce_msg_max_bytes(int bytes) + { + this->ptr->produce_msg_max_bytes = bytes; + } + int get_produce_msg_max_bytes() const + { + return this->ptr->produce_msg_max_bytes; + } + + void set_produce_msgset_cnt(int cnt) + { + this->ptr->produce_msgset_cnt = cnt; + } + int get_produce_msgset_cnt() const + { + return this->ptr->produce_msgset_cnt; + } + + void 
set_produce_msgset_max_bytes(int bytes) + { + this->ptr->produce_msgset_max_bytes = bytes; + } + int get_produce_msgset_max_bytes() const + { + return this->ptr->produce_msgset_max_bytes; + } + + void set_fetch_timeout(int ms) { this->ptr->fetch_timeout = ms; } + int get_fetch_timeout() const { return this->ptr->fetch_timeout; } + + void set_fetch_min_bytes(int bytes) { this->ptr->fetch_min_bytes = bytes; } + int get_fetch_min_bytes() const { return this->ptr->fetch_min_bytes; } + + void set_fetch_max_bytes(int bytes) { this->ptr->fetch_max_bytes = bytes; } + int get_fetch_max_bytes() const { return this->ptr->fetch_max_bytes; } + + void set_fetch_msg_max_bytes(int bytes) { this->ptr->fetch_msg_max_bytes = bytes; } + int get_fetch_msg_max_bytes() const { return this->ptr->fetch_msg_max_bytes; } + + void set_offset_timestamp(long long tm) + { + this->ptr->offset_timestamp = tm; + } + long long get_offset_timestamp() const + { + return this->ptr->offset_timestamp; + } + + void set_commit_timestamp(long long commit_timestamp) + { + this->ptr->commit_timestamp = commit_timestamp; + } + long long get_commit_timestamp() const { return this->ptr->commit_timestamp; } + + void set_session_timeout(int ms) { this->ptr->session_timeout = ms; } + int get_session_timeout() const { return this->ptr->session_timeout; } + + void set_rebalance_timeout(int ms) { this->ptr->rebalance_timeout = ms; } + int get_rebalance_timeout() const { return this->ptr->rebalance_timeout; } + + void set_retention_time_period(long long ms) + { + this->ptr->retention_time_period = ms; + } + long long get_retention_time_period() const + { + return this->ptr->retention_time_period; + } + + void set_produce_acks(int acks) { this->ptr->produce_acks = acks; } + int get_produce_acks() const { return this->ptr->produce_acks; } + + void set_allow_auto_topic_creation(bool allow_auto_topic_creation) + { + this->ptr->allow_auto_topic_creation = allow_auto_topic_creation; + } + bool get_allow_auto_topic_creation() 
const + { + return this->ptr->allow_auto_topic_creation; + } + + void set_api_version_request(int api_ver) + { + this->ptr->api_version_request = api_ver; + } + int get_api_version_request() const + { + return this->ptr->api_version_request; + } + + bool set_broker_version(const char *version) + { + char *p = strdup(version); + + if (!p) + return false; + + free(this->ptr->broker_version); + this->ptr->broker_version = p; + return true; + } + const char *get_broker_version() const + { + return this->ptr->broker_version; + } + + void set_compress_type(int type) { this->ptr->compress_type = type; } + int get_compress_type() const { return this->ptr->compress_type; } + + const char *get_client_id() { return this->ptr->client_id; } + bool set_client_id(const char *client_id) + { + char *p = strdup(client_id); + + if (!p) + return false; + + free(this->ptr->client_id); + this->ptr->client_id = p; + return true; + } + + bool get_check_crcs() const + { + return this->ptr->check_crcs != 0; + } + void set_check_crcs(bool check_crcs) + { + this->ptr->check_crcs = check_crcs; + } + + int get_offset_store() const + { + return this->ptr->offset_store; + } + void set_offset_store(int offset_store) + { + this->ptr->offset_store = offset_store; + } + + const char *get_rack_id() const + { + return this->ptr->rack_id; + } + bool set_rack_id(const char *rack_id) + { + char *p = strdup(rack_id); + if (!p) + return false; + + free(this->ptr->rack_id); + this->ptr->rack_id = p; + return true; + } + + const char *get_sasl_mech() const + { + return this->ptr->mechanisms; + } + bool set_sasl_mech(const char *mechanisms) + { + char *p = strdup(mechanisms); + + if (!p) + return false; + + free(this->ptr->mechanisms); + this->ptr->mechanisms = p; + if (kafka_sasl_set_mechanisms(this->ptr) != 0) + return false; + + return true; + } + + const char *get_sasl_username() const + { + return this->ptr->username; + } + bool set_sasl_username(const char *username) + { + return 
kafka_sasl_set_username(username, this->ptr) == 0; + } + + const char *get_sasl_password() const + { + return this->ptr->password; + } + bool set_sasl_password(const char *password) + { + return kafka_sasl_set_password(password, this->ptr) == 0; + } + + std::string get_sasl_info() const; + + bool new_client(kafka_sasl_t *sasl) + { + return this->ptr->client_new(this->ptr, sasl) == 0; + } + +public: + KafkaConfig() + { + this->ptr = new kafka_config_t; + kafka_config_init(this->ptr); + this->ref = new std::atomic(1); + } + + virtual ~KafkaConfig() + { + if (--*this->ref == 0) + { + kafka_config_deinit(this->ptr); + delete this->ptr; + delete this->ref; + } + } + + KafkaConfig(KafkaConfig&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_config_t; + kafka_config_init(move.ptr); + move.ref = new std::atomic(1); + } + + KafkaConfig& operator= (KafkaConfig&& move) + { + if (this != &move) + { + this->~KafkaConfig(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_config_t; + kafka_config_init(move.ptr); + move.ref = new std::atomic(1); + } + + return *this; + } + + KafkaConfig(const KafkaConfig& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + } + + KafkaConfig& operator= (const KafkaConfig& copy) + { + if (this != ©) + { + this->~KafkaConfig(); + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + } + + return *this; + } + + kafka_config_t *get_raw_ptr() { return this->ptr; } + +private: + kafka_config_t *ptr; + std::atomic *ref; +}; + +class KafkaRecord +{ +public: + bool set_key(const void *key, size_t key_len) + { + return kafka_record_set_key(key, key_len, this->ptr) == 0; + } + void get_key(const void **key, size_t *key_len) const + { + *key = this->ptr->key; + *key_len = this->ptr->key_len; + } + size_t get_key_len() const { return this->ptr->key_len; } + + bool set_value(const void *value, size_t value_len) + { + return kafka_record_set_value(value, value_len, 
this->ptr) == 0; + } + void get_value(const void **value, size_t *value_len) const + { + *value = this->ptr->value; + *value_len = this->ptr->value_len; + } + size_t get_value_len() const { return this->ptr->value_len; } + + bool add_header_pair(const void *key, size_t key_len, + const void *val, size_t val_len); + + bool add_header_pair(const std::string& key, const std::string& val); + + struct list_head *get_list() { return &this->list; } + + const char *get_topic() const { return this->ptr->toppar->topic_name; } + + void set_status(short err) { this->ptr->status = err; } + short get_status() const { return this->ptr->status; } + + int get_partition() const { return this->ptr->toppar->partition; } + + long long get_offset() const { return this->ptr->offset; } + void set_offset(long long offset) { this->ptr->offset = offset; } + + long long get_timestamp() const { return this->ptr->timestamp; } + void set_timestamp(long long timestamp) { this->ptr->timestamp = timestamp; } + +public: + KafkaRecord() + { + this->ptr = new kafka_record_t; + kafka_record_init(this->ptr); + this->ref = new std::atomic(1); + } + + ~KafkaRecord() + { + if (--*this->ref == 0) + { + kafka_record_deinit(this->ptr); + delete this->ptr; + delete this->ref; + } + } + + KafkaRecord(KafkaRecord&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_record_t; + kafka_record_init(move.ptr); + move.ref = new std::atomic(1); + } + + KafkaRecord& operator= (KafkaRecord&& move) + { + if (this != &move) + { + this->~KafkaRecord(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_record_t; + kafka_record_init(move.ptr); + move.ref = new std::atomic(1); + } + + return *this; + } + + KafkaRecord(const KafkaRecord& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + } + + KafkaRecord& operator= (const KafkaRecord& copy) + { + if (this != ©) + { + this->~KafkaRecord(); + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; 
+ } + + return *this; + } + + kafka_record_t *get_raw_ptr() const { return this->ptr; } + + struct list_head *get_header_list() const { return &this->ptr->header_list; } + +private: + struct list_head list; + kafka_record_t *ptr; + std::atomic *ref; + + friend class KafkaMessage; + friend class KafkaResponse; + friend class KafkaToppar; +}; + + +class KafkaMeta; +class KafkaBroker; +class KafkaToppar; + +using KafkaMetaList = KafkaList; +using KafkaBrokerList = KafkaList; +using KafkaBrokerMap = KafkaMap; +using KafkaTopparList = KafkaList; +using KafkaRecordList = KafkaList; + +extern KafkaToppar *get_toppar(const char *topic, int partition, + KafkaTopparList *toppar_list); + +extern const KafkaMeta *get_meta(const char *topic, KafkaMetaList *meta_list); + +class KafkaToppar +{ +public: + bool set_topic_partition(const std::string& topic, int partition) + { + return kafka_topic_partition_set_tp(topic.c_str(), partition, + this->ptr) == 0; + } + + int get_preferred_read_replica() const + { + return this->ptr->preferred_read_replica; + } + + bool set_topic(const char *topic) + { + this->ptr->topic_name = strdup(topic); + return this->ptr->topic_name != NULL; + } + + const char *get_topic() const { return this->ptr->topic_name; } + + int get_partition() const { return this->ptr->partition; } + + long long get_offset() const { return this->ptr->offset; } + void set_offset(long long offset) { this->ptr->offset = offset; } + + long long get_offset_timestamp() const { return this->ptr->offset_timestamp; } + void set_offset_timestamp(long long tm) { this->ptr->offset_timestamp = tm; } + + long long get_high_watermark() const { return this->ptr->high_watermark; } + void set_high_watermark(long long offset) const { this->ptr->high_watermark = offset; } + + long long get_low_watermark() const { return this->ptr->low_watermark; } + void set_low_watermark(long long offset) { this->ptr->low_watermark = offset; } + + void clear_records() + { + 
INIT_LIST_HEAD(&this->ptr->record_list); + this->curpos = &this->ptr->record_list; + this->startpos = this->endpos = this->curpos; + } + +public: + KafkaToppar() + { + this->ptr = new kafka_topic_partition_t; + kafka_topic_partition_init(this->ptr); + this->ref = new std::atomic(1); + this->curpos = &this->ptr->record_list; + this->startpos = this->endpos = this->curpos; + } + + ~KafkaToppar(); + + KafkaToppar(KafkaToppar&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_topic_partition_t; + kafka_topic_partition_init(move.ptr); + move.ref = new std::atomic(1); + this->curpos = &this->ptr->record_list; + this->startpos = this->endpos = this->curpos; + } + + KafkaToppar& operator= (KafkaToppar&& move) + { + if (this != &move) + { + this->~KafkaToppar(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_topic_partition_t; + kafka_topic_partition_init(move.ptr); + move.ref = new std::atomic(1); + this->curpos = &this->ptr->record_list; + this->startpos = this->endpos = this->curpos; + } + + return *this; + } + + KafkaToppar(const KafkaToppar& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + this->curpos = copy.curpos; + this->startpos = copy.startpos; + this->endpos = copy.endpos; + } + + KafkaToppar& operator= (const KafkaToppar& copy) + { + if (this != ©) + { + this->~KafkaToppar(); + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + this->curpos = copy.curpos; + this->startpos = copy.startpos; + this->endpos = copy.endpos; + } + + return *this; + } + + kafka_topic_partition_t *get_raw_ptr() { return this->ptr; } + + struct list_head *get_list() { return &this->list; } + + struct list_head *get_record() { return &this->ptr->record_list; } + + void set_error(short error) { this->ptr->error = error; } + int get_error() const { return this->ptr->error; } + + void add_record(KafkaRecord&& record) + { + KafkaRecord *tmp = new KafkaRecord; + + *tmp =std::move(record); + 
list_add_tail(tmp->get_list(), &this->ptr->record_list); + } + + void record_rewind() + { + this->curpos = &this->ptr->record_list; + } + + KafkaRecord *get_record_next() + { + if (this->curpos->next == &this->ptr->record_list) + return NULL; + + this->curpos = this->curpos->next; + return list_entry(this->curpos, KafkaRecord, list); + } + + void del_record_cur() + { + assert(this->curpos != &this->ptr->record_list); + this->curpos = this->curpos->prev; + list_del(this->curpos->next); + } + + struct list_head *get_record_startpos() + { + return this->startpos; + } + + struct list_head *get_record_endpos() + { + return this->endpos; + } + + void restore_record_curpos() + { + this->curpos = this->startpos; + this->endpos = NULL; + } + + void save_record_startpos() + { + this->startpos = this->curpos; + } + + void save_record_endpos() + { + this->endpos = this->curpos->next; + } + + bool record_reach_end() + { + return this->endpos == &this->ptr->record_list; + } + + void record_rollback() + { + this->curpos = this->curpos->prev; + } + + KafkaRecord *get_tail_record() + { + if (&this->ptr->record_list != this->ptr->record_list.prev) + { + return (KafkaRecord *)list_entry(this->ptr->record_list.prev, + KafkaRecord, list); + } + else + { + return NULL; + } + } + +private: + struct list_head list; + kafka_topic_partition_t *ptr; + std::atomic *ref; + struct list_head *curpos; + struct list_head *startpos; + struct list_head *endpos; + + friend class KafkaMessage; + friend class KafkaRequest; + friend class KafkaResponse; + friend class KafkaList; + friend class KafkaCgroup; + + friend KafkaToppar *get_toppar(const char *topic, int partition, + KafkaTopparList *toppar_list); +}; + +class KafkaBroker +{ +public: + const char *get_host() const + { + return this->ptr->host; + } + + int get_port() const + { + return this->ptr->port; + } + + std::string get_host_port() const + { + std::string host_port(this->ptr->host); + host_port += ":"; + host_port += 
std::to_string(this->ptr->port); + return host_port; + } + + int get_error() + { + return this->ptr->error; + } + +public: + KafkaBroker() + { + this->ptr = new kafka_broker_t; + kafka_broker_init(this->ptr); + this->ref = new std::atomic(1); + } + + ~KafkaBroker() + { + if (this->ref && --*this->ref == 0) + { + kafka_broker_deinit(this->ptr); + delete this->ptr; + delete this->ref; + } + } + + KafkaBroker(kafka_broker_t *ptr) + { + this->ptr = ptr; + this->ref = NULL; + } + + KafkaBroker(KafkaBroker&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_broker_t; + kafka_broker_init(move.ptr); + move.ref = new std::atomic(1); + } + + KafkaBroker& operator= (KafkaBroker&& move) + { + if (this != &move) + { + this->~KafkaBroker(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_broker_t; + kafka_broker_init(move.ptr); + move.ref = new std::atomic(1); + } + + return *this; + } + + KafkaBroker(const KafkaBroker& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + if (this->ref) + ++*this->ref; + } + + KafkaBroker& operator= (const KafkaBroker& copy) + { + if (this != ©) + { + this->~KafkaBroker(); + this->ptr = copy.ptr; + this->ref = copy.ref; + if (this->ref) + ++*this->ref; + } + + return *this; + } + + bool operator< (const KafkaBroker& broker) const + { + return this->get_host_port() < broker.get_host_port(); + } + + bool operator> (const KafkaBroker& broker) const + { + return this->get_host_port() > broker.get_host_port(); + } + + kafka_broker_t *get_raw_ptr() const { return this->ptr; } + + struct list_head *get_list() { return &this->list; } + + struct rb_node *get_rb() { return &this->rb; } + + int get_node_id() const { return this->ptr->node_id; } + + int get_id () const { return this->ptr->node_id; } + +private: + struct list_head list; + struct rb_node rb; + kafka_broker_t *ptr; + std::atomic *ref; + + friend class KafkaList; + friend class KafkaMap; +}; + +class KafkaMeta +{ +public: + const char 
*get_topic() const { return this->ptr->topic_name; } + + const kafka_broker_t *get_broker(int partition) const + { + if (partition >= this->ptr->partition_elements) + return NULL; + + for (int i = 0; i < this->ptr->partition_elements; ++i) + { + if (partition == this->ptr->partitions[i]->partition_index) + return &this->ptr->partitions[i]->leader; + } + + return NULL; + } + + kafka_partition_t **get_partitions() const { return this->ptr->partitions; } + + int get_partition_elements() const { return this->ptr->partition_elements; } + +public: + KafkaMeta() + { + this->ptr = new kafka_meta_t; + kafka_meta_init(this->ptr); + this->ref = new std::atomic(1); + } + + ~KafkaMeta() + { + if (--*this->ref == 0) + { + kafka_meta_deinit(this->ptr); + delete this->ptr; + delete this->ref; + } + } + + KafkaMeta(KafkaMeta&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_meta_t; + kafka_meta_init(move.ptr); + move.ref = new std::atomic(1); + } + + KafkaMeta& operator= (KafkaMeta&& move) + { + if (this != &move) + { + this->~KafkaMeta(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new kafka_meta_t; + kafka_meta_init(move.ptr); + move.ref = new std::atomic(1); + } + + return *this; + } + + KafkaMeta(const KafkaMeta& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + } + + KafkaMeta& operator= (const KafkaMeta& copy) + { + if (this != ©) + { + this->~KafkaMeta(); + this->ptr = copy.ptr; + this->ref = copy.ref; + ++*this->ref; + } + + return *this; + } + + kafka_meta_t *get_raw_ptr() { return this->ptr; } + + bool set_topic(const std::string& topic) + { + return kafka_meta_set_topic(topic.c_str(), this->ptr) == 0; + } + + struct list_head *get_list() { return &this->list; } + + int get_error() const { return this->ptr->error; } + + bool create_partitions(int partition_cnt); + + bool create_replica_nodes(int partition_idx, int replica_cnt) + { + int *replica_nodes = (int *)malloc(replica_cnt * 4); + + if 
(!replica_nodes) + return false; + + this->ptr->partitions[partition_idx]->replica_nodes = replica_nodes; + this->ptr->partitions[partition_idx]->replica_node_elements = replica_cnt; + return true; + } + + bool create_isr_nodes(int partition_idx, int isr_cnt) + { + int *isr_nodes = (int *)malloc(isr_cnt * 4); + + if (!isr_nodes) + return false; + + this->ptr->partitions[partition_idx]->isr_nodes = isr_nodes; + this->ptr->partitions[partition_idx]->isr_node_elements = isr_cnt; + return true; + } + +private: + struct list_head list; + kafka_meta_t *ptr; + std::atomic *ref; + + friend class KafkaList; + + friend const KafkaMeta *get_meta(const char *topic, KafkaMetaList *meta_list); +}; + +class KafkaMetaSubscriber +{ +public: + void set_meta(KafkaMeta *meta) + { + this->meta = meta; + } + + const KafkaMeta *get_meta() const + { + return this->meta; + } + + void add_member(kafka_member_t *member) + { + this->member_vec.push_back(member); + } + + const std::vector *get_member() const + { + return &this->member_vec; + } + + void sort_by_member(); + +private: + KafkaMeta *meta; + std::vector member_vec; +}; + +class KafkaCgroup +{ +public: + const char *get_group() const { return this->ptr->group_name; } + + const char *get_protocol_type() const { return this->ptr->protocol_type; } + + const char *get_protocol_name() const { return this->ptr->protocol_name; } + + int get_generation_id() const { return this->ptr->generation_id; } + + const char *get_member_id() const { return this->ptr->member_id; } + +public: + KafkaCgroup(); + + ~KafkaCgroup(); + + KafkaCgroup(KafkaCgroup&& move); + + KafkaCgroup& operator= (KafkaCgroup&& move); + + KafkaCgroup(const KafkaCgroup& copy); + + KafkaCgroup& operator= (const KafkaCgroup& copy); + + kafka_cgroup_t *get_raw_ptr() { return this->ptr; } + + void set_group(const std::string& group) + { + char *p = new char[group.size() + 1]; + strncpy(p, group.c_str(), group.size()); + p[group.size()] = 0; + this->ptr->group_name = p; + } + + 
struct list_head *get_list() { return &this->list; } + + int get_error() const { return this->ptr->error; } + + bool is_leader() const + { + return strcmp(this->ptr->leader_id, this->ptr->member_id) == 0; + } + + struct list_head *get_group_protocol() + { + return &this->ptr->group_protocol_list; + } + + void set_member_id(const char *p) + { + free(this->ptr->member_id); + this->ptr->member_id = strdup(p); + } + + void set_error(short error) { this->ptr->error = error; } + + bool create_members(int member_cnt); + + kafka_member_t **get_members() const { return this->ptr->members; } + + int get_member_elements() { return this->ptr->member_elements; } + + void add_assigned_toppar(KafkaToppar *toppar); + + struct list_head *get_assigned_toppar_list() + { + return &this->ptr->assigned_toppar_list; + } + + KafkaToppar *get_assigned_toppar_by_pos(struct list_head *pos) + { + return list_entry(pos, KafkaToppar, list); + } + + void assigned_toppar_rewind(); + + KafkaToppar *get_assigned_toppar_next(); + + void del_assigned_toppar_cur(); + + KafkaBroker *get_coordinator() + { + if (!this->coordinator) + this->coordinator = new KafkaBroker(&this->ptr->coordinator); + + return this->coordinator; + } + + int run_assignor(KafkaMetaList *meta_list, const char *protocol_name); + + void add_subscriber(KafkaMetaList *meta_list, + std::vector *subscribers); + + static int kafka_range_assignor(kafka_member_t **members, + int member_elements, + void *meta_topic); + + static int kafka_roundrobin_assignor(kafka_member_t **members, + int member_elements, + void *meta_topic); + +private: + struct list_head list; + kafka_cgroup_t *ptr; + std::atomic *ref; + struct list_head *curpos; + KafkaBroker *coordinator; +}; + +class KafkaBlock +{ +public: + KafkaBlock() + { + this->ptr = new kafka_block_t; + kafka_block_init(this->ptr); + } + + ~KafkaBlock() + { + kafka_block_deinit(this->ptr); + delete this->ptr; + } + + KafkaBlock(KafkaBlock&& move) + { + this->ptr = move.ptr; + move.ptr = new 
kafka_block_t; + kafka_block_init(move.ptr); + } + + KafkaBlock& operator= (KafkaBlock&& move) + { + if (this != &move) + { + this->~KafkaBlock(); + this->ptr = move.ptr; + move.ptr = new kafka_block_t; + kafka_block_init(move.ptr); + } + + return *this; + } + + kafka_block_t *get_raw_ptr() const { return this->ptr; } + + struct list_head *get_list() { return &this->list; } + + void *get_block() const { return this->ptr->buf; } + + size_t get_len() const { return this->ptr->len; } + + bool allocate(size_t len) + { + void *p = malloc(len); + + if (!p) + return false; + + free(this->ptr->buf); + this->ptr->buf = p; + this->ptr->len = len; + return true; + } + + bool reallocate(size_t len) + { + void *p = realloc(this->ptr->buf, len); + + if (p) + { + this->ptr->buf = p; + this->ptr->len = len; + return true; + } + else + return false; + } + + bool set_block(void *buf, size_t len) + { + if (!this->allocate(len)) + return false; + + memcpy(this->ptr->buf, buf, len); + return true; + } + + void set_block_nocopy(void *buf, size_t len) + { + this->ptr->buf = buf; + this->ptr->len = len; + } + + void set_len(size_t len) { this->ptr->len = len; } + +private: + struct list_head list; + kafka_block_t *ptr; + + friend class KafkaBuffer; + friend class KafkaList; +}; + +class KafkaBuffer +{ +public: + KafkaBuffer() + { + this->insert_pos = NULL; + this->insert_curpos = NULL; + this->buf_size = 0; + this->inited = false; + this->insert_buf_size = 0; + this->insert_flag = false; + } + + void backup(size_t n) + { + this->buf_size -= n; + } + + void list_splice(KafkaBuffer *buffer); + + void add_item(KafkaBlock block) + { + if (this->insert_flag) + this->insert_buf_size += block.get_len(); + + this->buf_size += block.get_len(); + this->block_list.add_item(std::move(block)); + } + + void set_insert_pos() + { + this->insert_pos = this->block_list.get_tail(); + this->insert_flag = true; + this->insert_buf_size = 0; + } + + void block_insert_rewind() + { + this->insert_flag = false; + 
this->insert_curpos = this->insert_pos; + } + + KafkaBlock *get_block_insert_next() + { + if (this->insert_curpos->next == this->block_list.get_head()) + return NULL; + + this->insert_curpos = this->insert_curpos->next; + return list_entry(this->insert_curpos, KafkaBlock, list); + } + + KafkaBlock *get_block_tail() + { + return this->block_list.get_tail_entry(); + } + + void insert_list(KafkaBlock *block) + { + this->buf_size += block->get_len(); + this->block_list.insert_pos(block->get_list(), this->insert_pos); + this->insert_pos = this->insert_pos->next; + } + + KafkaBlock *get_block_first() + { + this->block_list.rewind(); + return this->block_list.get_next(); + } + + KafkaBlock *get_block_next() + { + return this->block_list.get_next(); + } + + void append(const char *bytes, size_t n) + { + KafkaBlock block; + + block.set_block((void *)bytes, n); + this->block_list.add_item(std::move(block)); + this->buf_size += n; + } + + size_t get_size() const + { + return this->buf_size; + } + + size_t peek(const char **buf); + + long seek(long offset) + { + this->cur_pos.second += offset; + return offset; + } + + struct list_head *get_head() + { + return this->block_list.get_head(); + } + +private: + KafkaList block_list; + std::pair cur_pos; + struct list_head *insert_pos; + struct list_head *insert_curpos; + size_t buf_size; + size_t insert_buf_size; + bool inited; + bool insert_flag; +}; + +class KafkaSnappySink : public snappy::Sink +{ +public: + KafkaSnappySink(KafkaBuffer *buffer) + { + this->buffer = buffer; + } + + virtual void Append(const char *bytes, size_t n) + { + this->buffer->append(bytes, n); + } + + size_t size() const + { + return this->buffer->get_size(); + } + + KafkaBuffer *get_buffer() const + { + return buffer; + } + +private: + KafkaBuffer *buffer; +}; + +class KafkaSnappySource : public snappy::Source +{ +public: + KafkaSnappySource(KafkaBuffer *buffer) + { + this->buffer = buffer; + this->buf_size = this->buffer->get_size(); + this->pos = 0; + } 
+ + virtual size_t Available() const + { + return this->buf_size - this->pos; + } + + virtual const char *Peek(size_t *len) + { + const char *pos; + + *len = this->buffer->peek(&pos); + return pos; + } + + virtual void Skip(size_t n) + { + this->pos += this->buffer->seek(n); + } + +private: + KafkaBuffer *buffer; + size_t buf_size; + size_t pos; +}; + +} + +#endif diff --git a/src/protocol/KafkaMessage.cc b/src/protocol/KafkaMessage.cc new file mode 100644 index 0000000..af06e7e --- /dev/null +++ b/src/protocol/KafkaMessage.cc @@ -0,0 +1,3788 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
/*
 * Writes `val` as a 4-byte big-endian integer at the cursor `*buf` and
 * advances the cursor past it.
 *
 * The cursor is a raw byte position inside a message buffer; callers mix
 * 1-byte and variable-length fields, so it is not guaranteed to be 4-byte
 * aligned.  The value is therefore copied with memcpy() instead of being
 * stored through a cast int32_t pointer.
 *
 * Returns the number of bytes written (always 4).
 */
static size_t append_i32(void **buf, int32_t val)
{
	uint32_t v = htonl((uint32_t)val);

	memcpy(*buf, &v, sizeof v);
	*buf = (char *)*buf + sizeof v;
	return sizeof v;
}
/*
 * Writes a length-prefixed byte string at the cursor `*buf`: a 4-byte
 * big-endian length followed by `len` raw bytes, then advances the cursor
 * past both.
 *
 * Returns the number of bytes written, i.e. 4 + len, matching the
 * std::string overloads of append_bytes().
 *
 * The length prefix is copied with memcpy() because the cursor is a raw byte
 * position and need not be 4-byte aligned.  The payload copy is skipped when
 * `len` is 0 so that a NULL `str` is acceptable for an empty value.
 */
static size_t append_bytes(void **buf, const char *str, size_t len)
{
	uint32_t nlen = htonl((uint32_t)len);

	memcpy(*buf, &nlen, sizeof nlen);
	*buf = (char *)*buf + sizeof nlen;

	if (len > 0)
		memcpy(*buf, str, len);

	*buf = (char *)*buf + len;

	return sizeof nlen + len;
}
0x80 : 0); + buf.append((char *)&v, 1); + num >>= 7; + ++len; + } while (num); + return len; +} + +static inline size_t append_varint_i64(std::string& buf, int64_t num) +{ + return append_varint_u64(buf, (num << 1) ^ (num >> 63)); +} + +static inline size_t append_varint_i32(std::string& buf, int32_t num) +{ + return append_varint_i64(buf, num); +} + +static size_t append_compact_string(std::string& buf, const char *str) +{ + if (!str || str[0] == '\0') + append_string(buf, ""); + + size_t len = strlen(str); + size_t r = append_varint_u64(buf, len + 1); + append_string_raw(buf, str, len); + return r + len; +} + +static inline int parse_i8(void **buf, size_t *size, int8_t *val) +{ + if (*size >= 1) + { + *val = *(int8_t *)*buf; + *size -= sizeof(int8_t); + *buf = (int8_t *)*buf + 1; + return 0; + } + + errno = EBADMSG; + return -1; +} + +static inline int parse_i16(void **buf, size_t *size, int16_t *val) +{ + if (*size >= 2) + { + *val = ntohs(*(int16_t *)*buf); + *size -= sizeof(int16_t); + *buf = (int16_t *)*buf + 1; + return 0; + } + + errno = EBADMSG; + return -1; +} + +static inline int parse_i32(void **buf, size_t *size, int32_t *val) +{ + if (*size >= 4) + { + *val = ntohl(*(int32_t *)*buf); + *size -= sizeof(int32_t); + *buf = (int32_t *)*buf + 1; + return 0; + } + + errno = EBADMSG; + return -1; +} + +static inline int parse_i64(void **buf, size_t *size, int64_t *val) +{ + if (*size >= 8) + { + *val = htonll(*(int64_t *)*buf); + *size -= sizeof(int64_t); + *buf = (int64_t *)*buf + 1; + return 0; + } + + errno = EBADMSG; + return -1; +} + +static int parse_string(void **buf, size_t *size, std::string& str); + +static int parse_string(void **buf, size_t *size, char **str); + +static int parse_bytes(void **buf, size_t *size, std::string& str); + +static int parse_bytes(void **buf, size_t *size, + void **str, size_t *str_len); + +static int parse_varint_u64(void **buf, size_t *size, uint64_t *val); + +static int parse_varint_i64(void **buf, size_t *size, 
int64_t *val) +{ + uint64_t n; + int ret = parse_varint_u64(buf, size, &n); + + if (ret == 0) + *val = (int64_t)(n >> 1) ^ -(int64_t)(n & 1); + + return ret; +} + +static int parse_varint_i32(void **buf, size_t *size, int32_t *val) +{ + int64_t v = 0; + + if (parse_varint_i64(buf, size, &v) < 0) + return -1; + + *val = (int32_t)v; + return 0; +} + +static const LZ4F_preferences_t kPrefs = +{ + .frameInfo = {LZ4F_default, LZ4F_blockIndependent, }, + .compressionLevel = 0, +}; + +static int compress_buf(KafkaBlock *block, int compress_type, void *env) +{ + z_stream *c_stream; + size_t total_in = 0, gzip_in, bound_size; + KafkaBuffer *snappy_buffer; + KafkaBlock nblock; + LZ4F_errorCode_t lz4_r; + LZ4F_cctx *lz4_cctx; + ZSTD_CStream *zstd_cctx; + size_t zstd_r; + ZSTD_outBuffer out; + ZSTD_inBuffer in; + + switch (compress_type) + { + case Kafka_Gzip: + c_stream = static_cast(env); + gzip_in = c_stream->total_in; + while (total_in < block->get_len()) + { + if (c_stream->avail_in == 0) + { + c_stream->next_in = (Bytef *)block->get_block(); + c_stream->avail_in = block->get_len() - total_in; + } + + if (c_stream->avail_out == 0) + { + bound_size = compressBound(c_stream->avail_in); + if (!nblock.allocate(bound_size)) + { + delete c_stream; + return -1; + } + + c_stream->next_out = (Bytef *)nblock.get_block(); + c_stream->avail_out = bound_size; + } + + if (deflate(c_stream, Z_NO_FLUSH) != Z_OK) + { + delete c_stream; + errno = EBADMSG; + return -1; + } + + total_in += c_stream->total_in - gzip_in; + gzip_in = c_stream->total_in; + } + + *block = std::move(nblock); + + break; + + case Kafka_Snappy: + snappy_buffer = static_cast(env); + snappy_buffer->append((const char *)block->get_block(), block->get_len()); + + break; + + case Kafka_Lz4: + lz4_cctx = static_cast(env); + bound_size = LZ4F_compressBound(block->get_len(), &kPrefs); + if (!nblock.allocate(bound_size)) + { + LZ4F_freeCompressionContext(lz4_cctx); + return -1; + } + + lz4_r = LZ4F_compressUpdate(lz4_cctx, + 
nblock.get_block(), nblock.get_len(), + block->get_block(), block->get_len(), + NULL); + + if (LZ4F_isError(lz4_r)) + { + LZ4F_freeCompressionContext(lz4_cctx); + errno = EBADMSG; + return -1; + } + + nblock.set_len(lz4_r); + *block = std::move(nblock); + break; + + case Kafka_Zstd: + zstd_cctx = static_cast(env); + bound_size = ZSTD_compressBound(block->get_len()); + if (!nblock.allocate(bound_size)) + { + ZSTD_freeCStream(zstd_cctx); + return -1; + } + + in.src = block->get_block(); + in.pos = 0; + in.size = block->get_len(); + out.dst = nblock.get_block(); + out.pos = 0; + out.size = nblock.get_len(); + zstd_r = ZSTD_compressStream(zstd_cctx, &out, &in); + if (ZSTD_isError(zstd_r) || in.pos < in.size) + { + ZSTD_freeCStream(zstd_cctx); + errno = EBADMSG; + return -1; + } + + nblock.set_len(out.pos); + *block = std::move(nblock); + break; + + default: + return 0; + } + + return 0; +} + +static int gzip_decompress(void *compressed, size_t n, KafkaBlock *block) +{ + for (int pass = 1; pass <= 2; pass++) + { + z_stream strm = {0}; + gz_header hdr; + char buf[512]; + char *p; + int len; + int r; + + if ((r = inflateInit2(&strm, 15 | 32)) != Z_OK) + { + errno = EBADMSG; + return -1; + } + + strm.next_in = (Bytef *)compressed; + strm.avail_in = n; + + if ((r = inflateGetHeader(&strm, &hdr)) != Z_OK) + { + inflateEnd(&strm); + errno = EBADMSG; + return -1; + } + + if (pass == 1) + { + p = buf; + len = sizeof(buf); + } + else + { + p = (char *)block->get_block(); + len = block->get_len(); + } + + do + { + strm.next_out = (unsigned char *)p; + strm.avail_out = len; + + r = inflate(&strm, Z_NO_FLUSH); + switch (r) + { + case Z_STREAM_ERROR: + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm); + errno = EBADMSG; + return -1; + } + + if (pass == 2) + { + p += len - strm.avail_out; + len -= len - strm.avail_out; + } + + } while (strm.avail_out == 0 && r != Z_STREAM_END); + + + if (pass == 1) + { + if (!block->allocate(strm.total_out)) + { + 
inflateEnd(&strm); + return -1; + } + } + + inflateEnd(&strm); + + if (strm.total_in != n || r != Z_STREAM_END) + { + errno = EBADMSG; + return -1; + } + } + + return 0; +} + +static int kafka_snappy_java_uncompress(const char *inbuf, size_t inlen, KafkaBlock *block) +{ + char *obuf = NULL; + + for (int pass = 1; pass <= 2; pass++) + { + ssize_t off = 0; + ssize_t uoff = 0; + + while (off + 4 <= (ssize_t)inlen) + { + uint32_t clen; + size_t ulen; + + memcpy(&clen, inbuf + off, 4); + clen = ntohl(clen); + off += 4; + + if (clen > inlen - off) + { + errno = EBADMSG; + return -1; + } + + if (snappy_uncompressed_length(inbuf + off, clen, &ulen) != SNAPPY_OK) + { + errno = EBADMSG; + return -1; + } + + if (pass == 1) + { + off += clen; + uoff += ulen; + continue; + } + + size_t n = block->get_len() - uoff; + + if (snappy_uncompress(inbuf + off, clen, obuf + uoff, &n) != SNAPPY_OK) + { + errno = EBADMSG; + return -1; + } + + off += clen; + uoff += ulen; + } + + if (off != (ssize_t)inlen) + { + errno = EBADMSG; + return -1; + } + + if (pass == 1) + { + if (uoff <= 0) + { + errno = EBADMSG; + return -1; + } + + if (!block->allocate(uoff)) + return -1; + + obuf = (char *)block->get_block(); + } + else + block->set_len(uoff); + } + + return 0; +} + +static int snappy_decompress(void *buf, size_t n, KafkaBlock *block) +{ + const char *inbuf = (const char *)buf; + size_t inlen = n; + static const unsigned char snappy_java_magic[] = { + 0x82, 'S','N','A','P','P','Y', 0 + }; + static const size_t snappy_java_hdrlen = 8 + 4 + 4; + + if (!memcmp(buf, snappy_java_magic, 8)) + { + inbuf = inbuf + snappy_java_hdrlen; + inlen -= snappy_java_hdrlen; + return kafka_snappy_java_uncompress(inbuf, inlen, block); + } + else + { + size_t uncompressed_len; + + if (snappy_uncompressed_length(inbuf, n, &uncompressed_len) != SNAPPY_OK) + { + errno = EBADMSG; + return -1; + } + + if (!block->allocate(uncompressed_len)) + return -1; + + size_t nn = block->get_len(); + + return 
snappy_uncompress(inbuf, n, (char *)block->get_block(), &nn); + } +} + +static int lz4_decompress(void *buf, size_t n, KafkaBlock *block) +{ + LZ4F_errorCode_t code; + LZ4F_decompressionContext_t dctx; + LZ4F_frameInfo_t fi; + size_t in_sz, out_sz; + size_t in_off, out_off; + size_t r; + size_t uncompressed_size; + size_t outlen; + char *out = NULL; + size_t inlen = n; + + const char *inbuf = (const char *)buf; + + code = LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION); + if (LZ4F_isError(code)) + { + code = LZ4F_freeDecompressionContext(dctx); + errno = EBADMSG; + return -1; + } + + in_sz = n; + r = LZ4F_getFrameInfo(dctx, &fi, (const void *)buf, &in_sz); + if (LZ4F_isError(r)) + { + code = LZ4F_freeDecompressionContext(dctx); + errno = EBADMSG; + return -1; + } + + if (fi.contentSize == 0 || fi.contentSize > inlen * 255) + uncompressed_size = inlen * 4; + else + uncompressed_size = (size_t)fi.contentSize; + + if (!block->allocate(uncompressed_size)) + { + code = LZ4F_freeDecompressionContext(dctx); + return -1; + } + + out = (char *)block->get_block(); + outlen = block->get_len(); + in_off = in_sz; + out_off = 0; + while (in_off < inlen) + { + out_sz = outlen - out_off; + in_sz = inlen - in_off; + r = LZ4F_decompress(dctx, out + out_off, &out_sz, + inbuf + in_off, &in_sz, NULL); + if (LZ4F_isError(r)) + { + code = LZ4F_freeDecompressionContext(dctx); + errno = EBADMSG; + return -1; + } + + if (!(out_off + out_sz <= outlen && in_off + in_sz <= inlen)) + { + code = LZ4F_freeDecompressionContext(dctx); + errno = EBADMSG; + return -1; + } + + out_off += out_sz; + in_off += in_sz; + if (r == 0) + break; + + if (out_off == outlen) + { + size_t extra = outlen * 3 / 4; + + if (!block->reallocate(outlen + extra)) + { + code = LZ4F_freeDecompressionContext(dctx); + errno = EBADMSG; + return -1; + } + + out = (char *)block->get_block(); + outlen += extra; + } + } + + + if (in_off < inlen) + { + code = LZ4F_freeDecompressionContext(dctx); + errno = EBADMSG; + return -1; 
+ } + + LZ4F_freeDecompressionContext(dctx); + return 0; +} + +static int zstd_decompress(void *buf, size_t n, KafkaBlock *block) +{ + unsigned long long out_bufsize = ZSTD_getFrameContentSize(buf, n); + + switch (out_bufsize) + { + case ZSTD_CONTENTSIZE_UNKNOWN: + out_bufsize = n * 2; + break; + + case ZSTD_CONTENTSIZE_ERROR: + errno = EBADMSG; + return -1; + + default: + break; + } + + while (1) + { + size_t ret; + + if (!block->allocate(out_bufsize)) + return -1; + + ret = ZSTD_decompress(block->get_block(), out_bufsize, + buf, n); + + if (!ZSTD_isError(ret)) + return 0; + + if (ZSTD_getErrorCode(ret) == ZSTD_error_dstSize_tooSmall) + { + out_bufsize += out_bufsize * 2; + } + else + { + errno = EBADMSG; + return -1; + } + } +} + +static int uncompress_buf(void *buf, size_t size, KafkaBlock *block, + int compress_type) +{ + switch(compress_type) + { + case Kafka_Gzip: + return gzip_decompress(buf, size, block); + + case Kafka_Snappy: + return snappy_decompress(buf, size, block); + + case Kafka_Lz4: + return lz4_decompress(buf, size, block); + + case Kafka_Zstd: + return zstd_decompress(buf, size, block); + + default: + errno = EBADMSG; + return -1; + } +} + +static int append_message_set(KafkaBlock *block, + const KafkaRecord *record, + int offset, int msg_version, + const KafkaConfig& config, void *env, + int cur_msg_size) +{ + const void *key; + size_t key_len; + record->get_key(&key, &key_len); + + const void *value; + size_t value_len; + record->get_value(&value, &value_len); + + int message_size = 4 + 1 + 1 + 4 + 4 + key_len + value_len; + + if (msg_version == 1) + message_size += 8; + + int max_msg_size = std::min(config.get_produce_msgset_max_bytes(), + config.get_produce_msg_max_bytes()); + if (message_size + 8 + 4 + cur_msg_size > max_msg_size) + return 1; + + if (!block->allocate(message_size + 8 + 4)) + return -1; + + void *cur = block->get_block(); + + append_i64(&cur, offset); + append_i32(&cur, message_size); + + int crc_32 = crc32(0, NULL, 0); + + 
append_i32(&cur, crc_32); //need update + append_i8(&cur, msg_version); + append_i8(&cur, 0); + + if (msg_version == 1) + append_i64(&cur, record->get_timestamp()); + + append_bytes(&cur, (const char *)key, key_len); + append_nullable_bytes(&cur, (const char *)value, value_len); + + char *crc_buf = (char *)block->get_block() + 8 + 4; + + crc_32 = crc32(crc_32, (Bytef *)(crc_buf + 4), message_size - 4); + *(uint32_t *)crc_buf = htonl(crc_32); + + if (compress_buf(block, config.get_compress_type(), env) < 0) + return -1; + + return 0; +} + +static int append_batch_record(KafkaBlock *block, + const KafkaRecord *record, + int offset, const KafkaConfig& config, + int64_t first_timestamp, void *env, + int cur_msg_size) +{ + const void *key; + size_t key_len; + record->get_key(&key, &key_len); + + const void *value; + size_t value_len; + record->get_value(&value, &value_len); + + std::string klen_str; + std::string vlen_str; + std::string timestamp_delta_str; + int64_t timestamp_delta = record->get_timestamp() - first_timestamp; + + append_varint_i64(timestamp_delta_str, timestamp_delta); + + std::string offset_delta_str; + append_varint_i64(offset_delta_str, offset); + + if (key_len > 0) + append_varint_i32(klen_str, (int32_t)key_len); + else + append_varint_i32(klen_str, (int32_t)-1); + + if (value) + append_varint_i32(vlen_str, (int32_t)value_len); + else + append_varint_i32(vlen_str, -1); + + struct list_head *pos; + kafka_record_header_t *header; + std::string hdr_str; + int hdr_cnt = 0; + + list_for_each(pos, record->get_header_list()) + { + header = list_entry(pos, kafka_record_header_t, list); + append_varint_i32(hdr_str, (int32_t)header->key_len); + append_string_raw(hdr_str, (const char *)header->key, header->key_len); + append_varint_i32(hdr_str, (int32_t)header->value_len); + append_string_raw(hdr_str, (const char *)header->value, header->value_len); + ++hdr_cnt; + } + + std::string hdr_cnt_str; + append_varint_i32(hdr_cnt_str, hdr_cnt); + + int length = 1 + 
timestamp_delta_str.size() + offset_delta_str.size() + + klen_str.size() + key_len + vlen_str.size() + value_len + + hdr_cnt_str.size() + hdr_str.size(); + + std::string length_str; + append_varint_i32(length_str, length); + + int max_msg_size = std::min(config.get_produce_msgset_max_bytes(), + config.get_produce_msg_max_bytes()); + if ((int)(length + length_str.size() + cur_msg_size) > max_msg_size) + return 1; + + if (!block->allocate(length + length_str.size())) + return false; + + void *cur = block->get_block(); + + append_string_raw(&cur, length_str); + append_i8(&cur, 0); + append_string_raw(&cur, timestamp_delta_str); + append_string_raw(&cur, offset_delta_str); + append_string_raw(&cur, klen_str); + + if (key_len > 0) + append_string_raw(&cur, (const char *)key, key_len); + + append_string_raw(&cur, vlen_str); + if (value_len > 0) + append_string_raw(&cur, (const char *)value, value_len); + + append_string_raw(&cur, hdr_cnt_str); + if (hdr_cnt > 0) + append_string_raw(&cur, hdr_str); + + if (compress_buf(block, config.get_compress_type(), env) < 0) + return -1; + + return 0; +} + +static int append_record(KafkaBlock *block, + const KafkaRecord *record, + int offset, int msg_version, + const KafkaConfig& config, + int64_t first_timestamp, void *env, + int cur_msg_size) +{ + if (config.get_produce_msgset_cnt() < offset) + return 1; + + int ret = 0; + + switch (msg_version) + { + case 0: + case 1: + ret = append_message_set(block, record, offset, msg_version, + config, env, cur_msg_size); + break; + + case 2: + ret = append_batch_record(block, record, offset, config, + first_timestamp, env, cur_msg_size); + break; + + default: + break; + } + + return ret; +} + +static int parse_string(void **buf, size_t *size, std::string& str) +{ + if (*size >= 2) + { + int16_t len; + + if (parse_i16(buf, size, &len) >= 0) + { + if (len >= -1) + { + if (len == -1) + len = 0; + + if (*size >= (size_t)len) + { + str.assign((char *)*buf, len); + *size -= len; + *buf = (char 
/*
 * Decodes one unsigned base-128 varint (7 payload bits per byte, least
 * significant group first, high bit set on every byte except the last) from
 * the cursor `*buf`, which has `*size` readable bytes.
 *
 * On success stores the value in `*val`, advances `*buf` and decreases
 * `*size` by the number of bytes consumed, and returns 0.
 * On failure returns -1 with errno = EBADMSG and leaves `*buf`, `*size` and
 * `*val` untouched.
 *
 * Failure cases:
 *   - the input ends before a byte without the continuation bit is seen.
 *     The previous loop compared only the unchanged `*size` against 0, so a
 *     truncated varint read past the buffer and underflowed `*size`;
 *   - the encoding runs past 10 bytes, which would shift by 64 or more bits.
 */
static int parse_varint_u64(void **buf, size_t *size, uint64_t *val)
{
	const unsigned char *p = (const unsigned char *)*buf;
	size_t off = 0;
	uint64_t num = 0;
	int shift = 0;
	unsigned char byte;

	do
	{
		if (off >= *size || shift > 63)
		{
			errno = EBADMSG;
			return -1;
		}

		byte = p[off++];
		num |= (uint64_t)(byte & 0x7f) << shift;
		shift += 7;
	} while (byte & 0x80);

	*val = num;
	*buf = (char *)*buf + off;
	*size -= off;
	return 0;
}
&uncompressed_len, check_crcs, + msg_vers, record_list, uncompressed, toppar); + + uncompressed->add_item(std::move(block)); + + if (msg_vers == 1) + { + struct list_head *pos; + KafkaRecord *record; + int n = 0; + + for (pos = record_head->next; pos != record_list; pos = pos->next) + n++; + + for (pos = record_head->next; pos != record_list; pos = pos->next) + { + int64_t fix_offset; + + record = list_entry(pos, KafkaRecord, list); + fix_offset = offset + record->get_offset() - n + 1; + record->set_offset(fix_offset); + } + } + } + } + + if (*size > 0) + { + return parse_message_set(buf, size, check_crcs, msg_vers, + record_list, uncompressed, toppar); + } + + return 0; +} + +static int parse_varint_bytes(void **buf, size_t *size, + void **str, size_t *str_len) +{ + int64_t len = 0; + + if (parse_varint_i64(buf, size, &len) >= 0) + { + if (len >= -1) + { + if (len <= 0) + { + *str = NULL; + *str_len = 0; + return 0; + } + + if ((int64_t)*size >= len) + { + *str = *buf; + *str_len = (size_t)len; + *size -= len; + *buf = (char *)*buf + len; + return 0; + } + } + } + + errno = EBADMSG; + return -1; +} + +struct KafkaBatchRecordHeader +{ + int64_t base_offset; + int32_t length; + int32_t partition_leader_epoch; + int8_t magic; + int32_t crc; + int16_t attributes; + int32_t last_offset_delta; + int64_t base_timestamp; + int64_t max_timestamp; + int64_t produce_id; + int16_t producer_epoch; + int32_t base_sequence; + int32_t record_count; +}; + +int KafkaMessage::parse_message_record(void **buf, size_t *size, + kafka_record_t *record) +{ + int64_t length; + int8_t attributes; + int64_t timestamp_delta; + int64_t offset_delta; + int32_t hdr_size; + + if (parse_varint_i64(buf, size, &length) < 0) + return -1; + + if (parse_i8(buf, size, &attributes) < 0) + return -1; + + if (parse_varint_i64(buf, size, ×tamp_delta) < 0) + return -1; + + if (parse_varint_i64(buf, size, &offset_delta) < 0) + return -1; + + record->timestamp += timestamp_delta; + record->offset += 
offset_delta; + + if (parse_varint_bytes(buf, size, &record->key, &record->key_len) < 0) + return -1; + + if (parse_varint_bytes(buf, size, &record->value, &record->value_len) < 0) + return -1; + + if (parse_varint_i32(buf, size, &hdr_size) < 0) + return -1; + + for (int i = 0; i < hdr_size; ++i) + { + kafka_record_header_t *header; + + header = (kafka_record_header_t *)malloc(sizeof(kafka_record_header_t)); + if (!header) + return -1; + + kafka_record_header_init(header); + if (parse_varint_bytes(buf, size, &header->key, &header->key_len) < 0) + { + free(header); + return -1; + } + + if (parse_varint_bytes(buf, size, &header->value, &header->value_len) < 0) + { + kafka_record_header_deinit(header); + free(header); + return -1; + } + + header->key_is_moved = 1; + header->value_is_moved = 1; + + list_add_tail(&header->list, &record->header_list); + } + + return record->offset < record->toppar->offset ? 1 : 0; +} + +int KafkaMessage::parse_record_batch(void **buf, size_t *size, + bool check_crcs, + struct list_head *record_list, + KafkaBuffer *uncompressed, + KafkaToppar *toppar) +{ + KafkaBatchRecordHeader hdr; + + if (parse_i64(buf, size, &hdr.base_offset) < 0) + return -1; + + if (parse_i32(buf, size, &hdr.length) < 0) + return -1; + + if (parse_i32(buf, size, &hdr.partition_leader_epoch) < 0) + return -1; + + if (parse_i8(buf, size, &hdr.magic) < 0) + return -1; + + if (parse_i32(buf, size, &hdr.crc) < 0) + return -1; + + if (check_crcs) + { + if (hdr.length > (int)*size + 9) + { + errno = EBADMSG; + return -1; + } + + if ((int)crc32c(0, (const void *)*buf, hdr.length - 9) != hdr.crc) + { + errno = EBADMSG; + return -1; + } + } + + if (parse_i16(buf, size, &hdr.attributes) < 0) + return -1; + + if (parse_i32(buf, size, &hdr.last_offset_delta) < 0) + return -1; + + if (parse_i64(buf, size, &hdr.base_timestamp) < 0) + return -1; + + if (parse_i64(buf, size, &hdr.max_timestamp) < 0) + return -1; + + if (parse_i64(buf, size, &hdr.produce_id) < 0) + return -1; + + if 
(parse_i16(buf, size, &hdr.producer_epoch) < 0) + return -1; + + if (parse_i32(buf, size, &hdr.base_sequence) < 0) + return -1; + + if (parse_i32(buf, size, &hdr.record_count) < 0) + return -1; + + if (*size < (size_t)(hdr.length - 61 + 12)) + return 1; + + KafkaBlock block; + + if (hdr.attributes & 7) + { + if (uncompress_buf(*buf, hdr.length - 61 + 12, &block, hdr.attributes & 7) < 0) + return -1; + + *buf = (char *)*buf + hdr.length - 61 + 12; + *size -= hdr.length - 61 + 12; + } + + void *p = *buf; + size_t n = *size; + + if (block.get_len() > 0) + { + p = block.get_block(); + n = block.get_len(); + } + + for (int i = 0; i < hdr.record_count; ++i) + { + KafkaRecord *record = new KafkaRecord; + record->set_offset(hdr.base_offset); + record->set_timestamp(hdr.base_timestamp); + record->get_raw_ptr()->key_is_moved = 1; + record->get_raw_ptr()->value_is_moved = 1; + record->get_raw_ptr()->toppar = toppar->get_raw_ptr(); + + switch (parse_message_record(&p, &n, record->get_raw_ptr())) + { + case -1: + delete record; + return -1; + case 0: + list_add_tail(record->get_list(), record_list); + break; + default: + delete record; + break; + } + } + + if (hdr.attributes == 0) + { + *buf = p; + *size = n; + } + + if (block.get_len() > 0) + uncompressed->add_item(std::move(block)); + + return 0; +} + +int KafkaMessage::parse_records(void **buf, size_t *size, bool check_crcs, + KafkaBuffer *uncompressed, KafkaToppar *toppar) +{ + struct list_head *record_list = toppar->get_record(); + int msg_set_size = 0; + + if (parse_i32(buf, size, &msg_set_size) < 0) + return -1; + + if (msg_set_size == 0) + return 0; + + if (msg_set_size < 0) + return -1; + + if (*size < 17) + return -1; + + size_t msg_size = msg_set_size; + + while (msg_size > 16) + { + int ret = -1; + char magic = ((char *)(*buf))[16]; + + switch(magic) + { + case 0: + case 1: + ret = parse_message_set(buf, &msg_size, check_crcs, + magic, record_list, + uncompressed, toppar); + break; + + case 2: + ret = 
parse_record_batch(buf, &msg_size, check_crcs, + record_list, uncompressed, toppar); + break; + + default: + break; + } + + if (ret > 0) + { + *size -= msg_set_size; + *buf = (char *)*buf + msg_size; + return 0; + } + else if (ret < 0) + break; + } + + *size -= msg_set_size; + *buf = (char *)*buf + msg_size; + return 0; +} + +KafkaMessage::KafkaMessage() +{ + static struct Crc32cInitializer + { + Crc32cInitializer() + { + crc32c_global_init(); + } + } initializer; + + this->parser = new kafka_parser_t; + kafka_parser_init(this->parser); + this->stream = new EncodeStream; + this->api_type = Kafka_Unknown; + this->correlation_id = 0; + this->cur_size = 0; +} + +KafkaMessage::~KafkaMessage() +{ + if (this->parser) + { + kafka_parser_deinit(this->parser); + delete this->parser; + delete this->stream; + } +} + +KafkaMessage::KafkaMessage(KafkaMessage&& msg) : + ProtocolMessage(std::move(msg)) +{ + this->parser = msg.parser; + this->stream = msg.stream; + msg.parser = NULL; + msg.stream = NULL; + + this->msgbuf = std::move(msg.msgbuf); + this->headbuf = std::move(msg.headbuf); + + this->toppar_list = std::move(msg.toppar_list); + this->serialized = std::move(msg.serialized); + this->uncompressed = std::move(msg.uncompressed); + + this->api_type = msg.api_type; + msg.api_type = Kafka_Unknown; + + this->compress_env = msg.compress_env; + msg.compress_env = NULL; + + this->cur_size = msg.cur_size; + msg.cur_size = 0; +} + +KafkaMessage& KafkaMessage::operator= (KafkaMessage &&msg) +{ + if (this != &msg) + { + *(ProtocolMessage *)this = std::move(msg); + + if (this->parser) + { + kafka_parser_deinit(this->parser); + delete this->parser; + delete this->stream; + } + + this->parser = msg.parser; + this->stream = msg.stream; + msg.parser = NULL; + msg.stream = NULL; + + this->msgbuf = std::move(msg.msgbuf); + this->headbuf = std::move(msg.headbuf); + + this->toppar_list = std::move(msg.toppar_list); + this->serialized = std::move(msg.serialized); + this->uncompressed = 
std::move(msg.uncompressed); + + this->api_type = msg.api_type; + msg.api_type = Kafka_Unknown; + + this->compress_env = msg.compress_env; + msg.compress_env = NULL; + + this->cur_size = msg.cur_size; + msg.cur_size = 0; + } + + return *this; +} + +int KafkaMessage::encode_message(int api_type, struct iovec vectors[], int max) +{ + const auto it = this->encode_func_map.find(api_type); + + if (it == this->encode_func_map.cend()) + return -1; + + return it->second(vectors, max); +} + +static int kafka_api_get_max_ver(int api_type) +{ + switch (api_type) + { + case Kafka_Metadata: + return 4; + case Kafka_Produce: + return 7; + case Kafka_Fetch: + return 11; + case Kafka_FindCoordinator: + return 2; + case Kafka_JoinGroup: + return 5; + case Kafka_SyncGroup: + return 3; + case Kafka_Heartbeat: + return 3; + case Kafka_OffsetFetch: + return 1; + case Kafka_OffsetCommit: + return 7; + case Kafka_ListOffsets: + return 1; + case Kafka_LeaveGroup: + return 1; + case Kafka_ApiVersions: + return 0; + case Kafka_SaslHandshake: + return 1; + case Kafka_SaslAuthenticate: + return 0; + case Kafka_DescribeGroups: + return 0; + default: + return 0; + } +} + +static int kafka_get_api_version(const kafka_api_t *api, const KafkaConfig& conf, + int api_type, int max_ver, int message_version) +{ + int min_ver = 0; + + if (api_type == Kafka_Produce) + { + if (message_version == 2) + min_ver = 3; + else if (message_version == 1) + min_ver = 1; + + if (conf.get_compress_type() == Kafka_Zstd) + min_ver = 7; + } + + return kafka_broker_get_api_version(api, api_type, min_ver, max_ver); +} + +int KafkaMessage::encode_head() +{ + if (this->api_type == Kafka_ApiVersions) + this->api_version = 0; + else + { + int max_ver = kafka_api_get_max_ver(this->api_type); + + if (this->api->features & KAFKA_FEATURE_MSGVER2) + this->message_version = 2; + else if (this->api->features & KAFKA_FEATURE_MSGVER1) + this->message_version = 1; + else + this->message_version = 0; + + if 
(this->config.get_compress_type() == Kafka_Lz4 && + !(this->api->features & KAFKA_FEATURE_LZ4)) + { + this->config.set_compress_type(Kafka_NoCompress); + } + + if (this->config.get_compress_type() == Kafka_Zstd && + !(this->api->features & KAFKA_FEATURE_ZSTD)) + { + this->config.set_compress_type(Kafka_NoCompress); + } + + this->api_version = kafka_get_api_version(this->api, this->config, + this->api_type, max_ver, + this->message_version); + } + + if (this->api_version < 0) + return -1; + + append_i32(this->headbuf, 0); + append_i16(this->headbuf, this->api_type); + append_i16(this->headbuf, this->api_version); + append_i32(this->headbuf, this->correlation_id); + append_string(this->headbuf, this->config.get_client_id()); + + return 0; +} + +int KafkaMessage::encode(struct iovec vectors[], int max) +{ + if (encode_head() < 0) + return -1; + + int n = encode_message(this->api_type, vectors + 1, max - 1); + if (n < 0) + return -1; + + int msg_size = this->headbuf.size() + this->cur_size - 4; + *(int32_t *)this->headbuf.c_str() = htonl(msg_size); + + vectors[0].iov_base = (void *)this->headbuf.c_str(); + vectors[0].iov_len = this->headbuf.size(); + + return n + 1; +} + +int KafkaMessage::append(const void *buf, size_t *size) +{ + int ret = kafka_parser_append_message(buf, size, this->parser); + + if (ret >= 0) + { + this->cur_size += *size; + if (this->cur_size > this->size_limit) + { + errno = EMSGSIZE; + ret = -1; + } + } + else if (ret == -2) + { + errno = EBADMSG; + ret = -1; + } + + return ret; +} + +static int kafka_compress_prepare(int compress_type, void **env, + KafkaBlock *block) +{ + z_stream *c_stream; + KafkaBuffer *snappy_buffer; + size_t lz4_out_len; + LZ4F_errorCode_t lz4_r; + LZ4F_cctx *lz4_cctx = NULL; + ZSTD_CStream *zstd_cctx; + size_t zstd_r; + + switch (compress_type) + { + case Kafka_Gzip: + c_stream = new z_stream; + c_stream->zalloc = (alloc_func)0; + c_stream->zfree = (free_func)0; + c_stream->opaque = (voidpf)0; + if (deflateInit2(c_stream, 
Z_DEFAULT_COMPRESSION, Z_DEFLATED, 15 | 16, + 8, Z_DEFAULT_STRATEGY) != Z_OK) + { + delete c_stream; + errno = EBADMSG; + return -1; + } + + c_stream->avail_in = 0; + c_stream->avail_out = 0; + c_stream->total_in = 0; + + *env = (void *)c_stream; + break; + + case Kafka_Snappy: + snappy_buffer = new KafkaBuffer; + *env = (void *)snappy_buffer; + break; + + case Kafka_Lz4: + lz4_r = LZ4F_createCompressionContext(&lz4_cctx, LZ4F_VERSION); + if (LZ4F_isError(lz4_r)) + { + LZ4F_freeCompressionContext(lz4_cctx); + errno = EBADMSG; + return -1; + } + + lz4_out_len = LZ4F_HEADER_SIZE_MAX; + + if (!block->allocate(lz4_out_len)) + { + LZ4F_freeCompressionContext(lz4_cctx); + return -1; + } + + lz4_r = LZ4F_compressBegin(lz4_cctx, block->get_block(), + block->get_len(), &kPrefs); + if (LZ4F_isError(lz4_r)) + { + LZ4F_freeCompressionContext(lz4_cctx); + errno = EBADMSG; + return -1; + } + + block->set_len(lz4_r); + + *env = (void *)lz4_cctx; + break; + + case Kafka_Zstd: + zstd_cctx = ZSTD_createCStream(); + if (!zstd_cctx) + return -1; + + zstd_r = ZSTD_initCStream(zstd_cctx, ZSTD_CLEVEL_DEFAULT); + if (ZSTD_isError(zstd_r)) + { + ZSTD_freeCStream(zstd_cctx); + errno = EBADMSG; + return -1; + } + + *env = (void *)zstd_cctx; + break; + + default: + return 0; + } + + return 0; +} + +static int kafka_compress_finish(int compress_type, void *env, + KafkaBuffer *buffer, int *addon) +{ + int gzip_err; + LZ4F_cctx *lz4_cctx; + z_stream *c_stream; + size_t out_len; + KafkaBuffer *snappy_buffer; + LZ4F_errorCode_t lz4_r; + ZSTD_CStream *zstd_cctx; + size_t zstd_r; + ZSTD_outBuffer out; + KafkaBlock block; + size_t zstd_end_bufsize = ZSTD_compressBound(buffer->get_size()); + + switch (compress_type) + { + case Kafka_Gzip: + c_stream = static_cast(env); + out_len = c_stream->total_out; + for(;;) + { + if (c_stream->avail_out == 0) + { + block.allocate(1024); + c_stream->next_out = (Bytef *)block.get_block(); + c_stream->avail_out = 1024; + } + + gzip_err = deflate(c_stream, Z_FINISH); 
+ + if (gzip_err == Z_STREAM_END) + break; + + if (gzip_err != Z_OK) + { + delete c_stream; + errno = EBADMSG; + return -1; + } + else if (block.get_len() > 0) + { + size_t use_bytes = block.get_len() - c_stream->avail_out; + + block.set_len(use_bytes); + buffer->add_item(std::move(block)); + *addon += use_bytes; + } + } + + if (deflateEnd(c_stream) != Z_OK) + { + delete c_stream; + errno = EBADMSG; + return -1; + } + + if (block.get_len() > 0) + { + size_t use_bytes = block.get_len() - c_stream->avail_out; + block.set_len(use_bytes); + buffer->add_item(std::move(block)); + *addon += use_bytes; + } + else + { + KafkaBlock *b = buffer->get_block_tail(); + size_t use_bytes = b->get_len() - c_stream->avail_out; + int remainer = b->get_len() - use_bytes; + + b->set_len(use_bytes); + *addon += -remainer; + } + + delete c_stream; + break; + + case Kafka_Snappy: + snappy_buffer = static_cast(env); + { + KafkaBuffer kafka_buffer_sink; + KafkaSnappySource source(snappy_buffer); + KafkaSnappySink sink(&kafka_buffer_sink); + + if (snappy::Compress(&source, &sink) < 0) + { + delete snappy_buffer; + errno = EBADMSG; + return -1; + } + + size_t pre_n = buffer->get_size(); + + buffer->list_splice(&kafka_buffer_sink); + *addon = buffer->get_size() - pre_n; + } + + delete snappy_buffer; + break; + + case Kafka_Lz4: + lz4_cctx = static_cast(env); + out_len = LZ4F_compressBound(0, &kPrefs); + if (!block.allocate(out_len)) + { + LZ4F_freeCompressionContext(lz4_cctx); + return -1; + } + + lz4_r = LZ4F_compressEnd(lz4_cctx, block.get_block(), block.get_len(), NULL); + if (LZ4F_isError(lz4_r)) + { + LZ4F_freeCompressionContext(lz4_cctx); + errno = EBADMSG; + return -1; + } + + block.set_len(lz4_r); + buffer->add_item(std::move(block)); + *addon = lz4_r; + LZ4F_freeCompressionContext(lz4_cctx); + break; + + case Kafka_Zstd: + zstd_cctx = static_cast(env); + if (!block.allocate(zstd_end_bufsize)) + return -1; + + out.dst = block.get_block(); + out.pos = 0; + out.size = 1024000; + zstd_r = 
ZSTD_endStream(zstd_cctx, &out); + if (ZSTD_isError(zstd_r) || zstd_r > 0) + { + ZSTD_freeCStream(zstd_cctx); + errno = EBADMSG; + return -1; + } + + block.set_len(out.pos); + buffer->add_item(std::move(block)); + *addon = out.pos; + ZSTD_freeCStream(zstd_cctx); + break; + + default: + return 0; + } + + return 0; +} + +KafkaRequest::KafkaRequest() +{ + using namespace std::placeholders; + this->encode_func_map[Kafka_Metadata] = std::bind(&KafkaRequest::encode_metadata, this, _1, _2); + this->encode_func_map[Kafka_Produce] = std::bind(&KafkaRequest::encode_produce, this, _1, _2); + this->encode_func_map[Kafka_Fetch] = std::bind(&KafkaRequest::encode_fetch, this, _1, _2); + this->encode_func_map[Kafka_FindCoordinator] = std::bind(&KafkaRequest::encode_findcoordinator, this, _1, _2); + this->encode_func_map[Kafka_JoinGroup] = std::bind(&KafkaRequest::encode_joingroup, this, _1, _2); + this->encode_func_map[Kafka_SyncGroup] = std::bind(&KafkaRequest::encode_syncgroup, this, _1, _2); + this->encode_func_map[Kafka_Heartbeat] = std::bind(&KafkaRequest::encode_heartbeat, this, _1, _2); + this->encode_func_map[Kafka_OffsetFetch] = std::bind(&KafkaRequest::encode_offsetfetch, this, _1, _2); + this->encode_func_map[Kafka_OffsetCommit] = std::bind(&KafkaRequest::encode_offsetcommit, this, _1, _2); + this->encode_func_map[Kafka_ListOffsets] = std::bind(&KafkaRequest::encode_listoffset, this, _1, _2); + this->encode_func_map[Kafka_LeaveGroup] = std::bind(&KafkaRequest::encode_leavegroup, this, _1, _2); + this->encode_func_map[Kafka_ApiVersions] = std::bind(&KafkaRequest::encode_apiversions, this, _1, _2); + this->encode_func_map[Kafka_SaslHandshake] = std::bind(&KafkaRequest::encode_saslhandshake, this, _1, _2); + this->encode_func_map[Kafka_SaslAuthenticate] = std::bind(&KafkaRequest::encode_saslauthenticate, this, _1, _2); +} + +int KafkaRequest::encode_produce(struct iovec vectors[], int max) +{ + this->stream->reset(vectors, max); + + //transaction_id + if (this->api_version 
>= 3) + append_nullable_string(this->msgbuf, "", 0); + + append_i16(this->msgbuf, this->config.get_produce_acks()); + append_i32(this->msgbuf, this->config.get_produce_timeout()); + + int topic_cnt = 0; + this->toppar_list.rewind(); + KafkaToppar *toppar; + + while ((toppar = this->toppar_list.get_next()) != NULL) + { + std::string topic_header; + KafkaBlock header_block; + int record_flag = -1; + + append_string(topic_header, toppar->get_topic()); + append_i32(topic_header, 1); + append_i32(topic_header, toppar->get_partition()); + append_i32(topic_header, 0); // recordset length + + if (!header_block.set_block((void *)topic_header.c_str(), + topic_header.size())) + { + return -1; + } + + void *recordset_size_ptr = (void *)((char *)header_block.get_block() + + header_block.get_len() - 4); + + int64_t first_timestamp = 0; + int64_t max_timestamp = 0; + const int MSGV2HSize = (8 + 4 + 4 + 1 + 4 + 2 + 4 + 8 + 8 + 8 + 2 + 4 + 4); + int batch_length = 0; + + if (this->message_version == 2) + batch_length = MSGV2HSize - (8 + 4); + + size_t cur_serialized_len = this->serialized.get_size(); + int batch_cnt = 0; + + toppar->save_record_startpos(); + KafkaRecord *record; + while ((record = toppar->get_record_next()) != NULL) + { + KafkaBlock compress_block; + KafkaBlock record_block; + struct timespec ts; + + if (record->get_timestamp() == 0) + { + clock_gettime(CLOCK_REALTIME, &ts); + record->set_timestamp((ts.tv_sec * 1000000000 + + ts.tv_nsec) / 1000 / 1000); + } + + if (batch_cnt == 0) + { + if (kafka_compress_prepare(this->config.get_compress_type(), + &this->compress_env, + &compress_block) < 0) + { + return -1; + } + + first_timestamp = record->get_timestamp(); + } + + int ret = append_record(&record_block, record, batch_cnt, + this->message_version, this->config, + first_timestamp, this->compress_env, + batch_length); + + if (ret < 0) + return -1; + + if (ret > 0) + { + toppar->record_rollback(); + toppar->save_record_endpos(); + if (record_flag < 0) + { + errno = 
EMSGSIZE; + return -1; + } + else + record_flag = 1; + + break; + } + + if (batch_cnt == 0) + { + this->serialized.add_item(std::move(header_block)); + cur_serialized_len = this->serialized.get_size(); + this->serialized.set_insert_pos(); + + if (compress_block.get_len() > 0) + this->serialized.add_item(std::move(compress_block)); + } + + if (record_block.get_len() > 0) + this->serialized.add_item(std::move(record_block)); + + record_flag = 0; + toppar->save_record_endpos(); + + max_timestamp = record->get_timestamp(); + ++batch_cnt; + + batch_length += this->serialized.get_size() - cur_serialized_len; + cur_serialized_len = this->serialized.get_size(); + } + + if (record_flag < 0) + continue; + + if (this->message_version == 2) + { + if (this->config.get_compress_type() != Kafka_NoCompress) + { + int addon = 0; + + if (kafka_compress_finish(this->config.get_compress_type(), + this->compress_env, &this->serialized, &addon) < 0) + { + return -1; + } + + batch_length += addon; + } + + std::string record_header; + + append_i64(record_header, 0); + append_i32(record_header, batch_length); + append_i32(record_header, 0); + append_i8(record_header, 2); //magic + + uint32_t crc_32 = 0; + size_t crc32_offset = record_header.size(); + + append_i32(record_header, crc_32); + append_i16(record_header, this->config.get_compress_type()); + append_i32(record_header, batch_cnt - 1); + append_i64(record_header, first_timestamp); + append_i64(record_header, max_timestamp); + append_i64(record_header, -1); //produce_id + append_i16(record_header, -1); + append_i32(record_header, -1); + append_i32(record_header, batch_cnt); + + KafkaBlock *header_block = new KafkaBlock; + + if (!header_block->set_block((void *)record_header.c_str(), + record_header.size())) + { + delete header_block; + return -1; + } + + char *crc_ptr = (char *)header_block->get_block() + crc32_offset; + + this->serialized.insert_list(header_block); + + crc_32 = crc32c(crc_32, (const void *)(crc_ptr + 4), + 
header_block->get_len() - crc32_offset - 4); + + this->serialized.block_insert_rewind(); + KafkaBlock *block; + while ((block = this->serialized.get_block_insert_next()) != NULL) + crc_32 = crc32c(crc_32, block->get_block(), block->get_len()); + + *(uint32_t *)crc_ptr = htonl(crc_32); + *(uint32_t *)recordset_size_ptr = htonl(batch_length + 4 + 8); + } + else + { + if (this->config.get_compress_type() != Kafka_NoCompress) + { + int addon = 0; + + if (kafka_compress_finish(this->config.get_compress_type(), + this->compress_env, &this->serialized, &addon) < 0) + { + return -1; + } + + batch_length += addon; + + int message_size = 4 + 1 + 1 + 4 + 4 + batch_length; + + if (this->message_version == 1) + message_size += 8; + + std::string wrap_header; + + append_i64(wrap_header, 0); + append_i32(wrap_header, message_size); + + int crc_32 = crc32(0, NULL, 0); + size_t crc32_offset = wrap_header.size(); + + append_i32(wrap_header, crc_32); + append_i8(wrap_header, this->message_version); + append_i8(wrap_header, this->config.get_compress_type()); + + if (this->message_version == 1) + append_i64(wrap_header, first_timestamp); + + append_bytes(wrap_header, ""); + append_i32(wrap_header, batch_length); + const char *crc_ptr = (const char *)wrap_header.c_str() + crc32_offset; + + crc_32 = crc32(crc_32, (Bytef *)(crc_ptr + 4), + wrap_header.size() - crc32_offset - 4); + + this->serialized.block_insert_rewind(); + KafkaBlock *block; + + while ((block = this->serialized.get_block_insert_next()) != NULL) + crc_32 = crc32(crc_32, (Bytef *)block->get_block(), block->get_len()); + + *(uint32_t *)crc_ptr = htonl(crc_32); + + KafkaBlock *wrap_block = new KafkaBlock; + + if (!wrap_block->set_block((void *)wrap_header.c_str(), + wrap_header.size())) + { + delete wrap_block; + return -1; + } + + this->serialized.insert_list(wrap_block); + *(uint32_t *)recordset_size_ptr = htonl(message_size + 8 + 4); + } + else + *(uint32_t *)recordset_size_ptr = htonl(batch_length); + } + + ++topic_cnt; 
+ } + + append_i32(this->msgbuf, topic_cnt); + this->cur_size += this->msgbuf.size(); + this->stream->append_nocopy(this->msgbuf.c_str(), this->msgbuf.size()); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + + KafkaBlock *block = this->serialized.get_block_first(); + + while (block) + { + this->stream->append_nocopy((const char *)block->get_block(), + block->get_len()); + this->cur_size += block->get_len(); + block = this->serialized.get_block_next(); + } + + return this->stream->size(); +} + +int KafkaRequest::encode_fetch(struct iovec vectors[], int max) +{ + append_i32(this->msgbuf, -1); + append_i32(this->msgbuf, this->config.get_fetch_timeout()); + append_i32(this->msgbuf, this->config.get_fetch_min_bytes()); + + if (this->api_version >= 3) + append_i32(this->msgbuf, this->config.get_fetch_max_bytes()); + + //isolation_level + if (this->api_version >= 4) + append_i8(this->msgbuf, 0); + + if (this->api_version >= 7) + { + //sessionid + append_i32(this->msgbuf, 0); + //epoch + append_i32(this->msgbuf, -1); + } + + int topic_cnt_pos = this->msgbuf.size(); + + append_i32(this->msgbuf, 0); + + int topic_cnt = 0; + this->toppar_list.rewind(); + KafkaToppar *toppar; + + while ((toppar = this->toppar_list.get_next()) != NULL) + { + append_string(this->msgbuf, toppar->get_topic()); + append_i32(this->msgbuf, 1); + append_i32(this->msgbuf, toppar->get_partition()); + + //CurrentLeaderEpoch + if (this->api_version >= 9) + append_i32(this->msgbuf, -1); + + append_i64(this->msgbuf, toppar->get_offset()); + + //LogStartOffset + if (this->api_version >= 5) + append_i64(this->msgbuf, -1); + + append_i32(this->msgbuf, this->config.get_fetch_msg_max_bytes()); + ++topic_cnt; + } + + *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); + + //Length of the ForgottenTopics list + if (this->api_version >= 7) + append_i32(this->msgbuf, 0); + + //rackid + if (this->api_version >= 11) + { + if 
(this->config.get_rack_id()) + append_compact_string(this->msgbuf, this->config.get_rack_id()); + else + append_string(this->msgbuf, ""); + } + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_metadata(struct iovec vectors[], int max) +{ + int topic_cnt_pos = this->msgbuf.size(); + + if (this->api_version >= 1) + append_i32(this->msgbuf, -1); + else + append_i32(this->msgbuf, 0); + + this->meta_list.rewind(); + KafkaMeta *meta; + int topic_cnt = 0; + + while ((meta = this->meta_list.get_next()) != NULL) + { + append_string(this->msgbuf, meta->get_topic()); + ++topic_cnt; + } + + if (this->api_version >= 4) + { + append_bool(this->msgbuf, + this->config.get_allow_auto_topic_creation()); + } + + *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_findcoordinator(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->cgroup.get_group()); + + //coordinator key type + if (this->api_version >= 1) + append_i8(this->msgbuf, 0); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +static std::string kafka_cgroup_gen_metadata(KafkaMetaList& meta_list) +{ + std::string metadata; + int meta_pos; + int meta_cnt = 0; + meta_list.rewind(); + KafkaMeta *meta; + + append_i16(metadata, 2); // version + meta_pos = metadata.size(); + append_i32(metadata, 0); + + while ((meta = meta_list.get_next()) != NULL) + { + append_string(metadata, meta->get_topic()); + meta_cnt++; + } + + *(uint32_t *)(metadata.c_str() + meta_pos) = htonl(meta_cnt); + + //UserData empty + append_bytes(metadata, ""); + + return metadata; +} + +int 
KafkaRequest::encode_joingroup(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->cgroup.get_group()); + append_i32(this->msgbuf, this->config.get_session_timeout()); + + if (this->api_version >= 1) + append_i32(this->msgbuf, this->config.get_rebalance_timeout()); + + //member_id + append_string(this->msgbuf, this->cgroup.get_member_id()); + + //group_instance_id + if (this->api_version >= 5) + append_nullable_string(this->msgbuf, "", 0); + + append_string(this->msgbuf, this->cgroup.get_protocol_type()); + int protocol_pos = this->msgbuf.size(); + append_i32(this->msgbuf, 0); + + int protocol_cnt = 0; + struct list_head *pos; + kafka_group_protocol_t *group_protocol; + list_for_each(pos, this->cgroup.get_group_protocol()) + { + ++protocol_cnt; + group_protocol = list_entry(pos, kafka_group_protocol_t, list); + append_string(this->msgbuf, group_protocol->protocol_name); + append_bytes(this->msgbuf, + kafka_cgroup_gen_metadata(this->meta_list)); + } + + *(uint32_t *)(this->msgbuf.c_str() + protocol_pos) = htonl(protocol_cnt); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +std::string KafkaMessage::get_member_assignment(kafka_member_t *member) +{ + std::string assignment; + //version + append_i16(assignment, 2); + + size_t topic_cnt_pos = assignment.size(); + append_i32(assignment, 0); + + struct list_head *pos; + KafkaToppar *toppar; + int topic_cnt = 0; + + list_for_each(pos, &member->assigned_toppar_list) + { + toppar = list_entry(pos, KafkaToppar, list); + append_string(assignment, toppar->get_topic()); + append_i32(assignment, 1); + append_i32(assignment, toppar->get_partition()); + ++topic_cnt; + } + + //userdata + append_bytes(assignment, ""); + + *(uint32_t *)(assignment.c_str() + topic_cnt_pos) = htonl(topic_cnt); + + return assignment; +} + +int KafkaRequest::encode_syncgroup(struct iovec vectors[], int max) +{ + 
append_string(this->msgbuf, this->cgroup.get_group()); + append_i32(this->msgbuf, this->cgroup.get_generation_id()); + append_string(this->msgbuf, this->cgroup.get_member_id()); + + //group_instance_id + if (this->api_version >= 3) + append_nullable_string(this->msgbuf, "", 0); + + if (this->cgroup.is_leader()) + { + append_i32(this->msgbuf, this->cgroup.get_member_elements()); + for (int i = 0; i < this->cgroup.get_member_elements(); ++i) + { + kafka_member_t *member = this->cgroup.get_members()[i]; + append_string(this->msgbuf, member->member_id); + append_bytes(this->msgbuf, std::move(get_member_assignment(member))); + } + } + else + append_i32(this->msgbuf, 0); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_leavegroup(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->cgroup.get_group()); + append_string(this->msgbuf, this->cgroup.get_member_id()); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_listoffset(struct iovec vectors[], int max) +{ + append_i32(this->msgbuf, -1); + + int topic_cnt = 0; + int topic_cnt_pos = this->msgbuf.size(); + append_i32(this->msgbuf, 0); + + struct list_head *pos; + KafkaToppar *toppar; + + list_for_each(pos, this->toppar_list.get_head()) + { + toppar = this->toppar_list.get_entry(pos); + append_string(this->msgbuf, toppar->get_topic()); + + append_i32(this->msgbuf, 1); + append_i32(this->msgbuf, toppar->get_partition()); + append_i64(this->msgbuf, toppar->get_offset_timestamp()); + + if (this->api_version == 0) + append_i32(this->msgbuf, 1); + + ++topic_cnt; + } + + *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + 
vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_offsetfetch(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->cgroup.get_group()); + + int topic_cnt = 0; + int topic_cnt_pos = this->msgbuf.size(); + append_i32(this->msgbuf, 0); + + this->cgroup.assigned_toppar_rewind(); + KafkaToppar *toppar; + + while ((toppar = this->cgroup.get_assigned_toppar_next()) != NULL) + { + append_string(this->msgbuf, toppar->get_topic()); + append_i32(this->msgbuf, 1); + append_i32(this->msgbuf, toppar->get_partition()); + ++topic_cnt; + } + + *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_offsetcommit(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->cgroup.get_group()); + + if (this->api_version >= 1) + { + append_i32(this->msgbuf, this->cgroup.get_generation_id()); + append_string(this->msgbuf, this->cgroup.get_member_id()); + } + + //GroupInstanceId + if (this->api_version >= 7) + append_nullable_string(this->msgbuf, "", 0); + + //RetentionTime + if (this->api_version >= 2 && this->api_version <= 4) + append_i64(this->msgbuf, -1); + + int toppar_cnt = 0; + int toppar_cnt_pos = this->msgbuf.size(); + append_i32(this->msgbuf, 0); + + this->toppar_list.rewind(); + KafkaToppar *toppar; + + while ((toppar = this->toppar_list.get_next()) != NULL) + { + append_string(this->msgbuf, toppar->get_topic()); + append_i32(this->msgbuf, 1); + append_i32(this->msgbuf, toppar->get_partition()); + append_i64(this->msgbuf, toppar->get_offset() + 1); + + if (this->api_version >= 6) + append_i32(this->msgbuf, -1); + + if (this->api_version == 1) + append_i64(this->msgbuf, -1); + + append_nullable_string(this->msgbuf, "", 0); + ++toppar_cnt; + } + + *(uint32_t *)(this->msgbuf.c_str() + toppar_cnt_pos) = 
htonl(toppar_cnt); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_heartbeat(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->cgroup.get_group()); + append_i32(this->msgbuf, this->cgroup.get_generation_id()); + append_string(this->msgbuf, this->cgroup.get_member_id()); + + //group_instance_id + if (this->api_version >= 3) + append_nullable_string(this->msgbuf, "", 0); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_apiversions(struct iovec vectors[], int max) +{ + return 0; +} + +int KafkaRequest::encode_saslhandshake(struct iovec vectors[], int max) +{ + append_string(this->msgbuf, this->config.get_sasl_mech()); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +int KafkaRequest::encode_saslauthenticate(struct iovec vectors[], int max) +{ + append_bytes(this->msgbuf, this->sasl->buf, this->sasl->bsize); + + this->cur_size = this->msgbuf.size(); + + vectors[0].iov_base = (void *)this->msgbuf.c_str(); + vectors[0].iov_len = this->msgbuf.size(); + return 1; +} + +KafkaResponse::KafkaResponse() +{ + using namespace std::placeholders; + this->parse_func_map[Kafka_Metadata] = std::bind(&KafkaResponse::parse_metadata, this, _1, _2); + this->parse_func_map[Kafka_Produce] = std::bind(&KafkaResponse::parse_produce, this, _1, _2); + this->parse_func_map[Kafka_Fetch] = std::bind(&KafkaResponse::parse_fetch, this, _1, _2); + this->parse_func_map[Kafka_FindCoordinator] = std::bind(&KafkaResponse::parse_findcoordinator, this, _1, _2); + this->parse_func_map[Kafka_JoinGroup] = std::bind(&KafkaResponse::parse_joingroup, this, _1, _2); + this->parse_func_map[Kafka_SyncGroup] = 
std::bind(&KafkaResponse::parse_syncgroup, this, _1, _2); + this->parse_func_map[Kafka_Heartbeat] = std::bind(&KafkaResponse::parse_heartbeat, this, _1, _2); + this->parse_func_map[Kafka_OffsetFetch] = std::bind(&KafkaResponse::parse_offsetfetch, this, _1, _2); + this->parse_func_map[Kafka_OffsetCommit] = std::bind(&KafkaResponse::parse_offsetcommit, this, _1, _2); + this->parse_func_map[Kafka_ListOffsets] = std::bind(&KafkaResponse::parse_listoffset, this, _1, _2); + this->parse_func_map[Kafka_LeaveGroup] = std::bind(&KafkaResponse::parse_leavegroup, this, _1, _2); + this->parse_func_map[Kafka_ApiVersions] = std::bind(&KafkaResponse::parse_apiversions, this, _1, _2); + this->parse_func_map[Kafka_SaslHandshake] = std::bind(&KafkaResponse::parse_saslhandshake, this, _1, _2); + this->parse_func_map[Kafka_SaslAuthenticate] = std::bind(&KafkaResponse::parse_saslauthenticate, this, _1, _2); +} + +int KafkaResponse::parse_response() +{ + auto it = this->parse_func_map.find(this->api_type); + + if (it == this->parse_func_map.end()) + { + errno = EPROTO; + return -1; + } + + void *buf = this->parser->msgbuf; + size_t size = this->parser->message_size; + int32_t correlation_id; + + if (parse_i32(&buf, &size, &correlation_id) < 0) + return -1; + + this->correlation_id = correlation_id; + + int ret = it->second(&buf, &size); + + if (ret < 0) + return -1; + + if (size != 0) + { + errno = EBADMSG; + return -1; + } + + return ret; +} + +static int kafka_meta_parse_broker(void **buf, size_t *size, + int api_version, + KafkaBrokerList *broker_list) +{ + int32_t broker_cnt; + + CHECK_RET(parse_i32(buf, size, &broker_cnt)); + + if (broker_cnt < 0) + { + errno = EBADMSG; + return -1; + } + + for (int i = 0; i < broker_cnt; ++i) + { + KafkaBroker broker; + kafka_broker_t *ptr = broker.get_raw_ptr(); + + CHECK_RET(parse_i32(buf, size, &ptr->node_id)); + CHECK_RET(parse_string(buf, size, &ptr->host)); + CHECK_RET(parse_i32(buf, size, &ptr->port)); + + if (api_version >= 1) + 
CHECK_RET(parse_string(buf, size, &ptr->rack)); + + broker_list->rewind(); + KafkaBroker *last; + + while ((last = broker_list->get_next()) != NULL) + { + if (last->get_node_id() == broker.get_node_id()) + { + broker_list->del_cur(); + delete last; + break; + } + } + + broker_list->add_item(std::move(broker)); + } + + return 0; +} + +static bool kafka_broker_get_leader(int leader_id, KafkaBrokerList *broker_list, + kafka_broker_t *leader) +{ + KafkaBroker *bbroker; + + broker_list->rewind(); + while ((bbroker = broker_list->get_next()) != NULL) + { + if (bbroker->get_node_id() == leader_id) + { + kafka_broker_t *broker = bbroker->get_raw_ptr(); + + char *host = strdup(broker->host); + if (host) + { + char *rack = NULL; + + if (broker->rack) + rack = strdup(broker->rack); + + if (!broker->rack || rack) + { + kafka_broker_deinit(leader); + *leader = *broker; + leader->host = host; + leader->rack = rack; + return true; + } + + free(host); + } + + return false; + } + } + + errno = EBADMSG; + return false; +} + +static int kafka_meta_parse_partition(void **buf, size_t *size, + KafkaMeta *meta, + KafkaBrokerList *broker_list) +{ + int32_t leader_id; + int32_t replica_cnt, isr_cnt; + int32_t partition_cnt; + int32_t i, j; + + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + + if (partition_cnt < 0) + { + errno = EBADMSG; + return -1; + } + + if (!meta->create_partitions(partition_cnt)) + return -1; + + kafka_partition_t **partition = meta->get_partitions(); + + for (i = 0; i < partition_cnt; ++i) + { + int16_t error; + int32_t index; + + if (parse_i16(buf, size, &error) < 0) + break; + + partition[i]->error = error; + + if (parse_i32(buf, size, &index) < 0) + break; + + partition[i]->partition_index = index; + + if (parse_i32(buf, size, &leader_id) < 0) + break; + + if (!kafka_broker_get_leader(leader_id, broker_list, &partition[i]->leader)) + break; + + if (parse_i32(buf, size, &replica_cnt) < 0) + break; + + if (!meta->create_replica_nodes(i, replica_cnt)) + break; + 
+ for (j = 0; j < replica_cnt; ++j) + { + int32_t replica_node; + + if (parse_i32(buf, size, &replica_node) < 0) + break; + + partition[i]->replica_nodes[j] = replica_node; + } + + if (j != replica_cnt) + break; + + if (parse_i32(buf, size, &isr_cnt) < 0) + break; + + if (!meta->create_isr_nodes(i, isr_cnt)) + break; + + for (j = 0; j < isr_cnt; ++j) + { + int32_t isr_node; + + if (parse_i32(buf, size, &isr_node) < 0) + break; + + partition[i]->isr_nodes[j] = isr_node; + } + + if (j != isr_cnt) + break; + } + + if (i != partition_cnt) + return -1; + + return 0; +} + +static KafkaMeta *find_meta_by_name(const std::string& topic, KafkaMetaList *meta_list) +{ + meta_list->rewind(); + KafkaMeta *meta; + + while ((meta = meta_list->get_next()) != NULL) + { + if (meta->get_topic() == topic) + return meta; + } + + errno = EBADMSG; + return NULL; +} + +static int kafka_meta_parse_topic(void **buf, size_t *size, + int api_version, + KafkaMetaList *meta_list, + KafkaBrokerList *broker_list) +{ + KafkaMetaList lst; + int32_t topic_cnt; + + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) + { + int16_t error; + CHECK_RET(parse_i16(buf, size, &error)); + + std::string topic_name; + CHECK_RET(parse_string(buf, size, topic_name)); + KafkaMeta *meta = find_meta_by_name(topic_name, meta_list); + if (!meta) + return -1; + + KafkaMeta new_mta; + new_mta.set_topic(topic_name); + + kafka_meta_t *ptr = new_mta.get_raw_ptr(); + ptr->error = error; + + if (api_version >= 1) + CHECK_RET(parse_i8(buf, size, &ptr->is_internal)); + + CHECK_RET(kafka_meta_parse_partition(buf, size, &new_mta, broker_list)); + + lst.add_item(std::move(new_mta)); + } + + *meta_list = std::move(lst); + return 0; +} + +int KafkaResponse::parse_metadata(void **buf, size_t *size) +{ + int32_t throttle_time, controller_id; + std::string cluster_id; + + if (this->api_version >= 3) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + 
CHECK_RET(kafka_meta_parse_broker(buf, size, this->api_version, + &this->broker_list)); + + if (this->api_version >= 2) + CHECK_RET(parse_string(buf, size, cluster_id)); + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &controller_id)); + + CHECK_RET(kafka_meta_parse_topic(buf, size, this->api_version, + &this->meta_list, &this->broker_list)); + + return 0; +} + +KafkaToppar *KafkaMessage::find_toppar_by_name(const std::string& topic, int partition, + struct list_head *toppar_list) +{ + KafkaToppar *toppar; + struct list_head *pos; + + list_for_each(pos, toppar_list) + { + toppar = list_entry(pos, KafkaToppar, list); + if (toppar->get_topic() == topic && toppar->get_partition() == partition) + return toppar; + } + + errno = EBADMSG; + return NULL; +} + +KafkaToppar *KafkaMessage::find_toppar_by_name(const std::string& topic, int partition, + KafkaTopparList *toppar_list) +{ + toppar_list->rewind(); + KafkaToppar *toppar; + + while ((toppar = toppar_list->get_next()) != NULL) + { + if (toppar->get_topic() == topic && + toppar->get_partition() == partition) + return toppar; + } + + errno = EBADMSG; + return NULL; +} + +int KafkaResponse::parse_produce(void **buf, size_t *size) +{ + int32_t topic_cnt; + std::string topic_name; + int32_t partition_cnt; + int32_t partition; + int64_t base_offset, log_append_time, log_start_offset; + int32_t throttle_time; + int produce_timeout = this->config.get_produce_timeout() * 2; + + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) + { + CHECK_RET(parse_string(buf, size, topic_name)); + + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + for (int32_t i = 0; i < partition_cnt; ++i) + { + CHECK_RET(parse_i32(buf, size, &partition)); + + KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, + &this->toppar_list); + if (!toppar) + return -1; + + kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); + + CHECK_RET(parse_i16(buf, size, &ptr->error)); 
+ + log_append_time = -1; + CHECK_RET(parse_i64(buf, size, &base_offset)); + + if (this->api_version >= 2) + CHECK_RET(parse_i64(buf, size, &log_append_time)); + + if (this->api_version >=5) + CHECK_RET(parse_i64(buf, size, &log_start_offset)); + + struct list_head *pos; + KafkaRecord *record; + + if (ptr->error == KAFKA_REQUEST_TIMED_OUT) + { + toppar->restore_record_curpos(); + this->config.set_produce_timeout(produce_timeout); + continue; + } + + for (pos = toppar->get_record_startpos()->next; + pos != toppar->get_record_endpos(); pos = pos->next) + { + record = list_entry(pos, KafkaRecord, list); + record->set_status(ptr->error); + + if (ptr->error) + continue; + + record->set_offset(base_offset++); + + if (log_append_time != -1) + record->set_timestamp(log_append_time); + } + } + + } + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + return 0; +} + +int KafkaResponse::parse_fetch(void **buf, size_t *size) +{ + int32_t throttle_time; + + this->toppar_list.rewind(); + KafkaToppar *toppar; + while ((toppar = this->toppar_list.get_next()) != NULL) + toppar->clear_records(); + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + if (this->api_version >= 7) + { + int16_t error; + int32_t sessionid; + parse_i16(buf, size, &error); + parse_i32(buf, size, &sessionid); + } + + int32_t topic_cnt; + std::string topic_name; + int32_t partition_cnt; + int32_t partition; + int32_t aborted_cnt; + int32_t preferred_read_replica; + int64_t producer_id, first_offset; + int64_t high_watermark; + + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) + { + CHECK_RET(parse_string(buf, size, topic_name)); + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + + for (int i = 0; i < partition_cnt; ++i) + { + CHECK_RET(parse_i32(buf, size, &partition)); + + KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, + &this->toppar_list); + + if (!toppar) + 
return -1; + + kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); + + CHECK_RET(parse_i16(buf, size, &ptr->error)); + CHECK_RET(parse_i64(buf, size, &high_watermark)); + + if (high_watermark > ptr->low_watermark) + ptr->high_watermark = high_watermark; + + if (this->api_version >= 4) + { + CHECK_RET(parse_i64(buf, size, (int64_t *)&ptr->last_stable_offset)); + + if (this->api_version >= 5) + CHECK_RET(parse_i64(buf, size, (int64_t *)&ptr->log_start_offset)); + + CHECK_RET(parse_i32(buf, size, &aborted_cnt)); + for (int32_t j = 0; j < aborted_cnt; ++j) + { + CHECK_RET(parse_i64(buf, size, &producer_id)); + CHECK_RET(parse_i64(buf, size, &first_offset)); + } + } + + if (this->api_version >= 11) + { + CHECK_RET(parse_i32(buf, size, &preferred_read_replica)); + ptr->preferred_read_replica = preferred_read_replica; + } + + if (parse_records(buf, size, this->config.get_check_crcs(), + &this->uncompressed, toppar) != 0) + { + ptr->error = KAFKA_CORRUPT_MESSAGE; + return -1; + } + } + } + + return 0; +} + +int KafkaResponse::parse_listoffset(void **buf, size_t *size) +{ + int32_t throttle_time; + int32_t topic_cnt; + std::string topic_name; + int32_t partition_cnt; + int32_t partition; + int64_t offset_timestamp, offset; + int32_t offset_cnt; + + if (this->api_version >= 2) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) + { + CHECK_RET(parse_string(buf, size, topic_name)); + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + + for (int32_t i = 0; i < partition_cnt; ++i) + { + CHECK_RET(parse_i32(buf, size, &partition)); + + KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, + &this->toppar_list); + if (!toppar) + return -1; + + kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); + + CHECK_RET(parse_i16(buf, size, &ptr->error)); + + if (this->api_version == 1) + { + CHECK_RET(parse_i64(buf, size, &offset_timestamp)); + 
CHECK_RET(parse_i64(buf, size, &offset)); + if (ptr->offset_timestamp == -1) + ptr->high_watermark = offset; + else if (ptr->offset_timestamp == -2) + ptr->low_watermark = offset; + else + ptr->offset = offset; + } + else if (this->api_version == 0) + { + CHECK_RET(parse_i32(buf, size, &offset_cnt)); + for (int32_t j = 0; j < offset_cnt; ++j) + { + CHECK_RET(parse_i64(buf, size, &offset)); + ptr->offset = offset; + } + + ptr->low_watermark = 0; + } + } + } + + return 0; +} + +int KafkaResponse::parse_findcoordinator(void **buf, size_t *size) +{ + int32_t throttle_time; + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + kafka_cgroup_t *cgroup = this->cgroup.get_raw_ptr(); + CHECK_RET(parse_i16(buf, size, &cgroup->error)); + + if (this->api_version >= 1) + CHECK_RET(parse_string(buf, size, &cgroup->error_msg)); + + CHECK_RET(parse_i32(buf, size, &cgroup->coordinator.node_id)); + CHECK_RET(parse_string(buf, size, &cgroup->coordinator.host)); + CHECK_RET(parse_i32(buf, size, &cgroup->coordinator.port)); + + return 0; +} + +static bool kafka_meta_find_or_add_topic(const std::string& topic_name, + KafkaMetaList *meta_list) +{ + meta_list->rewind(); + bool find = false; + KafkaMeta *meta; + + while ((meta = meta_list->get_next()) != NULL) + { + if (topic_name == meta->get_topic()) + { + find = true; + break; + } + } + + if (!find) + { + KafkaMeta tmp; + if (!tmp.set_topic(topic_name)) + return false; + + meta_list->add_item(tmp); + } + + return true; +} + +static int kafka_cgroup_parse_member(void **buf, size_t *size, + KafkaCgroup *cgroup, + KafkaMetaList *meta_list, + int api_version) +{ + int32_t member_cnt = 0; + CHECK_RET(parse_i32(buf, size, &member_cnt)); + + if (member_cnt < 0) + { + errno = EBADMSG; + return -1; + } + + if (!cgroup->create_members(member_cnt)) + return -1; + + kafka_member_t **member = cgroup->get_members(); + int32_t i; + + for (i = 0; i < member_cnt; ++i) + { + if (parse_string(buf, size, 
&member[i]->member_id) < 0) + break; + + if (api_version >= 5) + { + std::string group_instance_id; + parse_string(buf, size, group_instance_id); + } + + if (parse_bytes(buf, size, &member[i]->member_metadata, + &member[i]->member_metadata_len) < 0) + break; + + void *metadata = member[i]->member_metadata; + size_t metadata_len = member[i]->member_metadata_len; + int16_t version; + int32_t topic_cnt; + std::string topic_name; + int32_t j; + + if (parse_i16(&metadata, &metadata_len, &version) < 0) + break; + + if (parse_i32(&metadata, &metadata_len, &topic_cnt) < 0) + break; + + for (j = 0; j < topic_cnt; ++j) + { + if (parse_string(&metadata, &metadata_len, topic_name) < 0) + break; + + KafkaToppar * toppar = new KafkaToppar; + if (!toppar->set_topic(topic_name.c_str())) + { + delete toppar; + break; + } + + list_add_tail(toppar->get_list(), &member[i]->toppar_list); + + if (!kafka_meta_find_or_add_topic(topic_name, meta_list)) + return -1; + } + + if (j != topic_cnt) + break; + } + + if (i != member_cnt) + return -1; + + return 0; +} + +int KafkaResponse::parse_joingroup(void **buf, size_t *size) +{ + int32_t throttle_time; + + if (this->api_version >= 2) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + kafka_cgroup_t *cgroup = this->cgroup.get_raw_ptr(); + CHECK_RET(parse_i16(buf, size, &cgroup->error)); + CHECK_RET(parse_i32(buf, size, &cgroup->generation_id)); + + CHECK_RET(parse_string(buf, size, &cgroup->protocol_name)); + CHECK_RET(parse_string(buf, size, &cgroup->leader_id)); + CHECK_RET(parse_string(buf, size, &cgroup->member_id)); + CHECK_RET(kafka_cgroup_parse_member(buf, size, &this->cgroup, + &this->meta_list, + this->api_version)); + + return 0; +} + +int KafkaMessage::kafka_parse_member_assignment(const char *bbuf, size_t n, + KafkaCgroup *cgroup) +{ + void **buf = (void **)&bbuf; + size_t *size = &n; + int32_t topic_cnt; + int32_t partition_cnt; + int16_t version; + struct list_head *pos, *tmp; + std::string topic_name; + int32_t partition; + 
+ list_for_each_safe(pos, tmp, cgroup->get_assigned_toppar_list()) + { + KafkaToppar *toppar = list_entry(pos, KafkaToppar, list); + list_del(pos); + delete toppar; + } + + CHECK_RET(parse_i16(buf, size, &version)); + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t i = 0; i < topic_cnt; ++i) + { + CHECK_RET(parse_string(buf, size, topic_name)); + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + + for (int32_t j = 0; j < partition_cnt; ++j) + { + CHECK_RET(parse_i32(buf, size, &partition)); + KafkaToppar *toppar = new KafkaToppar; + + if (!toppar->set_topic_partition(topic_name, partition)) + { + delete toppar; + return -1; + } + cgroup->add_assigned_toppar(toppar); + } + } + + return 0; +} + +int KafkaResponse::parse_syncgroup(void **buf, size_t *size) +{ + int32_t throttle_time; + int16_t error; + std::string member_assignment; + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + CHECK_RET(parse_i16(buf, size, &error)); + this->cgroup.set_error(error); + + CHECK_RET(parse_bytes(buf, size, member_assignment)); + if (!member_assignment.empty()) + { + CHECK_RET(kafka_parse_member_assignment(member_assignment.c_str(), + member_assignment.size(), + &this->cgroup)); + } + + return 0; +} + +int KafkaResponse::parse_leavegroup(void **buf, size_t *size) +{ + int32_t throttle_time; + int16_t error; + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + CHECK_RET(parse_i16(buf, size, &error)); + this->cgroup.set_error(error); + + return 0; +} + +int KafkaResponse::parse_offsetfetch(void **buf, size_t *size) +{ + int32_t topic_cnt; + std::string topic_name; + int32_t partition_cnt; + int32_t partition; + + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) + { + CHECK_RET(parse_string(buf, size, topic_name)); + + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + for (int32_t i = 0; i < partition_cnt; ++i) + { + CHECK_RET(parse_i32(buf, 
size, &partition)); + KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, + this->cgroup.get_assigned_toppar_list()); + if (!toppar) + return -1; + + kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); + + int64_t offset; + CHECK_RET(parse_i64(buf, size, &offset)); + if (this->config.get_offset_store() != KAFKA_OFFSET_ASSIGN) + ptr->offset = offset; + + CHECK_RET(parse_string(buf, size, &ptr->committed_metadata)); + CHECK_RET(parse_i16(buf, size, &ptr->error)); + } + } + + return 0; +} + +int KafkaResponse::parse_offsetcommit(void **buf, size_t *size) +{ + int32_t throttle_time; + int32_t topic_cnt; + std::string topic_name; + int32_t partition_cnt; + int32_t partition; + + if (this->api_version >= 3) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + CHECK_RET(parse_i32(buf, size, &topic_cnt)); + for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) + { + CHECK_RET(parse_string(buf, size, topic_name)); + CHECK_RET(parse_i32(buf, size, &partition_cnt)); + + for (int32_t i = 0 ; i < partition_cnt; ++i) + { + CHECK_RET(parse_i32(buf, size, &partition)); + CHECK_RET(parse_i16(buf, size, &this->cgroup.get_raw_ptr()->error)); + } + } + + return 0; +} + +int KafkaResponse::parse_heartbeat(void **buf, size_t *size) +{ + int32_t throttle_time; + int16_t error; + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + CHECK_RET(parse_i16(buf, size, &error)); + + this->cgroup.set_error(error); + return 0; +} + +static bool kafka_api_version_cmp(const kafka_api_version_t& api_ver1, + const kafka_api_version_t& api_ver2) +{ + return api_ver1.api_key < api_ver2.api_key; +} + +int KafkaResponse::parse_apiversions(void **buf, size_t *size) +{ + int16_t error; + int32_t api_cnt; + int32_t throttle_time; + + CHECK_RET(parse_i16(buf, size, &error)); + CHECK_RET(parse_i32(buf, size, &api_cnt)); + if (api_cnt < 0) + { + errno = EBADMSG; + return -1; + } + + void *p = malloc(api_cnt * sizeof(kafka_api_version_t)); + if (!p) + 
return -1; + + this->api->api = (kafka_api_version_t *)p; + this->api->elements = api_cnt; + + for (int32_t i = 0; i < api_cnt; ++i) + { + CHECK_RET(parse_i16(buf, size, &this->api->api[i].api_key)); + CHECK_RET(parse_i16(buf, size, &this->api->api[i].min_ver)); + CHECK_RET(parse_i16(buf, size, &this->api->api[i].max_ver)); + } + + if (this->api_version >= 1) + CHECK_RET(parse_i32(buf, size, &throttle_time)); + + std::sort(this->api->api, this->api->api + api_cnt, kafka_api_version_cmp); + this->api->features = kafka_get_features(this->api->api, api_cnt); + return 0; +} + +int KafkaResponse::parse_saslhandshake(void **buf, size_t *size) +{ + std::string mechanism; + int16_t error = 0; + int32_t cnt, i; + + CHECK_RET(parse_i16(buf, size, &error)); + if (error != 0) + { + this->broker.get_raw_ptr()->error = error; + return 1; + } + + CHECK_RET(parse_i32(buf, size, &cnt)); + + for (i = 0; i < cnt; i++) + { + CHECK_RET(parse_string(buf, size, mechanism)); + + if (strcasecmp(mechanism.c_str(), this->config.get_sasl_mech()) == 0) + break; + } + + if (i == cnt) + { + this->broker.get_raw_ptr()->error = KAFKA_SASL_AUTHENTICATION_FAILED; + return 1; + } + + for (i++; i < cnt; i++) + CHECK_RET(parse_string(buf, size, mechanism)); + + errno = 0; + if (!this->config.new_client(this->sasl)) + { + if (errno) + return -1; + + this->broker.get_raw_ptr()->error = KAFKA_SASL_AUTHENTICATION_FAILED; + return 1; + } + + return 0; +} + +int KafkaResponse::parse_saslauthenticate(void **buf, size_t *size) +{ + std::string error_message; + std::string auth_bytes; + int16_t error = 0; + + CHECK_RET(parse_i16(buf, size, &error)); + CHECK_RET(parse_string(buf, size, error_message)); + CHECK_RET(parse_bytes(buf, size, auth_bytes)); + + if (error != 0) + { + this->broker.get_raw_ptr()->error = error; + return 1; + } + + errno = 0; + if (this->config.get_raw_ptr()->recv(auth_bytes.c_str(), + auth_bytes.size(), + this->config.get_raw_ptr(), + this->sasl) != 0) + { + if (errno) + return -1; + + 
this->broker.get_raw_ptr()->error = KAFKA_SASL_AUTHENTICATION_FAILED; + return 1; + } + + return 0; +} + +int KafkaResponse::handle_sasl_continue() +{ + struct iovec iovecs[64]; + int ret; + int cnt = this->encode(iovecs, 64); + if ((unsigned int)cnt > 64) + { + if (cnt > 64) + errno = EOVERFLOW; + return -1; + } + + for (int i = 0; i < cnt; i++) + { + ret = this->feedback(iovecs[i].iov_base, iovecs[i].iov_len); + if (ret != (int)iovecs[i].iov_len) + { + if (ret >= 0) + errno = ENOBUFS; + return -1; + } + } + + return 0; +} + +int KafkaResponse::append(const void *buf, size_t *size) +{ + int ret = KafkaMessage::append(buf, size); + + if (ret <= 0) + return ret; + + ret = this->parse_response(); + if (ret != 0) + return ret; + + if (this->api_type == Kafka_SaslHandshake) + { + this->api_type = Kafka_SaslAuthenticate; + this->clear_buf(); + return this->handle_sasl_continue(); + } + else if (this->api_type == Kafka_SaslAuthenticate) + { + if (strncasecmp(this->config.get_sasl_mech(), "SCRAM", 5) == 0) + { + this->clear_buf(); + if (this->sasl->scram.state != + KAFKA_SASL_SCRAM_STATE_CLIENT_FINISHED) + return this->handle_sasl_continue(); + else + this->sasl->status = 1; + } + } + + return 1; +} + +} + diff --git a/src/protocol/KafkaMessage.h b/src/protocol/KafkaMessage.h new file mode 100644 index 0000000..d9ef982 --- /dev/null +++ b/src/protocol/KafkaMessage.h @@ -0,0 +1,266 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Wang Zhulei(wangzhulei@sogou-inc.com) +*/ + +#ifndef _KAFKAMESSAGE_H_ +#define _KAFKAMESSAGE_H_ + +#include +#include +#include +#include +#include +#include +#include "kafka_parser.h" +#include "ProtocolMessage.h" +#include "KafkaDataTypes.h" + +namespace protocol +{ + +class KafkaMessage : public ProtocolMessage +{ +public: + KafkaMessage(); + + virtual ~KafkaMessage(); + +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +private: + int encode_head(); + +public: + KafkaMessage(KafkaMessage&& msg); + KafkaMessage& operator= (KafkaMessage&& msg); + +public: + int encode_message(int api_type, struct iovec vectors[], int max); + + void set_api_type(int api_type) { this->api_type = api_type; } + int get_api_type() const { return this->api_type; } + + void set_api_version(int ver) { this->api_version = ver; } + int get_api_version() const { return this->api_version; } + + void set_correlation_id(int id) { this->correlation_id = id; } + int get_correlation_id() const { return this->correlation_id; } + + void set_config(const KafkaConfig& conf) + { + this->config = conf; + } + const KafkaConfig *get_config() const { return &this->config; } + + void set_cgroup(const KafkaCgroup& cgroup) + { + this->cgroup = cgroup; + } + KafkaCgroup *get_cgroup() + { + return &this->cgroup; + } + + void set_broker(const KafkaBroker& broker) + { + this->broker = broker; + } + KafkaBroker *get_broker() + { + return &this->broker; + } + + void set_meta_list(const KafkaMetaList& meta_list) + { + this->meta_list = meta_list; + } + KafkaMetaList *get_meta_list() + { + return &this->meta_list; + } + + void set_toppar_list(const KafkaTopparList& toppar_list) + { + this->toppar_list = toppar_list; + } + KafkaTopparList *get_toppar_list() + { + return &this->toppar_list; + } + + void set_broker_list(const KafkaBrokerList& broker_list) + { + this->broker_list = broker_list; + } + KafkaBrokerList *get_broker_list() + 
{ + return &this->broker_list; + } + + void set_sasl(kafka_sasl_t *sasl) + { + this->sasl = sasl; + } + + void set_api(kafka_api_t *api) + { + this->api = api; + } + + void duplicate(const KafkaMessage& msg) + { + this->config = msg.config; + this->cgroup = msg.cgroup; + this->broker = msg.broker; + this->meta_list = msg.meta_list; + this->broker_list = msg.broker_list; + this->toppar_list = msg.toppar_list; + this->sasl = msg.sasl; + this->api = msg.api; + } + + void clear_buf() + { + this->msgbuf.clear(); + this->headbuf.clear(); + kafka_parser_deinit(this->parser); + kafka_parser_init(this->parser); + this->cur_size = 0; + this->serialized = KafkaBuffer(); + this->uncompressed = KafkaBuffer(); + } + +protected: + static int parse_message_set(void **buf, size_t *size, + bool check_crcs, int msg_vers, + struct list_head *record_list, + KafkaBuffer *uncompressed, + KafkaToppar *toppar); + + static int parse_message_record(void **buf, size_t *size, + kafka_record_t *kafka_record); + + static int parse_record_batch(void **buf, size_t *size, + bool check_crcs, + struct list_head *record_list, + KafkaBuffer *uncompressed, + KafkaToppar *toppar); + + static int parse_records(void **buf, size_t *size, bool check_crcs, + KafkaBuffer *uncompressed, KafkaToppar *toppar); + + static std::string get_member_assignment(kafka_member_t *member); + + static KafkaToppar *find_toppar_by_name(const std::string& topic, int partition, + struct list_head *toppar_list); + + static KafkaToppar *find_toppar_by_name(const std::string& topic, int partition, + KafkaTopparList *toppar_list); + + static int kafka_parse_member_assignment(const char *bbuf, size_t n, + KafkaCgroup *cgroup); + +protected: + kafka_parser_t *parser; + using encode_func = std::function; + std::map encode_func_map; + + using parse_func = std::function; + std::map parse_func_map; + + class EncodeStream *stream; + std::string msgbuf; + std::string headbuf; + + KafkaConfig config; + KafkaCgroup cgroup; + KafkaBroker 
broker; + KafkaMetaList meta_list; + KafkaBrokerList broker_list; + KafkaTopparList toppar_list; + KafkaBuffer serialized; + KafkaBuffer uncompressed; + + int api_type; + int api_version; + int correlation_id; + int message_version; + + void *compress_env; + size_t cur_size; + + kafka_sasl_t *sasl; + kafka_api_t *api; +}; + +class KafkaRequest : public KafkaMessage +{ +public: + KafkaRequest(); + +private: + int encode_produce(struct iovec vectors[], int max); + int encode_fetch(struct iovec vectors[], int max); + int encode_metadata(struct iovec vectors[], int max); + int encode_findcoordinator(struct iovec vectors[], int max); + int encode_listoffset(struct iovec vectors[], int max); + int encode_joingroup(struct iovec vectors[], int max); + int encode_syncgroup(struct iovec vectors[], int max); + int encode_leavegroup(struct iovec vectors[], int max); + int encode_heartbeat(struct iovec vectors[], int max); + int encode_offsetcommit(struct iovec vectors[], int max); + int encode_offsetfetch(struct iovec vectors[], int max); + int encode_apiversions(struct iovec vectors[], int max); + int encode_saslhandshake(struct iovec vectors[], int max); + int encode_saslauthenticate(struct iovec vectors[], int max); +}; + +class KafkaResponse : public KafkaRequest +{ +public: + KafkaResponse(); + + int parse_response(); + +protected: + virtual int append(const void *buf, size_t *size); + +private: + int parse_produce(void **buf, size_t *size); + int parse_fetch(void **buf, size_t *size); + int parse_metadata(void **buf, size_t *size); + int parse_findcoordinator(void **buf, size_t *size); + int parse_joingroup(void **buf, size_t *size); + int parse_syncgroup(void **buf, size_t *size); + int parse_leavegroup(void **buf, size_t *size); + int parse_listoffset(void **buf, size_t *size); + int parse_offsetcommit(void **buf, size_t *size); + int parse_offsetfetch(void **buf, size_t *size); + int parse_heartbeat(void **buf, size_t *size); + int parse_apiversions(void **buf, size_t 
*size); + int parse_saslhandshake(void **buf, size_t *size); + int parse_saslauthenticate(void **buf, size_t *size); + + int handle_sasl_continue(); +}; + +} + +#endif + diff --git a/src/protocol/KafkaResult.cc b/src/protocol/KafkaResult.cc new file mode 100644 index 0000000..36a266a --- /dev/null +++ b/src/protocol/KafkaResult.cc @@ -0,0 +1,118 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#include "KafkaResult.h" + +namespace protocol +{ + +enum +{ + KAFKA_STATUS_GET_RESULT, + KAFKA_STATUS_END, +}; + +KafkaResult::KafkaResult() +{ + this->resp_vec = NULL; + this->resp_num = 0; +} + +KafkaResult& KafkaResult::operator= (KafkaResult&& move) +{ + if (this != &move) + { + delete []this->resp_vec; + + this->resp_vec = move.resp_vec; + move.resp_vec = NULL; + + this->resp_num = move.resp_num; + move.resp_num = 0; + } + + return *this; +} + +KafkaResult::KafkaResult(KafkaResult&& move) +{ + this->resp_vec = move.resp_vec; + move.resp_vec = NULL; + + this->resp_num = move.resp_num; + move.resp_num = 0; +} + +void KafkaResult::create(size_t n) +{ + delete []this->resp_vec; + this->resp_vec = new KafkaResponse[n]; + this->resp_num = n; +} + +void KafkaResult::set_resp(KafkaResponse&& resp, size_t i) +{ + assert(i < this->resp_num); + this->resp_vec[i] = std::move(resp); +} + +void KafkaResult::fetch_toppars(std::vector& toppars) +{ + toppars.clear(); + + KafkaToppar *toppar = NULL; + for 
(size_t i = 0; i < this->resp_num; ++i) + { + this->resp_vec[i].get_toppar_list()->rewind(); + + while ((toppar = this->resp_vec[i].get_toppar_list()->get_next()) != NULL) + toppars.push_back(toppar); + } +} + +void KafkaResult::fetch_records(std::vector>& records) +{ + records.clear(); + + KafkaToppar *toppar = NULL; + KafkaRecord *record = NULL; + + for (size_t i = 0; i < this->resp_num; ++i) + { + if (this->resp_vec[i].get_api_type() != Kafka_Produce && + this->resp_vec[i].get_api_type() != Kafka_Fetch) + continue; + + this->resp_vec[i].get_toppar_list()->rewind(); + + while ((toppar = this->resp_vec[i].get_toppar_list()->get_next()) != NULL) + { + std::vector tmp; + toppar->record_rewind(); + + while ((record = toppar->get_record_next()) != NULL) + tmp.push_back(record); + + if (!tmp.empty()) + records.emplace_back(std::move(tmp)); + } + } +} + +} + diff --git a/src/protocol/KafkaResult.h b/src/protocol/KafkaResult.h new file mode 100644 index 0000000..4eac839 --- /dev/null +++ b/src/protocol/KafkaResult.h @@ -0,0 +1,65 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#ifndef _KAFKARESULT_H_ +#define _KAFKARESULT_H_ + +#include +#include +#include +#include "KafkaMessage.h" +#include "KafkaDataTypes.h" + +namespace protocol +{ + +class KafkaResult +{ +public: + // for offsetcommit + void fetch_toppars(std::vector& toppars); + + // for produce, fetch + void fetch_records(std::vector>& records); + +public: + void create(size_t n); + + void set_resp(KafkaResponse&& resp, size_t i); + +public: + KafkaResult(); + + virtual ~KafkaResult() + { + delete []this->resp_vec; + } + + KafkaResult& operator= (KafkaResult&& move); + + KafkaResult(KafkaResult&& move); + +private: + KafkaResponse *resp_vec; + size_t resp_num; +}; + +} + +#endif + diff --git a/src/protocol/MySQLMessage.cc b/src/protocol/MySQLMessage.cc new file mode 100644 index 0000000..2dd1a51 --- /dev/null +++ b/src/protocol/MySQLMessage.cc @@ -0,0 +1,689 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "SSLWrapper.h" +#include "mysql_byteorder.h" +#include "mysql_types.h" +#include "MySQLResult.h" +#include "MySQLMessage.h" + +namespace protocol +{ + +#define MYSQL_PAYLOAD_MAX ((1 << 24) - 1) + +#define MYSQL_NATIVE_PASSWORD "mysql_native_password" +#define CACHING_SHA2_PASSWORD "caching_sha2_password" +#define MYSQL_CLEAR_PASSWORD "mysql_clear_password" + +MySQLMessage::~MySQLMessage() +{ + if (parser_) + { + mysql_parser_deinit(parser_); + mysql_stream_deinit(stream_); + delete parser_; + delete stream_; + } +} + +MySQLMessage::MySQLMessage(MySQLMessage&& move) : + ProtocolMessage(std::move(move)) +{ + parser_ = move.parser_; + stream_ = move.stream_; + seqid_ = move.seqid_; + cur_size_ = move.cur_size_; + + move.parser_ = NULL; + move.stream_ = NULL; + move.seqid_ = 0; + move.cur_size_ = 0; +} + +MySQLMessage& MySQLMessage::operator= (MySQLMessage&& move) +{ + if (this != &move) + { + *(ProtocolMessage *)this = std::move(move); + + if (parser_) + { + mysql_parser_deinit(parser_); + mysql_stream_deinit(stream_); + delete parser_; + delete stream_; + } + + parser_ = move.parser_; + stream_ = move.stream_; + seqid_ = move.seqid_; + cur_size_ = move.cur_size_; + + move.parser_ = NULL; + move.stream_ = NULL; + move.seqid_ = 0; + move.cur_size_ = 0; + } + + return *this; +} + +int MySQLMessage::append(const void *buf, size_t *size) +{ + const void *stream_buf; + size_t stream_len; + size_t nleft = *size; + size_t n; + int ret; + + cur_size_ += *size; + if (cur_size_ > this->size_limit) + { + errno = EMSGSIZE; + return -1; + } + + while (nleft > 0) + { + n = nleft; + ret = mysql_stream_write(buf, &n, stream_); + if (ret > 0) + { + seqid_ = mysql_stream_get_seq(stream_); + mysql_stream_get_buf(&stream_buf, &stream_len, stream_); + ret = decode_packet((const unsigned char *)stream_buf, 
stream_len); + if (ret == -2) + errno = EBADMSG; + } + + if (ret < 0) + return -1; + + nleft -= n; + buf = (const char *)buf + n; + } + + return ret; +} + +int MySQLMessage::encode(struct iovec vectors[], int max) +{ + const unsigned char *p = (unsigned char *)buf_.c_str(); + size_t nleft = buf_.size(); + uint8_t seqid_start = seqid_; + uint8_t seqid = seqid_; + unsigned char *head; + uint32_t length; + int i = 0; + + do + { + length = (nleft >= MYSQL_PAYLOAD_MAX ? MYSQL_PAYLOAD_MAX + : (uint32_t)nleft); + head = heads_[seqid]; + int3store(head, length); + head[3] = seqid++; + vectors[i].iov_base = head; + vectors[i].iov_len = 4; + i++; + vectors[i].iov_base = const_cast(p); + vectors[i].iov_len = length; + i++; + + if (i > max)//overflow + break; + + if (nleft < MYSQL_PAYLOAD_MAX) + return i; + + nleft -= MYSQL_PAYLOAD_MAX; + p += length; + } while (seqid != seqid_start); + + errno = EOVERFLOW; + return -1; +} + +void MySQLRequest::set_query(const char *query, size_t length) +{ + set_command(MYSQL_COM_QUERY); + buf_.resize(length + 1); + char *buffer = const_cast(buf_.c_str()); + + buffer[0] = MYSQL_COM_QUERY; + if (length > 0) + memcpy(buffer + 1, query, length); +} + +std::string MySQLRequest::get_query() const +{ + size_t len = buf_.size(); + if (len <= 1 || buf_[0] != MYSQL_COM_QUERY) + return ""; + + return std::string(buf_.c_str() + 1); +} + +#define MYSQL_CAPFLAG_CLIENT_SSL 0x00000800 +#define MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 0x00000200 +#define MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION 0x00008000 +#define MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB 0x00000008 +#define MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS 0x00010000 +#define MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS 0x00020000 +#define MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS 0x00040000 +#define MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH 0x00080000 +#define MYSQL_CAPFLAG_CLIENT_LOCAL_FILES 0x00000080 + +int MySQLHandshakeResponse::encode(struct iovec vectors[], int max) +{ + const char empty[10] = {0}; + uint16_t cap_flags_lower = 
capability_flags_ & 0xffffffff; + uint16_t cap_flags_upper = capability_flags_ >> 16; + + buf_.clear(); + buf_.append((const char *)&protocol_version_, 1); + buf_.append(server_version_.c_str(), server_version_.size() + 1); + buf_.append((const char *)&connection_id_, 4); + buf_.append((const char *)auth_plugin_data_, 8); + buf_.append(empty, 1); + buf_.append((const char *)&cap_flags_lower, 2); + buf_.append((const char *)&character_set_, 1); + buf_.append((const char *)&status_flags_, 2); + buf_.append((const char *)&cap_flags_upper, 2); + buf_.push_back(21); + buf_.append(empty, 10); + buf_.append((const char *)auth_plugin_data_ + 8, 12); + buf_.push_back(0); + if (capability_flags_ & MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH) + buf_.append(MYSQL_NATIVE_PASSWORD, strlen(MYSQL_NATIVE_PASSWORD) + 1); + + return MySQLMessage::encode(vectors, max); +} + +int MySQLHandshakeResponse::decode_packet(const unsigned char *buf, size_t buflen) +{ + const unsigned char *end = buf + buflen; + const unsigned char *pos; + uint16_t cap_flags_lower; + uint16_t cap_flags_upper; + + if (buflen == 0) + return -2; + + protocol_version_ = *buf; + if (protocol_version_ == 255) + { + if (buflen >= 4) + { + const_cast(buf)[3] = '#'; + if (mysql_parser_parse(buf, buflen, parser_) == 1) + { + disallowed_ = true; + return 1; + } + } + + errno = EBADMSG; + return -1; + } + + pos = ++buf; + while (pos < end && *pos) + pos++; + + if (pos >= end || end - pos < 45) + return -2; + + server_version_.assign((const char *)buf, pos - buf); + buf = pos + 1; + + connection_id_ = uint4korr(buf); + buf += 4; + memcpy(auth_plugin_data_, buf, 8); + buf += 9; + cap_flags_lower = uint2korr(buf); + buf += 2; + character_set_ = *buf++; + status_flags_ = uint2korr(buf); + buf += 2; + cap_flags_upper = uint2korr(buf); + buf += 2; + capability_flags_ = (cap_flags_upper << 16U) + cap_flags_lower; + auth_plugin_data_len_ = *buf++; + // 10 bytes reserved. All 0s. 
+ buf += 10; + // auth_plugin_data always 20 bytes + if (auth_plugin_data_len_ > 21) + return -2; + + memcpy(auth_plugin_data_ + 8, buf, 12); + buf += 13; + if (capability_flags_ & MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH) + { + if (buf == end || *(end - 1) != '\0') + return -2; + + auth_plugin_name_.assign((const char *)buf, end - 1 - buf); + } + + return 1; +} + +static std::string __native_password_encrypt(const std::string& password, + unsigned char seed[20]) +{ + unsigned char buf1[20]; + unsigned char buf2[40]; + int i; + + // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) ) + SHA1((unsigned char *)password.c_str(), password.size(), buf1); + SHA1(buf1, 20, buf2 + 20); + memcpy(buf2, seed, 20); + SHA1(buf2, 40, buf2); + for (i = 0; i < 20; i++) + buf1[i] ^= buf2[i]; + + return std::string((const char *)buf1, 20); +} + +static std::string __caching_sha2_password_encrypt(const std::string& password, + unsigned char seed[20]) +{ + unsigned char buf1[32]; + unsigned char buf2[52]; + int i; + + // SHA256( password ) ^ SHA256( SHA256( SHA256( password ) ) + seed) + SHA256((unsigned char *)password.c_str(), password.size(), buf1); + SHA256(buf1, 32, buf2); + memcpy(buf2 + 32, seed, 20); + SHA256(buf2, 52, buf2); + for (i = 0; i < 32; i++) + buf1[i] ^= buf2[i]; + + return std::string((const char *)buf1, 32); +} + +int MySQLSSLRequest::encode(struct iovec vectors[], int max) +{ + unsigned char header[32] = {0}; + unsigned char *pos = header; + int ret; + + int4store(pos, MYSQL_CAPFLAG_CLIENT_SSL | + MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 | + MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION | + MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB | + MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS| + MYSQL_CAPFLAG_CLIENT_LOCAL_FILES | + MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS | + MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS | + MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH); + pos += 4; + int4store(pos, 0); + pos += 4; + *pos = (uint8_t)character_set_; + + buf_.clear(); + buf_.append((char *)header, 32); + ret = 
MySQLMessage::encode(vectors, max); + if (ret >= 0) + { + max -= ret; + if (max >= 8) /* Indeed SSL handshaker needs only 1 vector. */ + { + max = ssl_handshaker_.encode(vectors + ret, max); + if (max >= 0) + return max + ret; + } + else + errno = EOVERFLOW; + } + + return -1; +} + +int MySQLAuthRequest::encode(struct iovec vectors[], int max) +{ + unsigned char header[32] = {0}; + unsigned char *pos = header; + std::string str; + + int4store(pos, MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 | + MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION | + MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB | + MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS| + MYSQL_CAPFLAG_CLIENT_LOCAL_FILES | + MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS | + MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS | + MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH); + pos += 4; + int4store(pos, 0); + pos += 4; + *pos = (uint8_t)character_set_; + + if (password_.empty()) + str.push_back(0); + else if (auth_plugin_name_ == CACHING_SHA2_PASSWORD) + { + str.push_back(32); + str += __caching_sha2_password_encrypt(password_, seed_); + } + else + { + str.push_back(20); + str += __native_password_encrypt(password_, seed_); + } + + buf_.clear(); + buf_.append((char *)header, 32); + buf_.append(username_.c_str(), username_.size() + 1); + buf_.append(str); + buf_.append(db_.c_str(), db_.size() + 1); + if (auth_plugin_name_.size() != 0) + buf_.append(auth_plugin_name_.c_str(), auth_plugin_name_.size() + 1); + + return MySQLMessage::encode(vectors, max); +} + +int MySQLAuthRequest::decode_packet(const unsigned char *buf, size_t buflen) +{ + const unsigned char *end = buf + buflen; + const unsigned char *pos; + + if (buflen < 32) + return -2; + + uint32_t flags = uint4korr(buf); + + if (!(flags & MYSQL_CAPFLAG_CLIENT_PROTOCOL_41)) + return -2; + + buf += 8; + character_set_ = *buf++; + buf += 23; + + pos = buf; + while (pos < end && *pos) + pos++; + + if (pos >= end) + return -2; + + username_.assign((const char *)buf, pos - buf); + buf = pos + 1; + + return 1; +} + +int 
MySQLAuthResponse::decode_packet(const unsigned char *buf, size_t buflen) +{ + const unsigned char *end = buf + buflen; + const unsigned char *pos; + const unsigned char *str; + unsigned long long len; + + if (end == buf) + return -2; + + switch (*buf) + { + case 0x00: + case 0xff: + return MySQLResponse::decode_packet(buf, buflen); + + case 0xfe: + pos = ++buf; + while (pos < end && *pos) + pos++; + + if (pos >= end) + return -2; + + auth_plugin_name_.assign((const char *)buf, pos - buf); + buf = pos + 1; + if (buf == end || *(end - 1) != '\0') + return -2; + + if (end - 1 - buf != 20) + return -2; + + memcpy(seed_, buf, 20); + return 1; + + default: + pos = buf; + if (decode_string(&str, &len, &pos, end) > 0 && len == 1) + { + if (*str == 0x03) + { + if (end > pos) + return MySQLResponse::decode_packet(pos, end - pos); + else + return 0; + } + else if (*str == 0x04) + { + continue_ = true; + return 1; + } + } + + return -2; + } +} + +int MySQLAuthSwitchRequest::encode(struct iovec vectors[], int max) +{ + if (password_.empty()) + { + buf_ = "\0"; + } + else if (auth_plugin_name_ == MYSQL_NATIVE_PASSWORD) + { + buf_ = __native_password_encrypt(password_, seed_); + } + else if (auth_plugin_name_ == CACHING_SHA2_PASSWORD) + { + buf_ = __caching_sha2_password_encrypt(password_, seed_); + } + else if (auth_plugin_name_ == MYSQL_CLEAR_PASSWORD) + { + buf_ = password_; + buf_.push_back('\0'); + } + else + { + errno = EINVAL; + return -1; + } + + return MySQLMessage::encode(vectors, max); +} + +int MySQLPublicKeyResponse::decode_packet(const unsigned char *buf, + size_t buflen) +{ + if (buflen == 0 || *buf != 0x01) + return -2; + + if (buflen == 1) + return 0; + + public_key_.assign((const char *)buf + 1, buflen - 1); + return 1; +} + +int MySQLPublicKeyResponse::encode(struct iovec vectors[], int max) +{ + buf_.clear(); + buf_.push_back(0x01); + buf_ += public_key_; + return MySQLMessage::encode(vectors, max); +} + +int MySQLRSAAuthRequest::rsa_encrypt(void *ctx) +{ + 
EVP_PKEY_CTX *pkey_ctx = (EVP_PKEY_CTX *)ctx; + unsigned char out[256]; + size_t outlen = 256; + std::string pass; + unsigned char *p; + size_t i; + + if (EVP_PKEY_encrypt_init(pkey_ctx) > 0 && + EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_OAEP_PADDING) > 0) + { + pass.reserve(password_.size() + 1); + p = (unsigned char *)pass.c_str(); + for (i = 0; i <= password_.size(); i++) + p[i] = (unsigned char)password_[i] ^ seed_[i % 20]; + + if (EVP_PKEY_encrypt(pkey_ctx, out, &outlen, p, i) > 0) + { + buf_.assign((char *)out, 256); + return 0; + } + } + + return -1; +} + +int MySQLRSAAuthRequest::encode(struct iovec vectors[], int max) +{ + BIO *bio; + EVP_PKEY *pkey; + EVP_PKEY_CTX *pkey_ctx; + int ret = -1; + + bio = BIO_new_mem_buf(public_key_.c_str(), public_key_.size()); + if (bio) + { + pkey = PEM_read_bio_PUBKEY(bio, NULL, NULL, NULL); + if (pkey) + { + pkey_ctx = EVP_PKEY_CTX_new(pkey, NULL); + if (pkey_ctx) + { + ret = rsa_encrypt(pkey_ctx); + EVP_PKEY_CTX_free(pkey_ctx); + } + + EVP_PKEY_free(pkey); + } + + BIO_free(bio); + } + + if (ret < 0) + return ret; + + return MySQLMessage::encode(vectors, max); +} + +void MySQLResponse::set_ok_packet() +{ + uint16_t zero16 = 0; + buf_.clear(); + buf_.push_back(0x00); + buf_.append((const char *)&zero16, 2); + buf_.append((const char *)&zero16, 2); + buf_.append((const char *)&zero16, 2); +} + +int MySQLResponse::decode_packet(const unsigned char *buf, size_t buflen) +{ + return mysql_parser_parse(buf, buflen, parser_); +} + +unsigned long long MySQLResponse::get_affected_rows() const +{ + unsigned long long affected_rows = 0; + MySQLResultCursor cursor(this); + + do { + affected_rows += cursor.get_affected_rows(); + } while (cursor.next_result_set()); + + return affected_rows; +} + +// return array api +unsigned long long MySQLResponse::get_last_insert_id() const +{ + unsigned long long insert_id = 0; + MySQLResultCursor cursor(this); + + do { + if (cursor.get_insert_id()) + insert_id = cursor.get_insert_id(); + } 
while (cursor.next_result_set()); + + return insert_id; +} + +int MySQLResponse::get_warnings() const +{ + int warning_count = 0; + MySQLResultCursor cursor(this); + + do { + warning_count += cursor.get_warnings(); + } while (cursor.next_result_set()); + + return warning_count; +} + +std::string MySQLResponse::get_info() const +{ + std::string info; + MySQLResultCursor cursor(this); + + do { + if (info.length() > 0) + info += " "; + info += cursor.get_info(); + } while (cursor.next_result_set()); + + return info; +} + +bool MySQLResponse::is_ok_packet() const +{ + return parser_->packet_type == MYSQL_PACKET_OK; +} + +bool MySQLResponse::is_error_packet() const +{ + return parser_->packet_type == MYSQL_PACKET_ERROR; +} + +} + diff --git a/src/protocol/MySQLMessage.h b/src/protocol/MySQLMessage.h new file mode 100644 index 0000000..866209e --- /dev/null +++ b/src/protocol/MySQLMessage.h @@ -0,0 +1,122 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _MYSQLMESSAGE_H_ +#define _MYSQLMESSAGE_H_ + +#include +#include +#include "ProtocolMessage.h" +#include "mysql_stream.h" +#include "mysql_parser.h" + +/** + * @file MySQLMessage.h + * @brief MySQL Protocol Interface + */ + +namespace protocol +{ + +class MySQLMessage : public ProtocolMessage +{ +public: + mysql_parser_t *get_parser() const; + int get_seqid() const; + void set_seqid(int seqid); + int get_command() const; + +protected: + virtual int append(const void *buf, size_t *size); + virtual int encode(struct iovec vectors[], int max); + virtual int decode_packet(const unsigned char *buf, size_t buflen) { return 1; } + + void set_command(int cmd) const; + + //append + mysql_stream_t *stream_; + mysql_parser_t *parser_; + + //encode + unsigned char heads_[256][4]; + uint8_t seqid_; + std::string buf_; + size_t cur_size_; + +public: + MySQLMessage(); + virtual ~MySQLMessage(); + //move constructor + MySQLMessage(MySQLMessage&& move); + //move operator + MySQLMessage& operator= (MySQLMessage&& move); +}; + +class MySQLRequest : public MySQLMessage +{ +public: + void set_query(const char *query); + void set_query(const std::string& query); + void set_query(const char *query, size_t length); + + std::string get_query() const; + bool query_is_unset() const; + +public: + MySQLRequest() = default; + //move constructor + MySQLRequest(MySQLRequest&& move) = default; + //move operator + MySQLRequest& operator= (MySQLRequest&& move) = default; +}; + +class MySQLResponse : public MySQLMessage +{ +public: + bool is_ok_packet() const; + bool is_error_packet() const; + int get_packet_type() const; + + unsigned long long get_affected_rows() const; + unsigned long long get_last_insert_id() const; + int get_warnings() const; + int get_error_code() const; + std::string get_error_msg() const; + std::string get_sql_state() const; + std::string get_info() const; + + void set_ok_packet(); + +public: + MySQLResponse() = 
default; + //move constructor + MySQLResponse(MySQLResponse&& move) = default; + //move operator + MySQLResponse& operator= (MySQLResponse&& move) = default; + +protected: + virtual int decode_packet(const unsigned char *buf, size_t buflen); +}; + +} + +//impl. not for user +#include "MySQLMessage.inl" + +#endif + diff --git a/src/protocol/MySQLMessage.inl b/src/protocol/MySQLMessage.inl new file mode 100644 index 0000000..b4caf49 --- /dev/null +++ b/src/protocol/MySQLMessage.inl @@ -0,0 +1,417 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "SSLWrapper.h" + +namespace protocol +{ + +class MySQLHandshakeRequest : public MySQLRequest +{ +private: + virtual int encode(struct iovec vectors[], int max) { return 0; } +}; + +class MySQLHandshakeResponse : public MySQLResponse +{ +public: + std::string get_server_version() const { return server_version_; } + std::string get_auth_plugin_name() const { return auth_plugin_name_; } + + void get_seed(unsigned char seed[20]) const + { + memcpy(seed, auth_plugin_data_, 20); + } + + virtual int encode(struct iovec vectors[], int max); + + void server_set(uint8_t protocol_version, const std::string server_version, + uint32_t connection_id, const unsigned char seed[20], + uint32_t capability_flags, uint8_t character_set, + uint16_t status_flags) + { + protocol_version_ = protocol_version; + server_version_ = server_version; + connection_id_ = connection_id; + memcpy(auth_plugin_data_, seed, 20); + capability_flags_ = capability_flags; + character_set_ = character_set; + status_flags_ = status_flags; + } + + bool host_disallowed() const { return disallowed_; } + uint32_t get_capability_flags() const { return capability_flags_; } + uint16_t get_status_flags() const { return status_flags_; } + +private: + virtual int decode_packet(const unsigned char *buf, size_t buflen); + + std::string server_version_; + std::string auth_plugin_name_; + unsigned char auth_plugin_data_[20]; + uint32_t connection_id_; + uint32_t capability_flags_; + uint16_t status_flags_; + uint8_t character_set_; + uint8_t auth_plugin_data_len_; + uint8_t protocol_version_; + bool disallowed_; + +public: + MySQLHandshakeResponse() : disallowed_(false) { } + //move constructor + MySQLHandshakeResponse(MySQLHandshakeResponse&& move) = default; + //move operator + MySQLHandshakeResponse& operator= (MySQLHandshakeResponse&& move) = default; +}; + +class 
MySQLSSLRequest : public MySQLRequest +{ +private: + virtual int encode(struct iovec vectors[], int max); + + /* Do not support server side with SSL currently. */ + virtual int decode_packet(const unsigned char *buf, size_t buflen) + { + return -2; + } + +private: + int character_set_; + SSLHandshaker ssl_handshaker_; + +public: + MySQLSSLRequest(int character_set, SSL *ssl) : ssl_handshaker_(ssl) + { + character_set_ = character_set; + } + + MySQLSSLRequest(MySQLSSLRequest&& move) = default; + MySQLSSLRequest& operator= (MySQLSSLRequest&& move) = default; +}; + +class MySQLAuthRequest : public MySQLRequest +{ +public: + void set_auth(const std::string username, + const std::string password, + const std::string db, + int character_set) + { + username_ = std::move(username); + password_ = std::move(password); + db_ = std::move(db); + character_set_ = character_set; + } + + void set_auth_plugin_name(std::string name) + { + auth_plugin_name_ = std::move(name); + } + + void set_seed(const unsigned char seed[20]) + { + memcpy(seed_, seed, 20); + } + +private: + virtual int encode(struct iovec vectors[], int max); + virtual int decode_packet(const unsigned char *buf, size_t buflen); + + std::string username_; + std::string password_; + std::string db_; + std::string auth_plugin_name_; + unsigned char seed_[20]; + int character_set_; + +public: + MySQLAuthRequest() : character_set_(33) { } + //move constructor + MySQLAuthRequest(MySQLAuthRequest&& move) = default; + //move operator + MySQLAuthRequest& operator= (MySQLAuthRequest&& move) = default; +}; + +class MySQLAuthResponse : public MySQLResponse +{ +public: + std::string get_auth_plugin_name() const { return auth_plugin_name_; } + + void get_seed(unsigned char seed[20]) const + { + memcpy(seed, seed_, 20); + } + + bool is_continue() const + { + return continue_; + } + +private: + virtual int decode_packet(const unsigned char *buf, size_t buflen); + +private: + std::string auth_plugin_name_; + unsigned char seed_[20]; 
+ bool continue_; + +public: + MySQLAuthResponse() : continue_(false) { } + //move constructor + MySQLAuthResponse(MySQLAuthResponse&& move) = default; + //move operator + MySQLAuthResponse& operator= (MySQLAuthResponse&& move) = default; +}; + +class MySQLAuthSwitchRequest : public MySQLRequest +{ +public: + void set_password(std::string password) + { + password_ = std::move(password); + } + + void set_auth_plugin_name(std::string name) + { + auth_plugin_name_ = std::move(name); + } + + void set_seed(const unsigned char seed[20]) + { + memcpy(seed_, seed, 20); + } + +private: + virtual int encode(struct iovec vectors[], int max); + + /* Not implemented. */ + virtual int decode_packet(const unsigned char *buf, size_t buflen) + { + return -2; + } + + std::string password_; + std::string auth_plugin_name_; + unsigned char seed_[20]; + +public: + MySQLAuthSwitchRequest() { } + //move constructor + MySQLAuthSwitchRequest(MySQLAuthSwitchRequest&& move) = default; + //move operator + MySQLAuthSwitchRequest& operator= (MySQLAuthSwitchRequest&& move) = default; +}; + +class MySQLPublicKeyRequest : public MySQLRequest +{ +public: + void set_caching_sha2() { byte_ = 0x02; } + void set_sha256() { byte_ = 0x01; } + +private: + virtual int encode(struct iovec vectors[], int max) + { + buf_.assign(&byte_, 1); + return MySQLRequest::encode(vectors, max); + } + + /* Not implemented. 
*/ + virtual int decode_packet(const unsigned char *buf, size_t buflen) + { + return -2; + } + + char byte_; + +public: + MySQLPublicKeyRequest() : byte_(0x01) { } + //move constructor + MySQLPublicKeyRequest(MySQLPublicKeyRequest&& move) = default; + //move operator + MySQLPublicKeyRequest& operator= (MySQLPublicKeyRequest&& move) = default; +}; + +class MySQLPublicKeyResponse : public MySQLResponse +{ +public: + std::string get_public_key() const + { + return public_key_; + } + + void set_public_key(std::string key) + { + public_key_ = std::move(key); + } + +private: + virtual int encode(struct iovec vectors[], int max); + virtual int decode_packet(const unsigned char *buf, size_t buflen); + + std::string public_key_; + +public: + MySQLPublicKeyResponse() { } + //move constructor + MySQLPublicKeyResponse(MySQLPublicKeyResponse&& move) = default; + //move operator + MySQLPublicKeyResponse& operator= (MySQLPublicKeyResponse&& move) = default; +}; + +class MySQLRSAAuthRequest : public MySQLRequest +{ +public: + void set_password(std::string password) + { + password_ = std::move(password); + } + + void set_public_key(std::string key) + { + public_key_ = std::move(key); + } + + void set_seed(const unsigned char seed[20]) + { + memcpy(seed_, seed, 20); + } + +private: + virtual int encode(struct iovec vectors[], int max); + + /* Not implemented. 
*/ + virtual int decode_packet(const unsigned char *buf, size_t buflen) + { + return -2; + } + + int rsa_encrypt(void *ctx); + + std::string password_; + std::string public_key_; + unsigned char seed_[20]; + +public: + MySQLRSAAuthRequest() { } + //move constructor + MySQLRSAAuthRequest(MySQLRSAAuthRequest&& move) = default; + //move operator + MySQLRSAAuthRequest& operator= (MySQLRSAAuthRequest&& move) = default; +}; + +////////// + +inline mysql_parser_t *MySQLMessage::get_parser() const +{ + return parser_; +} + +inline int MySQLMessage::get_seqid() const +{ + return seqid_; +} + +inline void MySQLMessage::set_seqid(int seqid) +{ + seqid_ = seqid; +} + +inline int MySQLMessage::get_command() const +{ + return parser_->cmd; +} + +inline void MySQLMessage::set_command(int cmd) const +{ + mysql_parser_set_command(cmd, parser_); +} + +inline MySQLMessage::MySQLMessage(): + stream_(new mysql_stream_t), + parser_(new mysql_parser_t), + seqid_(0), + cur_size_(0) +{ + mysql_stream_init(stream_); + mysql_parser_init(parser_); +} + +inline bool MySQLRequest::query_is_unset() const +{ + return buf_.empty(); +} + +inline void MySQLRequest::set_query(const char *query) +{ + set_query(query, strlen(query)); +} + +inline void MySQLRequest::set_query(const std::string& query) +{ + set_query(query.c_str(), query.size()); +} + +inline int MySQLResponse::get_packet_type() const +{ + return parser_->packet_type; +} + +inline int MySQLResponse::get_error_code() const +{ + return is_error_packet() ? 
parser_->error : 0; +} + +inline std::string MySQLResponse::get_error_msg() const +{ + if (is_error_packet()) + { + const char *s; + size_t slen; + + mysql_parser_get_err_msg(&s, &slen, parser_); + if (slen > 0) + return std::string(s, slen); + } + + return std::string(); +} + +inline std::string MySQLResponse::get_sql_state() const +{ + if (is_error_packet()) + { + const char *s; + size_t slen; + + mysql_parser_get_net_state(&s, &slen, parser_); + if (slen > 0) + return std::string(s, slen); + } + + return std::string(); +} + +} + diff --git a/src/protocol/MySQLResult.cc b/src/protocol/MySQLResult.cc new file mode 100644 index 0000000..8aeef65 --- /dev/null +++ b/src/protocol/MySQLResult.cc @@ -0,0 +1,349 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include +#include "mysql_types.h" +#include "mysql_byteorder.h" +#include "MySQLMessage.h" +#include "MySQLResult.h" + +namespace protocol +{ + +MySQLField::MySQLField(const void *buf, mysql_field_t *field) +{ + const char *p = (const char *)buf; + + this->name = p + field->name_offset; + this->name_length = field->name_length; + + this->org_name = p + field->org_name_offset; + this->org_name_length = field->org_name_length; + + this->table = p + field->table_offset; + this->table_length = field->table_length; + + this->org_table = p + field->org_table_offset; + this->org_table_length = field->org_table_length; + + this->db = p + field->db_offset; + this->db_length = field->db_length; + + this->catalog = p + field->catalog_offset; + this->catalog_length = field->catalog_length; + + if (field->def_offset == (size_t)-1 && field->def_length == 0) + { + this->def = NULL; + this->def_length = 0; + } else { + this->def = p + field->def_offset; + this->def_length = field->def_length; + } + + this->flags = field->flags; + this->length = field->length; + this->decimals = field->decimals; + this->charsetnr = field->charsetnr; + this->data_type = field->data_type; +} + +MySQLResultCursor::MySQLResultCursor() +{ + this->init(); +} + +void MySQLResultCursor::init() +{ + this->current_field = 0; + this->current_row = 0; + this->field_count = 0; + this->fields = NULL; + this->parser = NULL; + this->status = MYSQL_STATUS_NOT_INIT; +} + +MySQLResultCursor::MySQLResultCursor(const MySQLResponse *resp) +{ + this->init(resp); +} + +void MySQLResultCursor::reset(MySQLResponse *resp) +{ + this->clear(); + this->init(resp); +} + +void MySQLResultCursor::fetch_result_set(const struct __mysql_result_set *result_set) +{ + const char *buf = (const char *)this->parser->buf; + this->server_status = result_set->server_status; + + switch (result_set->type) + { + case MYSQL_PACKET_GET_RESULT: + this->status = MYSQL_STATUS_GET_RESULT; + 
this->field_count = result_set->field_count; + this->start = buf + result_set->rows_begin_offset; + this->pos = this->start; + this->end = buf + result_set->rows_end_offset; + this->row_count = result_set->row_count; + + this->fields = new MySQLField *[this->field_count]; + for (int i = 0; i < this->field_count; i++) + this->fields[i] = new MySQLField(this->parser->buf, result_set->fields[i]); + break; + + case MYSQL_PACKET_OK: + this->status = MYSQL_STATUS_OK; + this->affected_rows = result_set->affected_rows; + this->insert_id = result_set->insert_id; + this->warning_count = result_set->warning_count; + this->start = buf + result_set->info_offset; + this->info_len = result_set->info_len; + this->field_count = 0; + this->fields = NULL; + break; + + default: + this->status = MYSQL_STATUS_ERROR; + break; + } +} + +void MySQLResultCursor::init(const MySQLResponse *resp) +{ + this->current_field = 0; + this->current_row = 0; + this->field_count = 0; + this->fields = NULL; + this->parser = resp->get_parser(); + this->status = MYSQL_STATUS_NOT_INIT; + + if (!list_empty(&this->parser->result_set_list)) + { + struct __mysql_result_set *result_set; + + mysql_result_set_cursor_init(&this->cursor, this->parser); + mysql_result_set_cursor_next(&result_set, &this->cursor); + + this->fetch_result_set(result_set); + } +} + +bool MySQLResultCursor::next_result_set() +{ + if (this->status == MYSQL_STATUS_NOT_INIT || + this->status == MYSQL_STATUS_ERROR) + { + return false; + } + + struct __mysql_result_set *result_set; + if (mysql_result_set_cursor_next(&result_set, &this->cursor) == 0) + { + for (int i = 0; i < this->field_count; i++) + delete this->fields[i]; + + delete []this->fields; + + this->current_field = 0; + this->current_row = 0; + + this->fetch_result_set(result_set); + return true; + } + else + { + this->status = MYSQL_STATUS_END; + return false; + } +} + +bool MySQLResultCursor::fetch_row(std::vector& row_arr) +{ + if (this->status != MYSQL_STATUS_GET_RESULT) + 
return false; + + unsigned long long len; + const unsigned char *data; + int data_type; + + const unsigned char *p = (const unsigned char *)this->pos; + const unsigned char *end = (const unsigned char *)this->end; + + row_arr.clear(); + + for (int i = 0; i < this->field_count; i++) + { + data_type = this->fields[i]->get_data_type(); + if (*p == MYSQL_PACKET_HEADER_NULL) + { + data = NULL; + len = 0; + p++; + data_type = MYSQL_TYPE_NULL; + } else if (decode_string(&data, &len, &p, end) == 0) { + this->status = MYSQL_STATUS_ERROR; + return false; + } + + row_arr.emplace_back(data, len, data_type); + } + + if (++this->current_row == this->row_count) + this->status = MYSQL_STATUS_END; + + this->pos = p; + + return true; +} + +bool MySQLResultCursor::fetch_row(std::map& row_map) +{ + return this->fetch_row>(row_map); +} + +bool MySQLResultCursor::fetch_row(std::unordered_map& row_map) +{ + return this->fetch_row>(row_map); +} + +bool MySQLResultCursor::fetch_row_nocopy(const void **data, size_t *len, int *data_type) +{ + if (this->status != MYSQL_STATUS_GET_RESULT) + return false; + + unsigned long long cell_len; + const unsigned char *cell_data; + + const unsigned char *p = (const unsigned char *)this->pos; + const unsigned char *end = (const unsigned char *)this->end; + + for (int i = 0; i < this->field_count; i++) + { + if (*p == MYSQL_PACKET_HEADER_NULL) + { + cell_data = NULL; + cell_len = 0; + p++; + } + else if (decode_string(&cell_data, &cell_len, &p, end) == 0) + { + this->status = MYSQL_STATUS_ERROR; + return false; + } + + data[i] = cell_data; + len[i] = cell_len; + data_type[i] = this->fields[i]->get_data_type(); + } + + this->pos = p; + + if (++this->current_row == this->row_count) + this->status = MYSQL_STATUS_END; + + return true; +} + +bool MySQLResultCursor::fetch_all(std::vector>& rows) +{ + if (this->status != MYSQL_STATUS_GET_RESULT) + return false; + + unsigned long long len; + const unsigned char *data; + int data_type; + + const unsigned char *p = 
(const unsigned char *)this->pos; + const unsigned char *end = (const unsigned char *)this->end; + + rows.clear(); + + for (int i = this->current_row; i < this->row_count; i++) + { + std::vector tmp; + for (int j = 0; j < this->field_count; j++) + { + data_type = this->fields[j]->get_data_type(); + if (*p == MYSQL_PACKET_HEADER_NULL) + { + data = NULL; + len = 0; + p++; + data_type = MYSQL_TYPE_NULL; + } + else if (decode_string(&data, &len, &p, end) == 0) + { + this->status = MYSQL_STATUS_ERROR; + return false; + } + + tmp.emplace_back(data, len, data_type); + } + rows.emplace_back(std::move(tmp)); + } + + this->current_row = this->row_count; + this->status = MYSQL_STATUS_END; + this->pos = p; + + return true; +} + +void MySQLResultCursor::first_result_set() +{ + if (this->status == MYSQL_STATUS_NOT_INIT || + this->status == MYSQL_STATUS_ERROR) + { + return; + } + + mysql_result_set_cursor_rewind(&this->cursor); + struct __mysql_result_set *result_set; + + if (mysql_result_set_cursor_next(&result_set, &this->cursor) == 0) + { + for (int i = 0; i < this->field_count; i++) + delete this->fields[i]; + + delete []this->fields; + + this->current_field = 0; + this->current_row = 0; + + this->fetch_result_set(result_set); + } +} + +void MySQLResultCursor::rewind() +{ + if (this->status != MYSQL_STATUS_GET_RESULT && + this->status != MYSQL_STATUS_END) + { + return; + } + + this->current_field = 0; + this->current_row = 0; + this->pos = this->start; +} + +} + diff --git a/src/protocol/MySQLResult.h b/src/protocol/MySQLResult.h new file mode 100644 index 0000000..c81825e --- /dev/null +++ b/src/protocol/MySQLResult.h @@ -0,0 +1,196 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#ifndef _MYSQLRESULT_H_ +#define _MYSQLRESULT_H_ + +#include +#include +#include +#include +#include "mysql_types.h" +#include "mysql_parser.h" +#include "MySQLMessage.h" + +/** + * @file MySQLResult.h + * @brief MySQL toolbox for visit result + */ + +namespace protocol +{ + +class MySQLCell +{ +public: + MySQLCell(); + + MySQLCell(MySQLCell&& move); + MySQLCell& operator=(MySQLCell&& move); + + MySQLCell(const void *data, size_t len, int data_type); + + int get_data_type() const; + + bool is_null() const; + bool is_int() const; + bool is_string() const; + bool is_float() const; + bool is_double() const; + bool is_ulonglong() const; + bool is_date() const; + bool is_time() const; + bool is_datetime() const; + + // for copy + int as_int() const; + std::string as_string() const; + std::string as_binary_string() const; + float as_float() const; + double as_double() const; + unsigned long long as_ulonglong() const; + std::string as_date() const; + std::string as_time() const; + std::string as_datetime() const; + + // for nocopy + void get_cell_nocopy(const void **data, size_t *len, int *data_type) const; + +private: + int data_type; + void *data; + size_t len; +}; + +class MySQLField +{ +public: + MySQLField(const void *buf, mysql_field_t *field); + + std::string get_name() const; + std::string get_org_name() const; + std::string get_table() const; + std::string get_org_table() const; + std::string get_db() const; + std::string get_catalog() const; + std::string get_def() const; + int get_charsetnr() const; + int 
get_length() const; + int get_flags() const; + int get_decimals() const; + int get_data_type() const; + +private: + const char *name; /* Name of column */ + const char *org_name; /* Original column name, if an alias */ + const char *table; /* Table of column if column was a field */ + const char *org_table; /* Org table name, if table was an alias */ + const char *db; /* Database for table */ + const char *catalog; /* Catalog for table */ + const char *def; /* Default value (set by mysql_list_fields) */ + int length; /* Width of column (create length) */ + int name_length; + int org_name_length; + int table_length; + int org_table_length; + int db_length; + int catalog_length; + int def_length; + int flags; /* Div flags */ + int decimals; /* Number of decimals in field */ + int charsetnr; /* Character set */ + int data_type; /* Type of field. See mysql_types.h for types */ +}; + +class MySQLResultCursor +{ +public: + MySQLResultCursor(const MySQLResponse *resp); + + MySQLResultCursor(MySQLResultCursor&& move); + MySQLResultCursor& operator=(MySQLResultCursor&& move); + + virtual ~MySQLResultCursor(); + + bool next_result_set(); + void first_result_set(); + + const MySQLField *fetch_field(); + const MySQLField *const *fetch_fields() const; + + bool fetch_row(std::vector& row_arr); + bool fetch_row(std::map& row_map); + bool fetch_row(std::unordered_map& row_map); + + bool fetch_row_nocopy(const void **data, size_t *len, int *data_type); + bool fetch_all(std::vector>& rows); + + int get_cursor_status() const; + int get_server_status() const; + + int get_field_count() const; + int get_rows_count() const; + unsigned long long get_affected_rows() const; + unsigned long long get_insert_id() const; + int get_warnings() const; + std::string get_info() const; + + void rewind(); + +public: + MySQLResultCursor(); + void reset(MySQLResponse *resp); + +private: + void init(const MySQLResponse *resp); + void init(); + void clear(); + + void fetch_result_set(const struct 
__mysql_result_set *result_set); + + template + bool fetch_row(T& row_map); + + int status; + int server_status; + const void *start; + const void *end; + const void *pos; + + const void **row_data; + MySQLField **fields; + int row_count; + int field_count; + int current_row; + int current_field; + + unsigned long long affected_rows; + unsigned long long insert_id; + int warning_count; + int info_len; + + mysql_result_set_cursor_t cursor; + mysql_parser_t *parser; +}; + +} + +#include "MySQLResult.inl" + +#endif + diff --git a/src/protocol/MySQLResult.inl b/src/protocol/MySQLResult.inl new file mode 100644 index 0000000..4acbfd1 --- /dev/null +++ b/src/protocol/MySQLResult.inl @@ -0,0 +1,486 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include +#include +#include +#include "mysql_byteorder.h" + +namespace protocol +{ + +inline std::string MySQLField::get_name() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->name, this->name_length); +} + +inline std::string MySQLField::get_org_name() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->org_name, this->org_name_length); +} + +inline std::string MySQLField::get_table() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->table, this->table_length); +} + +inline std::string MySQLField::get_org_table() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->org_table, this->org_table_length); +} + +inline std::string MySQLField::get_db() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->db, this->db_length); +} + +inline std::string MySQLField::get_catalog() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->catalog, this->catalog_length); +} + +inline std::string MySQLField::get_def() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return ""; + return std::string(this->def, this->def_length); +} + +inline int MySQLField::get_charsetnr() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return 0; + return this->charsetnr; +} + +inline int MySQLField::get_length() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return 0; + return this->length; +} + +inline int MySQLField::get_flags() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return 0; + return this->flags; +} + +inline int MySQLField::get_decimals() const +{ + if (this->data_type == MYSQL_TYPE_NULL) + return 0; + return this->decimals; +} + +inline int MySQLField::get_data_type() const +{ + return this->data_type; +} + +inline MySQLCell::MySQLCell(MySQLCell&& move) +{ + 
this->operator=(std::move(move)); +} + +inline MySQLCell& MySQLCell::operator=(MySQLCell&& move) +{ + if (this != &move) + { + this->data = move.data; + this->len = move.len; + this->data_type = move.data_type; + + move.data = NULL; + move.len = 0; + } + + return *this; +} + +inline MySQLCell::MySQLCell(const void *data, size_t len, int data_type) +{ + this->data_type = data_type; + this->data = const_cast(data); + this->len = len; +} + +inline MySQLCell::MySQLCell() +{ + this->data = NULL; + this->len = 0; + this->data_type = MYSQL_TYPE_NULL; +} + +inline int MySQLCell::get_data_type() const +{ + return this->data_type; +} + +inline void MySQLCell::get_cell_nocopy(const void **data, size_t *len, + int *data_type) const +{ + *data = this->data; + *len = this->len; + *data_type = this->data_type; +} + +inline bool MySQLCell::is_null() const +{ + return (this->data_type == MYSQL_TYPE_NULL); +} + +inline std::string MySQLCell::as_binary_string() const +{ + return std::string((char *)this->data, this->len); +} + +inline bool MySQLCell::is_int() const +{ + return (this->data_type == MYSQL_TYPE_TINY || + this->data_type == MYSQL_TYPE_SHORT || + this->data_type == MYSQL_TYPE_INT24 || + this->data_type == MYSQL_TYPE_LONG); +} + +inline int MySQLCell::as_int() const +{ + if (!this->is_int()) + return 0; + + std::string num((char *)this->data, this->len); + return atoi(num.c_str()); +} + +inline bool MySQLCell::is_float() const +{ + return (this->data_type == MYSQL_TYPE_FLOAT); +} + +inline float MySQLCell::as_float() const +{ + if (!this->is_float()) + return NAN; + + std::string num((char *)this->data, this->len); + return strtof(num.c_str(), NULL); +} + +inline bool MySQLCell::is_double() const +{ + return (this->data_type == MYSQL_TYPE_DOUBLE); +} + +inline double MySQLCell::as_double() const +{ + if (!this->is_double()) + return NAN; + + std::string num((char *)this->data, this->len); + return strtod(num.c_str(), NULL); +} + +inline bool MySQLCell::is_ulonglong() const 
+{ + return (this->data_type == MYSQL_TYPE_LONGLONG); +} + +inline unsigned long long MySQLCell::as_ulonglong() const +{ + if (!this->is_ulonglong()) + return (unsigned long long)-1; + + std::string num((char *)this->data, this->len); + return strtoull(num.c_str(), NULL, 10); +} + +inline bool MySQLCell::is_date() const +{ + return (this->data_type == MYSQL_TYPE_DATE); +} + +inline std::string MySQLCell::as_date() const +{ + if (!this->is_date()) + return ""; + + return std::string((char *)this->data, this->len); +} + +inline bool MySQLCell::is_time() const +{ + return (this->data_type == MYSQL_TYPE_TIME); +} + +inline std::string MySQLCell::as_time() const +{ + if (!this->is_time()) + return ""; + + return std::string((char *)this->data, this->len); +} + +inline bool MySQLCell::is_datetime() const +{ + return (this->data_type == MYSQL_TYPE_DATETIME || + this->data_type == MYSQL_TYPE_TIMESTAMP); +} + +inline std::string MySQLCell::as_datetime() const +{ + if (!this->is_datetime()) + return ""; + + return std::string((char *)this->data, this->len); +} + +inline bool MySQLCell::is_string() const +{ + return (this->data_type == MYSQL_TYPE_DECIMAL || + this->data_type == MYSQL_TYPE_NEWDECIMAL || + this->data_type == MYSQL_TYPE_STRING || + this->data_type == MYSQL_TYPE_VARCHAR || + this->data_type == MYSQL_TYPE_VAR_STRING || + this->data_type == MYSQL_TYPE_JSON); +} + +inline std::string MySQLCell::as_string() const +{ + if (!this->is_string() && !this->is_time() && + !this->is_date() && !this->is_datetime()) + return ""; + + return std::string((char *)this->data, this->len); +} + +template +bool MySQLResultCursor::fetch_row(T& row_map) +{ + if (this->status != MYSQL_STATUS_GET_RESULT) + return false; + + unsigned long long len; + const unsigned char *data; + int data_type; + + const unsigned char *p = (const unsigned char *)this->pos; + const unsigned char *end = (const unsigned char *)this->end; + + row_map.clear(); + + for (int i = 0; i < this->field_count; i++) + { 
+ data_type = this->fields[i]->get_data_type(); + if (*p == MYSQL_PACKET_HEADER_NULL) + { + data = NULL; + len = 0; + p++; + data_type = MYSQL_TYPE_NULL; + } + else if (decode_string(&data, &len, &p, end) == 0) + { + this->status = MYSQL_STATUS_ERROR; + return false; + } + row_map.emplace(this->fields[i]->get_name(), MySQLCell(data, len, data_type)); + } + + this->pos = p; + + if (++this->current_row == this->row_count) + this->status = MYSQL_STATUS_END; + + return true; +} + +// Returns the next column descriptor and advances the field position; +// NULL when no result set is loaded or all fields have been returned. +inline const MySQLField *MySQLResultCursor::fetch_field() +{ + if (this->status != MYSQL_STATUS_GET_RESULT && + this->status != MYSQL_STATUS_END) + { + return NULL; + } + + if (this->current_field >= this->field_count) + return NULL; + + return this->fields[this->current_field++]; +} + +// Returns the whole field table of the current result set, or NULL when none is loaded. +inline const MySQLField *const *MySQLResultCursor::fetch_fields() const +{ + if (this->status != MYSQL_STATUS_GET_RESULT && + this->status != MYSQL_STATUS_END) + { + return NULL; + } + + return this->fields; +} + +// Returns the raw cursor status (one of the MYSQL_STATUS_* values). +inline int MySQLResultCursor::get_cursor_status() const +{ + return this->status; +} + +// Returns the server status recorded for the current result set, or 0 when none is loaded. +inline int MySQLResultCursor::get_server_status() const +{ + if (this->status != MYSQL_STATUS_GET_RESULT && + this->status != MYSQL_STATUS_END && + this->status != MYSQL_STATUS_OK) + { + return 0; + } + + return this->server_status; +} + +// Returns the column count of the current row result set, or 0 otherwise. +inline int MySQLResultCursor::get_field_count() const +{ + if (this->status != MYSQL_STATUS_GET_RESULT && + this->status != MYSQL_STATUS_END) + { + return 0; + } + + return this->field_count; +} + +// Returns the row count of the current row result set, or 0 otherwise. +inline int MySQLResultCursor::get_rows_count() const +{ + if (this->status != MYSQL_STATUS_GET_RESULT && + this->status != MYSQL_STATUS_END) + { + return 0; + } + + return this->row_count; +} + +// Returns the affected row count of an OK result; 0 for any other cursor status. +// The cursor status holds MYSQL_STATUS_* values, so it is compared with MYSQL_STATUS_OK. +inline unsigned long long MySQLResultCursor::get_affected_rows() const +{ + if (this->status != MYSQL_STATUS_OK) + return 0; + + return this->affected_rows; +} + +// Returns the warning count of an OK result; 0 for any other cursor status. +inline int MySQLResultCursor::get_warnings() const +{ + if (this->status != MYSQL_STATUS_OK) + return 0; + + return
this->warning_count; +} + +// Returns the insert id of an OK result; 0 for any other cursor status. +inline unsigned long long MySQLResultCursor::get_insert_id() const +{ + if (this->status != MYSQL_STATUS_OK) + return 0; + + return this->insert_id; +} + +// Returns a copy of the info string of an OK result; empty for any other cursor status. +inline std::string MySQLResultCursor::get_info() const +{ + if (this->status != MYSQL_STATUS_OK) + return ""; + + return std::string((char *)this->start, this->info_len); +} + +// Frees every MySQLField and the field table itself. +inline void MySQLResultCursor::clear() +{ + for (int i = 0; i < this->field_count; i++) + delete this->fields[i]; + + delete []this->fields; +} + +inline MySQLResultCursor::~MySQLResultCursor() +{ + this->clear(); +} + +// Move constructor: takes over the state and field table of the source cursor, +// then re-initializes the source so that its destructor frees nothing. +inline MySQLResultCursor::MySQLResultCursor(MySQLResultCursor&& move) +{ + this->start = move.start; + this->end = move.end; + this->pos = move.pos; + this->status = move.status; + // server_status must travel with status, or get_server_status() reads garbage. + this->server_status = move.server_status; + this->row_data = move.row_data; + this->fields = move.fields; + this->row_count = move.row_count; + this->field_count = move.field_count; + this->current_row = move.current_row; + this->current_field = move.current_field; + this->affected_rows = move.affected_rows; + this->insert_id = move.insert_id; + this->warning_count = move.warning_count; + this->info_len = move.info_len; + this->cursor = move.cursor; + this->parser = move.parser; + + move.init(); +} + +// Move assignment: frees the current field table, takes over the state of the +// source cursor, then re-initializes the source. +inline MySQLResultCursor& MySQLResultCursor::operator=(MySQLResultCursor&& move) +{ + if (this != &move) + { + this->clear(); + + this->start = move.start; + this->end = move.end; + this->pos = move.pos; + this->status = move.status; + // server_status must travel with status, or get_server_status() reads garbage. + this->server_status = move.server_status; + this->row_data = move.row_data; + this->fields = move.fields; + this->row_count = move.row_count; + this->field_count = move.field_count; + this->current_row = move.current_row; + this->current_field = move.current_field; + this->affected_rows = move.affected_rows; + this->insert_id = move.insert_id; + this->warning_count = move.warning_count; + this->info_len = move.info_len; + this->cursor = move.cursor; + this->parser = move.parser; + + move.init(); + } + + return *this; +} + +} + diff --git a/src/protocol/MySQLUtil.cc
b/src/protocol/MySQLUtil.cc new file mode 100644 index 0000000..d661fec --- /dev/null +++ b/src/protocol/MySQLUtil.cc @@ -0,0 +1,85 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include "MySQLUtil.h" + +namespace protocol +{ + +std::string MySQLUtil::escape_string(const std::string& str) +{ + std::string res; + char escape; + size_t i; + + for (i = 0; i < str.size(); i++) + { + switch (str[i]) + { + case '\0': + escape = '0'; + break; + case '\n': + escape = 'n'; + break; + case '\r': + escape = 'r'; + break; + case '\\': + escape = '\\'; + break; + case '\'': + escape = '\''; + break; + case '\"': + escape = '\"'; + break; + case '\032': + escape = 'Z'; + break; + default: + res.push_back(str[i]); + continue; + } + + res.push_back('\\'); + res.push_back(escape); + } + + return res; +} + +std::string MySQLUtil::escape_string_quote(const std::string& str, char quote) +{ + std::string res; + size_t i; + + for (i = 0; i < str.size(); i++) + { + if (str[i] == quote) + res.push_back(quote); + + res.push_back(str[i]); + } + + return res; +} + +} + diff --git a/src/protocol/MySQLUtil.h b/src/protocol/MySQLUtil.h new file mode 100644 index 0000000..f0916fe --- /dev/null +++ b/src/protocol/MySQLUtil.h @@ -0,0 +1,37 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _MYSQLUTIL_H_ +#define _MYSQLUTIL_H_ + +#include + +namespace protocol +{ + +class MySQLUtil +{ +public: + static std::string escape_string(const std::string& str); + static std::string escape_string_quote(const std::string& str, char quote); +}; + +} + +#endif + diff --git a/src/protocol/PackageWrapper.cc b/src/protocol/PackageWrapper.cc new file mode 100644 index 0000000..a8fae11 --- /dev/null +++ b/src/protocol/PackageWrapper.cc @@ -0,0 +1,72 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include "PackageWrapper.h" + +namespace protocol +{ + +// Encodes the current message and every following message returned by next_out() +// into the iovec array. Returns the total number of vectors used, a negative +// value passed through from the inner encode, or -1 with errno EOVERFLOW when +// the vectors run out. +int PackageWrapper::encode(struct iovec vectors[], int max) +{ + int cnt = 0; + int ret; + + // Another message is only attempted while at least 8 vectors remain. + while (max >= 8) + { + ret = this->ProtocolWrapper::encode(vectors, max); + // Single unsigned compare catches both a negative ret and ret > max. + if ((unsigned int)ret > (unsigned int)max) + { + if (ret < 0) + return ret; + + break; + } + + cnt += ret; + // Switch to the next outgoing message; NULL means everything was encoded. + this->set_message(this->next_out(this->message)); + if (!this->message) + return cnt; + + vectors += ret; + max -= ret; + } + + errno = EOVERFLOW; + return -1; +} + +// Feeds incoming bytes to the current message. When the inner append returns a +// positive value (presumably: message complete - confirm against the append +// contract) the wrapper switches to next_in(); if another message follows, it +// calls renew() and returns 0 instead of the positive value. +int PackageWrapper::append(const void *buf, size_t *size) +{ + int ret = this->ProtocolWrapper::append(buf, size); + + if (ret > 0) + { + this->set_message(this->next_in(this->message)); + if (this->message) + { + this->renew(); + ret = 0; + } + } + + return ret; +} + +} + diff --git a/src/protocol/PackageWrapper.h b/src/protocol/PackageWrapper.h new file mode 100644 index 0000000..75d9a25 --- /dev/null +++ b/src/protocol/PackageWrapper.h @@ -0,0 +1,62 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _PACKAGEWRAPPER_H_ +#define _PACKAGEWRAPPER_H_ + +#include "ProtocolMessage.h" + +namespace protocol +{ + +class PackageWrapper : public ProtocolWrapper +{ +private: + virtual ProtocolMessage *next_in(ProtocolMessage *message) + { + return this->next(message); + } + + virtual ProtocolMessage *next_out(ProtocolMessage *message) + { + return this->next(message); + } + + virtual ProtocolMessage *next(ProtocolMessage *message) + { + return NULL; + } + +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +public: + PackageWrapper(ProtocolMessage *message) : ProtocolWrapper(message) + { + } + +public: + PackageWrapper(PackageWrapper&& wrapper) = default; + PackageWrapper& operator = (PackageWrapper&& wrapper) = default; +}; + +} + +#endif + diff --git a/src/protocol/ProtocolMessage.h b/src/protocol/ProtocolMessage.h new file mode 100644 index 0000000..1f845ec --- /dev/null +++ b/src/protocol/ProtocolMessage.h @@ -0,0 +1,196 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _PROTOCOLMESSAGE_H_ +#define _PROTOCOLMESSAGE_H_ + +#include +#include +#include +#include "Communicator.h" + +/** + * @file ProtocolMessage.h + * @brief General Protocol Interface + */ + +namespace protocol +{ + +class ProtocolMessage : public CommMessageOut, public CommMessageIn +{ +protected: + virtual int encode(struct iovec vectors[], int max) + { + errno = ENOSYS; + return -1; + } + + /* You have to implement one of the 'append' functions, and the first one + * with arguement 'size_t *size' is recommmended. */ + + /* Argument 'size' indicates bytes to append, and returns bytes used. */ + virtual int append(const void *buf, size_t *size) + { + return this->append(buf, *size); + } + + /* When implementing this one, all bytes are consumed. Cannot support + * streaming protocol. */ + virtual int append(const void *buf, size_t size) + { + errno = ENOSYS; + return -1; + } + +public: + void set_size_limit(size_t limit) { this->size_limit = limit; } + size_t get_size_limit() const { return this->size_limit; } + +public: + class Attachment + { + public: + virtual ~Attachment() { } + }; + + void set_attachment(Attachment *att) { this->attachment = att; } + Attachment *get_attachment() const { return this->attachment; } + +protected: + virtual int feedback(const void *buf, size_t size) + { + if (this->wrapper) + return this->wrapper->feedback(buf, size); + else + return this->CommMessageIn::feedback(buf, size); + } + + virtual void renew() + { + if (this->wrapper) + return this->wrapper->renew(); + else + return this->CommMessageIn::renew(); + } + + virtual ProtocolMessage *inner() { return this; } + +protected: + size_t size_limit; + +private: + Attachment *attachment; + ProtocolMessage *wrapper; + +public: + ProtocolMessage() + { + this->size_limit = (size_t)-1; + this->attachment = NULL; + this->wrapper = NULL; + } + + virtual ~ProtocolMessage() { delete this->attachment; } + +public: + 
ProtocolMessage(ProtocolMessage&& message) + { + this->size_limit = message.size_limit; + this->attachment = message.attachment; + message.attachment = NULL; + this->wrapper = NULL; + } + + ProtocolMessage& operator = (ProtocolMessage&& message) + { + if (&message != this) + { + this->size_limit = message.size_limit; + delete this->attachment; + this->attachment = message.attachment; + message.attachment = NULL; + } + + return *this; + } + + friend class ProtocolWrapper; +}; + +class ProtocolWrapper : public ProtocolMessage +{ +protected: + virtual int encode(struct iovec vectors[], int max) + { + return this->message->encode(vectors, max); + } + + virtual int append(const void *buf, size_t *size) + { + return this->message->append(buf, size); + } + +protected: + virtual ProtocolMessage *inner() + { + return this->message->inner(); + } + +protected: + void set_message(ProtocolMessage *message) + { + this->message = message; + if (message) + message->wrapper = this; + } + +protected: + ProtocolMessage *message; + +public: + ProtocolWrapper(ProtocolMessage *message) + { + this->set_message(message); + } + +public: + ProtocolWrapper(ProtocolWrapper&& wrapper) : + ProtocolMessage(std::move(wrapper)) + { + this->set_message(wrapper.message); + wrapper.message = NULL; + } + + ProtocolWrapper& operator = (ProtocolWrapper&& wrapper) + { + if (&wrapper != this) + { + *(ProtocolMessage *)this = std::move(wrapper); + this->set_message(wrapper.message); + wrapper.message = NULL; + } + + return *this; + } +}; + +} + +#endif + diff --git a/src/protocol/RedisMessage.cc b/src/protocol/RedisMessage.cc new file mode 100644 index 0000000..842c86f --- /dev/null +++ b/src/protocol/RedisMessage.cc @@ -0,0 +1,699 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include "EncodeStream.h" +#include "RedisMessage.h" + +namespace protocol +{ + +// Concrete payload types stored behind RedisValue::data_, selected by type_. +typedef int64_t Rint; +typedef std::string Rstr; +typedef std::vector<RedisValue> Rarr; + +// Copy assignment: releases the current payload, then deep-copies the payload +// of the source according to its type. Unknown types become nil. +RedisValue& RedisValue::operator= (const RedisValue& copy) +{ + if (this != &copy) + { + free_data(); + + switch (copy.type_) + { + case REDIS_REPLY_TYPE_INTEGER: + type_ = copy.type_; + data_ = new Rint(*((Rint*)(copy.data_))); + break; + + case REDIS_REPLY_TYPE_ERROR: + case REDIS_REPLY_TYPE_STATUS: + case REDIS_REPLY_TYPE_STRING: + type_ = copy.type_; + data_ = new Rstr(*((Rstr*)(copy.data_))); + break; + + case REDIS_REPLY_TYPE_ARRAY: + type_ = copy.type_; + data_ = new Rarr(*((Rarr*)(copy.data_))); + break; + + default: + type_ = REDIS_REPLY_TYPE_NIL; + data_ = NULL; + } + } + + return *this; +} + +// Move assignment: releases the current payload, steals the payload pointer of +// the source and leaves the source as nil. +RedisValue& RedisValue::operator= (RedisValue&& move) +{ + if (this != &move) + { + free_data(); + + type_ = move.type_; + data_ = move.data_; + + move.type_ = REDIS_REPLY_TYPE_NIL; + move.data_ = NULL; + } + + return *this; +} + +// Deletes the payload through the pointer type that matches type_ and nulls +// data_. type_ itself is left unchanged; callers assign it afterwards. +void RedisValue::free_data() +{ + if (data_) + { + switch (type_) + { + case REDIS_REPLY_TYPE_INTEGER: + delete (Rint *)data_; + break; + + case REDIS_REPLY_TYPE_ERROR: + case REDIS_REPLY_TYPE_STATUS: + case REDIS_REPLY_TYPE_STRING: + delete (Rstr *)data_; + break; + + case REDIS_REPLY_TYPE_ARRAY: + delete (Rarr *)data_; + break; + } + + data_ = NULL; + } +} + +void RedisValue::only_set_string_data(const std::string& strv) +{ + Rstr *p = (Rstr *)(data_); +
p->assign(strv); +} + +void RedisValue::only_set_string_data(const char *str, size_t len) +{ + Rstr *p = (Rstr *)(data_); + if (str == NULL || len == 0) + p->clear(); + else + p->assign(str, len); +} + +void RedisValue::only_set_string_data(const char *str) +{ + Rstr *p = (Rstr *)(data_); + if (str == NULL) + p->clear(); + else + p->assign(str); +} + +void RedisValue::set_int(int64_t intv) +{ + if (type_ == REDIS_REPLY_TYPE_INTEGER) + *((Rint *)data_) = intv; + else + { + free_data(); + data_ = new Rint(intv); + type_ = REDIS_REPLY_TYPE_INTEGER; + } +} + +void RedisValue::set_string(const std::string& strv) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(strv); + else + { + free_data(); + data_ = new Rstr(strv); + } + + type_ = REDIS_REPLY_TYPE_STRING; +} + +void RedisValue::set_status(const std::string& strv) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(strv); + else + { + free_data(); + data_ = new Rstr(strv); + } + + type_ = REDIS_REPLY_TYPE_STATUS; +} + +void RedisValue::set_error(const std::string& strv) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(strv); + else + { + free_data(); + data_ = new Rstr(strv); + } + + type_ = REDIS_REPLY_TYPE_ERROR; +} + +void RedisValue::set_string(const char *str, size_t len) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(str, len); + else + { + free_data(); + data_ = new Rstr(str, len); + } + + type_ = REDIS_REPLY_TYPE_STRING; +} + +void RedisValue::set_status(const char *str, size_t len) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(str, len); + else + { + 
free_data(); + data_ = new Rstr(str, len); + } + + type_ = REDIS_REPLY_TYPE_STATUS; +} + +void RedisValue::set_error(const char *str, size_t len) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(str, len); + else + { + free_data(); + data_ = new Rstr(str, len); + } + + type_ = REDIS_REPLY_TYPE_ERROR; +} + +void RedisValue::set_string(const char *str) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(str); + else + { + free_data(); + data_ = new Rstr(str); + } + + type_ = REDIS_REPLY_TYPE_STRING; +} + +void RedisValue::set_status(const char *str) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(str); + else + { + free_data(); + data_ = new Rstr(str); + } + + type_ = REDIS_REPLY_TYPE_STATUS; +} + +void RedisValue::set_error(const char *str) +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + only_set_string_data(str); + else + { + free_data(); + data_ = new Rstr(str); + } + + type_ = REDIS_REPLY_TYPE_ERROR; +} + +void RedisValue::set_array(size_t new_size) +{ + if (type_ == REDIS_REPLY_TYPE_ARRAY) + ((Rarr *)data_)->resize(new_size); + else + { + free_data(); + data_ = new Rarr(new_size); + type_ = REDIS_REPLY_TYPE_ARRAY; + } +} + +void RedisValue::set(const redis_reply_t *reply) +{ + set_nil(); + switch (reply->type) + { + case REDIS_REPLY_TYPE_INTEGER: + set_int(reply->integer); + break; + + case REDIS_REPLY_TYPE_ERROR: + set_error(reply->str, reply->len); + break; + + case REDIS_REPLY_TYPE_STATUS: + set_status(reply->str, reply->len); + break; + + case REDIS_REPLY_TYPE_STRING: + set_string(reply->str, reply->len); + break; + + case REDIS_REPLY_TYPE_ARRAY: + set_array(reply->elements); + + if (reply->elements > 0) + { + Rarr 
*parr = (Rarr *)data_; + for (size_t i = 0; i < reply->elements; i++) + (*parr)[i].set(reply->element[i]); + } + + break; + } +} + +void RedisValue::arr_clear() +{ + if (type_ == REDIS_REPLY_TYPE_ARRAY) + ((Rarr *)data_)->clear(); +} + +void RedisValue::arr_resize(size_t new_size) +{ + if (type_ == REDIS_REPLY_TYPE_ARRAY) + ((Rarr *)data_)->resize(new_size); +} + +bool RedisValue::transform(redis_reply_t *reply) const +{ + //todo risk of stack overflow + Rarr *parr; + Rstr *pstr; + + redis_reply_set_null(reply); + switch (type_) + { + case REDIS_REPLY_TYPE_INTEGER: + redis_reply_set_integer(*((Rint *)data_), reply); + break; + + case REDIS_REPLY_TYPE_ARRAY: + parr = (Rarr *)data_; + if (redis_reply_set_array(parr->size(), reply) < 0) + return false; + + for (size_t i = 0; i < reply->elements; i++) + { + if (!(*parr)[i].transform(reply->element[i])) + return false; + } + + break; + + case REDIS_REPLY_TYPE_STATUS: + pstr = (Rstr *)data_; + redis_reply_set_status(pstr->c_str(), pstr->size(), reply); + + break; + + case REDIS_REPLY_TYPE_ERROR: + pstr = (Rstr *)data_; + redis_reply_set_error(pstr->c_str(), pstr->size(), reply); + + break; + + case REDIS_REPLY_TYPE_STRING: + pstr = (Rstr *)data_; + redis_reply_set_string(pstr->c_str(), pstr->size(), reply); + + break; + } + + return true; +} + +std::string RedisValue::debug_string() const +{ + std::string ret; + + if (is_error()) + { + ret += "ERROR: "; + ret += string_view()->c_str(); + } + else if (is_int()) + { + std::ostringstream oss; + oss << int_value(); + ret += oss.str(); + } + else if (is_nil()) + { + ret += "nil"; + } + else if (is_string()) + { + ret += '\"'; + ret += string_view()->c_str(); + ret += '\"'; + } + else if (is_array()) + { + ret += '['; + size_t l = arr_size(); + for (size_t i = 0; i < l; i++) + { + if (i) + ret += ", "; + + ret += (*this)[i].debug_string(); + } + ret += ']'; + } + + return ret; +} + +RedisValue::~RedisValue() +{ + free_data(); +} + +RedisMessage::RedisMessage(): + parser_(new 
redis_parser_t), + stream_(new EncodeStream), + cur_size_(0), + asking_(false) +{ + redis_parser_init(parser_); +} + +RedisMessage::~RedisMessage() +{ + if (parser_) + { + redis_parser_deinit(parser_); + delete parser_; + delete stream_; + } +} + +RedisMessage::RedisMessage(RedisMessage&& move) : + ProtocolMessage(std::move(move)) +{ + parser_ = move.parser_; + stream_ = move.stream_; + cur_size_ = move.cur_size_; + asking_ = move.asking_; + + move.parser_ = NULL; + move.stream_ = NULL; + move.cur_size_ = 0; + move.asking_ = false; +} + +RedisMessage& RedisMessage::operator= (RedisMessage &&move) +{ + if (this != &move) + { + *(ProtocolMessage *)this = std::move(move); + + if (parser_) + { + redis_parser_deinit(parser_); + delete parser_; + delete stream_; + } + + parser_ = move.parser_; + stream_ = move.stream_; + cur_size_ = move.cur_size_; + asking_ = move.asking_; + + move.parser_ = NULL; + move.stream_ = NULL; + move.cur_size_ = 0; + move.asking_ = false; + } + + return *this; +} + +bool RedisMessage::encode_reply(redis_reply_t *reply) +{ + EncodeStream& stream = *stream_; + switch (reply->type) + { + case REDIS_REPLY_TYPE_STATUS: + stream << "+" << std::make_pair(reply->str, reply->len) << "\r\n"; + break; + + case REDIS_REPLY_TYPE_ERROR: + stream << "-" << std::make_pair(reply->str, reply->len) << "\r\n"; + break; + + case REDIS_REPLY_TYPE_NIL: + stream << "$-1\r\n"; + break; + + case REDIS_REPLY_TYPE_INTEGER: + stream << ":" << reply->integer << "\r\n"; + break; + + case REDIS_REPLY_TYPE_STRING: + stream << "$" << reply->len << "\r\n"; + stream << std::make_pair(reply->str, reply->len) << "\r\n"; + break; + + case REDIS_REPLY_TYPE_ARRAY: + stream << "*" << reply->elements << "\r\n"; + for (size_t i = 0; i < reply->elements; i++) + if (!encode_reply(reply->element[i])) + return false; + + break; + + default: + return false; + } + + return true; +} + +int RedisMessage::encode(struct iovec vectors[], int max) +{ + stream_->reset(vectors, max); + + if 
(encode_reply(&parser_->reply)) + return stream_->size(); + + return 0; +} + +int RedisMessage::append(const void *buf, size_t *size) +{ + int ret = redis_parser_append_message(buf, size, parser_); + + if (ret >= 0) + { + cur_size_ += *size; + if (cur_size_ > this->size_limit) + { + errno = EMSGSIZE; + ret = -1; + } + } + else if (ret == -2) + { + errno = EBADMSG; + ret = -1; + } + + return ret; +} + +void RedisRequest::set_request(const std::string& command, + const std::vector& params) +{ + size_t n = params.size() + 1; + user_request_.reserve(n); + user_request_.clear(); + user_request_.push_back(command); + for (size_t i = 0; i < params.size(); i++) + user_request_.push_back(params[i]); + + redis_reply_t *reply = &parser_->reply; + redis_reply_set_array(n, reply); + for (size_t i = 0; i < n; i++) + { + redis_reply_set_string(user_request_[i].c_str(), + user_request_[i].size(), + reply->element[i]); + } +} + +bool RedisRequest::get_command(std::string& command) const +{ + const redis_reply_t *reply = &parser_->reply; + if (reply->type == REDIS_REPLY_TYPE_ARRAY && reply->elements > 0) + { + reply = reply->element[0]; + if (reply->type == REDIS_REPLY_TYPE_STRING) + { + command.assign(reply->str, reply->len); + return true; + } + } + + return false; +} + +bool RedisRequest::get_params(std::vector& params) const +{ + const redis_reply_t *reply = &parser_->reply; + if (reply->type == REDIS_REPLY_TYPE_ARRAY && reply->elements > 0) + { + for (size_t i = 1; i < reply->elements; i++) + { + if (reply->element[i]->type != REDIS_REPLY_TYPE_STRING && + reply->element[i]->type != REDIS_REPLY_TYPE_NIL) + { + return false; + } + } + + params.reserve(reply->elements - 1); + params.clear(); + for (size_t i = 1; i < reply->elements; i++) + params.emplace_back(reply->element[i]->str, reply->element[i]->len); + + return true; + } + + return false; +} + +#define REDIS_ASK_COMMAND "ASKING" +#define REDIS_ASK_REQUEST "*1\r\n$6\r\nASKING\r\n" +#define REDIS_OK_RESPONSE "+OK\r\n" + +int 
RedisRequest::encode(struct iovec vectors[], int max) +{ + stream_->reset(vectors, max); + + if (is_asking()) + (*stream_) << REDIS_ASK_REQUEST; + if (encode_reply(&parser_->reply)) + return stream_->size(); + + return 0; +} + +int RedisRequest::append(const void *buf, size_t *size) +{ + int ret = RedisMessage::append(buf, size); + + if (ret > 0) + { + std::string command; + + if (get_command(command) && + strcasecmp(command.c_str(), REDIS_ASK_COMMAND) == 0) + { + redis_parser_deinit(parser_); + redis_parser_init(parser_); + set_asking(true); + + ret = this->feedback(REDIS_OK_RESPONSE, strlen(REDIS_OK_RESPONSE)); + if (ret != strlen(REDIS_OK_RESPONSE)) + { + errno = ENOBUFS; + ret = -1; + } + else + ret = 0; + } + } + + return ret; +} + +int RedisResponse::append(const void *buf, size_t *size) +{ + int ret = RedisMessage::append(buf, size); + + if (ret > 0 && is_asking()) + { + redis_parser_deinit(parser_); + redis_parser_init(parser_); + ret = 0; + set_asking(false); + } + + return ret; +} + +bool RedisResponse::set_result(const RedisValue& value) +{ + redis_reply_t *reply = &parser_->reply; + redis_reply_deinit(reply); + redis_reply_init(reply); + + value_ = value; + return value_.transform(reply); +} + +} + diff --git a/src/protocol/RedisMessage.h b/src/protocol/RedisMessage.h new file mode 100644 index 0000000..043acc0 --- /dev/null +++ b/src/protocol/RedisMessage.h @@ -0,0 +1,323 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _REDISMESSAGE_H_ +#define _REDISMESSAGE_H_ + +#include +#include +#include +#include "ProtocolMessage.h" +#include "redis_parser.h" + +/** + * @file RedisMessage.h + * @brief Redis Protocol Interface + */ + +namespace protocol +{ + +class RedisValue +{ +public: + // nil + RedisValue(); + virtual ~RedisValue(); + + //copy constructor + RedisValue(const RedisValue& copy); + //copy operator + RedisValue& operator= (const RedisValue& copy); + //move constructor + RedisValue(RedisValue&& move); + //move operator + RedisValue& operator= (RedisValue&& move); + + // release memory and change type to nil + void set_nil(); + void set_int(int64_t intv); + void set_string(const std::string& strv); + void set_status(const std::string& strv); + void set_error(const std::string& strv); + void set_string(const char *str, size_t len); + void set_status(const char *str, size_t len); + void set_error(const char *str, size_t len); + void set_string(const char *str); + void set_status(const char *str); + void set_error(const char *str); + // array(resize) + void set_array(size_t new_size); + // set data by C style data struct + void set(const redis_reply_t *reply); + + // Return true if not error + bool is_ok() const; + // Return true if error + bool is_error() const; + // Return true if nil + bool is_nil() const; + // Return true if integer + bool is_int() const; + // Return true if array + bool is_array() const; + // Return true if string/status + bool is_string() const; + // Return type of C style data struct + int get_type() const; + + // Copy. If type isnot string/status/error, returns an empty std::string + std::string string_value() const; + // No copy. If type isnot string/status/error, returns NULL. 
+ const std::string *string_view() const; + // If type isnot integer, returns 0 + int64_t int_value() const; + // If type isnot array, returns 0 + size_t arr_size() const; + // If type isnot array, do nothing + void arr_clear(); + // If type isnot array, do nothing + void arr_resize(size_t new_size); + // Always return std::vector.at(pos); notice overflow exception + RedisValue& arr_at(size_t pos) const; + // Always return std::vector[pos]; notice overflow exception + RedisValue& operator[] (size_t pos) const; + + // transform data into C style data struct + bool transform(redis_reply_t *reply) const; + // equal to set_nil(); + void clear(); + // format data to text + std::string debug_string() const; + +private: + void free_data(); + void only_set_string_data(const std::string& strv); + void only_set_string_data(const char *str, size_t len); + void only_set_string_data(const char *str); + + int type_; + void *data_; +}; + +class RedisMessage : public ProtocolMessage +{ +public: + RedisMessage(); + virtual ~RedisMessage(); + //move constructor + RedisMessage(RedisMessage&& move); + //move operator + RedisMessage& operator= (RedisMessage&& move); + +public: + //peek after CommMessageIn append + //not for users. 
+ bool parse_success() const; + bool is_asking() const; + void set_asking(bool asking); + +protected: + redis_parser_t *parser_; + + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + bool encode_reply(redis_reply_t *reply); + + class EncodeStream *stream_; + +private: + size_t cur_size_; + bool asking_; +}; + +class RedisRequest : public RedisMessage +{ +public: + RedisRequest() = default; + //move constructor + RedisRequest(RedisRequest&& move) = default; + //move operator + RedisRequest& operator= (RedisRequest&& move) = default; + +public:// C++ style + // Usually, client use set_request to (prepare)send request to server + // Usually, server use get_command/get_params to get client request + + // set_request("HSET", {"keyname", "hashkey", "somevalue"}); + void set_request(const std::string& command, + const std::vector& params); + + bool get_command(std::string& command) const; + bool get_params(std::vector& params) const; + +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +private: + std::vector user_request_; +}; + +class RedisResponse : public RedisMessage +{ +public: + RedisResponse() = default; + //move constructor + RedisResponse(RedisResponse&& move) = default; + //move operator + RedisResponse& operator= (RedisResponse&& move) = default; + +public:// C++ style + // client use get_result to get result from server, copy + void get_result(RedisValue& value) const; + + // server use set_result to (prepare)send result to client, copy + bool set_result(const RedisValue& value); + +public:// C style + // redis_parser_t is absolutely same as hiredis-redisReply in memory + // If you include hiredis.h, redisReply* can cast to redis_reply_t* safely + // BUT this function return not a copy, DONOT free the pointer by yourself + + // client read data from redis_reply_t by pointer of result_ptr + // server write data into redis_reply_t by 
pointer of result_ptr + redis_reply_t *result_ptr(); + +protected: + virtual int append(const void *buf, size_t *size); + +private: + RedisValue value_; +}; + +//////////////////// + +inline RedisValue::RedisValue(): + type_(REDIS_REPLY_TYPE_NIL), + data_(NULL) +{ +} + +inline RedisValue::RedisValue(const RedisValue& copy): + type_(REDIS_REPLY_TYPE_NIL), + data_(NULL) +{ + this->operator= (copy); +} + +inline RedisValue::RedisValue(RedisValue&& move): + type_(REDIS_REPLY_TYPE_NIL), + data_(NULL) +{ + this->operator= (std::move(move)); +} + +inline bool RedisValue::is_ok() const { return type_ != REDIS_REPLY_TYPE_ERROR; } +inline bool RedisValue::is_error() const { return type_ == REDIS_REPLY_TYPE_ERROR; } +inline bool RedisValue::is_nil() const { return type_ == REDIS_REPLY_TYPE_NIL; } +inline bool RedisValue::is_int() const { return type_ == REDIS_REPLY_TYPE_INTEGER; } +inline bool RedisValue::is_array() const { return type_ == REDIS_REPLY_TYPE_ARRAY; } +inline int RedisValue::get_type() const { return type_; } + +inline bool RedisValue::is_string() const +{ + return type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS; +} + +inline std::string RedisValue::string_value() const +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + return *((std::string *)data_); + else + return ""; +} + +inline const std::string *RedisValue::string_view() const +{ + if (type_ == REDIS_REPLY_TYPE_STRING || + type_ == REDIS_REPLY_TYPE_STATUS || + type_ == REDIS_REPLY_TYPE_ERROR) + return ((std::string *)data_); + else + return NULL; +} + +inline int64_t RedisValue::int_value() const +{ + if (type_ == REDIS_REPLY_TYPE_INTEGER) + return *((int64_t *)data_); + else + return 0; +} + +inline size_t RedisValue::arr_size() const +{ + if (type_ == REDIS_REPLY_TYPE_ARRAY) + return ((std::vector *)data_)->size(); + else + return 0; +} + +inline RedisValue& RedisValue::arr_at(size_t pos) const +{ + return 
((std::vector *)data_)->at(pos); +} + +inline RedisValue& RedisValue::operator[] (size_t pos) const +{ + return (*((std::vector *)data_))[pos]; +} + +inline void RedisValue::set_nil() +{ + free_data(); + type_ = REDIS_REPLY_TYPE_NIL; +} + +inline void RedisValue::clear() +{ + set_nil(); +} + +inline bool RedisMessage::parse_success() const { return parser_->parse_succ; } + +inline bool RedisMessage::is_asking() const { return asking_; } + +inline void RedisMessage::set_asking(bool asking) { asking_ = asking; } + +inline redis_reply_t *RedisResponse::result_ptr() +{ + return &parser_->reply; +} + +inline void RedisResponse::get_result(RedisValue& value) const +{ + if (parser_->parse_succ) + value.set(&parser_->reply); + else + value.set_nil(); +} + +} + +#endif + diff --git a/src/protocol/SSLWrapper.cc b/src/protocol/SSLWrapper.cc new file mode 100644 index 0000000..977eb10 --- /dev/null +++ b/src/protocol/SSLWrapper.cc @@ -0,0 +1,301 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include "SSLWrapper.h" + +namespace protocol +{ + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static inline BIO *__get_wbio(SSL *ssl) +{ + BIO *wbio = SSL_get_wbio(ssl); + BIO *next = BIO_next(wbio); + return next ? 
next : wbio; +} + +# define SSL_get_wbio(ssl) __get_wbio(ssl) +#endif + +int SSLHandshaker::encode(struct iovec vectors[], int max) +{ + BIO *wbio = SSL_get_wbio(this->ssl); + char *ptr; + long len; + int ret; + + ret = SSL_do_handshake(this->ssl); + if (ret <= 0) + { + ret = SSL_get_error(this->ssl, ret); + if (ret != SSL_ERROR_WANT_READ) + { + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + return -1; + } + } + + len = BIO_get_mem_data(wbio, &ptr); + if (len > 0) + { + vectors[0].iov_base = ptr; + vectors[0].iov_len = len; + return 1; + } + else if (len == 0) + return 0; + else + return -1; +} + +static int __ssl_handshake(const void *buf, size_t *size, SSL *ssl, + char **ptr, long *len) +{ + BIO *wbio = SSL_get_wbio(ssl); + BIO *rbio = SSL_get_rbio(ssl); + int ret; + + ret = BIO_write(rbio, buf, *size); + if (ret <= 0) + return -1; + + *size = ret; + ret = SSL_do_handshake(ssl); + if (ret <= 0) + { + ret = SSL_get_error(ssl, ret); + if (ret != SSL_ERROR_WANT_READ) + { + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + return -1; + } + + ret = 0; + } + + *len = BIO_get_mem_data(wbio, ptr); + if (*len < 0) + return -1; + + return ret; +} + +int SSLHandshaker::append(const void *buf, size_t *size) +{ + BIO *wbio = SSL_get_wbio(this->ssl); + char *ptr; + long len; + long n; + int ret; + + BIO_reset(wbio); + ret = __ssl_handshake(buf, size, this->ssl, &ptr, &len); + if (ret != 0) + return ret; + + if (len > 0) + { + n = this->feedback(ptr, len); + BIO_reset(wbio); + } + else + n = 0; + + if (n == len) + return ret; + + if (n >= 0) + errno = ENOBUFS; + + return -1; +} + +int SSLWrapper::encode(struct iovec vectors[], int max) +{ + BIO *wbio = SSL_get_wbio(this->ssl); + struct iovec *iov; + char *ptr; + long len; + int ret; + + ret = this->ProtocolWrapper::encode(vectors, max); + if ((unsigned int)ret > (unsigned int)max) + return ret; + + max = ret; + for (iov = vectors; iov < vectors + max; iov++) + { + if (iov->iov_len > 0) + { + ret = SSL_write(this->ssl, 
iov->iov_base, iov->iov_len); + if (ret <= 0) + { + ret = SSL_get_error(this->ssl, ret); + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + return -1; + } + } + } + + len = BIO_get_mem_data(wbio, &ptr); + if (len > 0) + { + vectors[0].iov_base = ptr; + vectors[0].iov_len = len; + return 1; + } + else if (len == 0) + return 0; + else + return -1; +} + +#define BUFSIZE 8192 + +int SSLWrapper::append_message() +{ + char buf[BUFSIZE]; + int ret; + + while ((ret = SSL_read(this->ssl, buf, BUFSIZE)) > 0) + { + size_t nleft = ret; + char *p = buf; + size_t n; + + do + { + n = nleft; + ret = this->ProtocolWrapper::append(p, &n); + if (ret == 0) + { + nleft -= n; + p += n; + } + else + return ret; + + } while (nleft > 0); + } + + if (ret < 0) + { + ret = SSL_get_error(this->ssl, ret); + if (ret != SSL_ERROR_WANT_READ) + { + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + return -1; + } + } + + return 0; +} + +int SSLWrapper::append(const void *buf, size_t *size) +{ + BIO *wbio = SSL_get_wbio(this->ssl); + BIO *rbio = SSL_get_rbio(this->ssl); + int ret; + + BIO_reset(wbio); + ret = BIO_write(rbio, buf, *size); + if (ret <= 0) + return -1; + + *size = ret; + return this->append_message(); +} + +int SSLWrapper::feedback(const void *buf, size_t size) +{ + BIO *wbio = SSL_get_wbio(this->ssl); + char *ptr; + long len; + long n; + int ret; + + if (size == 0) + return 0; + + ret = SSL_write(this->ssl, buf, size); + if (ret <= 0) + { + ret = SSL_get_error(this->ssl, ret); + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + return -1; + } + + len = BIO_get_mem_data(wbio, &ptr); + if (len >= 0) + { + n = this->ProtocolWrapper::feedback(ptr, len); + BIO_reset(wbio); + if (n == len) + return size; + + if (ret > 0) + errno = ENOBUFS; + } + + return -1; +} + +int ServerSSLWrapper::append(const void *buf, size_t *size) +{ + BIO *wbio = SSL_get_wbio(this->ssl); + char *ptr; + long len; + long n; + + BIO_reset(wbio); + if (__ssl_handshake(buf, size, this->ssl, &ptr, &len) < 0) + return 
-1; + + if (len > 0) + { + n = this->ProtocolMessage::feedback(ptr, len); + BIO_reset(wbio); + } + else + n = 0; + + if (n == len) + return this->append_message(); + + if (n >= 0) + errno = ENOBUFS; + + return -1; +} + +} + diff --git a/src/protocol/SSLWrapper.h b/src/protocol/SSLWrapper.h new file mode 100644 index 0000000..44b9736 --- /dev/null +++ b/src/protocol/SSLWrapper.h @@ -0,0 +1,94 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _SSLWRAPPER_H_ +#define _SSLWRAPPER_H_ + +#include +#include "ProtocolMessage.h" + +namespace protocol +{ + +class SSLHandshaker : public ProtocolMessage +{ +public: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +protected: + SSL *ssl; + +public: + SSLHandshaker(SSL *ssl) + { + this->ssl = ssl; + } + +public: + SSLHandshaker(SSLHandshaker&& handshaker) = default; + SSLHandshaker& operator = (SSLHandshaker&& handshaker) = default; +}; + +class SSLWrapper : public ProtocolWrapper +{ +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +protected: + virtual int feedback(const void *buf, size_t size); + +protected: + int append_message(); + +protected: + SSL *ssl; + +public: + SSLWrapper(ProtocolMessage *message, SSL *ssl) : + ProtocolWrapper(message) + { + this->ssl = ssl; + } + +public: + SSLWrapper(SSLWrapper&& wrapper) 
= default; + SSLWrapper& operator = (SSLWrapper&& wrapper) = default; +}; + +class ServerSSLWrapper : public SSLWrapper +{ +protected: + virtual int append(const void *buf, size_t *size); + +public: + ServerSSLWrapper(ProtocolMessage *message, SSL *ssl) : + SSLWrapper(message, ssl) + { + } + +public: + ServerSSLWrapper(ServerSSLWrapper&& wrapper) = default; + ServerSSLWrapper& operator = (ServerSSLWrapper&& wrapper) = default; +}; + +} + +#endif + diff --git a/src/protocol/TLVMessage.cc b/src/protocol/TLVMessage.cc new file mode 100644 index 0000000..a21a7e9 --- /dev/null +++ b/src/protocol/TLVMessage.cc @@ -0,0 +1,85 @@ +/* + Copyright (c) 2023 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "TLVMessage.h" + +namespace protocol +{ + +int TLVMessage::encode(struct iovec vectors[], int max) +{ + this->head[0] = htonl((uint32_t)this->type); + this->head[1] = htonl(this->value.size()); + + vectors[0].iov_base = this->head; + vectors[0].iov_len = 8; + vectors[1].iov_base = (char *)this->value.data(); + vectors[1].iov_len = this->value.size(); + return 2; +} + +int TLVMessage::append(const void *buf, size_t *size) +{ + size_t n = *size; + size_t head_left; + + head_left = 8 - this->head_received; + if (head_left > 0) + { + void *p = (char *)this->head + this->head_received; + + if (n < head_left) + { + memcpy(p, buf, n); + this->head_received += n; + return 0; + } + + memcpy(p, buf, head_left); + this->head_received = 8; + buf = (const char *)buf + head_left; + n -= head_left; + + this->type = (int)ntohl(this->head[0]); + *this->head = ntohl(this->head[1]); + if (*this->head > this->size_limit) + { + errno = EMSGSIZE; + return -1; + } + + this->value.reserve(*this->head); + } + + if (this->value.size() + n > *this->head) + { + n = *this->head - this->value.size(); + *size = n + head_left; + } + + this->value.append((const char *)buf, n); + return this->value.size() == *this->head; +} + +} + diff --git a/src/protocol/TLVMessage.h b/src/protocol/TLVMessage.h new file mode 100644 index 0000000..f78159a --- /dev/null +++ b/src/protocol/TLVMessage.h @@ -0,0 +1,69 @@ +/* + Copyright (c) 2023 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _TLVMESSAGE_H_ +#define _TLVMESSAGE_H_ + +#include +#include +#include +#include "ProtocolMessage.h" + +namespace protocol +{ + +class TLVMessage : public ProtocolMessage +{ +public: + int get_type() const { return this->type; } + void set_type(int type) { this->type = type; } + + std::string *get_value() { return &this->value; } + void set_value(std::string value) { this->value = std::move(value); } + +protected: + virtual int encode(struct iovec vectors[], int max); + virtual int append(const void *buf, size_t *size); + +protected: + int type; + std::string value; + +private: + uint32_t head[2]; + size_t head_received; + +public: + TLVMessage() + { + this->type = 0; + this->head_received = 0; + } + +public: + TLVMessage(TLVMessage&& msg) = default; + TLVMessage& operator = (TLVMessage&& msg) = default; +}; + +using TLVRequest = TLVMessage; +using TLVResponse = TLVMessage; + +} + +#endif + diff --git a/src/protocol/dns_parser.c b/src/protocol/dns_parser.c new file mode 100644 index 0000000..2a698fa --- /dev/null +++ b/src/protocol/dns_parser.c @@ -0,0 +1,1204 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include "dns_parser.h" + +#define DNS_LABELS_MAX 63 +#define DNS_NAMES_MAX 256 +#define DNS_MSGBASE_INIT_SIZE 514 // 512 + 2(leading length) +#define MAX(x, y) ((x) <= (y) ? (y) : (x)) + +struct __dns_record_entry +{ + struct list_head entry_list; + struct dns_record record; +}; + + +static inline uint8_t __dns_parser_uint8(const char *ptr) +{ + return (unsigned char)ptr[0]; +} + +static inline uint16_t __dns_parser_uint16(const char *ptr) +{ + const unsigned char *p = (const unsigned char *)ptr; + return ((uint16_t)p[0] << 8) + + ((uint16_t)p[1]); +} + +static inline uint32_t __dns_parser_uint32(const char *ptr) +{ + const unsigned char *p = (const unsigned char *)ptr; + return ((uint32_t)p[0] << 24) + + ((uint32_t)p[1] << 16) + + ((uint32_t)p[2] << 8) + + ((uint32_t)p[3]); +} + +/* + * Parse a single . + * is a domain name represented as a series of labels, and + * terminated by a label with zero length. 
+ * + * phost must point to an char array with at least DNS_NAMES_MAX+1 size + */ +static int __dns_parser_parse_host(char *phost, dns_parser_t *parser) +{ + uint8_t len; + uint16_t pointer; + size_t hcur; + const char *msgend; + const char **cur; + const char *curbackup; // backup cur when host label is pointer + + msgend = (const char *)parser->msgbuf + parser->msgsize; + cur = &(parser->cur); + curbackup = NULL; + hcur = 0; + + if (*cur >= msgend) + return -2; + + while (*cur < msgend) + { + len = __dns_parser_uint8(*cur); + + if ((len & 0xC0) == 0) + { + (*cur)++; + if (len == 0) + break; + if (len > DNS_LABELS_MAX || *cur + len > msgend || + hcur + len + 1 > DNS_NAMES_MAX) + return -2; + + memcpy(phost + hcur, *cur, len); + *cur += len; + hcur += len; + phost[hcur++] = '.'; + } + // RFC 1035, 4.1.4 Message compression + else if ((len & 0xC0) == 0xC0) + { + pointer = __dns_parser_uint16(*cur) & 0x3FFF; + + if (pointer >= parser->msgsize) + return -2; + + // pointer must point to a prior position + if ((const char *)parser->msgbase + pointer >= *cur) + return -2; + + *cur += 2; + + // backup cur only when the first pointer occurs + if (curbackup == NULL) + curbackup = *cur; + *cur = (const char *)parser->msgbase + pointer; + } + else + return -2; + } + if (curbackup != NULL) + *cur = curbackup; + + if (hcur > 1 && phost[hcur - 1] == '.') + hcur--; + + if (hcur == 0) + phost[hcur++] = '.'; + phost[hcur++] = '\0'; + + return 0; +} + +static void __dns_parser_free_record(struct __dns_record_entry *r) +{ + switch (r->record.type) + { + case DNS_TYPE_SOA: + { + struct dns_record_soa *soa; + soa = (struct dns_record_soa *)(r->record.rdata); + free(soa->mname); + free(soa->rname); + break; + } + case DNS_TYPE_SRV: + { + struct dns_record_srv *srv; + srv = (struct dns_record_srv *)(r->record.rdata); + free(srv->target); + break; + } + case DNS_TYPE_MX: + { + struct dns_record_mx *mx; + mx = (struct dns_record_mx *)(r->record.rdata); + free(mx->exchange); + break; + } + 
} + free(r->record.name); + free(r); +} + +static void __dns_parser_free_record_list(struct list_head *head) +{ + struct list_head *pos, *tmp; + struct __dns_record_entry *entry; + + list_for_each_safe(pos, tmp, head) + { + entry = list_entry(pos, struct __dns_record_entry, entry_list); + list_del(pos); + __dns_parser_free_record(entry); + } +} + +/* + * A RDATA format, from RFC 1035: + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ADDRESS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * ADDRESS: A 32 bit Internet address. + * Hosts that have multiple Internet addresses will have multiple A records. + */ +static int __dns_parser_parse_a(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + + if (sizeof (struct in_addr) != rdlength) + return -2; + + cur = &(parser->cur); + entry_size = sizeof (struct __dns_record_entry) + sizeof (struct in_addr); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, *cur, rdlength); + *cur += rdlength; + *r = entry; + + return 0; +} + +/* + * AAAA RDATA format, from RFC 3596: + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ADDRESS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * ADDRESS: A 128 bit Internet address. + * Hosts that have multiple addresses will have multiple AAAA records. 
+ */ +static int __dns_parser_parse_aaaa(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + + if (sizeof (struct in6_addr) != rdlength) + return -2; + + cur = &(parser->cur); + entry_size = sizeof (struct __dns_record_entry) + sizeof (struct in6_addr); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, *cur, rdlength); + *cur += rdlength; + *r = entry; + + return 0; +} + +/* + * Parse any record. + */ +static int __dns_parser_parse_names(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + size_t name_len; + char name[DNS_NAMES_MAX + 2]; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + ret = __dns_parser_parse_host(name, parser); + if (ret < 0) + return ret; + + if (*cur != rcdend) + return -2; + + name_len = strlen(name); + entry_size = sizeof (struct __dns_record_entry) + name_len + 1; + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, name, name_len + 1); + *r = entry; + + return 0; +} + +/* + * SOA RDATA format, from RFC 1035: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / MNAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / RNAME / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | SERIAL | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | REFRESH | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | RETRY | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | EXPIRE | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | MINIMUM | + * 
| | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * MNAME: + * RNAME: + * SERIAL: The unsigned 32 bit version number. + * REFRESH: A 32 bit time interval. + * RETRY: A 32 bit time interval. + * EXPIRE: A 32 bit time value. + * MINIMUM: The unsigned 32 bit integer. + */ +static int __dns_parser_parse_soa(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + struct dns_record_soa *soa; + size_t entry_size; + char mname[DNS_NAMES_MAX + 2]; + char rname[DNS_NAMES_MAX + 2]; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + ret = __dns_parser_parse_host(mname, parser); + if (ret < 0) + return ret; + ret = __dns_parser_parse_host(rname, parser); + if (ret < 0) + return ret; + + if (*cur + 20 != rcdend) + return -2; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_soa); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + soa = (struct dns_record_soa *)(entry->record.rdata); + + soa->mname = strdup(mname); + soa->rname = strdup(rname); + soa->serial = __dns_parser_uint32(*cur); + soa->refresh = __dns_parser_uint32(*cur + 4); + soa->retry = __dns_parser_uint32(*cur + 8); + soa->expire = __dns_parser_uint32(*cur + 12); + soa->minimum = __dns_parser_uint32(*cur + 16); + + if (!soa->mname || !soa->rname) + { + free(soa->mname); + free(soa->rname); + free(entry); + return -1; + } + + *cur += 20; + *r = entry; + + return 0; +} + +/* + * SRV RDATA format, from RFC 2782: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | PRIORITY | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | WEIGHT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | PORT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / TARGET / + * 
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * PRIORITY: A 16 bit unsigned integer in network byte order. + * WEIGHT: A 16 bit unsigned integer in network byte order. + * PORT: A 16 bit unsigned integer in network byte order. + * TARGET: + */ +static int __dns_parser_parse_srv(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + struct dns_record_srv *srv; + size_t entry_size; + char target[DNS_NAMES_MAX + 2]; + uint16_t priority; + uint16_t weight; + uint16_t port; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + + if (*cur + 6 > rcdend) + return -2; + + priority = __dns_parser_uint16(*cur); + weight = __dns_parser_uint16(*cur + 2); + port = __dns_parser_uint16(*cur + 4); + *cur += 6; + + ret = __dns_parser_parse_host(target, parser); + if (ret < 0) + return ret; + if (*cur != rcdend) + return -2; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_srv); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + srv = (struct dns_record_srv *)(entry->record.rdata); + + srv->priority = priority; + srv->weight = weight; + srv->port = port; + srv->target = strdup(target); + + if (!srv->target) + { + free(entry); + return -1; + } + + *r = entry; + + return 0; +} + +static int __dns_parser_parse_mx(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + struct dns_record_mx *mx; + size_t entry_size; + char exchange[DNS_NAMES_MAX + 2]; + int16_t preference; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + + if (*cur + 2 > rcdend) + return -2; + preference = __dns_parser_uint16(*cur); + *cur += 2; + + ret = __dns_parser_parse_host(exchange, parser); + if (ret < 0) + return ret; + if (*cur != rcdend) + return -2; + + 
entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_mx); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + mx = (struct dns_record_mx *)(entry->record.rdata); + mx->exchange = strdup(exchange); + mx->preference = preference; + + if (!mx->exchange) + { + free(entry); + return -1; + } + + *r = entry; + + return 0; +} + +static int __dns_parser_parse_others(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + + cur = &(parser->cur); + entry_size = sizeof (struct __dns_record_entry) + rdlength; + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, *cur, rdlength); + *cur += rdlength; + *r = entry; + + return 0; +} + +/* + * RR format, from RFC 1035: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | | + * / NAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | TYPE | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | CLASS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | TTL | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | RDLENGTH | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / RDATA / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + */ +static int __dns_parser_parse_record(int idx, dns_parser_t *parser) +{ + uint16_t i; + uint16_t type; + uint16_t rclass; + uint32_t ttl; + uint16_t rdlength; + uint16_t count; + const char *msgend; + const char **cur; + int ret; + struct __dns_record_entry *entry; + char host[DNS_NAMES_MAX + 2]; + struct list_head *list; + + switch (idx) + { + case 0: + count = parser->header.ancount; + list = &parser->answer_list; + break; + case 1: + count = 
parser->header.nscount; + list = &parser->authority_list; + break; + case 2: + count = parser->header.arcount; + list = &parser->additional_list; + break; + default: + return -2; + } + + msgend = (const char *)parser->msgbuf + parser->msgsize; + cur = &(parser->cur); + + for (i = 0; i < count; i++) + { + ret = __dns_parser_parse_host(host, parser); + if (ret < 0) + return ret; + + // TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 + if (*cur + 10 > msgend) + return -2; + type = __dns_parser_uint16(*cur); + rclass = __dns_parser_uint16(*cur + 2); + ttl = __dns_parser_uint32(*cur + 4); + rdlength = __dns_parser_uint16(*cur + 8); + *cur += 10; + if (*cur + rdlength > msgend) + return -2; + + entry = NULL; + switch (type) + { + case DNS_TYPE_A: + ret = __dns_parser_parse_a(&entry, rdlength, parser); + break; + case DNS_TYPE_AAAA: + ret = __dns_parser_parse_aaaa(&entry, rdlength, parser); + break; + case DNS_TYPE_NS: + case DNS_TYPE_CNAME: + case DNS_TYPE_PTR: + ret = __dns_parser_parse_names(&entry, rdlength, parser); + break; + case DNS_TYPE_SOA: + ret = __dns_parser_parse_soa(&entry, rdlength, parser); + break; + case DNS_TYPE_SRV: + ret = __dns_parser_parse_srv(&entry, rdlength, parser); + break; + case DNS_TYPE_MX: + ret = __dns_parser_parse_mx(&entry, rdlength, parser); + break; + default: + ret = __dns_parser_parse_others(&entry, rdlength, parser); + } + + if (ret < 0) + return ret; + + entry->record.name = strdup(host); + if (!entry->record.name) + { + __dns_parser_free_record(entry); + return -1; + } + + entry->record.type = type; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + entry->record.rdlength = rdlength; + list_add_tail(&entry->entry_list, list); + } + + return 0; +} + +/* + * Question format, from RFC 1035: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | | + * / QNAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QTYPE | + * 
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QCLASS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * The query name is encoded as a series of labels, each represented + * as a one-byte length (maximum 63) followed by the text of the + * label. The list is terminated by a label of length zero (which can + * be thought of as the root domain). + */ +static int __dns_parser_parse_question(dns_parser_t *parser) +{ + uint16_t qtype; + uint16_t qclass; + const char *msgend; + const char **cur; + int ret; + char host[DNS_NAMES_MAX + 2]; + + msgend = (const char *)parser->msgbuf + parser->msgsize; + cur = &(parser->cur); + + // question count != 1 is an error + if (parser->header.qdcount != 1) + return -2; + + // parse qname + ret = __dns_parser_parse_host(host, parser); + if (ret < 0) + return ret; + + // parse qtype and qclass + if (*cur + 4 > msgend) + return -2; + + qtype = __dns_parser_uint16(*cur); + qclass = __dns_parser_uint16(*cur + 2); + *cur += 4; + + if (parser->question.qname) + free(parser->question.qname); + + parser->question.qname = strdup(host); + if (parser->question.qname == NULL) + return -1; + + parser->question.qtype = qtype; + parser->question.qclass = qclass; + + return 0; +} + +void dns_parser_init(dns_parser_t *parser) +{ + parser->msgbuf = NULL; + parser->msgbase = NULL; + parser->cur = NULL; + parser->msgsize = 0; + parser->bufsize = 0; + parser->complete = 0; + parser->single_packet = 0; + memset(&parser->header, 0, sizeof (struct dns_header)); + memset(&parser->question, 0, sizeof (struct dns_question)); + INIT_LIST_HEAD(&parser->answer_list); + INIT_LIST_HEAD(&parser->authority_list); + INIT_LIST_HEAD(&parser->additional_list); +} + +int dns_parser_set_question(const char *name, + uint16_t qtype, + uint16_t qclass, + dns_parser_t *parser) +{ + int ret; + + ret = dns_parser_set_question_name(name, parser); + if (ret < 0) + return ret; + + parser->question.qtype = qtype; + parser->question.qclass = qclass; + 
parser->header.qdcount = 1; + + return 0; +} + +int dns_parser_set_question_name(const char *name, dns_parser_t *parser) +{ + char *newname; + size_t len; + + len = strlen(name); + newname = (char *)malloc(len + 1); + + if (!newname) + return -1; + + memcpy(newname, name, len + 1); + // Remove trailing dot, except name is "." + if (len > 1 && newname[len - 1] == '.') + newname[len - 1] = '\0'; + + if (parser->question.qname) + free(parser->question.qname); + parser->question.qname = newname; + + return 0; +} + +void dns_parser_set_id(uint16_t id, dns_parser_t *parser) +{ + parser->header.id = id; +} + +int dns_parser_parse_all(dns_parser_t *parser) +{ + struct dns_header *h; + int ret; + int i; + + parser->complete = 1; + parser->cur = (const char *)parser->msgbase; + h = &parser->header; + + if (parser->msgsize < sizeof (struct dns_header)) + return -2; + + memcpy(h, parser->msgbase, sizeof (struct dns_header)); + h->id = ntohs(h->id); + h->qdcount = ntohs(h->qdcount); + h->ancount = ntohs(h->ancount); + h->nscount = ntohs(h->nscount); + h->arcount = ntohs(h->arcount); + parser->cur += sizeof (struct dns_header); + + ret = __dns_parser_parse_question(parser); + if (ret < 0) + return ret; + + for (i = 0; i < 3; i++) + { + ret = __dns_parser_parse_record(i, parser); + if (ret < 0) + return ret; + } + + return 0; +} + +int dns_parser_append_message(const void *buf, + size_t *n, + dns_parser_t *parser) +{ + int ret; + size_t total; + size_t new_size; + size_t msgsize_bak; + void *new_buf; + + if (parser->complete) + { + *n = 0; + return 1; + } + + if (!parser->single_packet) + { + msgsize_bak = parser->msgsize; + if (parser->msgsize + *n > parser->bufsize) + { + new_size = MAX(DNS_MSGBASE_INIT_SIZE, 2 * parser->bufsize); + + while (new_size < parser->msgsize + *n) + new_size *= 2; + + new_buf = realloc(parser->msgbuf, new_size); + if (!new_buf) + return -1; + + parser->msgbuf = new_buf; + parser->bufsize = new_size; + } + + memcpy((char*)parser->msgbuf + 
parser->msgsize, buf, *n); + parser->msgsize += *n; + + if (parser->msgsize < 2) + return 0; + + total = __dns_parser_uint16((char*)parser->msgbuf); + if (parser->msgsize < total + 2) + return 0; + + *n = total + 2 - msgsize_bak; + parser->msgsize = total + 2; + parser->msgbase = (char*)parser->msgbuf + 2; + } + else + { + parser->msgbuf = malloc(*n); + memcpy(parser->msgbuf, buf, *n); + parser->msgbase = parser->msgbuf; + parser->msgsize = *n; + parser->bufsize = *n; + } + + ret = dns_parser_parse_all(parser); + if (ret < 0) + return ret; + + return 1; +} + +void dns_parser_deinit(dns_parser_t *parser) +{ + free(parser->msgbuf); + free(parser->question.qname); + + __dns_parser_free_record_list(&parser->answer_list); + __dns_parser_free_record_list(&parser->authority_list); + __dns_parser_free_record_list(&parser->additional_list); +} + +int dns_record_cursor_next(struct dns_record **record, + dns_record_cursor_t *cursor) +{ + struct __dns_record_entry *e; + + if (cursor->next->next != cursor->head) + { + cursor->next = cursor->next->next; + e = list_entry(cursor->next, struct __dns_record_entry, entry_list); + *record = &e->record; + return 0; + } + + return 1; +} + +int dns_record_cursor_find_cname(const char *name, + const char **cname, + dns_record_cursor_t *cursor) +{ + struct __dns_record_entry *e; + + if (!name || !cname) + return 1; + + cursor->next = cursor->head; + while (cursor->next->next != cursor->head) + { + cursor->next = cursor->next->next; + e = list_entry(cursor->next, struct __dns_record_entry, entry_list); + + if (e->record.type == DNS_TYPE_CNAME && + strcasecmp(name, e->record.name) == 0) + { + *cname = (const char *)e->record.rdata; + return 0; + } + } + + return 1; +} + +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list) +{ + struct __dns_record_entry *entry; + size_t entry_size = sizeof (struct __dns_record_entry) + rlen; + + entry = (struct 
__dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.name = strdup(name); + if (!entry->record.name) + { + free(entry); + return -1; + } + + entry->record.type = type; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + entry->record.rdlength = rlen; + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, rdata, rlen); + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list) +{ + size_t rlen = strlen(rdata); + // record.rdlength has no meaning for parsed record types, ignore its + // correctness, same for soa/srv/mx record + return dns_add_raw_record(name, type, rclass, ttl, rlen+1, rdata, list); +} + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_soa *soa; + size_t entry_size; + char *pname, *pmname, *prname; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_soa); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + soa = (struct dns_record_soa *)(entry->record.rdata); + + pname = strdup(name); + pmname = strdup(mname); + prname = strdup(rname); + + if (!pname || !pmname || !prname) + { + free(pname); + free(pmname); + free(prname); + free(entry); + return -1; + } + + soa->mname = pmname; + soa->rname = prname; + soa->serial = serial; + soa->refresh = refresh; + soa->retry = retry; + soa->expire = expire; + soa->minimum = minimum; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SOA; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + 
list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_srv *srv; + size_t entry_size; + char *pname, *ptarget; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_srv); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + srv = (struct dns_record_srv *)(entry->record.rdata); + + pname = strdup(name); + ptarget = strdup(target); + + if (!pname || !ptarget) + { + free(pname); + free(ptarget); + free(entry); + return -1; + } + + srv->priority = priority; + srv->weight = weight; + srv->port = port; + srv->target = ptarget; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SRV; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_mx *mx; + size_t entry_size; + char *pname, *pexchange; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_mx); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + mx = (struct dns_record_mx *)(entry->record.rdata); + + pname = strdup(name); + pexchange = strdup(exchange); + + if (!pname || !pexchange) + { + free(pname); + free(pexchange); + free(entry); + return -1; + } + + mx->preference = preference; + mx->exchange = pexchange; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_MX; + entry->record.rclass = rclass; + 
entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +const char *dns_type2str(int type) +{ + switch (type) + { + case DNS_TYPE_A: + return "A"; + case DNS_TYPE_NS: + return "NS"; + case DNS_TYPE_MD: + return "MD"; + case DNS_TYPE_MF: + return "MF"; + case DNS_TYPE_CNAME: + return "CNAME"; + case DNS_TYPE_SOA: + return "SOA"; + case DNS_TYPE_MB: + return "MB"; + case DNS_TYPE_MG: + return "MG"; + case DNS_TYPE_MR: + return "MR"; + case DNS_TYPE_NULL: + return "NULL"; + case DNS_TYPE_WKS: + return "WKS"; + case DNS_TYPE_PTR: + return "PTR"; + case DNS_TYPE_HINFO: + return "HINFO"; + case DNS_TYPE_MINFO: + return "MINFO"; + case DNS_TYPE_MX: + return "MX"; + case DNS_TYPE_AAAA: + return "AAAA"; + case DNS_TYPE_SRV: + return "SRV"; + case DNS_TYPE_TXT: + return "TXT"; + case DNS_TYPE_AXFR: + return "AXFR"; + case DNS_TYPE_MAILB: + return "MAILB"; + case DNS_TYPE_MAILA: + return "MAILA"; + case DNS_TYPE_ALL: + return "ALL"; + default: + return "Unknown"; + } +} + +const char *dns_class2str(int dnsclass) +{ + switch (dnsclass) + { + case DNS_CLASS_IN: + return "IN"; + case DNS_CLASS_CS: + return "CS"; + case DNS_CLASS_CH: + return "CH"; + case DNS_CLASS_HS: + return "HS"; + case DNS_CLASS_ALL: + return "ALL"; + default: + return "Unknown"; + } +} + +const char *dns_opcode2str(int opcode) +{ + switch (opcode) + { + case DNS_OPCODE_QUERY: + return "QUERY"; + case DNS_OPCODE_IQUERY: + return "IQUERY"; + case DNS_OPCODE_STATUS: + return "STATUS"; + default: + return "Unknown"; + } +} + +const char *dns_rcode2str(int rcode) +{ + switch (rcode) + { + case DNS_RCODE_NO_ERROR: + return "NO_ERROR"; + case DNS_RCODE_FORMAT_ERROR: + return "FORMAT_ERROR"; + case DNS_RCODE_SERVER_FAILURE: + return "SERVER_FAILURE"; + case DNS_RCODE_NAME_ERROR: + return "NAME_ERROR"; + case DNS_RCODE_NOT_IMPLEMENTED: + return "NOT_IMPLEMENTED"; + case DNS_RCODE_REFUSED: + return "REFUSED"; + default: + return "Unknown"; + } +} + diff --git 
a/src/protocol/dns_parser.h b/src/protocol/dns_parser.h new file mode 100644 index 0000000..e950767 --- /dev/null +++ b/src/protocol/dns_parser.h @@ -0,0 +1,275 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _DNS_PARSER_H_ +#define _DNS_PARSER_H_ + +#include +#include +#include "list.h" + +enum +{ + DNS_TYPE_A = 1, + DNS_TYPE_NS, + DNS_TYPE_MD, + DNS_TYPE_MF, + DNS_TYPE_CNAME, + DNS_TYPE_SOA = 6, + DNS_TYPE_MB, + DNS_TYPE_MG, + DNS_TYPE_MR, + DNS_TYPE_NULL, + DNS_TYPE_WKS = 11, + DNS_TYPE_PTR, + DNS_TYPE_HINFO, + DNS_TYPE_MINFO, + DNS_TYPE_MX, + DNS_TYPE_TXT = 16, + + DNS_TYPE_AAAA = 28, + DNS_TYPE_SRV = 33, + + DNS_TYPE_AXFR = 252, + DNS_TYPE_MAILB = 253, + DNS_TYPE_MAILA = 254, + DNS_TYPE_ALL = 255 +}; + +enum +{ + DNS_CLASS_IN = 1, + DNS_CLASS_CS, + DNS_CLASS_CH, + DNS_CLASS_HS, + + DNS_CLASS_ALL = 255 +}; + +enum +{ + DNS_OPCODE_QUERY = 0, + DNS_OPCODE_IQUERY, + DNS_OPCODE_STATUS, +}; + +enum +{ + DNS_RCODE_NO_ERROR = 0, + DNS_RCODE_FORMAT_ERROR, + DNS_RCODE_SERVER_FAILURE, + DNS_RCODE_NAME_ERROR, + DNS_RCODE_NOT_IMPLEMENTED, + DNS_RCODE_REFUSED +}; + +enum +{ + DNS_ANSWER_SECTION = 1, + DNS_AUTHORITY_SECTION = 2, + DNS_ADDITIONAL_SECTION = 3, +}; + +/** + * dns_header_t is a struct to describe the header of a dns + * request or response packet, but the byte order is not + * transformed. 
+ */ +#pragma pack(1) +struct dns_header +{ + uint16_t id; +#if __BYTE_ORDER == __LITTLE_ENDIAN + uint8_t rd : 1; + uint8_t tc : 1; + uint8_t aa : 1; + uint8_t opcode : 4; + uint8_t qr : 1; + uint8_t rcode : 4; + uint8_t z : 3; + uint8_t ra : 1; +#elif __BYTE_ORDER == __BIG_ENDIAN + uint8_t qr : 1; + uint8_t opcode : 4; + uint8_t aa : 1; + uint8_t tc : 1; + uint8_t rd : 1; + uint8_t ra : 1; + uint8_t z : 3; + uint8_t rcode : 4; +#else +# error "unknown byte order" +#endif + uint16_t qdcount; + uint16_t ancount; + uint16_t nscount; + uint16_t arcount; +}; +#pragma pack() + +struct dns_question +{ + char *qname; + uint16_t qtype; + uint16_t qclass; +}; + +struct dns_record_soa +{ + char *mname; + char *rname; + uint32_t serial; + int32_t refresh; + int32_t retry; + int32_t expire; + uint32_t minimum; +}; + +struct dns_record_srv +{ + uint16_t priority; + uint16_t weight; + uint16_t port; + char *target; +}; + +struct dns_record_mx +{ + int16_t preference; + char *exchange; +}; + +struct dns_record +{ + char *name; + uint16_t type; + uint16_t rclass; + uint32_t ttl; + uint16_t rdlength; + void *rdata; +}; + +typedef struct __dns_parser +{ + void *msgbuf; // Message with leading length (TCP) + void *msgbase; // Message without leading length + const char *cur; // Current parser position + size_t msgsize; + size_t bufsize; + char complete; // Whether parse completed + char single_packet; // Response without leading length When UDP + struct dns_header header; + struct dns_question question; + struct list_head answer_list; + struct list_head authority_list; + struct list_head additional_list; +} dns_parser_t; + +typedef struct __dns_record_cursor +{ + const struct list_head *head; + const struct list_head *next; +} dns_record_cursor_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +void dns_parser_init(dns_parser_t *parser); + +void dns_parser_set_id(uint16_t id, dns_parser_t *parser); +int dns_parser_set_question(const char *name, + uint16_t qtype, + uint16_t qclass, + 
dns_parser_t *parser); +int dns_parser_set_question_name(const char *name, + dns_parser_t *parser); + +int dns_parser_parse_all(dns_parser_t *parser); +int dns_parser_append_message(const void *buf, size_t *n, + dns_parser_t *parser); + +void dns_parser_deinit(dns_parser_t *parser); + +int dns_record_cursor_next(struct dns_record **record, + dns_record_cursor_t *cursor); + +int dns_record_cursor_find_cname(const char *name, + const char **cname, + dns_record_cursor_t *cursor); + +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list); + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list); + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list); + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list); + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list); + +const char *dns_type2str(int type); +const char *dns_class2str(int dnsclass); +const char *dns_opcode2str(int opcode); +const char *dns_rcode2str(int rcode); + +#ifdef __cplusplus +} +#endif + +static inline void dns_answer_cursor_init(dns_record_cursor_t *cursor, + const dns_parser_t *parser) +{ + cursor->head = &parser->answer_list; + cursor->next = cursor->head; +} + +static inline void dns_authority_cursor_init(dns_record_cursor_t *cursor, + const dns_parser_t *parser) +{ + cursor->head = &parser->authority_list; + cursor->next = cursor->head; +} + +static inline void dns_additional_cursor_init(dns_record_cursor_t *cursor, + const dns_parser_t *parser) +{ + 
cursor->head = &parser->additional_list; + cursor->next = cursor->head; +} + +static inline void dns_record_cursor_deinit(dns_record_cursor_t *cursor) +{ +} + +#endif + diff --git a/src/protocol/http_parser.c b/src/protocol/http_parser.c new file mode 100644 index 0000000..1bb551b --- /dev/null +++ b/src/protocol/http_parser.c @@ -0,0 +1,853 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include "list.h" +#include "http_parser.h" + +#define MIN(x, y) ((x) <= (y) ? (x) : (y)) +#define MAX(x, y) ((x) >= (y) ? 
(x) : (y)) + +#define HTTP_START_LINE_MAX 8192 +#define HTTP_HEADER_VALUE_MAX 8192 +#define HTTP_CHUNK_LINE_MAX 1024 +#define HTTP_TRAILER_LINE_MAX 8192 +#define HTTP_MSGBUF_INIT_SIZE 2048 + +enum +{ + HPS_START_LINE, + HPS_HEADER_NAME, + HPS_HEADER_VALUE, + HPS_HEADER_COMPLETE +}; + +enum +{ + CPS_CHUNK_DATA, + CPS_TRAILER_PART, + CPS_CHUNK_COMPLETE +}; + +struct __header_line +{ + struct list_head list; + int name_len; + int value_len; + char *buf; +}; + +static int __add_message_header(const void *name, size_t name_len, + const void *value, size_t value_len, + http_parser_t *parser) +{ + size_t size = sizeof (struct __header_line) + name_len + value_len + 4; + struct __header_line *line; + + line = (struct __header_line *)malloc(size); + if (line) + { + line->buf = (char *)(line + 1); + memcpy(line->buf, name, name_len); + line->buf[name_len] = ':'; + line->buf[name_len + 1] = ' '; + memcpy(line->buf + name_len + 2, value, value_len); + line->buf[name_len + 2 + value_len] = '\r'; + line->buf[name_len + 2 + value_len + 1] = '\n'; + line->name_len = name_len; + line->value_len = value_len; + list_add_tail(&line->list, &parser->header_list); + return 0; + } + + return -1; +} + +static int __set_message_header(const void *name, size_t name_len, + const void *value, size_t value_len, + http_parser_t *parser) +{ + struct __header_line *line; + struct list_head *pos; + char *buf; + + list_for_each(pos, &parser->header_list) + { + line = list_entry(pos, struct __header_line, list); + if (line->name_len == name_len && + strncasecmp(line->buf, name, name_len) == 0) + { + if (value_len > line->value_len) + { + buf = (char *)malloc(name_len + value_len + 4); + if (!buf) + return -1; + + if (line->buf != (char *)(line + 1)) + free(line->buf); + + line->buf = buf; + memcpy(buf, name, name_len); + buf[name_len] = ':'; + buf[name_len + 1] = ' '; + } + + memcpy(line->buf + name_len + 2, value, value_len); + line->buf[name_len + 2 + value_len] = '\r'; + line->buf[name_len + 2 + 
value_len + 1] = '\n'; + line->value_len = value_len; + return 0; + } + } + + return __add_message_header(name, name_len, value, value_len, parser); +} + +static int __match_request_line(const char *method, + const char *uri, + const char *version, + http_parser_t *parser) +{ + if (strcmp(version, "HTTP/1.0") == 0 || strncmp(version, "HTTP/0", 6) == 0) + parser->keep_alive = 0; + + method = strdup(method); + if (method) + { + uri = strdup(uri); + if (uri) + { + version = strdup(version); + if (version) + { + free(parser->method); + free(parser->uri); + free(parser->version); + parser->method = (char *)method; + parser->uri = (char *)uri; + parser->version = (char *)version; + return 0; + } + + free((char *)uri); + } + + free((char *)method); + } + + return -1; +} + +static int __match_status_line(const char *version, + const char *code, + const char *phrase, + http_parser_t *parser) +{ + if (strcmp(version, "HTTP/1.0") == 0 || strncmp(version, "HTTP/0", 6) == 0) + parser->keep_alive = 0; + + if (*code == '1' || strcmp(code, "204") == 0 || strcmp(code, "304") == 0) + parser->transfer_length = 0; + + version = strdup(version); + if (version) + { + code = strdup(code); + if (code) + { + phrase = strdup(phrase); + if (phrase) + { + free(parser->version); + free(parser->code); + free(parser->phrase); + parser->version = (char *)version; + parser->code = (char *)code; + parser->phrase = (char *)phrase; + return 0; + } + + free((char *)code); + } + + free((char *)version); + } + + return -1; +} + +static void __check_message_header(const char *name, size_t name_len, + const char *value, size_t value_len, + http_parser_t *parser) +{ + switch (name_len) + { + case 6: + if (strncasecmp(name, "Expect", 6) == 0) + { + if (value_len == 12 && strncasecmp(value, "100-continue", 12) == 0) + parser->expect_continue = 1; + } + + break; + + case 10: + if (strncasecmp(name, "Connection", 10) == 0) + { + parser->has_connection = 1; + if (value_len == 10 && strncasecmp(value, 
"Keep-Alive", 10) == 0) + parser->keep_alive = 1; + else if (value_len == 5 && strncasecmp(value, "close", 5) == 0) + parser->keep_alive = 0; + } + else if (strncasecmp(name, "Keep-Alive", 10) == 0) + parser->has_keep_alive = 1; + + break; + + case 14: + if (strncasecmp(name, "Content-Length", 14) == 0) + { + parser->has_content_length = 1; + if (*value >= '0' && *value <= '9' && value_len <= 15) + { + char buf[16]; + memcpy(buf, value, value_len); + buf[value_len] = '\0'; + parser->content_length = atol(buf); + } + } + + break; + + case 17: + if (strncasecmp(name, "Transfer-Encoding", 17) == 0) + { + if (value_len != 8 || strncasecmp(value, "identity", 8) != 0) + parser->chunked = 1; + else + parser->chunked = 0; + } + + break; + } +} + +static int __parse_start_line(const char *ptr, size_t len, + http_parser_t *parser) +{ + char start_line[HTTP_START_LINE_MAX]; + size_t min = MIN(HTTP_START_LINE_MAX, len); + char *p1, *p2, *p3; + size_t i; + int ret; + + if (len >= 2 && ptr[0] == '\r' && ptr[1] == '\n') + { + parser->header_offset += 2; + return 1; + } + + for (i = 0; i < min; i++) + { + start_line[i] = ptr[i]; + if (start_line[i] == '\r') + { + if (i == len - 1) + return 0; + + if (ptr[i + 1] != '\n') + return -2; + + start_line[i] = '\0'; + p1 = start_line; + p2 = strchr(p1, ' '); + if (p2) + *p2++ = '\0'; + else + return -2; + + p3 = strchr(p2, ' '); + if (p3) + *p3++ = '\0'; + else + return -2; + + if (parser->is_resp) + ret = __match_status_line(p1, p2, p3, parser); + else + ret = __match_request_line(p1, p2, p3, parser); + + if (ret < 0) + return -1; + + parser->header_offset += i + 2; + parser->header_state = HPS_HEADER_NAME; + return 1; + } + + if (start_line[i] == 0) + return -2; + } + + if (i == HTTP_START_LINE_MAX) + return -2; + + return 0; +} + +static int __parse_header_name(const char *ptr, size_t len, + http_parser_t *parser) +{ + size_t min = MIN(HTTP_HEADER_NAME_MAX, len); + size_t i; + + if (len >= 2 && ptr[0] == '\r' && ptr[1] == '\n') + { + 
/*
 * Parse the value of the header whose name is currently in parser->namebuf.
 *
 * ptr/len describe the unparsed bytes that follow the ':' of the header
 * name. Leading and trailing blanks are stripped, and a continuation line
 * (a line that starts with a space or a tab) is appended to the value,
 * separated by a single space.
 *
 * Returns:
 *    1  the value is complete; the header was stored, header_offset was
 *       advanced past the line and the state went back to HPS_HEADER_NAME
 *    0  more input is needed; nothing is committed, so the next call
 *       parses this value again from its first byte
 *   -1  storing the header failed (http_parser_add_header() error)
 *   -2  malformed value: NUL or non-ASCII byte, CR not followed by LF, or
 *       more than HTTP_HEADER_VALUE_MAX bytes
 */
static int __parse_header_value(const char *ptr, size_t len,
								http_parser_t *parser)
{
	char header_value[HTTP_HEADER_VALUE_MAX];
	const char *end = ptr + len;
	const char *begin = ptr;
	size_t i = 0;

	/* One iteration per physical line of the (possibly folded) value. */
	while (1)
	{
		/* Skip blanks at the start of the line. */
		while (1)
		{
			if (ptr == end)
				return 0;

			if (*ptr == ' ' || *ptr == '\t')
				ptr++;
			else
				break;
		}

		/* Copy bytes up to the CR. The CR itself is written at index i
		 * but i is not advanced, so it is overwritten later. */
		while (1)
		{
			if (i == HTTP_HEADER_VALUE_MAX)
				return -2;

			header_value[i] = *ptr++;
			if (ptr == end)
				return 0;

			if (header_value[i] == '\r')
				break;

			/* Rejects NUL and bytes with the high bit set. */
			if ((signed char)header_value[i] <= 0)
				return -2;

			i++;
		}

		if (*ptr == '\n')
			ptr++;
		else
			return -2;

		/* The byte after CRLF decides whether the value continues, so it
		 * must be available before the line can be committed. */
		if (ptr == end)
			return 0;

		/* Drop trailing blanks of this line. */
		while (i > 0)
		{
			if (header_value[i - 1] == ' ' || header_value[i - 1] == '\t')
				i--;
			else
				break;
		}

		/* Not a continuation line: the value is finished. */
		if (*ptr != ' ' && *ptr != '\t')
			break;

		/* Continuation line: join it with a single space. */
		ptr++;
		header_value[i++] = ' ';
	}

	header_value[i] = '\0';
	if (http_parser_add_header(parser->namebuf, strlen(parser->namebuf),
							   header_value, i, parser) < 0)
		return -1;

	parser->header_offset += ptr - begin;
	parser->header_state = HPS_HEADER_NAME;
	return 1;
}
/*
 * Parse one chunk (size line plus data) of a chunked body.
 *
 * ptr/len describe the bytes starting at parser->chunk_offset.
 *
 * Returns:
 *    1  a whole chunk was consumed and chunk_offset was advanced; a chunk
 *       of size 0 switches the state to CPS_TRAILER_PART
 *    0  more input is needed (size line or chunk data not complete yet)
 *   -2  malformed chunk: CR not followed by LF, no hex digits, size not
 *       below CHUNK_SIZE_MAX, or a size line of HTTP_CHUNK_LINE_MAX bytes
 *       without CR
 */
static int __parse_chunk_data(const char *ptr, size_t len,
							  http_parser_t *parser)
{
	char chunk_line[HTTP_CHUNK_LINE_MAX];
	size_t min = MIN(HTTP_CHUNK_LINE_MAX, len);
	long chunk_size;
	char *end;
	size_t i;

	for (i = 0; i < min; i++)
	{
		chunk_line[i] = ptr[i];
		if (chunk_line[i] == '\r')
		{
			/* CR is the last byte received: wait for the LF. */
			if (i == len - 1)
				return 0;

			if (ptr[i + 1] != '\n')
				return -2;

			/* strtol stops at the first non-hex character, so anything
			 * after the size on this line is ignored. */
			chunk_line[i] = '\0';
			chunk_size = strtol(chunk_line, &end, 16);
			if (end == chunk_line)
				return -2;

			if (chunk_size == 0)
			{
				/* Last chunk: consume only the size line and CRLF. */
				chunk_size = i + 2;
				parser->chunk_state = CPS_TRAILER_PART;
			}
			else if ((unsigned long)chunk_size < CHUNK_SIZE_MAX)
			{
				/* Size line (i) + CRLF (2) + data + 2 bytes after the
				 * data; those 2 bytes are skipped without being checked. */
				chunk_size += i + 4;
				if (len < (size_t)chunk_size)
					return 0;
			}
			else
				return -2;

			parser->chunk_offset += chunk_size;
			return 1;
		}
	}

	/* The whole size-line buffer was scanned without finding a CR. */
	if (i == HTTP_CHUNK_LINE_MAX)
		return -2;

	return 0;
}
__parse_trailer_part(ptr, len, parser); + if (parser->chunk_state == CPS_CHUNK_COMPLETE) + return 1; + } + } while (ret > 0); + + return ret; +} + +void http_parser_init(int is_resp, http_parser_t *parser) +{ + parser->header_state = HPS_START_LINE; + parser->header_offset = 0; + parser->transfer_length = (size_t)-1; + parser->content_length = is_resp ? (size_t)-1 : 0; + parser->version = NULL; + parser->method = NULL; + parser->uri = NULL; + parser->code = NULL; + parser->phrase = NULL; + INIT_LIST_HEAD(&parser->header_list); + parser->msgbuf = NULL; + parser->msgsize = 0; + parser->bufsize = 0; + parser->has_connection = 0; + parser->has_content_length = 0; + parser->has_keep_alive = 0; + parser->expect_continue = 0; + parser->keep_alive = 1; + parser->chunked = 0; + parser->complete = 0; + parser->is_resp = is_resp; +} + +int http_parser_append_message(const void *buf, size_t *n, + http_parser_t *parser) +{ + int ret; + + if (parser->complete) + { + *n = 0; + return 1; + } + + if (parser->msgsize + *n + 1 > parser->bufsize) + { + size_t new_size = MAX(HTTP_MSGBUF_INIT_SIZE, 2 * parser->bufsize); + void *new_base; + + while (new_size < parser->msgsize + *n + 1) + new_size *= 2; + + new_base = realloc(parser->msgbuf, new_size); + if (!new_base) + return -1; + + parser->msgbuf = new_base; + parser->bufsize = new_size; + } + + memcpy((char *)parser->msgbuf + parser->msgsize, buf, *n); + parser->msgsize += *n; + if (parser->header_state != HPS_HEADER_COMPLETE) + { + ret = __parse_message_header(parser->msgbuf, parser->msgsize, parser); + if (ret <= 0) + return ret; + + if (parser->chunked) + { + parser->chunk_offset = parser->header_offset; + parser->chunk_state = CPS_CHUNK_DATA; + } + else if (parser->transfer_length == (size_t)-1) + parser->transfer_length = parser->content_length; + } + + if (parser->transfer_length != (size_t)-1) + { + size_t total = parser->header_offset + parser->transfer_length; + + if (parser->msgsize >= total) + { + *n -= parser->msgsize - 
total; + parser->msgsize = total; + parser->complete = 1; + return 1; + } + + return 0; + } + + if (!parser->chunked) + return 0; + + if (parser->chunk_state != CPS_CHUNK_COMPLETE) + { + ret = __parse_chunk(parser->msgbuf, parser->msgsize, parser); + if (ret <= 0) + return ret; + } + + *n -= parser->msgsize - parser->chunk_offset; + parser->msgsize = parser->chunk_offset; + parser->complete = 1; + return 1; +} + +int http_parser_header_complete(const http_parser_t *parser) +{ + return parser->header_state == HPS_HEADER_COMPLETE; +} + +int http_parser_get_body(const void **body, size_t *size, + const http_parser_t *parser) +{ + if (parser->complete && parser->header_state == HPS_HEADER_COMPLETE) + { + *body = (char *)parser->msgbuf + parser->header_offset; + *size = parser->msgsize - parser->header_offset; + ((char *)parser->msgbuf)[parser->msgsize] = '\0'; + return 0; + } + + return 1; +} + +int http_parser_set_method(const char *method, http_parser_t *parser) +{ + method = strdup(method); + if (method) + { + free(parser->method); + parser->method = (char *)method; + return 0; + } + + return -1; +} + +int http_parser_set_uri(const char *uri, http_parser_t *parser) +{ + uri = strdup(uri); + if (uri) + { + free(parser->uri); + parser->uri = (char *)uri; + return 0; + } + + return -1; +} + +int http_parser_set_version(const char *version, http_parser_t *parser) +{ + version = strdup(version); + if (version) + { + free(parser->version); + parser->version = (char *)version; + return 0; + } + + return -1; +} + +int http_parser_set_code(const char *code, http_parser_t *parser) +{ + code = strdup(code); + if (code) + { + free(parser->code); + parser->code = (char *)code; + return 0; + } + + return -1; +} + +int http_parser_set_phrase(const char *phrase, http_parser_t *parser) +{ + phrase = strdup(phrase); + if (phrase) + { + free(parser->phrase); + parser->phrase = (char *)phrase; + return 0; + } + + return -1; +} + +int http_parser_add_header(const void *name, size_t 
name_len, + const void *value, size_t value_len, + http_parser_t *parser) +{ + if (__add_message_header(name, name_len, value, value_len, parser) >= 0) + { + __check_message_header((const char *)name, name_len, + (const char *)value, value_len, + parser); + return 0; + } + + return -1; +} + +int http_parser_set_header(const void *name, size_t name_len, + const void *value, size_t value_len, + http_parser_t *parser) +{ + if (__set_message_header(name, name_len, value, value_len, parser) >= 0) + { + __check_message_header((const char *)name, name_len, + (const char *)value, value_len, + parser); + return 0; + } + + return -1; +} + +void http_parser_deinit(http_parser_t *parser) +{ + struct __header_line *line; + struct list_head *pos, *tmp; + + list_for_each_safe(pos, tmp, &parser->header_list) + { + line = list_entry(pos, struct __header_line, list); + list_del(pos); + if (line->buf != (char *)(line + 1)) + free(line->buf); + + free(line); + } + + free(parser->version); + free(parser->method); + free(parser->uri); + free(parser->code); + free(parser->phrase); + free(parser->msgbuf); +} + +int http_header_cursor_next(const void **name, size_t *name_len, + const void **value, size_t *value_len, + http_header_cursor_t *cursor) +{ + struct __header_line *line; + + if (cursor->next->next != cursor->head) + { + cursor->next = cursor->next->next; + line = list_entry(cursor->next, struct __header_line, list); + *name = line->buf; + *name_len = line->name_len; + *value = line->buf + line->name_len + 2; + *value_len = line->value_len; + return 0; + } + + return 1; +} + +int http_header_cursor_find(const void *name, size_t name_len, + const void **value, size_t *value_len, + http_header_cursor_t *cursor) +{ + struct __header_line *line; + + while (cursor->next->next != cursor->head) + { + cursor->next = cursor->next->next; + line = list_entry(cursor->next, struct __header_line, list); + if (line->name_len == name_len) + { + if (strncasecmp(line->buf, name, name_len) == 0) + { 
+ *value = line->buf + name_len + 2; + *value_len = line->value_len; + return 0; + } + } + } + + return 1; +} + +int http_header_cursor_erase(http_header_cursor_t *cursor) +{ + struct __header_line *line; + + if (cursor->next != cursor->head) + { + line = list_entry(cursor->next, struct __header_line, list); + cursor->next = cursor->next->prev; + list_del(&line->list); + if (line->buf != (char *)(line + 1)) + free(line->buf); + + free(line); + return 0; + } + + return 1; +} + diff --git a/src/protocol/http_parser.h b/src/protocol/http_parser.h new file mode 100644 index 0000000..0897ede --- /dev/null +++ b/src/protocol/http_parser.h @@ -0,0 +1,169 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
/*
 * State of one HTTP message, either being parsed incrementally from the
 * network or being built through the http_parser_set_xxx / add_header
 * calls. Initialise with http_parser_init(), release with
 * http_parser_deinit().
 */
typedef struct __http_parser
{
	int header_state;			/* header parsing state (HPS_xxx in http_parser.c) */
	int chunk_state;			/* chunked-body parsing state (CPS_xxx) */
	size_t header_offset;		/* bytes of msgbuf consumed by the header parser;
								   once the header is complete, where the body starts */
	size_t chunk_offset;		/* bytes of msgbuf consumed by the chunk parser */
	size_t content_length;		/* value of Content-Length; starts as (size_t)-1
								   for responses and 0 for requests */
	size_t transfer_length;		/* body length used to detect completion;
								   (size_t)-1 while unknown */
	char *version;				/* heap copies of the start-line parts; */
	char *method;				/* NULL until parsed or set, freed by */
	char *uri;					/* http_parser_deinit() */
	char *code;
	char *phrase;
	struct list_head header_list;	/* parsed/added header lines, in order */
	char namebuf[HTTP_HEADER_NAME_MAX];	/* name of the header being parsed */
	void *msgbuf;				/* raw received message (realloc'ed as it grows) */
	size_t msgsize;				/* bytes of msgbuf in use */
	size_t bufsize;				/* bytes allocated for msgbuf */
	char has_connection;		/* a Connection header was seen */
	char has_content_length;	/* a Content-Length header was seen */
	char has_keep_alive;		/* a Keep-Alive header was seen */
	char expect_continue;		/* "Expect: 100-continue" was seen */
	char keep_alive;			/* connection may be reused (default 1) */
	char chunked;				/* body uses chunked transfer encoding */
	char complete;				/* the whole message has been received */
	char is_resp;				/* non-zero: response; zero: request */
} http_parser_t;
http_header_cursor_find(const void *name, size_t name_len, + const void **value, size_t *value_len, + http_header_cursor_t *cursor); +int http_header_cursor_erase(http_header_cursor_t *cursor); + +#ifdef __cplusplus +} +#endif + +static inline const char *http_parser_get_method(const http_parser_t *parser) +{ + return parser->method; +} + +static inline const char *http_parser_get_uri(const http_parser_t *parser) +{ + return parser->uri; +} + +static inline const char *http_parser_get_version(const http_parser_t *parser) +{ + return parser->version; +} + +static inline const char *http_parser_get_code(const http_parser_t *parser) +{ + return parser->code; +} + +static inline const char *http_parser_get_phrase(const http_parser_t *parser) +{ + return parser->phrase; +} + +static inline int http_parser_chunked(const http_parser_t *parser) +{ + return parser->chunked; +} + +static inline int http_parser_keep_alive(const http_parser_t *parser) +{ + return parser->keep_alive; +} + +static inline int http_parser_has_connection(const http_parser_t *parser) +{ + return parser->has_connection; +} + +static inline int http_parser_has_content_length(const http_parser_t *parser) +{ + return parser->has_content_length; +} + +static inline int http_parser_has_keep_alive(const http_parser_t *parser) +{ + return parser->has_keep_alive; +} + +static inline void http_parser_close_message(http_parser_t *parser) +{ + parser->complete = 1; +} + +static inline void http_header_cursor_init(http_header_cursor_t *cursor, + const http_parser_t *parser) +{ + cursor->head = &parser->header_list; + cursor->next = cursor->head; +} + +static inline void http_header_cursor_rewind(http_header_cursor_t *cursor) +{ + cursor->next = cursor->head; +} + +static inline void http_header_cursor_deinit(http_header_cursor_t *cursor) +{ +} + +#endif + diff --git a/src/protocol/kafka_parser.c b/src/protocol/kafka_parser.c new file mode 100644 index 0000000..b05c070 --- /dev/null +++ 
/*
 * API versions assumed for a 0.9.0 broker, which cannot be asked for its
 * supported versions. Selected by kafka_get_legacy_api_version() when the
 * configured broker version starts with "0.9.0".
 *
 * Each entry is presumably { api_key, min_ver, max_ver } -- confirm against
 * the kafka_api_version_t declaration. The table is searched with bsearch()
 * using kafka_api_version_key_cmp(), so entries must stay sorted by
 * ascending api_key.
 */
static kafka_api_version_t kafka_api_version_0_9_0[] = {
	{ Kafka_Produce, 0, 1 },
	{ Kafka_Fetch, 0, 1 },
	{ Kafka_ListOffsets, 0, 0 },
	{ Kafka_Metadata, 0, 0 },
	{ Kafka_OffsetCommit, 0, 2 },
	{ Kafka_OffsetFetch, 0, 1 },
	{ Kafka_FindCoordinator, 0, 0 },
	{ Kafka_JoinGroup, 0, 0 },
	{ Kafka_Heartbeat, 0, 0 },
	{ Kafka_LeaveGroup, 0, 0 },
	{ Kafka_SyncGroup, 0, 0 },
	{ Kafka_DescribeGroups, 0, 0 },
	{ Kafka_ListGroups, 0, 0 }
};
Kafka_Metadata, 0, 0 } +}; + +static const struct kafka_feature_map { + unsigned feature; + kafka_api_version_t depends[Kafka_ApiNums]; +} kafka_feature_map[] = { + { + .feature = KAFKA_FEATURE_MSGVER1, + .depends = { + { Kafka_Produce, 2, 2 }, + { Kafka_Fetch, 2, 2 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_MSGVER2, + .depends = { + { Kafka_Produce, 3, 3 }, + { Kafka_Fetch, 4, 4 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_APIVERSION, + .depends = { + { Kafka_ApiVersions, 0, 0 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_BROKER_GROUP_COORD, + .depends = { + { Kafka_FindCoordinator, 0, 0 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_BROKER_BALANCED_CONSUMER, + .depends = { + { Kafka_FindCoordinator, 0, 0 }, + { Kafka_OffsetCommit, 1, 2 }, + { Kafka_OffsetFetch, 1, 1 }, + { Kafka_JoinGroup, 0, 0 }, + { Kafka_SyncGroup, 0, 0 }, + { Kafka_Heartbeat, 0, 0 }, + { Kafka_LeaveGroup, 0, 0 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_THROTTLETIME, + .depends = { + { Kafka_Produce, 1, 2 }, + { Kafka_Fetch, 1, 2 }, + { Kafka_Unknown, 0, 0 }, + }, + + }, + { + .feature = KAFKA_FEATURE_LZ4, + .depends = { + { Kafka_FindCoordinator, 0, 0 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_OFFSET_TIME, + .depends = { + { Kafka_ListOffsets, 1, 1 }, + { Kafka_Unknown, 0, 0 }, + } + }, + { + .feature = KAFKA_FEATURE_ZSTD, + .depends = { + { Kafka_Produce, 7, 7 }, + { Kafka_Fetch, 10, 10 }, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_SASL_GSSAPI, + .depends = { + { Kafka_JoinGroup, 0, 0}, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_SASL_HANDSHAKE, + .depends = { + { Kafka_SaslHandshake, 0, 0}, + { Kafka_Unknown, 0, 0 }, + }, + }, + { + .feature = KAFKA_FEATURE_SASL_AUTH_REQ, + .depends = { + { Kafka_SaslHandshake, 1, 1}, + { Kafka_SaslAuthenticate, 0, 0}, + { Kafka_Unknown, 0, 0 }, + }, + 
}, + { + .feature = 0, + }, +}; + +static int kafka_get_legacy_api_version(const char *broker_version, + kafka_api_version_t **api, + size_t *api_cnt) +{ + static const struct { + const char *pfx; + kafka_api_version_t *api; + size_t api_cnt; + } vermap[] = { + { "0.9.0", + kafka_api_version_0_9_0, + sizeof(kafka_api_version_0_9_0) / sizeof(kafka_api_version_t) + }, + { "0.8.2", + kafka_api_version_0_8_2, + sizeof(kafka_api_version_0_8_2) / sizeof(kafka_api_version_t) + }, + { "0.8.1", + kafka_api_version_0_8_1, + sizeof(kafka_api_version_0_8_1) / sizeof(kafka_api_version_t) + }, + { "0.8.0", + kafka_api_version_0_8_0, + sizeof(kafka_api_version_0_8_0) / sizeof(kafka_api_version_t) + }, + { "0.7.", NULL, 0 }, + { "0.6", NULL, 0 }, + { "", kafka_api_version_queryable, 1 }, + { NULL, NULL, 0 } + }; + int i; + + for (i = 0 ; vermap[i].pfx ; i++) + { + if (!strncmp(vermap[i].pfx, broker_version, strlen(vermap[i].pfx))) + { + if (!vermap[i].api) + return -1; + *api = vermap[i].api; + *api_cnt = vermap[i].api_cnt; + break; + } + } + + return 0; +} + +int kafka_api_version_is_queryable(const char *broker_version, + kafka_api_version_t **api, + size_t *api_cnt) +{ + return kafka_get_legacy_api_version(broker_version, api, api_cnt); +} + +static int kafka_api_version_key_cmp(const void *_a, const void *_b) +{ + const kafka_api_version_t *a = _a, *b = _b; + + if (a->api_key > b->api_key) + return 1; + else if (a->api_key == b->api_key) + return 0; + else + return -1; +} + +static int kafka_api_version_check(const kafka_api_version_t *apis, + size_t api_cnt, + const kafka_api_version_t *match) +{ + const kafka_api_version_t *api; + + api = bsearch(match, apis, api_cnt, sizeof(*apis), + kafka_api_version_key_cmp); + if (!api) + return 0; + + return match->min_ver <= api->max_ver && api->min_ver <= match->max_ver; +} + +unsigned kafka_get_features(kafka_api_version_t *api, size_t api_cnt) +{ + unsigned features = 0; + int i, fails, r; + const kafka_api_version_t *match; + + for 
(i = 0; kafka_feature_map[i].feature != 0; i++) + { + fails = 0; + for (match = &kafka_feature_map[i].depends[0]; + match->api_key != -1 ; match++) + { + r = kafka_api_version_check(api, api_cnt, match); + fails += !r; + } + + if (!fails) + features |= kafka_feature_map[i].feature; + } + + return features; +} + +int kafka_broker_get_api_version(const kafka_api_t *api, int api_key, + int min_ver, int max_ver) +{ + kafka_api_version_t sk = { .api_key = api_key }; + kafka_api_version_t *retp; + + retp = bsearch(&sk, api->api, api->elements, + sizeof(*api->api), kafka_api_version_key_cmp); + if (!retp) + return -1; + + if (retp->max_ver < max_ver) + { + if (retp->max_ver < min_ver) + return -1; + else + return retp->max_ver; + } + else if (retp->min_ver > min_ver) + return -1; + else + return max_ver; +} + +void kafka_parser_init(kafka_parser_t *parser) +{ + parser->complete = 0; + parser->message_size = 0; + parser->msgbuf = NULL; + parser->cur_size = 0; + parser->hsize = 0; +} + +void kafka_parser_deinit(kafka_parser_t *parser) +{ + free(parser->msgbuf); +} + +void kafka_config_init(kafka_config_t *conf) +{ + conf->produce_timeout = 100; + conf->produce_msg_max_bytes = 1000000; + conf->produce_msgset_cnt = 10000; + conf->produce_msgset_max_bytes = 1000000; + conf->fetch_timeout = 100; + conf->fetch_min_bytes = 1; + conf->fetch_max_bytes = 50 * 1024 * 1024; + conf->fetch_msg_max_bytes = 10 * 1024 * 1024; + conf->offset_timestamp = KAFKA_TIMESTAMP_LATEST; + conf->commit_timestamp = 0; + conf->session_timeout = 10*1000; + conf->rebalance_timeout = 10000; + conf->retention_time_period = 20000; + conf->produce_acks = -1; + conf->allow_auto_topic_creation = 1; + conf->api_version_request = 0; + conf->api_version_timeout = 10000; + conf->broker_version = NULL; + conf->compress_type = Kafka_NoCompress; + conf->compress_level = 0; + conf->client_id = NULL; + conf->check_crcs = 0; + conf->offset_store = KAFKA_OFFSET_AUTO; + conf->rack_id = NULL; + conf->mechanisms = NULL; + 
conf->username = NULL; + conf->password = NULL; + conf->recv = NULL; + conf->client_new = NULL; +} + +void kafka_config_deinit(kafka_config_t *conf) +{ + free(conf->broker_version); + free(conf->client_id); + free(conf->rack_id); + free(conf->mechanisms); + free(conf->username); + free(conf->password); +} + +void kafka_partition_init(kafka_partition_t *partition) +{ + partition->error = 0; + partition->partition_index = -1; + kafka_broker_init(&partition->leader); + partition->replica_nodes = NULL; + partition->replica_node_elements = 0; + partition->isr_nodes = NULL; + partition->isr_node_elements = 0; +} + +void kafka_partition_deinit(kafka_partition_t *partition) +{ + kafka_broker_deinit(&partition->leader); + free(partition->replica_nodes); + free(partition->isr_nodes); +} + +void kafka_api_init(kafka_api_t *api) +{ + api->features = 0; + api->api = NULL; + api->elements = 0; +} + +void kafka_api_deinit(kafka_api_t *api) +{ + free(api->api); +} + +void kafka_broker_init(kafka_broker_t *broker) +{ + broker->node_id = -1; + broker->port = 0; + broker->host = NULL; + broker->rack = NULL; + broker->error = 0; + broker->status = KAFKA_BROKER_UNINIT; +} + +void kafka_broker_deinit(kafka_broker_t *broker) +{ + free(broker->host); + free(broker->rack); +} + +void kafka_meta_init(kafka_meta_t *meta) +{ + meta->error = 0; + meta->topic_name = NULL; + meta->error_message = NULL; + meta->is_internal = 0; + meta->partitions = NULL; + meta->partition_elements = 0; +} + +void kafka_meta_deinit(kafka_meta_t *meta) +{ + int i; + + free(meta->topic_name); + free(meta->error_message); + + for (i = 0; i < meta->partition_elements; ++i) + { + kafka_partition_deinit(meta->partitions[i]); + free(meta->partitions[i]); + } + free(meta->partitions); +} + +void kafka_topic_partition_init(kafka_topic_partition_t *toppar) +{ + toppar->error = 0; + toppar->topic_name = NULL; + toppar->partition = -1; + toppar->preferred_read_replica = -1; + toppar->offset = KAFKA_OFFSET_UNINIT; + 
toppar->high_watermark = KAFKA_OFFSET_UNINIT; + toppar->low_watermark = KAFKA_OFFSET_UNINIT; + toppar->last_stable_offset = -1; + toppar->log_start_offset = -1; + toppar->offset_timestamp = KAFKA_TIMESTAMP_UNINIT; + toppar->committed_metadata = NULL; + INIT_LIST_HEAD(&toppar->record_list); +} + +void kafka_topic_partition_deinit(kafka_topic_partition_t *toppar) +{ + free(toppar->topic_name); + free(toppar->committed_metadata); +} + +void kafka_record_header_init(kafka_record_header_t *header) +{ + header->key = NULL; + header->key_len = 0; + header->key_is_moved = 0; + header->value = NULL; + header->value_len = 0; + header->value_is_moved = 0; +} + +void kafka_record_header_deinit(kafka_record_header_t *header) +{ + if (!header->key_is_moved) + free(header->key); + + if (!header->value_is_moved) + free(header->value); +} + +void kafka_record_init(kafka_record_t *record) +{ + record->key = NULL; + record->key_len = 0; + record->key_is_moved = 0; + record->value = NULL; + record->value_len = 0; + record->value_is_moved = 0; + record->timestamp = 0; + record->offset = 0; + INIT_LIST_HEAD(&record->header_list); + record->status = 0; + record->toppar = NULL; +} + +void kafka_record_deinit(kafka_record_t *record) +{ + struct list_head *tmp, *pos; + kafka_record_header_t *header; + + if (!record->key_is_moved) + free(record->key); + + if (!record->value_is_moved) + free(record->value); + + list_for_each_safe(pos, tmp, &record->header_list) + { + header = list_entry(pos, kafka_record_header_t, list); + list_del(pos); + kafka_record_header_deinit(header); + free(header); + } +} + +void kafka_member_init(kafka_member_t *member) +{ + member->member_id = NULL; + member->client_id = NULL; + member->client_host = NULL; + member->member_metadata = NULL; + member->member_metadata_len = 0; +} + +void kafka_member_deinit(kafka_member_t *member) +{ + free(member->member_id); + free(member->client_id); + free(member->client_host); + //do not need free! 
+ //free(member->member_metadata); +} + +void kafka_cgroup_init(kafka_cgroup_t *cgroup) +{ + INIT_LIST_HEAD(&cgroup->assigned_toppar_list); + cgroup->error = 0; + cgroup->error_msg = NULL; + kafka_broker_init(&cgroup->coordinator); + cgroup->leader_id = NULL; + cgroup->member_id = NULL; + cgroup->members = NULL; + cgroup->member_elements = 0; + cgroup->generation_id = -1; + cgroup->group_name = NULL; + cgroup->protocol_type = "consumer"; + cgroup->protocol_name = NULL; + INIT_LIST_HEAD(&cgroup->group_protocol_list); +} + +void kafka_cgroup_deinit(kafka_cgroup_t *cgroup) +{ + int i; + + free(cgroup->error_msg); + kafka_broker_deinit(&cgroup->coordinator); + free(cgroup->leader_id); + free(cgroup->member_id); + + for (i = 0; i < cgroup->member_elements; ++i) + { + kafka_member_deinit(cgroup->members[i]); + free(cgroup->members[i]); + } + + free(cgroup->members); + free(cgroup->protocol_name); +} + +void kafka_block_init(kafka_block_t *block) +{ + block->buf = NULL; + block->len = 0; + block->is_moved = 0; +} + +void kafka_block_deinit(kafka_block_t *block) +{ + if (!block->is_moved) + free(block->buf); +} + +int kafka_parser_append_message(const void *buf, size_t *size, + kafka_parser_t *parser) +{ + size_t s = *size; + int totaln; + + if (parser->complete) + { + *size = 0; + return 1; + } + + if (parser->hsize + s < 4) + { + memcpy(parser->headbuf + parser->hsize, buf, s); + parser->hsize += s; + return 0; + } + else if (!parser->msgbuf) + { + memcpy(parser->headbuf + parser->hsize, buf, 4 - parser->hsize); + buf = (const char *)buf + 4 - parser->hsize; + s -= 4 - parser->hsize; + parser->hsize = 4; + memcpy(&totaln, parser->headbuf, 4); + parser->message_size = ntohl(totaln); + parser->msgbuf = malloc(parser->message_size); + if (!parser->msgbuf) + return -1; + + parser->cur_size = 0; + } + + if (s > parser->message_size - parser->cur_size) + { + memcpy((char *)parser->msgbuf + parser->cur_size, buf, + parser->message_size - parser->cur_size); + parser->cur_size = 
parser->message_size; + } + else + { + memcpy((char *)parser->msgbuf + parser->cur_size, buf, s); + parser->cur_size += s; + } + + if (parser->cur_size < parser->message_size) + return 0; + + *size -= parser->message_size - parser->cur_size; + return 1; +} + +int kafka_topic_partition_set_tp(const char *topic_name, int partition, + kafka_topic_partition_t *toppar) +{ + char *p = strdup(topic_name); + + if (!p) + return -1; + + free(toppar->topic_name); + toppar->topic_name = p; + toppar->partition = partition; + return 0; +} + +int kafka_record_set_key(const void *key, size_t key_len, + kafka_record_t *record) +{ + void *k = malloc(key_len); + + if (!k) + return -1; + + free(record->key); + memcpy(k, key, key_len); + record->key = k; + record->key_len = key_len; + return 0; +} + +int kafka_record_set_value(const void *val, size_t val_len, + kafka_record_t *record) +{ + void *v = malloc(val_len); + + if (!v) + return -1; + + free(record->value); + memcpy(v, val, val_len); + record->value = v; + record->value_len = val_len; + return 0; +} + +int kafka_record_header_set_kv(const void *key, size_t key_len, + const void *val, size_t val_len, + kafka_record_header_t *header) +{ + void *k = malloc(key_len); + void *v = malloc(val_len); + + if (!k || !v) + { + free(k); + free(v); + return -1; + } + + memcpy(k, key, key_len); + memcpy(v, val, val_len); + header->key = k; + header->key_len = key_len; + header->value = v; + header->value_len = val_len; + return 0; +} + +int kafka_meta_set_topic(const char *topic, kafka_meta_t *meta) +{ + char *t = strdup(topic); + + if (!t) + return -1; + + free(meta->topic_name); + meta->topic_name = t; + return 0; +} + +int kafka_cgroup_set_group(const char *group, kafka_cgroup_t *cgroup) +{ + char *t = strdup(group); + + if (!t) + return -1; + + free(cgroup->group_name); + cgroup->group_name = t; + return 0; +} + +static int kafka_sasl_plain_recv(const char *buf, size_t len, void *conf, void *q) +{ + return 0; +} + +static int 
kafka_sasl_plain_client_new(void *p, kafka_sasl_t *sasl) +{ + kafka_config_t *conf = (kafka_config_t *)p; + size_t ulen = strlen(conf->username); + size_t plen = strlen(conf->password); + size_t blen = ulen + plen + 2; + size_t off = 0; + char *buf = (char *)malloc(blen); + + if (!buf) + return -1; + + buf[off++] = '\0'; + + memcpy(buf + off, conf->username, ulen); + off += ulen; + buf[off++] = '\0'; + + memcpy(buf + off, conf->password, plen); + + free(sasl->buf); + sasl->buf = buf; + sasl->bsize = blen; + + return 0; +} + +static int scram_get_attr(const struct iovec *inbuf, char attr, + struct iovec *outbuf) +{ + const char *td; + size_t len; + size_t of = 0; + void *ptr; + char ochar, nchar; + + for (of = 0; of < inbuf->iov_len;) + { + ptr = (char *)inbuf->iov_base + of; + td = memchr(ptr, ',', inbuf->iov_len - of); + if (td) + len = (size_t)((char *)td - (char *)inbuf->iov_base - of); + else + len = inbuf->iov_len - of; + + ochar = *((char *)inbuf->iov_base + of); + nchar = *((char *)inbuf->iov_base + of + 1); + if (ochar == attr && inbuf->iov_len > of + 1 && nchar == '=') + { + outbuf->iov_base = (char *)ptr + 2; + outbuf->iov_len = len - 2; + return 0; + } + + of += len + 1; + } + + return -1; +} + +static char *scram_base64_encode(const struct iovec *in) +{ + char *ret; + size_t ret_len, max_len; + + if (in->iov_len > INT_MAX) + return NULL; + + max_len = (((in->iov_len + 2) / 3) * 4) + 1; + ret = malloc(max_len); + if (!ret) + return NULL; + + ret_len = EVP_EncodeBlock((uint8_t *)ret, (uint8_t *)in->iov_base, + (int)in->iov_len); + if (ret_len >= max_len) + { + free(ret); + return NULL; + } + ret[ret_len] = 0; + + return ret; +} + +static int scram_base64_decode(const struct iovec *in, struct iovec *out) +{ + size_t ret_len; + + if (in->iov_len % 4 != 0 || in->iov_len > INT_MAX) + return -1; + + ret_len = ((in->iov_len / 4) * 3); + out->iov_base = malloc(ret_len + 1); + if (!out->iov_base) + return -1; + + if (EVP_DecodeBlock((uint8_t*)out->iov_base, 
(uint8_t*)in->iov_base, + (int)in->iov_len) == -1) + { + free(out->iov_base); + out->iov_base = NULL; + return -1; + } + + if (in->iov_len > 1 && ((char *)(in->iov_base))[in->iov_len - 1] == '=') + { + if (in->iov_len > 2 && ((char *)(in->iov_base))[in->iov_len - 2] == '=') + ret_len -= 2; + else + ret_len -= 1; + } + + ((char *)(out->iov_base))[ret_len] = '\0'; + out->iov_len = ret_len; + + return 0; +} + +static int scram_hi(const EVP_MD *evp, int itcnt, const struct iovec *in, + const struct iovec *salt, struct iovec *out) +{ + unsigned int ressize = 0; + unsigned char tempres[EVP_MAX_MD_SIZE]; + unsigned char tempdest[EVP_MAX_MD_SIZE]; + unsigned char *saltplus; + int i, j; + + saltplus = alloca(salt->iov_len + 4); + if (!saltplus) + return -1; + + memcpy(saltplus, salt->iov_base, salt->iov_len); + saltplus[salt->iov_len] = '\0'; + saltplus[salt->iov_len + 1] = '\0'; + saltplus[salt->iov_len + 2] = '\0'; + saltplus[salt->iov_len + 3] = '\1'; + + if (!HMAC(evp, (const unsigned char *)in->iov_base, (int)in->iov_len, + saltplus, salt->iov_len + 4, tempres, &ressize)) + { + return -1; + } + + memcpy(out->iov_base, tempres, ressize); + + for (i = 1; i < itcnt; i++) + { + if (!HMAC(evp, (const unsigned char *)in->iov_base, (int)in->iov_len, + tempres, ressize, tempdest, NULL)) + { + return -1; + } + + for (j = 0; j < (int)ressize; j++) + { + ((char *)(out->iov_base))[j] ^= tempdest[j]; + tempres[j] = tempdest[j]; + } + } + + out->iov_len = ressize; + return 0; +} + +static int scram_hmac(const EVP_MD *evp, const struct iovec *key, + const struct iovec *str, struct iovec *out) +{ + unsigned int outsize; + + if (!HMAC(evp, (const unsigned char *)key->iov_base, (int)key->iov_len, + (const unsigned char *)str->iov_base, (int)str->iov_len, + (unsigned char *)out->iov_base, &outsize)) + { + return -1; + } + + out->iov_len = outsize; + + return 0; +} + +static void scram_h(kafka_scram_t *scram, const struct iovec *str, + struct iovec *out) +{ + scram->scram_h((const 
unsigned char *)str->iov_base, str->iov_len, + (unsigned char *)out->iov_base); + out->iov_len = scram->scram_h_size; +} + +static void scram_build_client_final_message_wo_proof( + kafka_scram_t *scram, const struct iovec *snonce, struct iovec *out) +{ + const char *attr_c = "biws"; + + out->iov_len = 9 + scram->cnonce.iov_len + snonce->iov_len; + out->iov_base = malloc(out->iov_len + 1); + if (out->iov_base) + { + snprintf((char *)out->iov_base, out->iov_len + 1, "c=%s,r=%.*s%.*s", + attr_c, (int)scram->cnonce.iov_len, + (char *)scram->cnonce.iov_base, (int)snonce->iov_len, + (char *)snonce->iov_base); + } +} + +static int scram_build_client_final_message(kafka_scram_t *scram, int itcnt, + const struct iovec *salt, + const struct iovec *server_first_msg, + const struct iovec *server_nonce, + struct iovec *out, + const kafka_config_t *conf) +{ + char salted_pwd[EVP_MAX_MD_SIZE]; + char client_key[EVP_MAX_MD_SIZE]; + char server_key[EVP_MAX_MD_SIZE]; + char stored_key[EVP_MAX_MD_SIZE]; + char client_sign[EVP_MAX_MD_SIZE]; + char server_sign[EVP_MAX_MD_SIZE]; + char client_proof[EVP_MAX_MD_SIZE]; + struct iovec password_iov = {conf->password, strlen(conf->password)}; + struct iovec salted_pwd_iov = {salted_pwd, EVP_MAX_MD_SIZE}; + struct iovec client_key_verbatim_iov = {"Client Key", 10}; + struct iovec server_key_verbatim_iov = {"Server Key", 10}; + struct iovec client_key_iov = {client_key, EVP_MAX_MD_SIZE}; + struct iovec server_key_iov = {server_key, EVP_MAX_MD_SIZE}; + struct iovec stored_key_iov = {stored_key, EVP_MAX_MD_SIZE}; + struct iovec server_sign_iov = {server_sign, EVP_MAX_MD_SIZE}; + struct iovec client_sign_iov = {client_sign, EVP_MAX_MD_SIZE}; + struct iovec client_proof_iov = {client_proof, EVP_MAX_MD_SIZE}; + struct iovec client_final_msg_wo_proof_iov; + struct iovec auth_message_iov; + char *server_sign_b64, *client_proof_b64 = NULL; + int i; + + if (scram_hi((const EVP_MD *)scram->evp, itcnt, &password_iov, salt, + &salted_pwd_iov) == -1) + 
return -1; + + if (scram_hmac((const EVP_MD *)scram->evp, &salted_pwd_iov, + &client_key_verbatim_iov, &client_key_iov) == -1) + return -1; + + scram_h(scram, &client_key_iov, &stored_key_iov); + + scram_build_client_final_message_wo_proof(scram, server_nonce, + &client_final_msg_wo_proof_iov); + + auth_message_iov.iov_len = scram->first_msg.iov_len + 1 + + server_first_msg->iov_len + 1 + client_final_msg_wo_proof_iov.iov_len; + auth_message_iov.iov_base = alloca(auth_message_iov.iov_len + 1); + if (auth_message_iov.iov_base) + { + snprintf(auth_message_iov.iov_base, auth_message_iov.iov_len + 1, + "%.*s,%.*s,%.*s", + (int)scram->first_msg.iov_len, + (char *)scram->first_msg.iov_base, + (int)server_first_msg->iov_len, + (char *)server_first_msg->iov_base, + (int)client_final_msg_wo_proof_iov.iov_len, + (char *)client_final_msg_wo_proof_iov.iov_base); + + if (scram_hmac((const EVP_MD *)scram->evp, &salted_pwd_iov, + &server_key_verbatim_iov, &server_key_iov) == 0 && + scram_hmac((const EVP_MD *)scram->evp, &server_key_iov, + &auth_message_iov, &server_sign_iov) == 0) + { + server_sign_b64 = scram_base64_encode(&server_sign_iov); + if (server_sign_b64 && + scram_hmac((const EVP_MD *)scram->evp, &stored_key_iov, + &auth_message_iov, &client_sign_iov) ==0 && + client_key_iov.iov_len == client_sign_iov.iov_len) + { + scram->server_signature_b64.iov_base = server_sign_b64; + scram->server_signature_b64.iov_len = strlen(server_sign_b64); + for (i = 0 ; i < (int)client_key_iov.iov_len; i++) + ((char *)(client_proof_iov.iov_base))[i] = + ((char *)(client_key_iov.iov_base))[i] ^ + ((char *)(client_sign_iov.iov_base))[i]; + client_proof_iov.iov_len = client_key_iov.iov_len; + + client_proof_b64 = scram_base64_encode(&client_proof_iov); + if (client_proof_b64) + { + out->iov_len = client_final_msg_wo_proof_iov.iov_len + 3 + + strlen(client_proof_b64); + out->iov_base = malloc(out->iov_len + 1); + + snprintf((char *)out->iov_base, out->iov_len + 1, "%.*s,p=%s", + 
(int)client_final_msg_wo_proof_iov.iov_len, + (char *)client_final_msg_wo_proof_iov.iov_base, + client_proof_b64); + } + } + } + } + + free(client_proof_b64); + free(client_final_msg_wo_proof_iov.iov_base); + return 0; +} + +static int scram_handle_server_first_message(const char *buf, size_t len, + kafka_config_t *conf, + kafka_sasl_t *sasl) +{ + int itcnt; + int ret = -1; + const char *endptr; + struct iovec out, salt, server_nonce; + const struct iovec in = {(void *)buf, len}; + + if (scram_get_attr(&in, 'm', &out) == 0) + return -1; + + if (scram_get_attr(&in, 'r', &server_nonce) != 0) + return -1; + + if (server_nonce.iov_len <= sasl->scram.cnonce.iov_len || + memcmp(server_nonce.iov_base, sasl->scram.cnonce.iov_base, + sasl->scram.cnonce.iov_len) != 0) + { + return -1; + } + + if (scram_get_attr(&in, 's', &out) != 0) + return -1; + + if (scram_base64_decode(&out, &salt) != 0) + return -1; + + if (scram_get_attr(&in, 'i', &out) == 0) + { + itcnt = (int)strtoul((const char *)out.iov_base, (char **)&endptr, 10); + if ((const char *)out.iov_base != endptr && *endptr == '\0' && + itcnt <= 1000000) + { + ret = scram_build_client_final_message(&sasl->scram, itcnt, &salt, + &in, &server_nonce, &out, + conf); + if (ret == 0) + { + free(sasl->buf); + sasl->buf = out.iov_base; + sasl->bsize = out.iov_len; + } + } + } + + free(salt.iov_base); + return ret; +} + +static int scram_handle_server_final_message(const char *buf, size_t len, + kafka_config_t *conf, + kafka_sasl_t *sasl) +{ + struct iovec attr_v, attr_e; + const struct iovec in = {(void *)buf, len}; + + if (scram_get_attr(&in, 'm', &attr_e) == 0) + return -1; + + if (scram_get_attr(&in, 'v', &attr_v) == 0) + { + if (sasl->scram.server_signature_b64.iov_len == attr_v.iov_len && + strncmp((const char *)sasl->scram.server_signature_b64.iov_base, + (const char *)attr_v.iov_base, attr_v.iov_len) != 0) + { + return -1; + } + } + + return 0; +} + +static int kafka_sasl_scram_recv(const char *buf, size_t len, void *p, 
void *q) +{ + kafka_config_t *conf = (kafka_config_t *)p; + kafka_sasl_t *sasl = (kafka_sasl_t *)q; + int ret = -1; + + switch(sasl->scram.state) + { + case KAFKA_SASL_SCRAM_STATE_SERVER_FIRST_MESSAGE: + ret = scram_handle_server_first_message(buf, len, conf, sasl); + sasl->scram.state = KAFKA_SASL_SCRAM_STATE_CLIENT_FINAL_MESSAGE; + break; + + case KAFKA_SASL_SCRAM_STATE_CLIENT_FINAL_MESSAGE: + ret = scram_handle_server_final_message(buf, len, conf, sasl); + sasl->scram.state = KAFKA_SASL_SCRAM_STATE_CLIENT_FINISHED; + break; + + default: + break; + } + + return ret; +} + +static int jitter(int low, int high) +{ + return (low + (rand() % ((high - low) + 1))); +} + +static int scram_generate_nonce(struct iovec *iov) +{ + int i; + char *ptr = (char *)malloc(33); + + if (!ptr) + return -1; + + for (i = 0; i < 32; i++) + ptr[i] = jitter(0x2d, 0x7e); + ptr[32] = '\0'; + + iov->iov_base = ptr; + iov->iov_len = 32; + return 0; +} + +static int kafka_sasl_scram_client_new(void *p, kafka_sasl_t *sasl) +{ + kafka_config_t *conf = (kafka_config_t *)p; + size_t ulen = strlen(conf->username); + size_t tlen = strlen("n,,n=,r="); + size_t olen = ulen + tlen + 32; + void *ptr; + + if (sasl->scram.state != KAFKA_SASL_SCRAM_STATE_CLIENT_FIRST_MESSAGE) + return -1; + + if (scram_generate_nonce(&sasl->scram.cnonce) != 0) + return -1; + + ptr = malloc(olen + 1); + if (!ptr) + return -1; + + snprintf(ptr, olen + 1, "n,,n=%s,r=%.*s", conf->username, + (int)sasl->scram.cnonce.iov_len, + (char *)sasl->scram.cnonce.iov_base); + sasl->buf = ptr; + sasl->bsize = olen; + + sasl->scram.first_msg.iov_base = (char *)ptr + 3; + sasl->scram.first_msg.iov_len = olen - 3; + sasl->scram.state = KAFKA_SASL_SCRAM_STATE_SERVER_FIRST_MESSAGE; + return 0; +} + +int kafka_sasl_set_mechanisms(kafka_config_t *conf) +{ + if (strcasecmp(conf->mechanisms, "plain") == 0) + { + conf->recv = kafka_sasl_plain_recv; + conf->client_new = kafka_sasl_plain_client_new; + + return 0; + } + else if 
(strncasecmp(conf->mechanisms, "SCRAM", 5) == 0) + { + conf->recv = kafka_sasl_scram_recv; + conf->client_new = kafka_sasl_scram_client_new; + } + + return -1; +} + +void kafka_sasl_init(kafka_sasl_t *sasl) +{ + sasl->scram.evp = NULL; + sasl->scram.scram_h = NULL; + sasl->scram.scram_h_size = 0; + sasl->scram.state = KAFKA_SASL_SCRAM_STATE_CLIENT_FIRST_MESSAGE; + sasl->scram.cnonce.iov_base = NULL; + sasl->scram.cnonce.iov_len = 0; + sasl->scram.first_msg.iov_base = NULL; + sasl->scram.first_msg.iov_len = 0; + sasl->scram.server_signature_b64.iov_base = NULL; + sasl->scram.server_signature_b64.iov_len = 0; + sasl->buf = NULL; + sasl->bsize = 0; + sasl->status = 0; +} + +void kafka_sasl_deinit(kafka_sasl_t *sasl) +{ + free(sasl->scram.cnonce.iov_base); + free(sasl->scram.server_signature_b64.iov_base); + free(sasl->buf); +} + +int kafka_sasl_set_username(const char *username, kafka_config_t *conf) +{ + char *t = strdup(username); + + if (!t) + return -1; + + free(conf->username); + conf->username = t; + return 0; +} + +int kafka_sasl_set_password(const char *password, kafka_config_t *conf) +{ + char *t = strdup(password); + + if (!t) + return -1; + + free(conf->password); + conf->password = t; + return 0; +} + diff --git a/src/protocol/kafka_parser.h b/src/protocol/kafka_parser.h new file mode 100644 index 0000000..df613d4 --- /dev/null +++ b/src/protocol/kafka_parser.h @@ -0,0 +1,503 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+
+  Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
+*/
+
+#ifndef _KAFKA_PARSER_H_
+#define _KAFKA_PARSER_H_
+
+#include	/* NOTE(review): the header names of these three #include lines are missing in this copy of the patch */
+#include
+#include
+#include "list.h"
+
+enum	/* KAFKA_* error code constants (presumably the broker protocol error codes -- confirm) */
+{
+	KAFKA_UNKNOWN_SERVER_ERROR = -1,
+	KAFKA_OFFSET_OUT_OF_RANGE = 1,
+	KAFKA_CORRUPT_MESSAGE = 2,
+	KAFKA_UNKNOWN_TOPIC_OR_PARTITION = 3,
+	KAFKA_INVALID_FETCH_SIZE = 4,
+	KAFKA_LEADER_NOT_AVAILABLE = 5,
+	KAFKA_NOT_LEADER_FOR_PARTITION = 6,
+	KAFKA_REQUEST_TIMED_OUT = 7,
+	KAFKA_BROKER_NOT_AVAILABLE = 8,
+	KAFKA_REPLICA_NOT_AVAILABLE = 9,
+	KAFKA_MESSAGE_TOO_LARGE = 10,
+	KAFKA_STALE_CONTROLLER_EPOCH = 11,
+	KAFKA_OFFSET_METADATA_TOO_LARGE = 12,
+	KAFKA_NETWORK_EXCEPTION = 13,
+	KAFKA_COORDINATOR_LOAD_IN_PROGRESS = 14,
+	KAFKA_COORDINATOR_NOT_AVAILABLE = 15,
+	KAFKA_NOT_COORDINATOR = 16,
+	KAFKA_INVALID_TOPIC_EXCEPTION = 17,
+	KAFKA_RECORD_LIST_TOO_LARGE = 18,
+	KAFKA_NOT_ENOUGH_REPLICAS = 19,
+	KAFKA_NOT_ENOUGH_REPLICAS_AFTER_APPEND = 20,
+	KAFKA_INVALID_REQUIRED_ACKS = 21,
+	KAFKA_ILLEGAL_GENERATION = 22,
+	KAFKA_INCONSISTENT_GROUP_PROTOCOL = 23,
+	KAFKA_INVALID_GROUP_ID = 24,
+	KAFKA_UNKNOWN_MEMBER_ID = 25,
+	KAFKA_INVALID_SESSION_TIMEOUT = 26,
+	KAFKA_REBALANCE_IN_PROGRESS = 27,
+	KAFKA_INVALID_COMMIT_OFFSET_SIZE = 28,
+	KAFKA_TOPIC_AUTHORIZATION_FAILED = 29,
+	KAFKA_GROUP_AUTHORIZATION_FAILED = 30,
+	KAFKA_CLUSTER_AUTHORIZATION_FAILED = 31,
+	KAFKA_INVALID_TIMESTAMP = 32,
+	KAFKA_UNSUPPORTED_SASL_MECHANISM = 33,
+	KAFKA_ILLEGAL_SASL_STATE = 34,
+	KAFKA_UNSUPPORTED_VERSION = 35,
+	KAFKA_TOPIC_ALREADY_EXISTS = 36,
+	KAFKA_INVALID_PARTITIONS = 37,
+	KAFKA_INVALID_REPLICATION_FACTOR = 38,
+	KAFKA_INVALID_REPLICA_ASSIGNMENT = 39,
+	KAFKA_INVALID_CONFIG = 40,
+	KAFKA_NOT_CONTROLLER = 41,
+	KAFKA_INVALID_REQUEST = 42,
+	KAFKA_UNSUPPORTED_FOR_MESSAGE_FORMAT = 43,
+	KAFKA_POLICY_VIOLATION = 44,
+	KAFKA_OUT_OF_ORDER_SEQUENCE_NUMBER = 45,
+	KAFKA_DUPLICATE_SEQUENCE_NUMBER = 46,
+	KAFKA_INVALID_PRODUCER_EPOCH = 47,
+	KAFKA_INVALID_TXN_STATE = 48,
+	KAFKA_INVALID_PRODUCER_ID_MAPPING = 49,
+	KAFKA_INVALID_TRANSACTION_TIMEOUT = 50,
+	KAFKA_CONCURRENT_TRANSACTIONS = 51,
+	KAFKA_TRANSACTION_COORDINATOR_FENCED = 52,
+	KAFKA_TRANSACTIONAL_ID_AUTHORIZATION_FAILED = 53,
+	KAFKA_SECURITY_DISABLED = 54,
+	KAFKA_OPERATION_NOT_ATTEMPTED = 55,
+	KAFKA_KAFKA_STORAGE_ERROR = 56,
+	KAFKA_LOG_DIR_NOT_FOUND = 57,
+	KAFKA_SASL_AUTHENTICATION_FAILED = 58,
+	KAFKA_UNKNOWN_PRODUCER_ID = 59,
+	KAFKA_REASSIGNMENT_IN_PROGRESS = 60,
+	KAFKA_DELEGATION_TOKEN_AUTH_DISABLED = 61,
+	KAFKA_DELEGATION_TOKEN_NOT_FOUND = 62,
+	KAFKA_DELEGATION_TOKEN_OWNER_MISMATCH = 63,
+	KAFKA_DELEGATION_TOKEN_REQUEST_NOT_ALLOWED = 64,
+	KAFKA_DELEGATION_TOKEN_AUTHORIZATION_FAILED = 65,
+	KAFKA_DELEGATION_TOKEN_EXPIRED = 66,
+	KAFKA_INVALID_PRINCIPAL_TYPE = 67,
+	KAFKA_NON_EMPTY_GROUP = 68,
+	KAFKA_GROUP_ID_NOT_FOUND = 69,
+	KAFKA_FETCH_SESSION_ID_NOT_FOUND = 70,
+	KAFKA_INVALID_FETCH_SESSION_EPOCH = 71,
+	KAFKA_LISTENER_NOT_FOUND = 72,
+	KAFKA_TOPIC_DELETION_DISABLED = 73,
+	KAFKA_FENCED_LEADER_EPOCH = 74,
+	KAFKA_UNKNOWN_LEADER_EPOCH = 75,
+	KAFKA_UNSUPPORTED_COMPRESSION_TYPE = 76,
+	KAFKA_STALE_BROKER_EPOCH = 77,
+	KAFKA_OFFSET_NOT_AVAILABLE = 78,
+	KAFKA_MEMBER_ID_REQUIRED = 79,
+	KAFKA_PREFERRED_LEADER_NOT_AVAILABLE = 80,
+	KAFKA_GROUP_MAX_SIZE_REACHED = 81,
+	KAFKA_FENCED_INSTANCE_ID = 82,
+};
+
+enum	/* Kafka_* request type numbers (presumably the values used as api_key -- confirm) */
+{
+	Kafka_Unknown = -1,
+	Kafka_Produce = 0,
+	Kafka_Fetch = 1,
+	Kafka_ListOffsets = 2,
+	Kafka_Metadata = 3,
+	Kafka_LeaderAndIsr = 4,
+	Kafka_StopReplica = 5,
+	Kafka_UpdateMetadata = 6,
+	Kafka_ControlledShutdown = 7,
+	Kafka_OffsetCommit = 8,
+	Kafka_OffsetFetch = 9,
+	Kafka_FindCoordinator = 10,
+	Kafka_JoinGroup = 11,
+	Kafka_Heartbeat = 12,
+	Kafka_LeaveGroup = 13,
+	Kafka_SyncGroup = 14,
+	Kafka_DescribeGroups = 15,
+	Kafka_ListGroups = 16,
+	Kafka_SaslHandshake = 17,
+	Kafka_ApiVersions = 18,
+	Kafka_CreateTopics = 19,
+	Kafka_DeleteTopics = 20,
+	Kafka_DeleteRecords = 21,
+	Kafka_InitProducerId = 22,
+	Kafka_OffsetForLeaderEpoch = 23,
+	Kafka_AddPartitionsToTxn = 24,
+	Kafka_AddOffsetsToTxn = 25,
+	Kafka_EndTxn = 26,
+	Kafka_WriteTxnMarkers = 27,
+	Kafka_TxnOffsetCommit = 28,
+	Kafka_DescribeAcls = 29,
+	Kafka_CreateAcls = 30,
+	Kafka_DeleteAcls = 31,
+	Kafka_DescribeConfigs = 32,
+	Kafka_AlterConfigs = 33,
+	Kafka_AlterReplicaLogDirs = 34,
+	Kafka_DescribeLogDirs = 35,
+	Kafka_SaslAuthenticate = 36,
+	Kafka_CreatePartitions = 37,
+	Kafka_CreateDelegationToken = 38,
+	Kafka_RenewDelegationToken = 39,
+	Kafka_ExpireDelegationToken = 40,
+	Kafka_DescribeDelegationToken = 41,
+	Kafka_DeleteGroups = 42,
+	Kafka_ElectPreferredLeaders = 43,
+	Kafka_IncrementalAlterConfigs = 44,
+	Kafka_ApiNums,	/* one past the last value above, i.e. 45 */
+};
+
+enum	/* compression identifiers, numbered 0..4 in declaration order */
+{
+	Kafka_NoCompress,
+	Kafka_Gzip,
+	Kafka_Snappy,
+	Kafka_Lz4,
+	Kafka_Zstd,
+};
+
+enum	/* single-bit feature flags that can be OR-ed together */
+{
+	KAFKA_FEATURE_APIVERSION = 1<<0,
+	KAFKA_FEATURE_BROKER_BALANCED_CONSUMER = 1<<1,
+	KAFKA_FEATURE_THROTTLETIME = 1<<2,
+	KAFKA_FEATURE_BROKER_GROUP_COORD = 1<<3,
+	KAFKA_FEATURE_LZ4 = 1<<4,
+	KAFKA_FEATURE_OFFSET_TIME = 1<<5,
+	KAFKA_FEATURE_MSGVER2 = 1<<6,
+	KAFKA_FEATURE_MSGVER1 = 1<<7,
+	KAFKA_FEATURE_ZSTD = 1<<8,
+	KAFKA_FEATURE_SASL_GSSAPI = 1<<9,
+	KAFKA_FEATURE_SASL_HANDSHAKE = 1<<10,
+	KAFKA_FEATURE_SASL_AUTH_REQ = 1<<11,
+};
+
+enum
+{
+	KAFKA_OFFSET_AUTO,
+	KAFKA_OFFSET_ASSIGN,
+};
+
+enum
+{
+	KAFKA_BROKER_UNINIT,
+	KAFKA_BROKER_DOING,
+	KAFKA_BROKER_INITED,
+};
+
+enum
+{
+	KAFKA_TIMESTAMP_EARLIEST = -2,
+	KAFKA_TIMESTAMP_LATEST = -1,
+	KAFKA_TIMESTAMP_UNINIT = 0,
+};
+
+enum
+{
+	KAFKA_OFFSET_UNINIT = -2,
+	KAFKA_OFFSET_OVERFLOW = -1,
+};
+
+typedef struct __kafka_api_version	/* one API key with its version range */
+{
+	short api_key;
+	short min_ver;
+	short max_ver;
+} kafka_api_version_t;
+
+typedef struct __kafka_api_t
+{
+	unsigned features;	/* presumably a mask of KAFKA_FEATURE_* bits -- confirm */
+	kafka_api_version_t *api;
+	int elements;
+} kafka_api_t;
+
+typedef struct __kafka_parser	/* state of kafka_parser_append_message() */
+{
+	int complete;	/* when non-zero, append consumes nothing and returns 1 */
+	size_t message_size;	/* payload length decoded from headbuf with ntohl() */
+	void *msgbuf;	/* malloc'ed payload buffer of message_size bytes */
+	size_t cur_size;	/* payload bytes copied into msgbuf so far */
+	char headbuf[4];	/* the 4-byte length prefix as received */
+	size_t hsize;	/* number of prefix bytes collected in headbuf */
+} kafka_parser_t;
+
+enum __kafka_scram_state	/* progress of the SCRAM exchange */
+{
+	KAFKA_SASL_SCRAM_STATE_CLIENT_FIRST_MESSAGE,	/* initial state set by kafka_sasl_init() */
+	KAFKA_SASL_SCRAM_STATE_SERVER_FIRST_MESSAGE,	/* client-first built, waiting for server-first */
+	KAFKA_SASL_SCRAM_STATE_CLIENT_FINAL_MESSAGE,	/* server-first handled, waiting for server-final */
+	KAFKA_SASL_SCRAM_STATE_CLIENT_FINISHED,	/* server-final handled */
+};
+
+typedef struct __kafka_scram
+{
+	const void *evp;	/* cast to const EVP_MD * for the HMAC calls in kafka_parser.c */
+	unsigned char *(*scram_h)(const unsigned char *d, size_t n,
+							  unsigned char *md);	/* hash function H(); NULL after kafka_sasl_init() */
+	size_t scram_h_size;	/* digest size reported for scram_h output */
+	enum __kafka_scram_state state;
+	struct iovec cnonce;	/* heap client nonce, freed by kafka_sasl_deinit() */
+	struct iovec first_msg;	/* points into kafka_sasl_t.buf (past "n,,"), not owned */
+	struct iovec server_signature_b64;	/* heap base64 signature expected from the server */
+} kafka_scram_t;
+
+typedef struct __kafka_sasl
+{
+	kafka_scram_t scram;
+	char *buf;	/* heap message produced by the client_new/recv callbacks */
+	size_t bsize;	/* length of buf in bytes */
+	int status;
+} kafka_sasl_t;
+
+typedef struct __kafka_config
+{
+	int produce_timeout;
+	int produce_msg_max_bytes;
+	int produce_msgset_cnt;
+	int produce_msgset_max_bytes;
+	int fetch_timeout;
+	int fetch_min_bytes;
+	int fetch_max_bytes;
+	int fetch_msg_max_bytes;
+	long long offset_timestamp;
+	long long commit_timestamp;
+	int session_timeout;
+	int rebalance_timeout;
+	long long retention_time_period;
+	int produce_acks;
+	int allow_auto_topic_creation;
+	int api_version_request;
+	int api_version_timeout;
+	char *broker_version;
+	int compress_type;
+	int compress_level;
+	char *client_id;
+	int check_crcs;
+	int offset_store;
+	char *rack_id;
+
+	char *mechanisms;	/* SASL mechanism name examined by kafka_sasl_set_mechanisms() */
+	char *username;	/* strdup'ed by kafka_sasl_set_username() */
+	char *password;	/* strdup'ed by kafka_sasl_set_password() */
+	int (*client_new)(void *conf, kafka_sasl_t *sasl);	/* builds the first client message */
+	int (*recv)(const char *buf, size_t len, void *conf, void *sasl);	/* handles a server message */
+} kafka_config_t;
+
+typedef struct __kafka_broker
+{
+	int node_id;
+	int port;
+	char *host;
+	char *rack;
+	short error;
+	int status;
+} kafka_broker_t;
+
+typedef struct __kafka_partition
+{
+	short error;
+	int partition_index;
+	kafka_broker_t leader;
+	int *replica_nodes;
+	int replica_node_elements;
+	int *isr_nodes;
+	int isr_node_elements;
+} kafka_partition_t;
+
+typedef struct __kafka_meta
+{
+	short error;
+	char *topic_name;	/* strdup'ed by kafka_meta_set_topic() */
+	char *error_message;
+	signed char is_internal;
+	kafka_partition_t **partitions;
+	int partition_elements;
+} kafka_meta_t;
+
+typedef struct __kafka_topic_partition
+{
+	short error;
+	char *topic_name;	/* strdup'ed by kafka_topic_partition_set_tp() */
+	int partition;
+	int preferred_read_replica;
+	long long offset;
+	long long high_watermark;
+	long long low_watermark;
+	long long last_stable_offset;
+	long long log_start_offset;
+	long long offset_timestamp;
+	char *committed_metadata;
+	struct list_head record_list;
+} kafka_topic_partition_t;
+
+typedef struct __kafka_record_header
+{
+	struct list_head list;
+	void *key;	/* malloc'ed copy made by kafka_record_header_set_kv() */
+	size_t key_len;
+	int key_is_moved;
+	void *value;	/* malloc'ed copy made by kafka_record_header_set_kv() */
+	size_t value_len;
+	int value_is_moved;
+} kafka_record_header_t;
+
+typedef struct __kafka_record
+{
+	void *key;	/* malloc'ed copy made by kafka_record_set_key() */
+	size_t key_len;
+	int key_is_moved;
+	void *value;	/* malloc'ed copy made by kafka_record_set_value() */
+	size_t value_len;
+	int value_is_moved;
+	long long timestamp;
+	long long offset;
+	struct list_head header_list;
+	short status;
+	kafka_topic_partition_t *toppar;
+} kafka_record_t;
+
+typedef struct __kafka_memeber
+{
+	char *member_id;
+	char *client_id;
+	char *client_host;
+	void *member_metadata;
+	size_t member_metadata_len;
+	struct list_head toppar_list;
+	struct list_head assigned_toppar_list;
+} kafka_member_t;
+
+typedef int (*kafka_assignor_t)(kafka_member_t **members, int member_elements,
+								void *meta_topic);
+
+typedef struct __kafka_group_protocol
+{
+	struct list_head list;
+	char *protocol_name;
+	kafka_assignor_t assignor;
+} kafka_group_protocol_t;
+
+typedef struct __kafka_cgroup
+{
+	struct list_head assigned_toppar_list;
+	short error;
+	char *error_msg;
+	kafka_broker_t coordinator;
+	char *leader_id;
+	char *member_id;
+	kafka_member_t **members;	/* member_elements heap objects, freed by kafka_cgroup_deinit() */
+	int member_elements;
+	int generation_id;	/* -1 after kafka_cgroup_init() */
+	char *group_name;	/* strdup'ed by kafka_cgroup_set_group() */
+	char *protocol_type;	/* set to the literal "consumer" by kafka_cgroup_init(); never freed */
+	char *protocol_name;
+	struct list_head group_protocol_list;
+} kafka_cgroup_t;
+
+typedef struct __kafka_block
+{
+	void *buf;	/* freed by kafka_block_deinit() unless is_moved is set */
+	size_t len;
+	int is_moved;
+} kafka_block_t;
+
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+int kafka_parser_append_message(const void *buf, size_t *size,
+								kafka_parser_t *parser);	/* 1: message complete, 0: need more data, -1: allocation failed */
+
+void kafka_parser_init(kafka_parser_t *parser);
+void kafka_parser_deinit(kafka_parser_t *parser);
+
+void kafka_topic_partition_init(kafka_topic_partition_t *toppar);
+void kafka_topic_partition_deinit(kafka_topic_partition_t *toppar);
+
+void kafka_cgroup_init(kafka_cgroup_t *cgroup);
+void kafka_cgroup_deinit(kafka_cgroup_t *cgroup);
+
+void kafka_block_init(kafka_block_t *block);
+void kafka_block_deinit(kafka_block_t *block);
+
+void kafka_broker_init(kafka_broker_t *brock);
+void kafka_broker_deinit(kafka_broker_t *broker);
+
+void kafka_config_init(kafka_config_t *config);
+void kafka_config_deinit(kafka_config_t *config);
+
+void kafka_meta_init(kafka_meta_t *meta);
+void kafka_meta_deinit(kafka_meta_t *meta);
+
+void kafka_partition_init(kafka_partition_t *partition);
+void kafka_partition_deinit(kafka_partition_t *partition);
+
+void kafka_member_init(kafka_member_t *member);
+void kafka_member_deinit(kafka_member_t *member);
+
+void kafka_record_init(kafka_record_t *record);
+void kafka_record_deinit(kafka_record_t *record);
+
+void kafka_record_header_init(kafka_record_header_t *header);
+void kafka_record_header_deinit(kafka_record_header_t *header);
+
+void kafka_api_init(kafka_api_t *api);
+void kafka_api_deinit(kafka_api_t *api);
+
+void kafka_sasl_init(kafka_sasl_t *sasl);
+void kafka_sasl_deinit(kafka_sasl_t *sasl);
+
+int kafka_topic_partition_set_tp(const char *topic_name, int partition,
+								 kafka_topic_partition_t *toppar);	/* 0 on success, -1 on allocation failure */
+
+int kafka_record_set_key(const void *key, size_t key_len,
+						 kafka_record_t *record);	/* copies key; 0 on success, -1 on allocation failure */
+
+int kafka_record_set_value(const void *val, size_t val_len,
+						   kafka_record_t *record);	/* copies val; 0 on success, -1 on allocation failure */
+
+int kafka_record_header_set_kv(const void *key, size_t key_len,
+							   const void *val, size_t val_len,
+							   kafka_record_header_t *header);	/* copies both; 0 on success, -1 on allocation failure */
+
+int kafka_meta_set_topic(const char *topic_name, kafka_meta_t *meta);
+
+int kafka_cgroup_set_group(const char *group_name, kafka_cgroup_t *cgroup);
+
+int kafka_broker_get_api_version(const kafka_api_t *broker, int api_key,
+								 int min_ver, int max_ver);
+
+unsigned kafka_get_features(kafka_api_version_t *api, size_t api_cnt);
+
+int kafka_api_version_is_queryable(const char *broker_version,
+								   kafka_api_version_t **api,
+								   size_t *api_cnt);
+
+int kafka_sasl_set_mechanisms(kafka_config_t *conf);	/* installs client_new/recv according to conf->mechanisms */
+int kafka_sasl_set_username(const char *username, kafka_config_t *conf);
+int kafka_sasl_set_password(const char *passwd, kafka_config_t *conf);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+
diff --git a/src/protocol/mysql_byteorder.c b/src/protocol/mysql_byteorder.c
new file mode 100644
index 0000000..c093860
--- /dev/null
+++ b/src/protocol/mysql_byteorder.c
@@ -0,0 +1,91 @@
+/*
+  Copyright (c) 2019 Sogou, Inc.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include "mysql_byteorder.h" + +int decode_length_safe(unsigned long long *res, const unsigned char **pos, + const unsigned char *end) +{ + const unsigned char *p = *pos; + + if (p >= end) + return 0; + + switch (*p) + { + default: + *res = *p; + *pos = p + 1; + break; + + case 251: + *res = (~0ULL); + *pos = p + 1; + break; + + case 252: + if (p + 2 > end) + return 0; + + *res = uint2korr(p + 1); + *pos = p + 3; + break; + + case 253: + if (p + 3 > end) + return 0; + + *res = uint3korr(p + 1); + *pos = p + 4; + break; + + case 254: + if (p + 8 > end) + return 0; + + *res = uint8korr(p + 1); + *pos = p + 9; + break; + + case 255: + return -1; + } + + return 1; +} + +int decode_string(const unsigned char **str, unsigned long long *len, + const unsigned char **pos, const unsigned char *end) +{ + unsigned long long length; + + if (decode_length_safe(&length, pos, end) <= 0) + return 0; + + if (length == (~0ULL)) + length = 0; + + if (*pos + length > end) + return 0; + + *len = length; + *str = *pos; + *pos = *pos + length; + return 1; +} + diff --git a/src/protocol/mysql_byteorder.h b/src/protocol/mysql_byteorder.h new file mode 100644 index 0000000..79c2f5d --- /dev/null +++ b/src/protocol/mysql_byteorder.h @@ -0,0 +1,438 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
/*
  Authors: Li Yingxin (liyingxin@sogou-inc.com)
*/

#ifndef _MYSQL_BYTEORDER_H_
#define _MYSQL_BYTEORDER_H_

/*
 * NOTE(review): the three #include directives here had lost their header
 * names. They are reconstructed from what this header uses (fixed-width
 * integer types, memcpy, and the __BYTE_ORDER macros, which glibc exposes
 * through <sys/types.h>) -- confirm against the original file.
 */
#include <sys/types.h>
#include <stdint.h>
#include <string.h>

/*
 * Helpers for reading ("korr") and writing ("store") integers and floats at
 * arbitrary, possibly unaligned, byte positions.
 *
 * The *korr / int*store / float4* / float8* helpers use little-endian byte
 * order in the buffer on both host types. On a little-endian host that is a
 * plain memcpy; the big-endian branch assembles the value byte by byte.
 */

#if __BYTE_ORDER == __LITTLE_ENDIAN

static inline int16_t sint2korr(const unsigned char *A)
{
    int16_t ret;
    memcpy(&ret, A, sizeof(ret));
    return ret;
}

static inline int32_t sint4korr(const unsigned char *A)
{
    int32_t ret;
    memcpy(&ret, A, sizeof(ret));
    return ret;
}

static inline uint16_t uint2korr(const unsigned char *A)
{
    uint16_t ret;
    memcpy(&ret, A, sizeof(ret));
    return ret;
}

static inline uint32_t uint4korr(const unsigned char *A)
{
    uint32_t ret;
    memcpy(&ret, A, sizeof(ret));
    return ret;
}

static inline uint64_t uint8korr(const unsigned char *A)
{
    uint64_t ret;
    memcpy(&ret, A, sizeof(ret));
    return ret;
}

static inline int64_t sint8korr(const unsigned char *A)
{
    int64_t ret;
    memcpy(&ret, A, sizeof(ret));
    return ret;
}

static inline void int2store(unsigned char *T, uint16_t A)
{
    memcpy(T, &A, sizeof(A));
}

static inline void int4store(unsigned char *T, uint32_t A)
{
    memcpy(T, &A, sizeof(A));
}

/* Stores only the low 7 bytes of A. */
static inline void int7store(unsigned char *T, uint64_t A)
{
    memcpy(T, &A, 7);
}

static inline void int8store(unsigned char *T, uint64_t A)
{
    memcpy(T, &A, sizeof(A));
}

static inline void float4get(float *V, const unsigned char *M)
{
    memcpy(V, (M), sizeof(float));
}

static inline void float4store(unsigned char *V, float M)
{
    memcpy(V, (&M), sizeof(float));
}

static inline void float8get(double *V, const unsigned char *M)
{
    memcpy(V, M, sizeof(double));
}

static inline void float8store(unsigned char *V, double M)
{
    memcpy(V, &M, sizeof(double));
}

static inline void floatget(float *V, const unsigned char *M)
{
    float4get(V, M);
}

static inline void floatstore(unsigned char *V, float M)
{
    float4store(V, M);
}

static inline void doublestore(unsigned char *T, double V)
{
    memcpy(T, &V, sizeof(double));
}

static inline void doubleget(double *V, const unsigned char *M)
{
    memcpy(V, M, sizeof(double));
}

static inline void ushortget(uint16_t *V, const unsigned char *pM)
{
    *V = uint2korr(pM);
}

static inline void shortget(int16_t *V, const unsigned char *pM)
{
    *V = sint2korr(pM);
}

static inline void longget(int32_t *V, const unsigned char *pM)
{
    *V = sint4korr(pM);
}

static inline void ulongget(uint32_t *V, const unsigned char *pM)
{
    *V = uint4korr(pM);
}

static inline void shortstore(unsigned char *T, int16_t V)
{
    int2store(T, V);
}

static inline void longstore(unsigned char *T, int32_t V)
{
    int4store(T, V);
}

static inline void longlongget(int64_t *V, const unsigned char *M)
{
    memcpy(V, (M), sizeof(uint64_t));
}

static inline void longlongstore(unsigned char *T, int64_t V)
{
    memcpy((T), &V, sizeof(uint64_t));
}

/*
 * Read a signed 24-bit little-endian integer.
 *
 * Bit 23 is the sign bit and is extended into the top byte, so that
 * ff ff ff reads as -1. This matches the big-endian branch below, which has
 * always sign-extended; copying three bytes into a zeroed int32_t would
 * return 16777215 instead.
 */
static inline int32_t sint3korr(const unsigned char *A)
{
    uint32_t ret = 0;
    memcpy(&ret, A, 3);
    if (ret & 0x800000U)
        ret |= 0xFF000000U;
    return (int32_t)ret;
}

static inline uint32_t uint3korr(const unsigned char *A)
{
    uint32_t ret = 0;
    memcpy(&ret, A, 3);
    return ret;
}

static inline void int3store(unsigned char *T, uint32_t A)
{
    memcpy(T, &A, 3);
}

#elif __BYTE_ORDER == __BIG_ENDIAN

static inline int16_t sint2korr(const unsigned char *A)
{
    return (int16_t)(((int16_t)(A[0])) + ((int16_t)(A[1]) << 8));
}

static inline int32_t sint4korr(const unsigned char *A)
{
    return (int32_t)(((int32_t)(A[0])) + (((int32_t)(A[1]) << 8)) +
                     (((int32_t)(A[2]) << 16)) + (((int32_t)(A[3]) << 24)));
}

static inline uint16_t uint2korr(const unsigned char *A)
{
    return (uint16_t)(((uint16_t)(A[0])) + ((uint16_t)(A[1]) << 8));
}

static inline uint32_t uint4korr(const unsigned char *A)
{
    return (uint32_t)(((uint32_t)(A[0])) + (((uint32_t)(A[1])) << 8) +
                      (((uint32_t)(A[2])) << 16) + (((uint32_t)(A[3])) << 24));
}

static inline uint64_t uint8korr(const unsigned char *A)
{
    return ((uint64_t)(((uint32_t)(A[0])) + (((uint32_t)(A[1])) << 8) +
                       (((uint32_t)(A[2])) << 16) + (((uint32_t)(A[3])) << 24)) +
            (((uint64_t)(((uint32_t)(A[4])) + (((uint32_t)(A[5])) << 8) +
                         (((uint32_t)(A[6])) << 16) + (((uint32_t)(A[7])) << 24))) << 32));
}

static inline int64_t sint8korr(const unsigned char *A)
{
    return (int64_t)uint8korr(A);
}

/* `unsigned int` is spelled out: the `uint` typedef is not standard C. */
static inline void int2store(unsigned char *T, uint16_t A)
{
    unsigned int def_temp = A;
    *(T) = (unsigned char)(def_temp);
    *(T + 1) = (unsigned char)(def_temp >> 8);
}

static inline void int4store(unsigned char *T, uint32_t A)
{
    *(T) = (unsigned char)(A);
    *(T + 1) = (unsigned char)(A >> 8);
    *(T + 2) = (unsigned char)(A >> 16);
    *(T + 3) = (unsigned char)(A >> 24);
}

/* Stores only the low 7 bytes of A. */
static inline void int7store(unsigned char *T, uint64_t A)
{
    *(T) = (unsigned char)(A);
    *(T + 1) = (unsigned char)(A >> 8);
    *(T + 2) = (unsigned char)(A >> 16);
    *(T + 3) = (unsigned char)(A >> 24);
    *(T + 4) = (unsigned char)(A >> 32);
    *(T + 5) = (unsigned char)(A >> 40);
    *(T + 6) = (unsigned char)(A >> 48);
}

static inline void int8store(unsigned char *T, uint64_t A)
{
    unsigned int def_temp = (unsigned int)A, def_temp2 = (unsigned int)(A >> 32);
    int4store(T, def_temp);
    int4store(T + 4, def_temp2);
}

/* float4/float8 helpers reverse the host byte order of the value. */
static inline void float4store(unsigned char *T, float A)
{
    *(T) = ((unsigned char *)&A)[3];
    *((T) + 1) = (unsigned char)((unsigned char *)&A)[2];
    *((T) + 2) = (unsigned char)((unsigned char *)&A)[1];
    *((T) + 3) = (unsigned char)((unsigned char *)&A)[0];
}

static inline void float4get(float *V, const unsigned char *M)
{
    float def_temp;
    ((unsigned char *)&def_temp)[0] = (M)[3];
    ((unsigned char *)&def_temp)[1] = (M)[2];
    ((unsigned char *)&def_temp)[2] = (M)[1];
    ((unsigned char *)&def_temp)[3] = (M)[0];
    (*V) = def_temp;
}

static inline void float8store(unsigned char *T, double V)
{
    *(T) = ((unsigned char *)&V)[7];
    *((T) + 1) = (unsigned char)((unsigned char *)&V)[6];
    *((T) + 2) = (unsigned char)((unsigned char *)&V)[5];
    *((T) + 3) = (unsigned char)((unsigned char *)&V)[4];
    *((T) + 4) = (unsigned char)((unsigned char *)&V)[3];
    *((T) + 5) = (unsigned char)((unsigned char *)&V)[2];
    *((T) + 6) = (unsigned char)((unsigned char *)&V)[1];
    *((T) + 7) = (unsigned char)((unsigned char *)&V)[0];
}

static inline void float8get(double *V, const unsigned char *M)
{
    double def_temp;
    ((unsigned char *)&def_temp)[0] = (M)[7];
    ((unsigned char *)&def_temp)[1] = (M)[6];
    ((unsigned char *)&def_temp)[2] = (M)[5];
    ((unsigned char *)&def_temp)[3] = (M)[4];
    ((unsigned char *)&def_temp)[4] = (M)[3];
    ((unsigned char *)&def_temp)[5] = (M)[2];
    ((unsigned char *)&def_temp)[6] = (M)[1];
    ((unsigned char *)&def_temp)[7] = (M)[0];
    (*V) = def_temp;
}

/*
 * The *get / *store helpers below keep the value in host (big-endian) byte
 * order in the buffer: pM[0] / T[0] is the most significant byte.
 */
static inline void ushortget(uint16_t *V, const unsigned char *pM)
{
    *V = (uint16_t)(((uint16_t)((unsigned char)(pM)[1])) +
                    ((uint16_t)((uint16_t)(pM)[0]) << 8));
}

static inline void shortget(int16_t *V, const unsigned char *pM)
{
    *V = (short)(((short)((unsigned char)(pM)[1])) +
                 ((short)((short)(pM)[0]) << 8));
}

static inline void longget(int32_t *V, const unsigned char *pM)
{
    int32_t def_temp;
    ((unsigned char *)&def_temp)[0] = (pM)[0];
    ((unsigned char *)&def_temp)[1] = (pM)[1];
    ((unsigned char *)&def_temp)[2] = (pM)[2];
    ((unsigned char *)&def_temp)[3] = (pM)[3];
    (*V) = def_temp;
}

static inline void ulongget(uint32_t *V, const unsigned char *pM)
{
    uint32_t def_temp;
    ((unsigned char *)&def_temp)[0] = (pM)[0];
    ((unsigned char *)&def_temp)[1] = (pM)[1];
    ((unsigned char *)&def_temp)[2] = (pM)[2];
    ((unsigned char *)&def_temp)[3] = (pM)[3];
    (*V) = def_temp;
}

static inline void shortstore(unsigned char *T, int16_t A)
{
    unsigned int def_temp = (unsigned int)(A);
    *(((unsigned char *)T) + 1) = (unsigned char)(def_temp);
    *(((unsigned char *)T) + 0) = (unsigned char)(def_temp >> 8);
}

static inline void longstore(unsigned char *T, int32_t A)
{
    *(((unsigned char *)T) + 3) = ((A));
    *(((unsigned char *)T) + 2) = (((A) >> 8));
    *(((unsigned char *)T) + 1) = (((A) >> 16));
    *(((unsigned char *)T) + 0) = (((A) >> 24));
}

static inline void floatget(float *V, const unsigned char *M)
{
    memcpy(V, (M), sizeof(float));
}

static inline void floatstore(unsigned char *T, float V)
{
    memcpy((T), (&V), sizeof(float));
}

static inline void doubleget(double *V, const unsigned char *M)
{
    memcpy(V, (M), sizeof(double));
}

static inline void doublestore(unsigned char *T, double V)
{
    memcpy((T), &V, sizeof(double));
}

static inline void longlongget(int64_t *V, const unsigned char *M)
{
    memcpy(V, (M), sizeof(uint64_t));
}

static inline void longlongstore(unsigned char *T, int64_t V)
{
    memcpy((T), &V, sizeof(uint64_t));
}

/* Read a signed 24-bit little-endian integer, sign-extending bit 23. */
static inline int32_t sint3korr(const unsigned char *p)
{
    return ((int32_t)(((p[2]) & 128)
                      ? (((uint32_t)255L << 24) | (((uint32_t)p[2]) << 16)
                         | (((uint32_t)p[1]) << 8) | ((uint32_t)p[0]))
                      : (((uint32_t)p[2]) << 16) | (((uint32_t)p[1]) << 8)
                         | ((uint32_t)p[0])));
}

static inline uint32_t uint3korr(const unsigned char *p)
{
    return (uint32_t)(((uint32_t)(p[0])) +
                      (((uint32_t)(p[1])) << 8) +
                      (((uint32_t)(p[2])) << 16));
}

static inline void int3store(unsigned char *p, uint32_t x)
{
    *(p) = (unsigned char)(x);
    *(p + 1) = (unsigned char)(x >> 8);
    *(p + 2) = (unsigned char)(x >> 16);
}

#else
# error "unknown byte order"
#endif

// length of buffer needed to store this number [1, 3, 4, 9].
//
// Only values below 251 fit the 1-byte form: decode_length_safe() reads a
// first byte of 251 (0xfb) as the NULL marker, so the number 251 itself has
// to use the 3-byte form.
static inline unsigned int get_length_size(unsigned long long num)
{
    if (num < (unsigned long long)251LL)
        return 1;

    if (num < (unsigned long long)65536LL)
        return 3;

    if (num < (unsigned long long)16777216LL)
        return 4;

    return 9;
}

#ifdef __cplusplus
extern "C"
{
#endif

// decode encoded length integer within *end, move pos forward
int decode_length_safe(unsigned long long *res, const unsigned char **pos,
                       const unsigned char *end);

// decode encoded length string within *end, move pos forward
int decode_string(const unsigned char **str, unsigned long long *len,
                  const unsigned char **pos, const unsigned char *end);

#ifdef __cplusplus
}
#endif

#endif

/*
  diff --git a/src/protocol/mysql_parser.c b/src/protocol/mysql_parser.c
  new file mode 100644
  index 0000000..babe03b
  --- /dev/null
  +++ b/src/protocol/mysql_parser.c
  @@ -0,0 +1,481 @@

  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/
+ + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#include +#include "mysql_types.h" +#include "mysql_byteorder.h" +#include "mysql_parser.h" + +static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_ok_packet(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_eof_packet(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_field_eof_packet(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_field_count(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_column_def_packet(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_local_inline(const void *buf, size_t len, mysql_parser_t *parser); + +static int parse_row_packet(const void *buf, size_t len, mysql_parser_t *parser); + +void mysql_parser_init(mysql_parser_t *parser) +{ + parser->offset = 0; + parser->cmd = MYSQL_COM_QUERY; + parser->packet_type = MYSQL_PACKET_OTHER; + parser->parse = parse_base_packet; + parser->result_set_count = 0; + INIT_LIST_HEAD(&parser->result_set_list); +} + +void mysql_parser_deinit(mysql_parser_t *parser) +{ + struct __mysql_result_set *result_set; + struct list_head *pos, *tmp; + int i; + + list_for_each_safe(pos, tmp, &parser->result_set_list) + { + result_set = list_entry(pos, struct __mysql_result_set, list); + list_del(pos); + + if (result_set->field_count) + { + for (i = 0; i < result_set->field_count; i++) + free(result_set->fields[i]); + + free(result_set->fields); + } + + free(result_set); + } +} + +int mysql_parser_parse(const void *buf, size_t len, mysql_parser_t *parser) +{ +// const char *end = (const char *)buf + len; + int ret; + + do { + ret = parser->parse(buf, len, parser); + if (ret < 0) + return ret; + + if (ret > 0 && parser->offset != len) + return -2; + + } while (parser->offset < len); + + return ret; +} + +void 
mysql_parser_get_net_state(const char **net_state_str, + size_t *net_state_len, + mysql_parser_t *parser) +{ + *net_state_str = (const char *)parser->buf + parser->net_state_offset; + *net_state_len = MYSQL_STATE_LENGTH; +} + +void mysql_parser_get_err_msg(const char **err_msg_str, + size_t *err_msg_len, + mysql_parser_t *parser) +{ + if (parser->err_msg_offset == (size_t)-1 && parser->err_msg_len == 0) + { + *err_msg_str = MYSQL_STATE_DEFAULT; + *err_msg_len = MYSQL_STATE_LENGTH; + } else { + *err_msg_str = (const char *)parser->buf + parser->err_msg_offset; + *err_msg_len = parser->err_msg_len; + } +} + +static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + + switch (*p) + { + // OK PACKET + case MYSQL_PACKET_HEADER_OK: + parser->parse = parse_ok_packet; + break; + // ERR PACKET + case MYSQL_PACKET_HEADER_ERROR: + parser->parse = parse_error_packet; + break; + // EOF PACKET + case MYSQL_PACKET_HEADER_EOF: + parser->parse = parse_eof_packet; + break; + // LOCAL INFILE PACKET + case MYSQL_PACKET_HEADER_NULL: + // if (field_count == -1) + parser->parse = parse_local_inline; + break; + default: + parser->parse = parse_field_count; + break; + } + + return 0; +} + +// 1:0xFF|2:err_no|1:#|5:server_state|0-512:err_msg +static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + if (p + 9 > buf_end) + return -2; + + parser->error = uint2korr(p + 1); + p += 3; + + if (*p == '#') + { + p += 1; + parser->net_state_offset = p - (const unsigned char *)buf; + p += MYSQL_STATE_LENGTH; + + size_t msg_len = len - parser->offset - 9; + parser->err_msg_offset = p - (const unsigned char *)buf; + parser->err_msg_len = msg_len; + } else { + parser->err_msg_offset = (size_t)-1; + parser->err_msg_len = 0; + } + + 
parser->offset = len; + parser->packet_type = MYSQL_PACKET_ERROR; + parser->buf = buf; + return 1; +} + +// 1:0x00|1-9:affect_row|1-9:insert_id|2:server_status|2:warning_count|0-n:server_msg +static int parse_ok_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + unsigned long long affected_rows, insert_id, info_len; + const unsigned char *str; + struct __mysql_result_set *result_set; + unsigned int warning_count; + int server_status; + + p += 1;// 0x00 + if (decode_length_safe(&affected_rows, &p, buf_end) <= 0) + return -2; + + if (decode_length_safe(&insert_id, &p, buf_end) <= 0) + return -2; + + if (p + 4 > buf_end) + return -2; + + server_status = uint2korr(p); + p += 2; + warning_count = uint2korr(p); + p += 2; + + if (p != buf_end) + { + if (decode_string(&str, &info_len, &p, buf_end) == 0) + return -2; + + if (p != buf_end) + { + if (server_status & MYSQL_SERVER_SESSION_STATE_CHANGED) + { + const unsigned char *tmp_str; + unsigned long long tmp_len; + if (decode_string(&tmp_str, &tmp_len, &p, buf_end) == 0) + return -2; + } else + return -2; + } + } else { + str = p; + info_len = 0; + } + + result_set = (struct __mysql_result_set *)malloc(sizeof(struct __mysql_result_set)); + if (result_set == NULL) + return -1; + + result_set->info_offset = str - (const unsigned char *)buf; + result_set->info_len = info_len; + result_set->affected_rows = (affected_rows == ~0ULL) ? 0 : affected_rows; + result_set->insert_id = (insert_id == ~0ULL) ? 
0 : insert_id; + result_set->server_status = server_status; + result_set->warning_count = warning_count; + result_set->type = MYSQL_PACKET_OK; + result_set->field_count = 0; + + list_add_tail(&result_set->list, &parser->result_set_list); + parser->current_result_set = result_set; + parser->result_set_count++; + parser->packet_type = MYSQL_PACKET_OK; + + parser->buf = buf; + parser->offset = p - (const unsigned char *)buf; + + if (server_status & MYSQL_SERVER_MORE_RESULTS_EXIST) + { + parser->parse = parse_base_packet; + return 0; + } + + return 1; +} + +// 1:0xfe|2:warnings|2:status_flag +static int parse_eof_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + if (p + 5 > buf_end) + return -2; + + parser->offset += 5; + parser->packet_type = MYSQL_PACKET_EOF; + parser->buf = buf; + + int status_flag = uint2korr(p + 3); + if (status_flag & MYSQL_SERVER_MORE_RESULTS_EXIST) + { + parser->parse = parse_base_packet; + return 0; + } + + return 1; +} + +static int parse_field_eof_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + if (p + 5 > buf_end) + return -2; + + parser->offset += 5; + parser->current_result_set->rows_begin_offset = parser->offset; + parser->parse = parse_row_packet; + return 0; +} + +//raw file data +static int parse_local_inline(const void *buf, size_t len, mysql_parser_t *parser) +{ + parser->local_inline_offset = parser->offset; + parser->local_inline_length = len - parser->offset; + parser->offset = len; + parser->packet_type = MYSQL_PACKET_LOCAL_INLINE; + parser->buf = buf; + return 1; +} + +// for each field: +// NULL as 0xfb, or a length-encoded-string +static int parse_row_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + 
const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + unsigned long long cell_len; + const unsigned char *cell_data; + + size_t i; + + if (*p == MYSQL_PACKET_HEADER_ERROR) + { + parser->parse = parse_error_packet; + return 0; + } + + if (*p == MYSQL_PACKET_HEADER_EOF) + { + parser->parse = parse_eof_packet; + parser->current_result_set->rows_end_offset = parser->offset; + + return 0; + } + + for (i = 0; i < parser->current_result_set->field_count; i++) + { + if (*p == MYSQL_PACKET_HEADER_NULL) + { + p++; + } else { + if (decode_string(&cell_data, &cell_len, &p, buf_end) == 0) + break; + } + } + + if (i != parser->current_result_set->field_count) + return -2; + + parser->current_result_set->row_count++; + parser->offset = p - (const unsigned char *)buf; + return 0; +} + +static int parse_field_count(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + unsigned long long field_count; + struct __mysql_result_set *result_set; + + if (decode_length_safe(&field_count, &p, buf_end) <= 0) + return -2; + + field_count = (field_count == ~0ULL) ? 
0 : field_count; + + if (field_count) + { + result_set = (struct __mysql_result_set *)malloc(sizeof (struct __mysql_result_set)); + if (result_set == NULL) + return -1; + + result_set->fields = (mysql_field_t **)calloc(field_count, sizeof (mysql_field_t *)); + if (result_set->fields == NULL) + { + free(result_set); + return -1; + } + + result_set->field_count = field_count; + result_set->row_count = 0; + result_set->type = MYSQL_PACKET_GET_RESULT; + + list_add_tail(&result_set->list, &parser->result_set_list); + parser->current_result_set = result_set; + parser->current_field_count = 0; + parser->result_set_count++; + parser->packet_type = MYSQL_PACKET_GET_RESULT; + + parser->parse = parse_column_def_packet; + parser->offset = p - (const unsigned char *)buf; + } else { + parser->parse = parse_ok_packet; + } + return 0; +} + +// COLUMN DEFINATION PACKET. for one field: (after protocol 41) +// str:catalog|str:db|str:table|str:org_table|str:name|str:org_name| +// 2:charsetnr|4:length|1:type|2:flags|1:decimals|1:0x00|1:0x00|n:str(if COM_FIELD_LIST) +static int parse_column_def_packet(const void *buf, size_t len, mysql_parser_t *parser) +{ + const unsigned char *p = (const unsigned char *)buf + parser->offset; + const unsigned char *buf_end = (const unsigned char *)buf + len; + + int flag = 0; + const unsigned char *str; + unsigned long long str_len; + mysql_field_t *field = (mysql_field_t *)malloc(sizeof(mysql_field_t)); + + if (!field) + return -1; + + do { + if (decode_string(&str, &str_len, &p, buf_end) == 0) + break; + field->catalog_offset = str - (const unsigned char *)buf; + field->catalog_length = str_len; + + if (decode_string(&str, &str_len, &p, buf_end) == 0) + break; + field->db_offset = str - (const unsigned char *)buf; + field->db_length = str_len; + + if (decode_string(&str, &str_len, &p, buf_end) == 0) + break; + field->table_offset = str - (const unsigned char *)buf; + field->table_length = str_len; + + if (decode_string(&str, &str_len, &p, buf_end) == 
0) + break; + field->org_table_offset = str - (const unsigned char *)buf; + field->org_table_length = str_len; + + if (decode_string(&str, &str_len, &p, buf_end) == 0) + break; + field->name_offset = str - (const unsigned char *)buf; + field->name_length = str_len; + + if (decode_string(&str, &str_len, &p, buf_end) == 0) + break; + field->org_name_offset = str - (const unsigned char *)buf; + field->org_name_length = str_len; + + // the rest needs at least 13 + if (p + 13 > buf_end) + break; + + p++; // length of the following fields (always 0x0c) + field->charsetnr = uint2korr(p); + field->length = uint4korr(p + 2); + field->data_type = *(p + 6); + field->flags = uint2korr(p + 7); + field->decimals = (int)p[9]; + p += 12; + // if is COM_FIELD_LIST, the rest is a string + // 0x03 for COM_QUERY + if (parser->cmd == MYSQL_COM_FIELD_LIST) + { + if (decode_string(&str, &str_len, &p, buf_end) == 0) + break; + field->def_offset = str - (const unsigned char *)buf; + field->def_length = str_len; + } else { + field->def_offset = (size_t)-1; + field->def_length = 0; + } + flag = 1; + } while (0); + + if (flag == 0) + { + free(field); + return -2; + } + + //parser->fields.emplace_back(std::move(field)); + parser->current_result_set->fields[parser->current_field_count] = field; + + parser->offset = p - (const unsigned char *)buf; + if (++parser->current_field_count == parser->current_result_set->field_count) + parser->parse = parse_field_eof_packet; + + return 0; +} + diff --git a/src/protocol/mysql_parser.h b/src/protocol/mysql_parser.h new file mode 100644 index 0000000..faf39e2 --- /dev/null +++ b/src/protocol/mysql_parser.h @@ -0,0 +1,178 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) +*/ + +#ifndef _MYSQL_PARSER_H_ +#define _MYSQL_PARSER_H_ + +#include +#include "list.h" + +// the first byte in response message +// from 1 to 0xfa means result_set field or data_row +// NULL is sent as 0xfb, will be treated as LOCAL_INLINE +// MySQL MESSAGE STATUS +enum +{ + MYSQL_PACKET_HEADER_OK = 0, + MYSQL_PACKET_HEADER_NULL = 251, //0xfb + MYSQL_PACKET_HEADER_EOF = 254, //0xfe + MYSQL_PACKET_HEADER_ERROR = 255, //0xff +}; + +typedef struct __mysql_field +{ + size_t name_offset; /* Name of column */ + size_t org_name_offset; /* Original column name, if an alias */ + size_t table_offset; /* Table of column if column was a field */ + size_t org_table_offset; /* Org table name, if table was an alias */ + size_t db_offset; /* Database for table */ + size_t catalog_offset; /* Catalog for table */ + size_t def_offset; /* Default value (set by mysql_list_fields) */ + int length; /* Width of column (create length) */ + int name_length; + int org_name_length; + int table_length; + int org_table_length; + int db_length; + int catalog_length; + int def_length; + int flags; /* Div flags */ + int decimals; /* Number of decimals in field */ + int charsetnr; /* Character set */ + int data_type; /* Type of field. 
See mysql_types.h for types */ +// void *extension; +} mysql_field_t; + +struct __mysql_result_set +{ + struct list_head list; + int type; + int server_status; + + int field_count; + int row_count; + size_t rows_begin_offset; + size_t rows_end_offset; + mysql_field_t **fields; + + unsigned long long affected_rows; + unsigned long long insert_id; + int warning_count; + size_t info_offset; + int info_len; +}; + +typedef struct __mysql_result_set_cursor +{ + const struct list_head *head; + const struct list_head *current; +} mysql_result_set_cursor_t; + +typedef struct __mysql_parser +{ + size_t offset; + int cmd; + int packet_type; + int (*parse)(const void *, size_t, struct __mysql_parser *); + + size_t net_state_offset; // err packet server_state + size_t err_msg_offset; // -1 for default + int err_msg_len; // -1 for default + + size_t local_inline_offset; // local inline file name + int local_inline_length; + const void *buf; + int error; + + int result_set_count; + struct list_head result_set_list; + struct __mysql_result_set *current_result_set; + int current_field_count; +} mysql_parser_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +void mysql_parser_init(mysql_parser_t *parser); +void mysql_parser_deinit(mysql_parser_t *parser); +void mysql_parser_get_info(const char **info_str, + size_t *info_len, + mysql_parser_t *parser); +void mysql_parser_get_net_state(const char **net_state_str, + size_t *net_state_len, + mysql_parser_t *parser); +void mysql_parser_get_err_msg(const char **err_msg_str, + size_t *err_msg_len, + mysql_parser_t *parser); +// if append check get 0, don`t need to parse() +// if append check get 1, parse and tell them if this is all the package +// +// ret: 1: this ResultSet is received finished +// 0: this ResultSet is not recieved finished +// -1: system error +// -2: bad message error +int mysql_parser_parse(const void *buf, size_t len, mysql_parser_t *parser); + +#ifdef __cplusplus +} +#endif + +static inline void 
mysql_parser_set_command(int cmd, mysql_parser_t *parser) +{ + parser->cmd = cmd; +} + +static inline void mysql_parser_get_local_inline(const char **local_inline_name, + size_t *local_inline_len, + mysql_parser_t *parser) +{ + *local_inline_name = (const char *)parser->buf + parser->local_inline_offset; + *local_inline_len = parser->local_inline_length; +} + +static inline void mysql_result_set_cursor_init(mysql_result_set_cursor_t *cursor, + mysql_parser_t *parser) +{ + cursor->head = &parser->result_set_list; + cursor->current = cursor->head; +} + +static inline void mysql_result_set_cursor_rewind(mysql_result_set_cursor_t *cursor) +{ + cursor->current = cursor->head; +} + +static inline void mysql_result_set_cursor_deinit(mysql_result_set_cursor_t *cursor) +{ +} + +static inline int mysql_result_set_cursor_next(struct __mysql_result_set **result_set, + mysql_result_set_cursor_t *cursor) +{ + if (cursor->current->next != cursor->head) + { + cursor->current = cursor->current->next; + *result_set = list_entry(cursor->current, struct __mysql_result_set, list); + return 0; + } + return 1; +} + +#endif diff --git a/src/protocol/mysql_stream.c b/src/protocol/mysql_stream.c new file mode 100644 index 0000000..b659582 --- /dev/null +++ b/src/protocol/mysql_stream.c @@ -0,0 +1,98 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include "mysql_stream.h" + +#define MAX(x, y) ((x) >= (y) ? (x) : (y)) + +static int __mysql_stream_write_payload(const void *buf, size_t *n, + mysql_stream_t *stream); + +static int __mysql_stream_write_head(const void *buf, size_t *n, + mysql_stream_t *stream) +{ + void *p = &stream->head[4 - stream->head_left]; + + if (*n < stream->head_left) + { + memcpy(p, buf, *n); + stream->head_left -= *n; + return 0; + } + + memcpy(p, buf, stream->head_left); + stream->payload_length = (stream->head[2] << 16) + + (stream->head[1] << 8) + + stream->head[0]; + stream->payload_left = stream->payload_length; + stream->sequence_id = stream->head[3]; + if (stream->bufsize < stream->length + stream->payload_left) + { + size_t new_size = MAX(2048, 2 * stream->bufsize); + void *new_base; + + while (new_size < stream->length + stream->payload_left) + new_size *= 2; + + new_base = realloc(stream->buf, new_size); + if (!new_base) + return -1; + + stream->buf = new_base; + stream->bufsize = new_size; + } + + *n = stream->head_left; + stream->write = __mysql_stream_write_payload; + return 0; +} + +static int __mysql_stream_write_payload(const void *buf, size_t *n, + mysql_stream_t *stream) +{ + char *p = (char *)stream->buf + stream->length; + + if (*n < stream->payload_left) + { + memcpy(p, buf, *n); + stream->length += *n; + stream->payload_left -= *n; + return 0; + } + + memcpy(p, buf, stream->payload_left); + stream->length += stream->payload_left; + + *n = stream->payload_left; + stream->head_left = 4; + stream->write = __mysql_stream_write_head; + return stream->payload_length != (1 << 24) - 1; +} + +void mysql_stream_init(mysql_stream_t *stream) +{ + stream->head_left = 4; + stream->sequence_id = 0; + stream->buf = NULL; + stream->length = 0; + stream->bufsize = 0; + stream->write = __mysql_stream_write_head; +} + diff --git a/src/protocol/mysql_stream.h b/src/protocol/mysql_stream.h new file mode 100644 index 
0000000..b4685bb --- /dev/null +++ b/src/protocol/mysql_stream.h @@ -0,0 +1,72 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _MYSQL_STREAM_H_ +#define _MYSQL_STREAM_H_ + +#include + +typedef struct __mysql_stream +{ + unsigned char head[4]; + unsigned char head_left; + unsigned char sequence_id; + int payload_length; + int payload_left; + void *buf; + size_t length; + size_t bufsize; + int (*write)(const void *, size_t *, struct __mysql_stream *); +} mysql_stream_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +void mysql_stream_init(mysql_stream_t *stream); + +#ifdef __cplusplus +} +#endif + +static inline int mysql_stream_write(const void *buf, size_t *n, + mysql_stream_t *stream) +{ + return stream->write(buf, n, stream); +} + +static inline int mysql_stream_get_seq(mysql_stream_t *stream) +{ + return stream->sequence_id; +} + +static inline void mysql_stream_get_buf(const void **buf, size_t *length, + mysql_stream_t *stream) +{ + *buf = stream->buf; + *length = stream->length; +} + +static inline void mysql_stream_deinit(mysql_stream_t *stream) +{ + free(stream->buf); +} + +#endif + diff --git a/src/protocol/mysql_types.h b/src/protocol/mysql_types.h new file mode 100644 index 0000000..238ec29 --- /dev/null +++ b/src/protocol/mysql_types.h @@ -0,0 +1,218 @@ +/* + Copyright (c) 2019 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Li Yingxin (liyingxin@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _MYSQL_TYPES_H_ +#define _MYSQL_TYPES_H_ + +#define MYSQL_STATE_LENGTH 5 +#define MYSQL_STATE_DEFAULT "HY000" + +#define MYSQL_SERVER_MORE_RESULTS_EXIST 0x0008 +#define MYSQL_SERVER_SESSION_STATE_CHANGED 0x4000 + +enum +{ + MYSQL_COM_SLEEP, + MYSQL_COM_QUIT, + MYSQL_COM_INIT_DB, + MYSQL_COM_QUERY, + MYSQL_COM_FIELD_LIST, + MYSQL_COM_CREATE_DB, + MYSQL_COM_DROP_DB, + MYSQL_COM_REFRESH, + MYSQL_COM_DEPRECATED_1, + MYSQL_COM_STATISTICS, + MYSQL_COM_PROCESS_INFO, + MYSQL_COM_CONNECT, + MYSQL_COM_PROCESS_KILL, + MYSQL_COM_DEBUG, + MYSQL_COM_PING, + MYSQL_COM_TIME, + MYSQL_COM_DELAYED_INSERT, + MYSQL_COM_CHANGE_USER, + MYSQL_COM_BINLOG_DUMP, + MYSQL_COM_TABLE_DUMP, + MYSQL_COM_CONNECT_OUT, + MYSQL_COM_REGISTER_SLAVE, + MYSQL_COM_STMT_PREPARE, + MYSQL_COM_STMT_EXECUTE, + MYSQL_COM_STMT_SEND_LONG_DATA, + MYSQL_COM_STMT_CLOSE, + MYSQL_COM_STMT_RESET, + MYSQL_COM_SET_OPTION, + MYSQL_COM_STMT_FETCH, + MYSQL_COM_DAEMON, + MYSQL_COM_BINLOG_DUMP_GTID, + MYSQL_COM_RESET_CONNECTION, + MYSQL_COM_CLONE, + MYSQL_COM_END +}; + +// MySQL packet type +enum +{ + MYSQL_PACKET_OTHER = 0, + MYSQL_PACKET_OK, + MYSQL_PACKET_NULL, + MYSQL_PACKET_EOF, + MYSQL_PACKET_ERROR, + MYSQL_PACKET_GET_RESULT, + MYSQL_PACKET_LOCAL_INLINE, +}; + +// MySQL cursor status +enum +{ + MYSQL_STATUS_NOT_INIT = 0, + MYSQL_STATUS_OK, + MYSQL_STATUS_GET_RESULT, + MYSQL_STATUS_ERROR, + 
MYSQL_STATUS_END, +}; + +// Column types for MySQL +enum +{ + MYSQL_TYPE_DECIMAL = 0, + MYSQL_TYPE_TINY, + MYSQL_TYPE_SHORT, + MYSQL_TYPE_LONG, + MYSQL_TYPE_FLOAT, + MYSQL_TYPE_DOUBLE, + MYSQL_TYPE_NULL, + MYSQL_TYPE_TIMESTAMP, + MYSQL_TYPE_LONGLONG, + MYSQL_TYPE_INT24, + MYSQL_TYPE_DATE, + MYSQL_TYPE_TIME, + MYSQL_TYPE_DATETIME, + MYSQL_TYPE_YEAR, + MYSQL_TYPE_NEWDATE, // Internal to MySQL. Not used in protocol + MYSQL_TYPE_VARCHAR, + MYSQL_TYPE_BIT, + MYSQL_TYPE_TIMESTAMP2, + MYSQL_TYPE_DATETIME2, // Internal to MySQL. Not used in protocol + MYSQL_TYPE_TIME2, // Internal to MySQL. Not used in protocol + MYSQL_TYPE_TYPED_ARRAY = 244, // Used for replication only + MYSQL_TYPE_JSON = 245, + MYSQL_TYPE_NEWDECIMAL = 246, + MYSQL_TYPE_ENUM = 247, + MYSQL_TYPE_SET = 248, + MYSQL_TYPE_TINY_BLOB = 249, + MYSQL_TYPE_MEDIUM_BLOB = 250, + MYSQL_TYPE_LONG_BLOB = 251, + MYSQL_TYPE_BLOB = 252, + MYSQL_TYPE_VAR_STRING = 253, + MYSQL_TYPE_STRING = 254, + MYSQL_TYPE_GEOMETRY = 255 +}; + +static inline const char *datatype2str(int data_type) +{ + switch (data_type) + { + case MYSQL_TYPE_BIT: + return "BIT"; + + case MYSQL_TYPE_BLOB: + return "BLOB"; + + case MYSQL_TYPE_DATE: + return "DATE"; + + case MYSQL_TYPE_DATETIME: + return "DATETIME"; + + case MYSQL_TYPE_NEWDECIMAL: + return "NEWDECIMAL"; + + case MYSQL_TYPE_DECIMAL: + return "DECIMAL"; + + case MYSQL_TYPE_DOUBLE: + return "DOUBLE"; + + case MYSQL_TYPE_ENUM: + return "ENUM"; + + case MYSQL_TYPE_FLOAT: + return "FLOAT"; + + case MYSQL_TYPE_GEOMETRY: + return "GEOMETRY"; + + case MYSQL_TYPE_INT24: + return "INT24"; + + case MYSQL_TYPE_JSON: + return "JSON"; + + case MYSQL_TYPE_LONG: + return "LONG"; + + case MYSQL_TYPE_LONGLONG: + return "LONGLONG"; + + case MYSQL_TYPE_LONG_BLOB: + return "LONG_BLOB"; + + case MYSQL_TYPE_MEDIUM_BLOB: + return "MEDIUM_BLOB"; + + case MYSQL_TYPE_NEWDATE: + return "NEWDATE"; + + case MYSQL_TYPE_NULL: + return "NULL"; + + case MYSQL_TYPE_SET: + return "SET"; + + case MYSQL_TYPE_SHORT: + return 
"SHORT"; + + case MYSQL_TYPE_STRING: + return "STRING"; + + case MYSQL_TYPE_TIME: + return "TIME"; + + case MYSQL_TYPE_TIMESTAMP: + return "TIMESTAMP"; + + case MYSQL_TYPE_TINY: + return "TINY"; + + case MYSQL_TYPE_TINY_BLOB: + return "TINY_BLOB"; + + case MYSQL_TYPE_VAR_STRING: + return "VAR_STRING"; + + case MYSQL_TYPE_YEAR: + return "YEAR"; + + default: + return "?-unknown-?"; + + } +} + +#endif + diff --git a/src/protocol/redis_parser.c b/src/protocol/redis_parser.c new file mode 100644 index 0000000..e18a0e7 --- /dev/null +++ b/src/protocol/redis_parser.c @@ -0,0 +1,505 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include "list.h" +#include "redis_parser.h" + +#define MIN(x, y) ((x) <= (y) ? (x) : (y)) +#define MAX(x, y) ((x) >= (y) ? 
(x) : (y)) + +#define REDIS_MSGBUF_INIT_SIZE 8 +#define REDIS_REPLY_DEPTH_LIMIT 64 +#define REDIS_ARRAY_SIZE_LIMIT (4 * 1024 * 1024) + +enum +{ + //REDIS_PARSE_INIT = 0, + REDIS_GET_CMD = 1, + REDIS_GET_CR, + REDIS_GET_LF, + REDIS_UNTIL_CRLF, + REDIS_GET_NCHAR, + REDIS_PARSE_END +}; + +struct __redis_read_record +{ + struct list_head list; + redis_reply_t *reply; +}; + +void redis_reply_deinit(redis_reply_t *reply) +{ + size_t i; + + for (i = 0; i < reply->elements; i++) + { + redis_reply_deinit(reply->element[i]); + free(reply->element[i]); + } + + free(reply->element); +} + +static redis_reply_t **__redis_create_array(size_t size, redis_reply_t *reply) +{ + size_t elements = 0; + redis_reply_t **element = (redis_reply_t **)malloc(size * sizeof (void *)); + + if (element) + { + size_t i; + + for (i = 0; i < size; i++) + { + element[i] = (redis_reply_t *)malloc(sizeof (redis_reply_t)); + if (element[i]) + { + redis_reply_init(element[i]); + elements++; + continue; + } + + break; + } + + if (elements == size) + return element; + + while (elements > 0) + free(element[--elements]); + + free(element); + } + + return NULL; +} + +int redis_reply_set_array(size_t size, redis_reply_t *reply) +{ + redis_reply_t **element = __redis_create_array(size, reply); + + if (element == NULL) + return -1; + + redis_reply_deinit(reply); + reply->element = element; + reply->elements = size; + reply->type = REDIS_REPLY_TYPE_ARRAY; + return 0; +} + +static int __redis_parse_cmd(const char ch, redis_parser_t *parser) +{ + switch (ch) + { + case '+': + case '-': + case ':': + case '$': + case '*': + parser->cmd = ch; + parser->status = REDIS_UNTIL_CRLF; + parser->findidx = parser->msgidx; + return 0; + } + + return -2; +} + +static int __redis_parse_cr(const char ch, redis_parser_t *parser) +{ + if (ch != '\r') + return -2; + + parser->status = REDIS_GET_LF; + return 0; +} + +static int __redis_parse_lf(const char ch, redis_parser_t *parser) +{ + if (ch != '\n') + return -2; + + return 1; 
+} + +static int __redis_parse_line(redis_parser_t *parser) +{ + char *str = parser->msgbuf + parser->msgidx; + size_t slen = parser->findidx - parser->msgidx; + char data[32]; + int i, n; + const char *offset = (const char *)parser->msgidx; + struct __redis_read_record *node; + + parser->msgidx = parser->findidx + 2; + switch (parser->cmd) + { + case '+': + redis_reply_set_status(offset, slen, parser->cur); + return 1; + + case '-': + redis_reply_set_error(offset, slen, parser->cur); + return 1; + + case ':': + if (slen == 0 || slen > 30) + return -2; + + memcpy(data, str, slen); + data[slen] = '\0'; + redis_reply_set_integer(atoll(data), parser->cur); + return 1; + + case '$': + n = atoi(str); + if (n < 0) + { + redis_reply_set_null(parser->cur); + return 1; + } + else if (n == 0) + { + /* "-0" not acceptable. */ + if (!isdigit(*str)) + return -2; + + redis_reply_set_string(offset, 0, parser->cur); + parser->status = REDIS_GET_CR; + return 0; + } + + parser->nchar = n; + parser->status = REDIS_GET_NCHAR; + return 0; + + case '*': + n = atoi(str); + if (n < 0) + { + redis_reply_set_null(parser->cur); + return 1; + } + + if (n == 0 && !isdigit(*str)) + return -2; + + if (n > REDIS_ARRAY_SIZE_LIMIT) + return -2; + + parser->nleft += n; + if (redis_reply_set_array(n, parser->cur) < 0) + return -1; + + if (n == 0) + return 1; + + parser->nleft--; + for (i = 0; i < n - 1; i++) + { + node = (struct __redis_read_record *)malloc(sizeof *node); + if (!node) + return -1; + + node->reply = parser->cur->element[n - 1 - i]; + list_add(&node->list, &parser->read_list); + } + + parser->cur = parser->cur->element[0]; + parser->status = REDIS_GET_CMD; + return 0; + + } + + return -2; +} + +static int __redis_parse_crlf(redis_parser_t *parser) +{ + char *buf = parser->msgbuf; + + for (; parser->findidx + 1 < parser->msgsize; parser->findidx++) + { + if (buf[parser->findidx] == '\r' && buf[parser->findidx + 1] == '\n') + return __redis_parse_line(parser); + } + + return 2; +} + 
+static int __redis_parse_nchar(redis_parser_t *parser) +{ + if (parser->nchar <= parser->msgsize - parser->msgidx) + { + redis_reply_set_string((const char *)parser->msgidx, parser->nchar, + parser->cur); + + parser->msgidx += parser->nchar; + parser->status = REDIS_GET_CR; + return 0; + } + + return 2; +} + +//-1 error | 0 continue | 1 finish-one | 2 not-enough +static int __redis_parser_forward(redis_parser_t *parser) +{ + char *buf = parser->msgbuf; + + if (parser->msgidx >= parser->msgsize) + return 2; + + switch (parser->status) + { + case REDIS_GET_CMD: + return __redis_parse_cmd(buf[parser->msgidx++], parser); + + case REDIS_GET_CR: + return __redis_parse_cr(buf[parser->msgidx++], parser); + + case REDIS_GET_LF: + return __redis_parse_lf(buf[parser->msgidx++], parser); + + case REDIS_UNTIL_CRLF: + return __redis_parse_crlf(parser); + + case REDIS_GET_NCHAR: + return __redis_parse_nchar(parser); + } + + return -2; +} + +void redis_parser_init(redis_parser_t *parser) +{ + redis_reply_init(&parser->reply); + parser->parse_succ = 0; + parser->msgbuf = NULL; + parser->msgsize = 0; + parser->bufsize = 0; + //parser->status = REDIS_PARSE_INIT; + //parser->nleft = 0; + parser->status = REDIS_GET_CMD; + parser->nleft = 1; + parser->cur = &parser->reply; + INIT_LIST_HEAD(&parser->read_list); + parser->msgidx = 0; + parser->cmd = '\0'; + parser->nchar = 0; + parser->findidx = 0; +} + +void redis_parser_deinit(redis_parser_t *parser) +{ + struct list_head *pos, *tmp; + struct __redis_read_record *next; + + list_for_each_safe(pos, tmp, &parser->read_list) + { + next = list_entry(pos, struct __redis_read_record, list); + list_del(pos); + free(next); + } + + redis_reply_deinit(&parser->reply); + free(parser->msgbuf); +} + +static int __redis_parse_done(redis_reply_t *reply, char *buf, int depth) +{ + size_t i; + + if (depth == REDIS_REPLY_DEPTH_LIMIT) + return -2; + + switch (reply->type) + { + case REDIS_REPLY_TYPE_INTEGER: + break; + + case REDIS_REPLY_TYPE_ARRAY: + for 
(i = 0; i < reply->elements; i++) + { + if (__redis_parse_done(reply->element[i], buf, depth + 1) < 0) + return -2; + } + + break; + + case REDIS_REPLY_TYPE_STATUS: + case REDIS_REPLY_TYPE_ERROR: + case REDIS_REPLY_TYPE_STRING: + reply->str = buf + (size_t)reply->str; + break; + } + + return 1; +} + +static int __redis_split_inline_command(redis_parser_t *parser) +{ + char *msg = parser->msgbuf; + char *end = msg + parser->msgsize; + size_t arr_size = 0; + redis_reply_t **ele; + char *cur; + int ret; + + while (msg != end) + { + while (msg != end && isspace(*msg)) + msg++; + + if (msg == end) + break; + + arr_size++; + + while (msg != end && !isspace(*msg)) + msg++; + } + + if (arr_size == 0) + { + parser->msgsize = 0; + parser->msgidx = 0; + return 0; + } + + ret = redis_reply_set_array(arr_size, &parser->reply); + if (ret < 0) + return ret; + + ele = parser->reply.element; + msg = parser->msgbuf; + + while (msg != end) + { + while (msg != end && isspace(*msg)) + msg++; + + if (msg == end) + break; + + cur = msg; + while (cur != end && !isspace(*cur)) + cur++; + + redis_reply_set_string(msg, cur - msg, *ele); + + msg = cur; + ele++; + } + + parser->status = REDIS_PARSE_END; + return 1; +} + +int redis_parser_append_message(const void *buf, size_t *size, + redis_parser_t *parser) +{ + size_t msgsize_bak = parser->msgsize; + + if (parser->status == REDIS_PARSE_END) + { + *size = 0; + return 1; + } + + if (parser->msgsize + *size > parser->bufsize) + { + size_t new_size = MAX(REDIS_MSGBUF_INIT_SIZE, 2 * parser->bufsize); + void *new_base; + + while (new_size < parser->msgsize + *size) + new_size *= 2; + + new_base = realloc(parser->msgbuf, new_size); + if (!new_base) + return -1; + + parser->msgbuf = (char *)new_base; + parser->bufsize = new_size; + } + + memcpy(parser->msgbuf + parser->msgsize, buf, *size); + parser->msgsize += *size; + if (parser->msgsize > 0 && (isalpha(*parser->msgbuf) || + isspace(*parser->msgbuf))) + { + while (parser->msgidx < parser->msgsize 
&& + *(parser->msgbuf + parser->msgidx) != '\n') + { + parser->msgidx++; + } + + if (parser->msgidx == parser->msgsize) + return 0; + + parser->msgidx++; + parser->msgsize = parser->msgidx; + *size = parser->msgsize - msgsize_bak; + + return __redis_split_inline_command(parser); + } + + do + { + int ret = __redis_parser_forward(parser); + + if (ret < 0) + return ret; + + if (ret == 1) + { + struct list_head *lnext = parser->read_list.next; + struct __redis_read_record *next; + + parser->nleft--; + if (lnext && lnext != &parser->read_list) + { + next = list_entry(lnext, struct __redis_read_record, list); + parser->cur = next->reply; + list_del(lnext); + free(next); + } + + if (parser->nleft > 0) + parser->status = REDIS_GET_CMD; + else + { + parser->parse_succ = 1; + parser->status = REDIS_PARSE_END; + } + } + else if (ret == 2) + return 0; + + } while (parser->status != REDIS_PARSE_END); + + *size = parser->msgidx - msgsize_bak; + return __redis_parse_done(&parser->reply, parser->msgbuf, 0); +} + diff --git a/src/protocol/redis_parser.h b/src/protocol/redis_parser.h new file mode 100644 index 0000000..29e6188 --- /dev/null +++ b/src/protocol/redis_parser.h @@ -0,0 +1,125 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _REDIS_PARSER_H_ +#define _REDIS_PARSER_H_ + +#include +#include "list.h" + +// redis_parser_t is absolutely same as hiredis-redisReply in memory +// If you include hiredis.h, redisReply* can cast to redis_reply_t* safely + +#define REDIS_REPLY_TYPE_STRING 1 +#define REDIS_REPLY_TYPE_ARRAY 2 +#define REDIS_REPLY_TYPE_INTEGER 3 +#define REDIS_REPLY_TYPE_NIL 4 +#define REDIS_REPLY_TYPE_STATUS 5 +#define REDIS_REPLY_TYPE_ERROR 6 + +typedef struct __redis_reply { + int type; /* REDIS_REPLY_TYPE_* */ + long long integer; /* The integer when type is REDIS_REPLY_TYPE_INTEGER */ + size_t len; /* Length of string */ + char *str; /* Used for both REDIS_REPLY_TYPE_ERROR and REDIS_REPLY_TYPE_STRING */ + size_t elements; /* number of elements, for REDIS_REPLY_TYPE_ARRAY */ + struct __redis_reply **element; /* elements vector for REDIS_REPLY_TYPE_ARRAY */ +} redis_reply_t; + +typedef struct __redis_parser +{ + int parse_succ;//check first + int status; + char *msgbuf; + size_t msgsize; + size_t bufsize; + redis_reply_t *cur; + struct list_head read_list; + size_t msgidx; + size_t findidx; + int nleft; + int nchar; + char cmd; + redis_reply_t reply; +} redis_parser_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +void redis_parser_init(redis_parser_t *parser); +void redis_parser_deinit(redis_parser_t *parser); +int redis_parser_append_message(const void *buf, size_t *size, + redis_parser_t *parser); + +void redis_reply_deinit(redis_reply_t *reply); + +int redis_reply_set_array(size_t size, redis_reply_t *reply); + +#ifdef __cplusplus +} +#endif + +static inline void redis_reply_init(redis_reply_t *reply) +{ + reply->type = REDIS_REPLY_TYPE_NIL; + reply->integer = 0; + reply->len = 0; + reply->str = NULL; + reply->elements = 0; + reply->element = NULL; +} + +static inline void redis_reply_set_string(const char *str, size_t len, + redis_reply_t *reply) +{ + reply->type = REDIS_REPLY_TYPE_STRING; + reply->len = len; + 
reply->str = (char *)str; +} + +static inline void redis_reply_set_integer(long long intv, redis_reply_t *reply) +{ + reply->type = REDIS_REPLY_TYPE_INTEGER; + reply->integer = intv; +} + +static inline void redis_reply_set_null(redis_reply_t *reply) +{ + reply->type = REDIS_REPLY_TYPE_NIL; +} + +static inline void redis_reply_set_error(const char *err, size_t len, + redis_reply_t *reply) +{ + reply->type = REDIS_REPLY_TYPE_ERROR; + reply->len = len; + reply->str = (char *)err; +} + +static inline void redis_reply_set_status(const char *str, size_t len, + redis_reply_t *reply) +{ + reply->type = REDIS_REPLY_TYPE_STATUS; + reply->len = len; + reply->str = (char *)str; +} + +#endif + diff --git a/src/protocol/xmake.lua b/src/protocol/xmake.lua new file mode 100644 index 0000000..422659b --- /dev/null +++ b/src/protocol/xmake.lua @@ -0,0 +1,59 @@ +target("basic_protocol") + set_kind("object") + add_files("PackageWrapper.cc", + "SSLWrapper.cc", + "dns_parser.c", + "DnsMessage.cc", + "DnsUtil.cc", + "http_parser.c", + "HttpMessage.cc", + "HttpUtil.cc") + +target("mysql_protocol") + if has_config("mysql") then + add_files("mysql_stream.c", + "mysql_parser.c", + "mysql_byteorder.c", + "MySQLMessage.cc", + "MySQLResult.cc", + "MySQLUtil.cc") + set_kind("object") + add_deps("basic_protocol") + else + set_kind("phony") + end + +target("redis_protocol") + if has_config("redis") then + add_files("redis_parser.c", "RedisMessage.cc") + set_kind("object") + add_deps("basic_protocol") + else + set_kind("phony") + end + +target("protocol") + set_kind("object") + add_deps("basic_protocol", "mysql_protocol", "redis_protocol") + +target("kafka_message") + if has_config("kafka") then + add_files("KafkaMessage.cc") + set_kind("object") + add_cxxflags("-fno-rtti") + add_packages("lz4", "zstd", "zlib", "snappy") + else + set_kind("phony") + end + +target("kafka_protocol") + if has_config("kafka") then + set_kind("object") + add_files("kafka_parser.c", + "KafkaDataTypes.cc", + 
"KafkaResult.cc") + add_deps("kafka_message", "protocol") + add_packages("zlib", "snappy", "zstd", "lz4") + else + set_kind("phony") + end diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt new file mode 100644 index 0000000..db987f2 --- /dev/null +++ b/src/server/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.6) +project(server) + +set(SRC + WFServer.cc +) + +if (NOT MYSQL STREQUAL "n") + set(SRC + ${SRC} + WFMySQLServer.cc + ) +endif () + +add_library(${PROJECT_NAME} OBJECT ${SRC}) diff --git a/src/server/WFDnsServer.h b/src/server/WFDnsServer.h new file mode 100644 index 0000000..94123c2 --- /dev/null +++ b/src/server/WFDnsServer.h @@ -0,0 +1,62 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _WFDNSSERVER_H_ +#define _WFDNSSERVER_H_ + +#include "DnsMessage.h" +#include "WFServer.h" +#include "WFTaskFactory.h" + +using dns_process_t = std::function; +using WFDnsServer = WFServer; + +static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_UDP, + .max_connections = 2000, + .peer_response_timeout = 10 * 1000, + .receive_timeout = -1, + .keep_alive_timeout = 300 * 1000, + .request_size_limit = (size_t)-1, + .ssl_accept_timeout = 5000, +}; + +template<> inline +WFDnsServer::WFServer(dns_process_t proc) : + WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), + process(std::move(proc)) +{ +} + +template<> inline +CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn) +{ + WFDnsTask *task; + + task = WFServerTaskFactory::create_dns_task(this, this->process); + task->set_keep_alive(this->params.keep_alive_timeout); + task->set_receive_timeout(this->params.receive_timeout); + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; +} + +#endif + diff --git a/src/server/WFHttpServer.h b/src/server/WFHttpServer.h new file mode 100644 index 0000000..c820aa4 --- /dev/null +++ b/src/server/WFHttpServer.h @@ -0,0 +1,63 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFHTTPSERVER_H_ +#define _WFHTTPSERVER_H_ + +#include +#include "HttpMessage.h" +#include "WFServer.h" +#include "WFTaskFactory.h" + +using http_process_t = std::function; +using WFHttpServer = WFServer; + +static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_TCP, + .max_connections = 2000, + .peer_response_timeout = 10 * 1000, + .receive_timeout = -1, + .keep_alive_timeout = 60 * 1000, + .request_size_limit = (size_t)-1, + .ssl_accept_timeout = 10 * 1000, +}; + +template<> inline +WFHttpServer::WFServer(http_process_t proc) : + WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), + process(std::move(proc)) +{ +} + +template<> inline +CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn) +{ + WFHttpTask *task; + + task = WFServerTaskFactory::create_http_task(this, this->process); + task->set_keep_alive(this->params.keep_alive_timeout); + task->set_receive_timeout(this->params.receive_timeout); + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; +} + +#endif + diff --git a/src/server/WFMySQLServer.cc b/src/server/WFMySQLServer.cc new file mode 100644 index 0000000..d3448f1 --- /dev/null +++ b/src/server/WFMySQLServer.cc @@ -0,0 +1,60 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include "WFMySQLServer.h" + +WFConnection *WFMySQLServer::new_connection(int accept_fd) +{ + WFConnection *conn = this->WFServer::new_connection(accept_fd); + + if (conn) + { + protocol::MySQLHandshakeResponse resp; + struct iovec vec[8]; + int count; + + resp.server_set(0x0a, "5.5", 1, (const uint8_t *)"12345678901234567890", + 0, 33, 0); + count = resp.encode(vec, 8); + if (count >= 0) + { + if (writev(accept_fd, vec, count) >= 0) + return conn; + } + + this->delete_connection(conn); + } + + return NULL; +} + +CommSession *WFMySQLServer::new_session(long long seq, CommConnection *conn) +{ + static mysql_process_t empty = [](WFMySQLTask *){ }; + WFMySQLTask *task; + + task = WFServerTaskFactory::create_mysql_task(this, seq ? this->process : + empty); + task->set_keep_alive(this->params.keep_alive_timeout); + task->set_receive_timeout(this->params.receive_timeout); + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; +} + diff --git a/src/server/WFMySQLServer.h b/src/server/WFMySQLServer.h new file mode 100644 index 0000000..1d6a46f --- /dev/null +++ b/src/server/WFMySQLServer.h @@ -0,0 +1,57 @@ +/* + Copyright (c) 2020 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _WFMYSQLSERVER_H_ +#define _WFMYSQLSERVER_H_ + +#include +#include "MySQLMessage.h" +#include "WFServer.h" +#include "WFTaskFactory.h" +#include "WFConnection.h" + +using mysql_process_t = std::function; +class MySQLServer; + +static constexpr struct WFServerParams MYSQL_SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_TCP, + .max_connections = 2000, + .peer_response_timeout = 10 * 1000, + .receive_timeout = -1, + .keep_alive_timeout = 28800 * 1000, + .request_size_limit = (size_t)-1, + .ssl_accept_timeout = 10 * 1000, +}; + +class WFMySQLServer : public WFServer +{ +public: + WFMySQLServer(mysql_process_t proc): + WFServer(&MYSQL_SERVER_PARAMS_DEFAULT, std::move(proc)) + { + } + +protected: + virtual WFConnection *new_connection(int accept_fd); + virtual CommSession *new_session(long long seq, CommConnection *conn); +}; + +#endif + diff --git a/src/server/WFRedisServer.h b/src/server/WFRedisServer.h new file mode 100644 index 0000000..a735cbb --- /dev/null +++ b/src/server/WFRedisServer.h @@ -0,0 +1,49 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFREDISSERVER_H_ +#define _WFREDISSERVER_H_ + +#include "RedisMessage.h" +#include "WFServer.h" +#include "WFTaskFactory.h" + +using redis_process_t = std::function; +using WFRedisServer = WFServer; + +static constexpr struct WFServerParams REDIS_SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_TCP, + .max_connections = 2000, + .peer_response_timeout = 10 * 1000, + .receive_timeout = -1, + .keep_alive_timeout = 300 * 1000, + .request_size_limit = (size_t)-1, + .ssl_accept_timeout = 5000, +}; + +template<> inline +WFRedisServer::WFServer(redis_process_t proc) : + WFServerBase(&REDIS_SERVER_PARAMS_DEFAULT), + process(std::move(proc)) +{ +} + +#endif + diff --git a/src/server/WFServer.cc b/src/server/WFServer.cc new file mode 100644 index 0000000..7dd6724 --- /dev/null +++ b/src/server/WFServer.cc @@ -0,0 +1,288 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "CommScheduler.h" +#include "EndpointParams.h" +#include "WFConnection.h" +#include "WFGlobal.h" +#include "WFServer.h" + +#define PORT_STR_MAX 5 + +class WFServerConnection : public WFConnection +{ +public: + WFServerConnection(std::atomic *conn_count) + { + this->conn_count = conn_count; + } + + virtual ~WFServerConnection() + { + (*this->conn_count)--; + } + +private: + std::atomic *conn_count; +}; + +int WFServerBase::ssl_ctx_callback(SSL *ssl, int *al, void *arg) +{ + WFServerBase *server = (WFServerBase *)arg; + const char *servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + SSL_CTX *ssl_ctx = server->get_server_ssl_ctx(servername); + + if (!ssl_ctx) + return SSL_TLSEXT_ERR_NOACK; + + if (ssl_ctx != server->get_ssl_ctx()) + SSL_set_SSL_CTX(ssl, ssl_ctx); + + return SSL_TLSEXT_ERR_OK; +} + +SSL_CTX *WFServerBase::new_ssl_ctx(const char *cert_file, const char *key_file) +{ + SSL_CTX *ssl_ctx = WFGlobal::new_ssl_server_ctx(); + + if (!ssl_ctx) + return NULL; + + if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) > 0 && + SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, SSL_FILETYPE_PEM) > 0 && + SSL_CTX_check_private_key(ssl_ctx) > 0 && + SSL_CTX_set_tlsext_servername_callback(ssl_ctx, ssl_ctx_callback) > 0 && + SSL_CTX_set_tlsext_servername_arg(ssl_ctx, this) > 0) + { + return ssl_ctx; + } + + SSL_CTX_free(ssl_ctx); + return NULL; +} + +int WFServerBase::init(const struct sockaddr *bind_addr, socklen_t addrlen, + const char *cert_file, const char *key_file) +{ + int timeout = this->params.peer_response_timeout; + + if (this->params.receive_timeout >= 0) + { + if ((unsigned int)timeout > (unsigned int)this->params.receive_timeout) + timeout = this->params.receive_timeout; + } + + if (this->params.transport_type == TT_TCP_SSL || + this->params.transport_type == 
TT_SCTP_SSL) + { + if (!cert_file || !key_file) + { + errno = EINVAL; + return -1; + } + } + + if (this->CommService::init(bind_addr, addrlen, -1, timeout) < 0) + return -1; + + if (cert_file && key_file && this->params.transport_type != TT_UDP) + { + SSL_CTX *ssl_ctx = this->new_ssl_ctx(cert_file, key_file); + + if (!ssl_ctx) + { + this->deinit(); + return -1; + } + + this->set_ssl(ssl_ctx, this->params.ssl_accept_timeout); + } + + this->scheduler = WFGlobal::get_scheduler(); + return 0; +} + +int WFServerBase::create_listen_fd() +{ + if (this->listen_fd < 0) + { + const struct sockaddr *bind_addr; + socklen_t addrlen; + int type, protocol; + int reuse = 1; + + switch (this->params.transport_type) + { + case TT_TCP: + case TT_TCP_SSL: + type = SOCK_STREAM; + protocol = 0; + break; + case TT_UDP: + type = SOCK_DGRAM; + protocol = 0; + break; +#ifdef IPPROTO_SCTP + case TT_SCTP: + case TT_SCTP_SSL: + type = SOCK_STREAM; + protocol = IPPROTO_SCTP; + break; +#endif + default: + errno = EPROTONOSUPPORT; + return -1; + } + + this->get_addr(&bind_addr, &addrlen); + this->listen_fd = socket(bind_addr->sa_family, type, protocol); + if (this->listen_fd >= 0) + { + setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR, + &reuse, sizeof (int)); + } + } + else + this->listen_fd = dup(this->listen_fd); + + return this->listen_fd; +} + +WFConnection *WFServerBase::new_connection(int accept_fd) +{ + if (++this->conn_count <= this->params.max_connections || + this->drain(1) == 1) + { + int reuse = 1; + setsockopt(accept_fd, SOL_SOCKET, SO_REUSEADDR, + &reuse, sizeof (int)); + return new WFServerConnection(&this->conn_count); + } + + this->conn_count--; + errno = EMFILE; + return NULL; +} + +void WFServerBase::delete_connection(WFConnection *conn) +{ + delete (WFServerConnection *)conn; +} + +void WFServerBase::handle_unbound() +{ + this->mutex.lock(); + this->unbind_finish = true; + this->cond.notify_one(); + this->mutex.unlock(); +} + +int WFServerBase::start(const struct sockaddr 
*bind_addr, socklen_t addrlen, + const char *cert_file, const char *key_file) +{ + SSL_CTX *ssl_ctx; + + if (this->init(bind_addr, addrlen, cert_file, key_file) >= 0) + { + if (this->scheduler->bind(this) >= 0) + return 0; + + ssl_ctx = this->get_ssl_ctx(); + this->deinit(); + if (ssl_ctx) + SSL_CTX_free(ssl_ctx); + } + + this->listen_fd = -1; + return -1; +} + +int WFServerBase::start(int family, const char *host, unsigned short port, + const char *cert_file, const char *key_file) +{ + struct addrinfo hints = { + .ai_flags = AI_PASSIVE, + .ai_family = family, + .ai_socktype = SOCK_STREAM, + }; + struct addrinfo *addrinfo; + char port_str[PORT_STR_MAX + 1]; + int ret; + + snprintf(port_str, PORT_STR_MAX + 1, "%d", port); + ret = getaddrinfo(host, port_str, &hints, &addrinfo); + if (ret == 0) + { + ret = start(addrinfo->ai_addr, (socklen_t)addrinfo->ai_addrlen, + cert_file, key_file); + freeaddrinfo(addrinfo); + } + else + { + if (ret != EAI_SYSTEM) + errno = EINVAL; + ret = -1; + } + + return ret; +} + +int WFServerBase::serve(int listen_fd, + const char *cert_file, const char *key_file) +{ + struct sockaddr_storage ss; + socklen_t len = sizeof ss; + + if (getsockname(listen_fd, (struct sockaddr *)&ss, &len) < 0) + return -1; + + this->listen_fd = listen_fd; + return start((struct sockaddr *)&ss, len, cert_file, key_file); +} + +void WFServerBase::shutdown() +{ + this->listen_fd = -1; + this->scheduler->unbind(this); +} + +void WFServerBase::wait_finish() +{ + SSL_CTX *ssl_ctx = this->get_ssl_ctx(); + std::unique_lock lock(this->mutex); + + while (!this->unbind_finish) + this->cond.wait(lock); + + this->deinit(); + this->unbind_finish = false; + lock.unlock(); + if (ssl_ctx) + SSL_CTX_free(ssl_ctx); +} + diff --git a/src/server/WFServer.h b/src/server/WFServer.h new file mode 100644 index 0000000..ed6e70e --- /dev/null +++ b/src/server/WFServer.h @@ -0,0 +1,244 @@ +/* + Copyright (c) 2019 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _WFSERVER_H_ +#define _WFSERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "EndpointParams.h" +#include "WFTaskFactory.h" + +struct WFServerParams +{ + enum TransportType transport_type; + size_t max_connections; + int peer_response_timeout; /* timeout of each read or write operation */ + int receive_timeout; /* timeout of receiving the whole message */ + int keep_alive_timeout; + size_t request_size_limit; + int ssl_accept_timeout; /* if not ssl, this will be ignored */ +}; + +static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_TCP, + .max_connections = 2000, + .peer_response_timeout = 10 * 1000, + .receive_timeout = -1, + .keep_alive_timeout = 60 * 1000, + .request_size_limit = (size_t)-1, + .ssl_accept_timeout = 10 * 1000, +}; + +class WFServerBase : protected CommService +{ +public: + WFServerBase(const struct WFServerParams *params) : + conn_count(0) + { + this->params = *params; + this->unbind_finish = false; + this->listen_fd = -1; + } + +public: + /* To start a TCP server */ + + /* Start on port with IPv4. */ + int start(unsigned short port) + { + return start(AF_INET, NULL, port, NULL, NULL); + } + + /* Start with family. AF_INET or AF_INET6. 
*/ + int start(int family, unsigned short port) + { + return start(family, NULL, port, NULL, NULL); + } + + /* Start with hostname and port. */ + int start(const char *host, unsigned short port) + { + return start(AF_INET, host, port, NULL, NULL); + } + + /* Start with family, hostname and port. */ + int start(int family, const char *host, unsigned short port) + { + return start(family, host, port, NULL, NULL); + } + + /* Start with binding address. */ + int start(const struct sockaddr *bind_addr, socklen_t addrlen) + { + return start(bind_addr, addrlen, NULL, NULL); + } + + /* To start an SSL server. */ + + int start(unsigned short port, const char *cert_file, const char *key_file) + { + return start(AF_INET, NULL, port, cert_file, key_file); + } + + int start(int family, unsigned short port, + const char *cert_file, const char *key_file) + { + return start(family, NULL, port, cert_file, key_file); + } + + int start(const char *host, unsigned short port, + const char *cert_file, const char *key_file) + { + return start(AF_INET, host, port, cert_file, key_file); + } + + int start(int family, const char *host, unsigned short port, + const char *cert_file, const char *key_file); + + /* This is the only necessary start function. */ + int start(const struct sockaddr *bind_addr, socklen_t addrlen, + const char *cert_file, const char *key_file); + + /* To start with a specified fd. For graceful restart or SCTP server. */ + int serve(int listen_fd) + { + return serve(listen_fd, NULL, NULL); + } + + int serve(int listen_fd, const char *cert_file, const char *key_file); + + /* stop() is a blocking operation. */ + void stop() + { + this->shutdown(); + this->wait_finish(); + } + + /* Nonblocking terminating the server. For stopping multiple servers. + * Typically, call shutdown() and then wait_finish(). + * But indeed wait_finish() can be called before shutdown(), even before + * start() in another thread. 
*/ + void shutdown(); + void wait_finish(); + +public: + size_t get_conn_count() const { return this->conn_count; } + + /* Get the listening address. This is often used after starting + * server on a random port (start() with port == 0). */ + int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const + { + if (this->listen_fd >= 0) + return getsockname(this->listen_fd, addr, addrlen); + + errno = ENOTCONN; + return -1; + } + + const struct WFServerParams *get_params() const { return &this->params; } + +protected: + /* Override this function to create the initial SSL CTX of the server */ + virtual SSL_CTX *new_ssl_ctx(const char *cert_file, const char *key_file); + + /* Override this function to implement server that supports TLS SNI. + * "servername" will be NULL if client does not set a host name. + * Returning NULL to indicate that servername is not supported. */ + virtual SSL_CTX *get_server_ssl_ctx(const char *servername) + { + return this->get_ssl_ctx(); + } + + /* This can be used by the implementation of 'new_ssl_ctx'. 
*/ + static int ssl_ctx_callback(SSL *ssl, int *al, void *arg); + +protected: + WFServerParams params; + +protected: + virtual int create_listen_fd(); + virtual WFConnection *new_connection(int accept_fd); + void delete_connection(WFConnection *conn); + +private: + int init(const struct sockaddr *bind_addr, socklen_t addrlen, + const char *cert_file, const char *key_file); + virtual void handle_unbound(); + +protected: + std::atomic conn_count; + +private: + int listen_fd; + bool unbind_finish; + + std::mutex mutex; + std::condition_variable cond; + + class CommScheduler *scheduler; +}; + +template +class WFServer : public WFServerBase +{ +public: + WFServer(const struct WFServerParams *params, + std::function *)> proc) : + WFServerBase(params), + process(std::move(proc)) + { + } + + WFServer(std::function *)> proc) : + WFServerBase(&SERVER_PARAMS_DEFAULT), + process(std::move(proc)) + { + } + +protected: + virtual CommSession *new_session(long long seq, CommConnection *conn); + +protected: + std::function *)> process; +}; + +template +CommSession *WFServer::new_session(long long seq, CommConnection *conn) +{ + using factory = WFNetworkTaskFactory; + WFNetworkTask *task; + + task = factory::create_server_task(this, this->process); + task->set_keep_alive(this->params.keep_alive_timeout); + task->set_receive_timeout(this->params.receive_timeout); + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; +} + +#endif + diff --git a/src/server/xmake.lua b/src/server/xmake.lua new file mode 100644 index 0000000..4d9086c --- /dev/null +++ b/src/server/xmake.lua @@ -0,0 +1,6 @@ +target("server") + set_kind("object") + add_files("*.cc") + if not has_config("mysql") then + remove_files("WFMySQLServer.cc") + end diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt new file mode 100644 index 0000000..74d4f19 --- /dev/null +++ b/src/util/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.6) +project(util) + +set(SRC + 
json_parser.c + EncodeStream.cc + StringUtil.cc + URIParser.cc +) + +add_library(${PROJECT_NAME} OBJECT ${SRC}) + +if (KAFKA STREQUAL "y") + set(SRC + crc32c.c + ) + add_library("util_kafka" OBJECT ${SRC}) +endif () diff --git a/src/util/EncodeStream.cc b/src/util/EncodeStream.cc new file mode 100644 index 0000000..31db698 --- /dev/null +++ b/src/util/EncodeStream.cc @@ -0,0 +1,137 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include "list.h" +#include "EncodeStream.h" + +#define ALIGN(x,a) (((x)+(a)-1)&~((a)-1)) +#define ENCODE_BUF_SIZE 1024 + +struct EncodeBuf +{ + struct list_head list; + char *pos; + char data[ENCODE_BUF_SIZE]; +}; + +void EncodeStream::clear_buf_data() +{ + struct list_head *pos, *tmp; + struct EncodeBuf *entry; + + list_for_each_safe(pos, tmp, &buf_list_) + { + entry = list_entry(pos, struct EncodeBuf, list); + list_del(pos); + delete [](char *)entry; + } +} + +void EncodeStream::merge() +{ + size_t len = bytes_ - merged_bytes_; + struct EncodeBuf *buf; + size_t n; + char *p; + int i; + + if (len > ENCODE_BUF_SIZE) + n = offsetof(struct EncodeBuf, data) + ALIGN(len, 8); + else + n = sizeof (struct EncodeBuf); + + buf = (struct EncodeBuf *)new char[n]; + p = buf->data; + for (i = merged_size_; i < size_; i++) + { + memcpy(p, vec_[i].iov_base, vec_[i].iov_len); + p += vec_[i].iov_len; + } + + buf->pos = buf->data + ALIGN(len, 8); + 
list_add(&buf->list, &buf_list_); + + vec_[merged_size_].iov_base = buf->data; + vec_[merged_size_].iov_len = len; + merged_size_++; + merged_bytes_ = bytes_; + size_ = merged_size_; +} + +void EncodeStream::append_nocopy(const char *data, size_t len) +{ + if (size_ >= max_) + { + if (merged_size_ + 1 < max_) + merge(); + else + { + size_ = max_ + 1; /* Overflow */ + return; + } + } + + vec_[size_].iov_base = (char *)data; + vec_[size_].iov_len = len; + size_++; + bytes_ += len; +} + +void EncodeStream::append_copy(const char *data, size_t len) +{ + if (size_ >= max_) + { + if (merged_size_ + 1 < max_) + merge(); + else + { + size_ = max_ + 1; /* Overflow */ + return; + } + } + + struct EncodeBuf *buf = list_entry(buf_list_.prev, struct EncodeBuf, list); + + if (list_empty(&buf_list_) || buf->pos + len > buf->data + ENCODE_BUF_SIZE) + { + size_t n; + + if (len > ENCODE_BUF_SIZE) + n = offsetof(struct EncodeBuf, data) + ALIGN(len, 8); + else + n = sizeof (struct EncodeBuf); + + buf = (struct EncodeBuf *)new char[n]; + buf->pos = buf->data; + list_add_tail(&buf->list, &buf_list_); + } + + memcpy(buf->pos, data, len); + vec_[size_].iov_base = buf->pos; + vec_[size_].iov_len = len; + size_++; + bytes_ += len; + + buf->pos += ALIGN(len, 8); + if (buf->pos >= buf->data + ENCODE_BUF_SIZE) + list_move(&buf->list, &buf_list_); +} + diff --git a/src/util/EncodeStream.h b/src/util/EncodeStream.h new file mode 100644 index 0000000..0046399 --- /dev/null +++ b/src/util/EncodeStream.h @@ -0,0 +1,138 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) + Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _ENCODESTREAM_H_ +#define _ENCODESTREAM_H_ + +#include +#include +#include +#include +#include "list.h" + +/** + * @file EncodeStream.h + * @brief Encoder toolbox for protocol message encode + */ + +// make sure max > 0 +class EncodeStream +{ +public: + EncodeStream() + { + init_vec(NULL, 0); + INIT_LIST_HEAD(&buf_list_); + } + + EncodeStream(struct iovec *vectors, int max) + { + init_vec(vectors, max); + INIT_LIST_HEAD(&buf_list_); + } + + ~EncodeStream() { clear_buf_data(); } + + void reset(struct iovec *vectors, int max) + { + clear_buf_data(); + init_vec(vectors, max); + } + + int size() const { return size_; } + size_t bytes() const { return bytes_; } + + void append_nocopy(const char *data, size_t len); + + void append_nocopy(const char *data) + { + append_nocopy(data, strlen(data)); + } + + void append_nocopy(const std::string& data) + { + append_nocopy(data.c_str(), data.size()); + } + + void append_copy(const char *data, size_t len); + + void append_copy(const char *data) + { + append_copy(data, strlen(data)); + } + + void append_copy(const std::string& data) + { + append_copy(data.c_str(), data.size()); + } + +private: + void init_vec(struct iovec *vectors, int max) + { + vec_ = vectors; + max_ = max; + bytes_ = 0; + size_ = 0; + merged_bytes_ = 0; + merged_size_ = 0; + } + + void merge(); + void clear_buf_data(); + +private: + struct iovec *vec_; + int max_; + int size_; + size_t bytes_; + int merged_size_; + size_t merged_bytes_; + struct list_head buf_list_; +}; + +static inline EncodeStream& operator << (EncodeStream& stream, + const char *data) +{ + stream.append_nocopy(data, strlen(data)); + return stream; +} + +static inline EncodeStream& operator << (EncodeStream& stream, + const std::string& data) +{ + stream.append_nocopy(data.c_str(), data.size()); + 
return stream; +} + +static inline EncodeStream& operator << (EncodeStream& stream, + const std::pair& data) +{ + stream.append_nocopy(data.first, data.second); + return stream; +} + +static inline EncodeStream& operator << (EncodeStream& stream, + int64_t intv) +{ + stream.append_copy(std::to_string(intv)); + return stream; +} + +#endif + diff --git a/src/util/LRUCache.h b/src/util/LRUCache.h new file mode 100644 index 0000000..730f31e --- /dev/null +++ b/src/util/LRUCache.h @@ -0,0 +1,258 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _LRUCACHE_H_ +#define _LRUCACHE_H_ + +#include +#include "list.h" +#include "rbtree.h" + +/** + * @file LRUCache.h + * @brief Template LRU Cache + */ + +// RAII: NO. Release ref by LRUCache::release +// Thread safety: NO. +// DONOT change value by handler, use Cache::put instead +template +class LRUHandle +{ +public: + VALUE value; + +private: + LRUHandle(const KEY& k, const VALUE& v) : + value(v), key(k) + { + } + + KEY key; + struct list_head list; + struct rb_node rb; + bool in_cache; + int ref; + + template friend class LRUCache; +}; + +// RAII: NO. 
Release ref by LRUCache::release +// Define ValueDeleter(VALUE& v) for value deleter +// Thread safety: NO +// Make sure KEY operator< usable +template +class LRUCache +{ +protected: + typedef LRUHandle Handle; + +public: + LRUCache() + { + INIT_LIST_HEAD(&this->not_use); + INIT_LIST_HEAD(&this->in_use); + this->cache_map.rb_node = NULL; + this->max_size = 0; + this->size = 0; + } + + ~LRUCache() + { + struct list_head *pos, *tmp; + Handle *e; + + // Error if caller has an unreleased handle + assert(list_empty(&this->in_use)); + list_for_each_safe(pos, tmp, &this->not_use) + { + e = list_entry(pos, Handle, list); + assert(e->in_cache); + e->in_cache = false; + assert(e->ref == 1);// Invariant for not_use_ list. + this->unref(e); + } + } + + // default max_size=0 means no-limit cache + // max_size means max cache number of key-value pairs + void set_max_size(size_t max_size) + { + this->max_size = max_size; + } + + // Remove all cache that are not actively in use. + void prune() + { + struct list_head *pos, *tmp; + Handle *e; + + list_for_each_safe(pos, tmp, &this->not_use) + { + e = list_entry(pos, Handle, list); + assert(e->ref == 1); + rb_erase(&e->rb, &this->cache_map); + this->erase_node(e); + } + } + + // release handle by get/put + void release(const Handle *handle) + { + this->unref(const_cast(handle)); + } + + // get handler + // Need call release when handle no longer needed + const Handle *get(const KEY& key) + { + struct rb_node *p = this->cache_map.rb_node; + Handle *bound = NULL; + Handle *e; + + while (p) + { + e = rb_entry(p, Handle, rb); + if (!(e->key < key)) + { + bound = e; + p = p->rb_left; + } + else + p = p->rb_right; + } + + if (bound && !(key < bound->key)) + { + this->ref(bound); + return bound; + } + + return NULL; + } + + // put copy + // Need call release when handle no longer needed + const Handle *put(const KEY& key, VALUE value) + { + struct rb_node **p = &this->cache_map.rb_node; + struct rb_node *parent = NULL; + Handle *bound = 
NULL; + Handle *e; + + while (*p) + { + parent = *p; + e = rb_entry(*p, Handle, rb); + if (!(e->key < key)) + { + bound = e; + p = &(*p)->rb_left; + } + else + p = &(*p)->rb_right; + } + + e = new Handle(key, value); + e->in_cache = true; + e->ref = 2; + list_add_tail(&e->list, &this->in_use); + this->size++; + + if (bound && !(key < bound->key)) + { + rb_replace_node(&bound->rb, &e->rb, &this->cache_map); + this->erase_node(bound); + } + else + { + rb_link_node(&e->rb, parent, p); + rb_insert_color(&e->rb, &this->cache_map); + } + + if (this->max_size > 0) + { + while (this->size > this->max_size && !list_empty(&this->not_use)) + { + Handle *tmp = list_entry(this->not_use.next, Handle, list); + assert(tmp->ref == 1); + rb_erase(&tmp->rb, &this->cache_map); + this->erase_node(tmp); + } + } + + return e; + } + + // delete from cache, deleter delay called when all inuse-handle release. + void del(const KEY& key) + { + Handle *e = const_cast(this->get(key)); + + if (e) + { + this->unref(e); + rb_erase(&e->rb, &this->cache_map); + this->erase_node(e); + } + } + +private: + void ref(Handle *e) + { + if (e->in_cache && e->ref == 1) + list_move_tail(&e->list, &this->in_use); + + e->ref++; + } + + void unref(Handle *e) + { + assert(e->ref > 0); + if (--e->ref == 0) + { + assert(!e->in_cache); + this->value_deleter(e->value); + delete e; + } + else if (e->in_cache && e->ref == 1) + list_move_tail(&e->list, &this->not_use); + } + + void erase_node(Handle *e) + { + assert(e->in_cache); + list_del(&e->list); + e->in_cache = false; + this->size--; + this->unref(e); + } + + size_t max_size; + size_t size; + + struct list_head not_use; + struct list_head in_use; + struct rb_root cache_map; + + ValueDeleter value_deleter; +}; + +#endif + diff --git a/src/util/StringUtil.cc b/src/util/StringUtil.cc new file mode 100644 index 0000000..fe69bce --- /dev/null +++ b/src/util/StringUtil.cc @@ -0,0 +1,230 @@ +/* + Copyright (c) 2019 Sogou, Inc. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include "StringUtil.h" + +static int __htoi(unsigned char *s) +{ + int value; + int c; + + c = s[0]; + if (isupper(c)) + c = tolower(c); + + value = (c >= '0' && c <= '9' ? c - '0' : c - 'a' + 10) * 16; + + c = s[1]; + if (isupper(c)) + c = tolower(c); + + value += (c >= '0' && c <= '9' ? c - '0' : c - 'a' + 10); + return value; +} + +static inline char __itoh(int n) +{ + if (n > 9) + return n - 10 + 'A'; + + return n + '0'; +} + +static size_t __url_decode(char *str) +{ + char *dest = str; + char *data = str; + + while (*data) + { + if (*data == '%' && isxdigit(data[1]) && isxdigit(data[2])) + { + *dest = __htoi((unsigned char *)data + 1); + data += 2; + } + else if (*data == '+') + *dest = ' '; + else + *dest = *data; + + data++; + dest++; + } + + *dest = '\0'; + return dest - str; +} + +void StringUtil::url_decode(std::string& str) +{ + if (str.empty()) + return; + + size_t sz = __url_decode(const_cast(str.c_str())); + + str.resize(sz); +} + +std::string StringUtil::url_encode(const std::string& str) +{ + const char *cur = str.c_str(); + const char *end = cur + str.size(); + std::string res; + + while (cur < end) + { + if (*cur == ' ') + res += '+'; + else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' || + *cur == '!' 
|| *cur == '~' || *cur == '*' || *cur == '\'' || + *cur == '(' || *cur == ')' || *cur == ':' || *cur == '/' || + *cur == '@' || *cur == '?' || *cur == '#' || *cur == '&') + res += *cur; + else + { + res += '%'; + res += __itoh(((const unsigned char)(*cur)) >> 4); + res += __itoh(((const unsigned char)(*cur)) % 16); + } + + cur++; + } + + return res; +} + +std::string StringUtil::url_encode_component(const std::string& str) +{ + const char *cur = str.c_str(); + const char *end = cur + str.size(); + std::string res; + + while (cur < end) + { + if (*cur == ' ') + res += '+'; + else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' || + *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' || + *cur == '(' || *cur == ')') + res += *cur; + else + { + res += '%'; + res += __itoh(((const unsigned char)(*cur)) >> 4); + res += __itoh(((const unsigned char)(*cur)) % 16); + } + + cur++; + } + + return res; +} + +std::vector StringUtil::split(const std::string& str, char sep) +{ + std::string::const_iterator cur = str.begin(); + std::string::const_iterator end = str.end(); + std::string::const_iterator next = find(cur, end, sep); + std::vector res; + + while (next != end) + { + res.emplace_back(cur, next); + cur = next + 1; + next = std::find(cur, end, sep); + } + + res.emplace_back(cur, next); + return res; +} + +std::vector StringUtil::split_filter_empty(const std::string& str, + char sep) +{ + std::vector res; + std::string::const_iterator cur = str.begin(); + std::string::const_iterator end = str.end(); + std::string::const_iterator next = find(cur, end, sep); + + while (next != end) + { + if (cur < next) + res.emplace_back(cur, next); + + cur = next + 1; + next = find(cur, end, sep); + } + + if (cur < next) + res.emplace_back(cur, next); + + return res; +} + +std::string StringUtil::strip(const std::string& str) +{ + std::string res; + + if (!str.empty()) + { + const char *cur = str.c_str(); + const char *end = cur + str.size(); + + while (cur < end) + { 
+ if (!isspace(*cur)) + break; + + cur++; + } + + while (end > cur) + { + if (!isspace(*(end - 1))) + break; + + end--; + } + + if (end > cur) + res.assign(cur, end - cur); + } + + return res; +} + +bool StringUtil::start_with(const std::string& str, const std::string& prefix) +{ + size_t prefix_len = prefix.size(); + + if (str.size() < prefix_len) + return false; + + for (size_t i = 0; i < prefix_len; i++) + { + if (str[i] != prefix[i]) + return false; + } + + return true; +} + diff --git a/src/util/StringUtil.h b/src/util/StringUtil.h new file mode 100644 index 0000000..e40e87a --- /dev/null +++ b/src/util/StringUtil.h @@ -0,0 +1,46 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) +*/ + +#ifndef _STRINGUTIL_H_ +#define _STRINGUTIL_H_ + +#include +#include + +/** + * @file StringUtil.h + * @brief String toolbox + */ + +// static class +class StringUtil +{ +public: + static void url_decode(std::string& str); + static std::string url_encode(const std::string& str); + static std::string url_encode_component(const std::string& str); + static std::vector split(const std::string& str, char sep); + static std::string strip(const std::string& str); + static bool start_with(const std::string& str, const std::string& prefix); + + //this will filter any empty result, so the result vector has no empty string + static std::vector split_filter_empty(const std::string& str, char sep); +}; + +#endif + diff --git a/src/util/URIParser.cc b/src/util/URIParser.cc new file mode 100644 index 0000000..cd1ed0b --- /dev/null +++ b/src/util/URIParser.cc @@ -0,0 +1,512 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Wang Zhulei (wangzhulei@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "StringUtil.h" +#include "URIParser.h" + +enum +{ + URI_SCHEME, + URI_USERINFO, + URI_HOST, + URI_PORT, + URI_QUERY, + URI_FRAGMENT, + URI_PATH, + URI_PART_ELEMENTS, +}; + +//scheme://[userinfo@]host[:port][/path][?query][#fragment] +//0-6 (scheme, userinfo, host, port, path, query, fragment) +static constexpr unsigned char valid_char[4][256] = { + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }, + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }, + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }, + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }, +}; + +static unsigned char authority_map[256] = { + URI_PART_ELEMENTS, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, URI_FRAGMENT, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
URI_PATH, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, URI_HOST, 0, 0, 0, 0, URI_QUERY, + URI_USERINFO, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +ParsedURI::ParsedURI(ParsedURI&& uri) +{ + scheme = uri.scheme; + userinfo = uri.userinfo; + host = uri.host; + port = uri.port; + path = uri.path; + query = uri.query; + fragment = uri.fragment; + state = uri.state; + error = uri.error; + uri.init(); +} + +ParsedURI& ParsedURI::operator= (ParsedURI&& uri) +{ + if (this != &uri) + { + deinit(); + scheme = uri.scheme; + userinfo = uri.userinfo; + host = uri.host; + port = uri.port; + path = uri.path; + query = uri.query; + fragment = uri.fragment; + state = uri.state; + error = uri.error; + uri.init(); + } + + return *this; +} + +void ParsedURI::copy(const ParsedURI& uri) +{ + init(); + state = uri.state; + error = uri.error; + if (state == URI_STATE_SUCCESS) + { + bool succ = false; + + do + { + if (uri.scheme) + { + scheme = strdup(uri.scheme); + if (!scheme) + break; + } + + if (uri.userinfo) + { + userinfo = strdup(uri.userinfo); + if (!userinfo) + break; + } + + if (uri.host) + { + host = strdup(uri.host); + if (!host) + break; + } + + if (uri.port) + { + port = strdup(uri.port); + if (!port) + break; + } + + if (uri.path) + { + path = strdup(uri.path); + if (!path) + break; + } + + if (uri.query) + { + query = strdup(uri.query); + if (!query) + break; + } + + if (uri.fragment) + { + fragment = strdup(uri.fragment); + if 
(!fragment) + break; + } + + succ = true; + } while (0); + + if (!succ) + { + deinit(); + init(); + state = URI_STATE_ERROR; + error = errno; + } + } +} + +int URIParser::parse(const char *str, ParsedURI& uri) +{ + uri.state = URI_STATE_INVALID; + + int start_idx[URI_PART_ELEMENTS] = {0}; + int end_idx[URI_PART_ELEMENTS] = {0}; + int pre_state = URI_SCHEME; + bool in_ipv6 = false; + int i; + + for (i = 0; str[i]; i++) + { + if (str[i] == ':') + { + end_idx[URI_SCHEME] = i++; + break; + } + } + + if (end_idx[URI_SCHEME] == 0) + return -1; + + if (str[i] == '/' && str[i + 1] == '/') + { + pre_state = URI_HOST; + i += 2; + if (str[i] == '[') + in_ipv6 = true; + else + start_idx[URI_USERINFO] = i; + + start_idx[URI_HOST] = i; + } + else + { + pre_state = URI_PATH; + start_idx[URI_PATH] = i; + } + + bool skip_path = false; + if (start_idx[URI_PATH] == 0) + { + for (; ; i++) + { + switch (authority_map[(unsigned char)str[i]]) + { + case 0: + continue; + + case URI_USERINFO: + if (str[i + 1] == '[') + in_ipv6 = true; + + end_idx[URI_USERINFO] = i; + start_idx[URI_HOST] = i + 1; + pre_state = URI_HOST; + continue; + + case URI_HOST: + if (str[i - 1] == ']') + in_ipv6 = false; + + if (!in_ipv6) + { + end_idx[URI_HOST] = i; + start_idx[URI_PORT] = i + 1; + pre_state = URI_PORT; + } + continue; + + case URI_QUERY: + end_idx[pre_state] = i; + start_idx[URI_QUERY] = i + 1; + pre_state = URI_QUERY; + skip_path = true; + continue; + + case URI_FRAGMENT: + end_idx[pre_state] = i; + start_idx[URI_FRAGMENT] = i + 1; + end_idx[URI_FRAGMENT] = i + strlen(str + i); + pre_state = URI_PART_ELEMENTS; + skip_path = true; + break; + + case URI_PATH: + if (skip_path) + continue; + + start_idx[URI_PATH] = i; + break; + + case URI_PART_ELEMENTS: + skip_path = true; + break; + } + + break; + } + } + + if (pre_state != URI_PART_ELEMENTS) + end_idx[pre_state] = i; + + if (!skip_path) + { + pre_state = URI_PATH; + for (; str[i]; i++) + { + if (str[i] == '?') + { + end_idx[URI_PATH] = i; + 
start_idx[URI_QUERY] = i + 1; + pre_state = URI_QUERY; + while (str[i + 1]) + { + if (str[++i] == '#') + break; + } + } + + if (str[i] == '#') + { + end_idx[pre_state] = i; + start_idx[URI_FRAGMENT] = i + 1; + pre_state = URI_FRAGMENT; + break; + } + } + + end_idx[pre_state] = i + strlen(str + i); + } + + for (int i = 0; i < URI_QUERY; i++) + { + for (int j = start_idx[i]; j < end_idx[i]; j++) + { + if (!valid_char[i][(unsigned char)str[j]]) + return -1;//invalid char + } + } + + char **dst[URI_PART_ELEMENTS] = {&uri.scheme, &uri.userinfo, &uri.host, &uri.port, + &uri.query, &uri.fragment, &uri.path}; + + for (int i = 0; i < URI_PART_ELEMENTS; i++) + { + if (end_idx[i] > start_idx[i]) + { + size_t len = end_idx[i] - start_idx[i]; + + *dst[i] = (char *)realloc(*dst[i], len + 1); + if (*dst[i] == NULL) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + return -1; + } + + if (i == URI_HOST && str[start_idx[i]] == '[' && + str[end_idx[i] - 1] == ']') + { + len -= 2; + memcpy(*dst[i], str + start_idx[i] + 1, len); + } + else + memcpy(*dst[i], str + start_idx[i], len); + + (*dst[i])[len] = '\0'; + } + else + { + free(*dst[i]); + *dst[i] = NULL; + } + } + + uri.state = URI_STATE_SUCCESS; + return 0; +} + +std::map> +URIParser::split_query_strict(const std::string &query) +{ + std::map> res; + + if (query.empty()) + return res; + + std::vector arr = StringUtil::split(query, '&'); + + if (arr.empty()) + return res; + + for (const auto& ele : arr) + { + if (ele.empty()) + continue; + + std::vector kv = StringUtil::split(ele, '='); + size_t kv_size = kv.size(); + std::string& key = kv[0]; + + if (key.empty()) + continue; + + if (kv_size == 1) + { + res[key].emplace_back(); + continue; + } + + std::string& val = kv[1]; + + if (val.empty()) + res[key].emplace_back(); + else + res[key].emplace_back(std::move(val)); + } + + return res; +} + +std::map URIParser::split_query(const std::string &query) +{ + std::map res; + + if (query.empty()) + return res; + + std::vector arr 
= StringUtil::split(query, '&'); + + if (arr.empty()) + return res; + + for (const auto& ele : arr) + { + if (ele.empty()) + continue; + + std::vector kv = StringUtil::split(ele, '='); + size_t kv_size = kv.size(); + std::string& key = kv[0]; + + if (key.empty() || res.count(key) > 0) + continue; + + if (kv_size == 1) + { + res.emplace(std::move(key), ""); + continue; + } + + std::string& val = kv[1]; + + if (val.empty()) + res.emplace(std::move(key), ""); + else + res.emplace(std::move(key), std::move(val)); + } + + return res; +} + +std::vector URIParser::split_path(const std::string &path) +{ + return StringUtil::split_filter_empty(path, '/'); +} + diff --git a/src/util/URIParser.h b/src/util/URIParser.h new file mode 100644 index 0000000..74f2529 --- /dev/null +++ b/src/util/URIParser.h @@ -0,0 +1,123 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) + Wang Zhulei (wangzhulei@sogou-inc.com) +*/ + +#ifndef _URIPARSER_H_ +#define _URIPARSER_H_ + +#include +#include +#include +#include + +#define URI_STATE_INIT 0 +#define URI_STATE_SUCCESS 1 +#define URI_STATE_INVALID 2 +#define URI_STATE_ERROR 3 + +/** + * @file URIParser.h + * @brief URI parser + */ + +// RAII: YES +class ParsedURI +{ +public: + char *scheme; + char *userinfo; + char *host; + char *port; + char *path; + char *query; + char *fragment; + int state; + int error; + + ParsedURI() { init(); } + virtual ~ParsedURI() { deinit(); } + + //copy constructor + ParsedURI(const ParsedURI& uri) { copy(uri); } + //copy operator + ParsedURI& operator= (const ParsedURI& uri) + { + if (this != &uri) + { + deinit(); + copy(uri); + } + + return *this; + } + + //move constructor + ParsedURI(ParsedURI&& uri); + //move operator + ParsedURI& operator= (ParsedURI&& uri); + +private: + void init() + { + scheme = NULL; + userinfo = NULL; + host = NULL; + port = NULL; + path = NULL; + query = NULL; + fragment = NULL; + state = URI_STATE_INIT; + error = 0; + } + + void deinit() + { + free(scheme); + free(userinfo); + free(host); + free(port); + free(path); + free(query); + free(fragment); + } + + void copy(const ParsedURI& uri); +}; + +// static class +class URIParser +{ +public: + // return 0 mean succ, -1 mean fail + static int parse(const char *str, ParsedURI& uri); + static int parse(const std::string& str, ParsedURI& uri) + { + return parse(str.c_str(), uri); + } + + static std::map> + split_query_strict(const std::string &query); + + static std::map + split_query(const std::string &query); + + static std::vector split_path(const std::string &path); +}; + +#endif + diff --git a/src/util/crc32c.c b/src/util/crc32c.c new file mode 100644 index 0000000..e0494c7 --- /dev/null +++ b/src/util/crc32c.c @@ -0,0 +1,519 @@ +/* Copied from http://stackoverflow.com/a/17646775/1821055 + * with the following modifications: + * * remove 
test code + * * global hw/sw initialization to be called once per process + * * HW support is determined by configure's WITH_CRC32C_HW + * * Windows porting (no hardware support on Windows yet) + * + * FIXME: + * * Hardware support on Windows (MSVC assembler) + * * Hardware support on ARM + */ + +/* crc32c.c -- compute CRC-32C using the Intel crc32 instruction + * Copyright (C) 2013 Mark Adler + * Version 1.1 1 Aug 2013 Mark Adler + */ + +/* + This software is provided 'as-is', without any express or implied + warranty. In no event will the author be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + Mark Adler + madler@alumni.caltech.edu + */ + +/* Use hardware CRC instruction on Intel SSE 4.2 processors. This computes a + CRC-32C, *not* the CRC-32 used by Ethernet and zip, gzip, etc. A software + version is provided as a fall-back, as well as for speed comparisons. */ + +/* Version history: + 1.0 10 Feb 2013 First version + 1.1 1 Aug 2013 Correct comments on why three crc instructions in parallel + */ + + +#include +#include +#include +#include + +/** + * Provides portable endian-swapping macros/functions. 
+ * + * be64toh() + * htobe64() + * be32toh() + * htobe32() + * be16toh() + * htobe16() + * le64toh() + */ + +#ifdef __FreeBSD__ + #include +#elif defined __GLIBC__ + #include + #ifndef be64toh + /* Support older glibc (<2.9) which lack be64toh */ + #include + #if __BYTE_ORDER == __BIG_ENDIAN + #define be16toh(x) (x) + #define be32toh(x) (x) + #define be64toh(x) (x) + #define le64toh(x) __bswap_64 (x) + #define le32toh(x) __bswap_32 (x) + #else + #define be16toh(x) __bswap_16 (x) + #define be32toh(x) __bswap_32 (x) + #define be64toh(x) __bswap_64 (x) + #define le64toh(x) (x) + #define le32toh(x) (x) + #endif + #endif + +#elif defined __CYGWIN__ + #include +#elif defined __BSD__ + #include +#elif defined __sun + #include + #include +#define __LITTLE_ENDIAN 1234 +#define __BIG_ENDIAN 4321 +#ifdef _BIG_ENDIAN +#define __BYTE_ORDER __BIG_ENDIAN +#define be64toh(x) (x) +#define be32toh(x) (x) +#define be16toh(x) (x) +#define le16toh(x) ((uint16_t)BSWAP_16(x)) +#define le32toh(x) BSWAP_32(x) +#define le64toh(x) BSWAP_64(x) +# else +#define __BYTE_ORDER __LITTLE_ENDIAN +#define be64toh(x) BSWAP_64(x) +#define be32toh(x) ntohl(x) +#define be16toh(x) ntohs(x) +#define le16toh(x) (x) +#define le32toh(x) (x) +#define le64toh(x) (x) +#define htole16(x) (x) +#define htole64(x) (x) +#endif /* __sun */ + +#elif defined __APPLE__ + #include + #include +#if __DARWIN_BYTE_ORDER == __DARWIN_BIG_ENDIAN +#define be64toh(x) (x) +#define be32toh(x) (x) +#define be16toh(x) (x) +#define le16toh(x) OSSwapInt16(x) +#define le32toh(x) OSSwapInt32(x) +#define le64toh(x) OSSwapInt64(x) +#else +#define be64toh(x) OSSwapInt64(x) +#define be32toh(x) OSSwapInt32(x) +#define be16toh(x) OSSwapInt16(x) +#define le16toh(x) (x) +#define le32toh(x) (x) +#define le64toh(x) (x) +#endif + +#elif defined(_WIN32) +#include + +#define be64toh(x) _byteswap_uint64(x) +#define be32toh(x) _byteswap_ulong(x) +#define be16toh(x) _byteswap_ushort(x) +#define le16toh(x) (x) +#define le32toh(x) (x) +#define le64toh(x) 
(x) + +#elif defined _AIX /* AIX is always big endian */ +#define be64toh(x) (x) +#define be32toh(x) (x) +#define be16toh(x) (x) +#define le32toh(x) \ + ((((x) & 0xff) << 24) | \ + (((x) & 0xff00) << 8) | \ + (((x) & 0xff0000) >> 8) | \ + (((x) & 0xff000000) >> 24)) +#define le64toh(x) \ + ((((x) & 0x00000000000000ffL) << 56) | \ + (((x) & 0x000000000000ff00L) << 40) | \ + (((x) & 0x0000000000ff0000L) << 24) | \ + (((x) & 0x00000000ff000000L) << 8) | \ + (((x) & 0x000000ff00000000L) >> 8) | \ + (((x) & 0x0000ff0000000000L) >> 24) | \ + (((x) & 0x00ff000000000000L) >> 40) | \ + (((x) & 0xff00000000000000L) >> 56)) +#else + #include +#endif + + + +/* + * On Solaris, be64toh is a function, not a macro, so there's no need to error + * if it's not defined. + */ +#if !defined(__sun) && !defined(be64toh) +#error Missing definition for be64toh +#endif + +#ifndef be32toh +#define be32toh(x) ntohl(x) +#endif + +#ifndef be16toh +#define be16toh(x) ntohs(x) +#endif + +#ifndef htobe64 +#define htobe64(x) be64toh(x) +#endif +#ifndef htobe32 +#define htobe32(x) be32toh(x) +#endif +#ifndef htobe16 +#define htobe16(x) be16toh(x) +#endif + +#ifndef htole32 +#define htole32(x) le32toh(x) +#endif + + +/* CRC-32C (iSCSI) polynomial in reversed bit order. */ +#define POLY 0x82f63b78 + +/* Table for a quadword-at-a-time software crc. */ +static uint32_t crc32c_table[8][256]; + +/* Construct table for software CRC-32C calculation. */ +static void crc32c_init_sw(void) +{ + uint32_t n, crc, k; + + for (n = 0; n < 256; n++) { + crc = n; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; + crc = crc & 1 ? 
/* Multiply a matrix times a vector over the Galois field of two elements,
   GF(2).  Each element is a bit in an unsigned integer.  mat must have at
   least as many entries as the power of two for most significant one bit in
   vec.

   For every set bit k of vec, row mat[k] is XORed into the result; the loop
   stops as soon as no set bits remain, so rows above the highest set bit are
   never read.

   Declared with the standard `inline` keyword: the RD_INLINE macro used
   previously is not defined in this file. */
static inline uint32_t gf2_matrix_times(uint32_t *mat, uint32_t vec)
{
        uint32_t sum;

        sum = 0;
        while (vec) {
                if (vec & 1)
                        sum ^= *mat;
                vec >>= 1;
                mat++;
        }
        return sum;
}
/* Multiply a matrix by itself over GF(2).  Both mat and square must have 32
   rows.  Row n of the result is mat applied to row n of mat.

   Declared with the standard `inline` keyword: the RD_INLINE macro used
   previously is not defined in this file. */
static inline void gf2_matrix_square(uint32_t *square, uint32_t *mat)
{
        int n;

        for (n = 0; n < 32; n++)
                square[n] = gf2_matrix_times(mat, mat[n]);
}
/* Apply the zeros operator table to crc.

   zeros holds four 256-entry tables, one per byte position of the 32-bit crc
   (table 0 for the least significant byte); the result is the XOR of the four
   per-byte lookups.

   Declared with the standard `inline` keyword: the RD_INLINE macro used
   previously is not defined in this file. */
static inline uint32_t crc32c_shift(uint32_t zeros[][256], uint32_t crc)
{
        return zeros[0][crc & 0xff] ^ zeros[1][(crc >> 8) & 0xff] ^
               zeros[2][(crc >> 16) & 0xff] ^ zeros[3][crc >> 24];
}
/* Check for SSE 4.2.  SSE 4.2 was first supported in Nehalem processors
   introduced in November, 2008.  This does not check for the existence of the
   cpuid instruction itself, which was introduced on the 486SL in 1992, so this
   will fail on earlier x86 processors.  cpuid works on all Pentium and later
   processors.

   Sets (have) to bit 20 of ECX as returned by cpuid leaf 1 (1 if SSE 4.2 is
   present, 0 otherwise).

   cpuid overwrites EAX as well as EBX/ECX/EDX, so EAX is declared as a
   read-write operand ("+a"); an input-only operand must not be modified by
   the asm statement. */
#define SSE42(have) \
        do { \
                uint32_t eax, ecx; \
                eax = 1; \
                __asm__("cpuid" \
                        : "+a"(eax), "=c"(ecx) \
                        : \
                        : "%ebx", "%edx"); \
                (have) = (ecx >> 20) & 1; \
        } while (0)
/* Compute the CRC-32C of len bytes at buf, continuing from the running value
   crc.  When the file is built with WITH_CRC32C_HW and the cached sse42 flag
   is set, the hardware implementation is used; in every other case the
   table-driven software implementation is used. */
uint32_t crc32c(uint32_t crc, const void *buf, size_t len)
{
#if WITH_CRC32C_HW
        if (sse42)
                return crc32c_hw(crc, buf, len);
#endif
        return crc32c_sw(crc, buf, len);
}
+ */ + +#ifndef _CRC32C_H_ +#define _CRC32C_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" +{ +#endif + +uint32_t crc32c(uint32_t crc, const void *buf, size_t len); + +void crc32c_global_init (void); + +#ifdef __cplusplus +} +#endif + +#endif /* _CRC32C_H_ */ diff --git a/src/util/json_parser.c b/src/util/json_parser.c new file mode 100644 index 0000000..647b3ec --- /dev/null +++ b/src/util/json_parser.c @@ -0,0 +1,1399 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
/* Scan a JSON string body (cursor points just past the opening quote) up to
   the closing quote and return the number of units counted, where a backslash
   together with the character following it counts as one unit and every other
   character counts as one unit.
   Returns -2 when an unescaped character below 0x20 is met (this includes the
   terminating NUL, i.e. a missing closing quote) or when a backslash is the
   last character of the input. */
static int __json_string_length(const char *cursor)
{
	const char *p;
	int count = 0;

	for (p = cursor; *p != '\"'; p++, count++)
	{
		if (*p != '\\')
		{
			if ((unsigned char)*p < 0x20)
				return -2;

			continue;
		}

		/* Step over the backslash; the escaped character itself is
		   consumed by the loop increment. */
		p++;
		if (*p == '\0')
			return -2;
	}

	return count;
}
/* Decode a JSON string body into str.
   cursor points just past the opening quote; decoding stops at the closing
   quote.  Ordinary characters are copied as they are, the two-character
   escapes are replaced by the character they stand for, and "\u" escapes are
   handed to __parse_json_unicode(), which writes the bytes and reports how
   many it wrote.
   On success str is NUL-terminated, *end is set just past the closing quote
   and 0 is returned.  An unknown escape yields -2; a negative result from
   __parse_json_unicode() is passed through.
   The caller provides a str buffer that is large enough; no bound is checked
   here. */
static int __parse_json_string(const char *cursor, const char **end,
							   char *str)
{
	char ch;
	int n;

	for (;;)
	{
		ch = *cursor;
		if (ch == '\"')
			break;

		if (ch != '\\')
		{
			*str++ = ch;
			cursor++;
			continue;
		}

		/* Escape sequence: look at the character after the backslash. */
		ch = *++cursor;
		if (ch == 'u')
		{
			cursor++;
			n = __parse_json_unicode(cursor, &cursor, str);
			if (n < 0)
				return n;

			str += n;
			continue;
		}

		if (ch == '\"' || ch == '\\' || ch == '/')
			*str = ch;
		else if (ch == 'b')
			*str = '\b';
		else if (ch == 'f')
			*str = '\f';
		else if (ch == 'n')
			*str = '\n';
		else if (ch == 'r')
			*str = '\r';
		else if (ch == 't')
			*str = '\t';
		else
			return -2;

		cursor++;
		str++;
	}

	*str = '\0';
	*end = cursor + 1;
	return 0;
}
1e80, 1e81, 1e82, 1e83, 1e84, + 1e85, 1e86, 1e87, 1e88, 1e89, + 1e90, 1e91, 1e92, 1e93, 1e94, + 1e95, 1e96, 1e97, 1e98, 1e99, + 1e100, 1e101, 1e102, 1e103, 1e104, + 1e105, 1e106, 1e107, 1e108, 1e109, + 1e110, 1e111, 1e112, 1e113, 1e114, + 1e115, 1e116, 1e117, 1e118, 1e119, + 1e120, 1e121, 1e122, 1e123, 1e124, + 1e125, 1e126, 1e127, 1e128, 1e129, + 1e130, 1e131, 1e132, 1e133, 1e134, + 1e135, 1e136, 1e137, 1e138, 1e139, + 1e140, 1e141, 1e142, 1e143, 1e144, + 1e145, 1e146, 1e147, 1e148, 1e149, + 1e150, 1e151, 1e152, 1e153, 1e154, + 1e155, 1e156, 1e157, 1e158, 1e159, + 1e160, 1e161, 1e162, 1e163, 1e164, + 1e165, 1e166, 1e167, 1e168, 1e169, + 1e170, 1e171, 1e172, 1e173, 1e174, + 1e175, 1e176, 1e177, 1e178, 1e179, + 1e180, 1e181, 1e182, 1e183, 1e184, + 1e185, 1e186, 1e187, 1e188, 1e189, + 1e190, 1e191, 1e192, 1e193, 1e194, + 1e195, 1e196, 1e197, 1e198, 1e199, + 1e200, 1e201, 1e202, 1e203, 1e204, + 1e205, 1e206, 1e207, 1e208, 1e209, + 1e210, 1e211, 1e212, 1e213, 1e214, + 1e215, 1e216, 1e217, 1e218, 1e219, + 1e220, 1e221, 1e222, 1e223, 1e224, + 1e225, 1e226, 1e227, 1e228, 1e229, + 1e230, 1e231, 1e232, 1e233, 1e234, + 1e235, 1e236, 1e237, 1e238, 1e239, + 1e240, 1e241, 1e242, 1e243, 1e244, + 1e245, 1e246, 1e247, 1e248, 1e249, + 1e250, 1e251, 1e252, 1e253, 1e254, + 1e255, 1e256, 1e257, 1e258, 1e259, + 1e260, 1e261, 1e262, 1e263, 1e264, + 1e265, 1e266, 1e267, 1e268, 1e269, + 1e270, 1e271, 1e272, 1e273, 1e274, + 1e275, 1e276, 1e277, 1e278, 1e279, + 1e280, 1e281, 1e282, 1e283, 1e284, + 1e285, 1e286, 1e287, 1e288, 1e289, + 1e290, 1e291, 1e292, 1e293, 1e294, + 1e295, 1e296, 1e297, 1e298, 1e299, + 1e300, 1e301, 1e302, 1e303, 1e304, + 1e305, 1e306, 1e307, 1e308 +}; + +static double __evaluate_json_number(const char *integer, + const char *fraction, + int exp) +{ + long long mant = 0; + int figures = 0; + double num; + int sign; + + sign = (*integer == '-'); + if (sign) + integer++; + + if (*integer != '0') + { + mant = *integer - '0'; + integer++; + figures++; + while 
/* Parse a JSON number starting at cursor.
   Grammar accepted: optional '-', an integer part without a leading zero
   (a lone "0" is allowed), an optional '.' followed by at least one digit,
   and an optional exponent ('e' or 'E', optional sign, at least one digit).
   On success the value is stored in *num, *end is set to the first character
   after the number and 0 is returned; a malformed number, or one whose text is
   longer than 1000000 characters, yields -2.

   Every character handed to isdigit() is converted to unsigned char first:
   passing a negative char value (any byte >= 0x80 where char is signed) to
   the <ctype.h> functions is undefined behaviour. */
static int __parse_json_number(const char *cursor, const char **end,
							   double *num)
{
	const char *integer = cursor;
	const char *fraction = "";
	int exp = 0;
	int sign;

	if (*cursor == '-')
		cursor++;

	if (!isdigit((unsigned char)*cursor))
		return -2;

	/* Leading zeros are not allowed. */
	if (*cursor == '0' && isdigit((unsigned char)cursor[1]))
		return -2;

	cursor++;
	while (isdigit((unsigned char)*cursor))
		cursor++;

	if (*cursor == '.')
	{
		cursor++;
		fraction = cursor;
		if (!isdigit((unsigned char)*cursor))
			return -2;

		cursor++;
		while (isdigit((unsigned char)*cursor))
			cursor++;
	}

	if (*cursor == 'E' || *cursor == 'e')
	{
		cursor++;
		sign = (*cursor == '-');
		if (sign || *cursor == '+')
			cursor++;

		if (!isdigit((unsigned char)*cursor))
			return -2;

		exp = *cursor - '0';
		cursor++;
		/* Stop accumulating once the exponent is already huge so that
		   the int cannot overflow; the remaining digits are skipped. */
		while (isdigit((unsigned char)*cursor) && exp < 2000000)
		{
			exp *= 10;
			exp += *cursor - '0';
			cursor++;
		}

		while (isdigit((unsigned char)*cursor))
			cursor++;

		if (sign)
			exp = -exp;
	}

	if (cursor - integer > 1000000)
		return -2;

	*num = __evaluate_json_number(integer, fraction, exp);
	*end = cursor;
	return 0;
}
/* Parse one object member: the name (cursor points just past its opening
   quote), a ':' separator and the value.
   The decoded name is written into memb->name, which the caller has sized
   beforehand, and the value into memb->value.
   Returns 0 and sets *end just past the value on success.  A missing ':'
   yields -2; negative results from the string or value parser are passed
   through.

   Characters are converted to unsigned char before being handed to isspace():
   passing a negative char value (any byte >= 0x80 where char is signed) to
   the <ctype.h> functions is undefined behaviour. */
static int __parse_json_member(const char *cursor, const char **end,
							   int depth, json_member_t *memb)
{
	int ret;

	ret = __parse_json_string(cursor, &cursor, memb->name);
	if (ret < 0)
		return ret;

	while (isspace((unsigned char)*cursor))
		cursor++;

	if (*cursor != ':')
		return -2;

	cursor++;
	while (isspace((unsigned char)*cursor))
		cursor++;

	ret = __parse_json_value(cursor, &cursor, depth, &memb->value);
	if (ret < 0)
		return ret;

	*end = cursor;
	return 0;
}
return cnt; +} + +static void __destroy_json_members(json_object_t *obj) +{ + struct list_head *pos, *tmp; + json_member_t *memb; + + list_for_each_safe(pos, tmp, &obj->head) + { + memb = list_entry(pos, json_member_t, list); + __destroy_json_value(&memb->value); + free(memb); + } +} + +static int __parse_json_object(const char *cursor, const char **end, + int depth, json_object_t *obj) +{ + int ret; + + if (depth == JSON_DEPTH_LIMIT) + return -3; + + INIT_LIST_HEAD(&obj->head); + obj->root.rb_node = NULL; + ret = __parse_json_members(cursor, end, depth + 1, obj); + if (ret < 0) + { + __destroy_json_members(obj); + return ret; + } + + obj->size = ret; + return 0; +} + +static int __parse_json_elements(const char *cursor, const char **end, + int depth, json_array_t *arr) +{ + json_element_t *elem; + int cnt = 0; + int ret; + + while (isspace(*cursor)) + cursor++; + + if (*cursor == ']') + { + *end = cursor + 1; + return 0; + } + + while (1) + { + elem = (json_element_t *)malloc(sizeof (json_element_t)); + if (!elem) + return -1; + + ret = __parse_json_value(cursor, &cursor, depth, &elem->value); + if (ret < 0) + { + free(elem); + return ret; + } + + list_add_tail(&elem->list, &arr->head); + cnt++; + + while (isspace(*cursor)) + cursor++; + + if (*cursor == ',') + { + cursor++; + while (isspace(*cursor)) + cursor++; + } + else if (*cursor == ']') + break; + else + return -2; + } + + *end = cursor + 1; + return cnt; +} + +static void __destroy_json_elements(json_array_t *arr) +{ + struct list_head *pos, *tmp; + json_element_t *elem; + + list_for_each_safe(pos, tmp, &arr->head) + { + elem = list_entry(pos, json_element_t, list); + __destroy_json_value(&elem->value); + free(elem); + } +} + +static int __parse_json_array(const char *cursor, const char **end, + int depth, json_array_t *arr) +{ + int ret; + + if (depth == JSON_DEPTH_LIMIT) + return -3; + + INIT_LIST_HEAD(&arr->head); + ret = __parse_json_elements(cursor, end, depth + 1, arr); + if (ret < 0) + { + 
__destroy_json_elements(arr); + return ret; + } + + arr->size = ret; + return 0; +} + +static int __parse_json_value(const char *cursor, const char **end, + int depth, json_value_t *val) +{ /* Dispatch on the first character at cursor and parse one value into val. Returns 0 with val->type set and *end pointing past the value, or a negative code: -1 on allocation failure, -2 on an unexpected character or a misspelled literal, otherwise the code of the failing sub-parser. val->type is not written on failure. */ + int ret; + + switch (*cursor) + { + case '\"': + cursor++; + ret = __json_string_length(cursor); + if (ret < 0) + return ret; + + val->value.string = (char *)malloc(ret + 1); + if (!val->value.string) + return -1; + + ret = __parse_json_string(cursor, end, val->value.string); + if (ret < 0) + { + free(val->value.string); + return ret; + } + + val->type = JSON_VALUE_STRING; + break; + + case '-': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + ret = __parse_json_number(cursor, end, &val->value.number); + if (ret < 0) + return ret; + + val->type = JSON_VALUE_NUMBER; + break; + + case '{': + cursor++; + ret = __parse_json_object(cursor, end, depth, &val->value.object); + if (ret < 0) + return ret; + + val->type = JSON_VALUE_OBJECT; + break; + + case '[': + cursor++; + ret = __parse_json_array(cursor, end, depth, &val->value.array); + if (ret < 0) + return ret; + + val->type = JSON_VALUE_ARRAY; + break; + + case 't': + if (strncmp(cursor, "true", 4) != 0) + return -2; + + *end = cursor + 4; + val->type = JSON_VALUE_TRUE; + break; + + case 'f': + if (strncmp(cursor, "false", 5) != 0) + return -2; + + *end = cursor + 5; + val->type = JSON_VALUE_FALSE; + break; + + case 'n': + if (strncmp(cursor, "null", 4) != 0) + return -2; + + *end = cursor + 4; + val->type = JSON_VALUE_NULL; + break; + + default: + return -2; + } + + return 0; +} + +static void __destroy_json_value(json_value_t *val) +{ /* Release what val owns: the string buffer, or the members / elements of a container. The remaining types own nothing. val itself is not freed. */ + switch (val->type) + { + case JSON_VALUE_STRING: + free(val->value.string); + break; + + case JSON_VALUE_OBJECT: + __destroy_json_members(&val->value.object); + break; + + case JSON_VALUE_ARRAY: + __destroy_json_elements(&val->value.array); + break; + } +} + +json_value_t *json_value_parse(const char *cursor) +{ /* Parse a complete NUL-terminated JSON text. Returns a malloc-ed value, or NULL on allocation failure, on a parse error, or when anything other than whitespace follows the value. */ + 
json_value_t *val; + + val = (json_value_t *)malloc(sizeof (json_value_t)); + if (!val) + return NULL; + + while (isspace(*cursor)) + cursor++; + + if (__parse_json_value(cursor, &cursor, 0, val) >= 0) + { + while (isspace(*cursor)) + cursor++; + + if (*cursor == '\0') + return val; + + __destroy_json_value(val); + } + + free(val); + return NULL; +} + +static void __move_json_value(json_value_t *src, json_value_t *dest) +{ + switch (src->type) + { + case JSON_VALUE_STRING: + dest->value.string = src->value.string; + break; + + case JSON_VALUE_NUMBER: + dest->value.number = src->value.number; + break; + + case JSON_VALUE_OBJECT: + INIT_LIST_HEAD(&dest->value.object.head); + list_splice(&src->value.object.head, &dest->value.object.head); + dest->value.object.root.rb_node = src->value.object.root.rb_node; + dest->value.object.size = src->value.object.size; + break; + + case JSON_VALUE_ARRAY: + INIT_LIST_HEAD(&dest->value.array.head); + list_splice(&src->value.array.head, &dest->value.array.head); + dest->value.array.size = src->value.array.size; + break; + } + + dest->type = src->type; +} + +static int __set_json_value(int type, va_list ap, json_value_t *val) +{ + json_value_t *src; + const char *str; + int len; + + switch (type) + { + case 0: + src = va_arg(ap, json_value_t *); + __move_json_value(src, val); + free(src); + return 0; + + case JSON_VALUE_STRING: + str = va_arg(ap, const char *); + len = strlen(str); + val->value.string = (char *)malloc(len + 1); + if (!val->value.string) + return -1; + + memcpy(val->value.string, str, len + 1); + break; + + case JSON_VALUE_NUMBER: + val->value.number = va_arg(ap, double); + break; + + case JSON_VALUE_OBJECT: + INIT_LIST_HEAD(&val->value.object.head); + val->value.object.root.rb_node = NULL; + val->value.object.size = 0; + break; + + case JSON_VALUE_ARRAY: + INIT_LIST_HEAD(&val->value.array.head); + val->value.array.size = 0; + break; + } + + val->type = type; + return 0; +} + +json_value_t *json_value_create(int type, 
...) +{ /* Allocate a json_value_t and initialise it through __set_json_value() with the variadic argument that matches type. Returns NULL if either allocation fails. */ + json_value_t *val; + va_list ap; + int ret; + + val = (json_value_t *)malloc(sizeof (json_value_t)); + if (!val) + return NULL; + + va_start(ap, type); + ret = __set_json_value(type, ap, val); + va_end(ap); + if (ret < 0) + { + free(val); + return NULL; + } + + return val; +} + +static int __copy_json_value(const json_value_t *src, json_value_t *dest); + +static int __copy_json_members(const json_object_t *src, json_object_t *dest) +{ /* Deep-copy every member of src, in list order, onto the tail of dest (dest must already have its list head and rb root initialised). Returns src->size, or -1 on allocation failure, in which case the members copied so far remain in dest. */ + struct list_head *pos; + json_member_t *entry; + json_member_t *memb; + int len; + + list_for_each(pos, &src->head) + { + entry = list_entry(pos, json_member_t, list); + len = strlen(entry->name); + memb = (json_member_t *)malloc(offsetof(json_member_t, name) + len + 1); + if (!memb) + return -1; + + if (__copy_json_value(&entry->value, &memb->value) < 0) + { + free(memb); + return -1; + } + + memcpy(memb->name, entry->name, len + 1); + __insert_json_member(memb, dest->head.prev, dest); + } + + return src->size; +} + +static int __copy_json_elements(const json_array_t *src, json_array_t *dest) +{ /* Deep-copy every element of src, in order, onto the tail of dest (dest must already have its list head initialised). Returns src->size, or -1 on allocation failure, in which case the elements copied so far remain in dest. */ + struct list_head *pos; + json_element_t *entry; + json_element_t *elem; + + list_for_each(pos, &src->head) + { + elem = (json_element_t *)malloc(sizeof (json_element_t)); + if (!elem) + return -1; + + entry = list_entry(pos, json_element_t, list); + if (__copy_json_value(&entry->value, &elem->value) < 0) + { + free(elem); + return -1; + } + + list_add_tail(&elem->list, &dest->head); + } + + return src->size; +} + +static int __copy_json_value(const json_value_t *src, json_value_t *dest) +{ /* Deep-copy src into dest: strings are duplicated, containers are copied recursively, the remaining types carry only the type code. Returns 0, or -1 on allocation failure. */ + int len; + + switch (src->type) + { + case JSON_VALUE_STRING: + len = strlen(src->value.string); + dest->value.string = (char *)malloc(len + 1); + if (!dest->value.string) + return -1; + + memcpy(dest->value.string, src->value.string, len + 1); + break; + + case JSON_VALUE_NUMBER: + dest->value.number = src->value.number; + break; + + case JSON_VALUE_OBJECT: + INIT_LIST_HEAD(&dest->value.object.head); + dest->value.object.root.rb_node = 
NULL; + if (__copy_json_members(&src->value.object, &dest->value.object) < 0) + { + __destroy_json_members(&dest->value.object); /* drop the members copied before the failure */ + return -1; + } + + dest->value.object.size = src->value.object.size; + break; + + case JSON_VALUE_ARRAY: + INIT_LIST_HEAD(&dest->value.array.head); + if (__copy_json_elements(&src->value.array, &dest->value.array) < 0) + { + __destroy_json_elements(&dest->value.array); /* drop the elements copied before the failure */ + return -1; + } + + dest->value.array.size = src->value.array.size; + break; + } + + dest->type = src->type; + return 0; +} + +json_value_t *json_value_copy(const json_value_t *val) +{ /* Return a malloc-ed deep copy of val, or NULL on allocation failure. */ + json_value_t *copy; + + copy = (json_value_t *)malloc(sizeof (json_value_t)); + if (!copy) + return NULL; + + if (__copy_json_value(val, copy) < 0) + { + free(copy); + return NULL; + } + + return copy; +} + +void json_value_destroy(json_value_t *val) +{ /* Release everything val owns, then free val itself. */ + __destroy_json_value(val); + free(val); +} + +int json_value_type(const json_value_t *val) +{ /* Return the stored type code of val. */ + return val->type; +} + +const char *json_value_string(const json_value_t *val) +{ /* Return the string owned by val, or NULL when val is not a string. */ + if (val->type != JSON_VALUE_STRING) + return NULL; + + return val->value.string; +} + +double json_value_number(const json_value_t *val) +{ /* Return the number held by val, or NAN when val is not a number. */ + if (val->type != JSON_VALUE_NUMBER) + return NAN; + + return val->value.number; +} + +json_object_t *json_value_object(const json_value_t *val) +{ /* Return the object embedded in val (const is cast away), or NULL when val is not an object. */ + if (val->type != JSON_VALUE_OBJECT) + return NULL; + + return (json_object_t *)&val->value.object; +} + +json_array_t *json_value_array(const json_value_t *val) +{ /* Return the array embedded in val (const is cast away), or NULL when val is not an array. */ + if (val->type != JSON_VALUE_ARRAY) + return NULL; + + return (json_array_t *)&val->value.array; +} + +const json_value_t *json_object_find(const char *name, + const json_object_t *obj) +{ /* Look name up in the rb-tree of obj using strcmp(). Returns the value of a member whose name compares equal, or NULL when there is none. */ + struct rb_node *p = obj->root.rb_node; + json_member_t *memb; + int n; + + while (p) + { + memb = rb_entry(p, json_member_t, rb); + n = strcmp(name, memb->name); + if (n < 0) + p = p->rb_left; + else if (n > 0) + p = p->rb_right; + else + return &memb->value; + } + + return NULL; +} + +int json_object_size(const json_object_t 
*obj) +{ /* Return the stored member count of obj. */ + return obj->size; +} + +const char *json_object_next_name(const char *name, + const json_object_t *obj) +{ /* Return the name of the member that follows the one owning name, or of the first member when name is NULL; NULL when there is no such member. name is mapped back to its member with list_entry(), so it has to be a name pointer that belongs to a member of obj, not an arbitrary string. */ + const struct list_head *pos; + + if (name) + pos = &list_entry(name, json_member_t, name)->list; + else + pos = &obj->head; + + if (pos->next == &obj->head) + return NULL; + + return list_entry(pos->next, json_member_t, list)->name; +} + +const json_value_t *json_object_next_value(const json_value_t *val, + const json_object_t *obj) +{ /* Return the value of the member that follows the one owning val, or of the first member when val is NULL; NULL when there is no such member. val is mapped back to its member with list_entry(), so it has to be a value that belongs to a member of obj. */ + const struct list_head *pos; + + if (val) + pos = &list_entry(val, json_member_t, value)->list; + else + pos = &obj->head; + + if (pos->next == &obj->head) + return NULL; + + return &list_entry(pos->next, json_member_t, list)->value; +} + +const char *json_object_prev_name(const char *name, + const json_object_t *obj) +{ /* Backward counterpart of json_object_next_name(): a NULL name starts from the last member; returns NULL when there is no preceding member. */ + const struct list_head *pos; + + if (name) + pos = &list_entry(name, json_member_t, name)->list; + else + pos = &obj->head; + + if (pos->prev == &obj->head) + return NULL; + + return list_entry(pos->prev, json_member_t, list)->name; +} + +const json_value_t *json_object_prev_value(const json_value_t *val, + const json_object_t *obj) +{ /* Backward counterpart of json_object_next_value(): a NULL val starts from the last member; returns NULL when there is no preceding member. */ + const struct list_head *pos; + + if (val) + pos = &list_entry(val, json_member_t, value)->list; + else + pos = &obj->head; + + if (pos->prev == &obj->head) + return NULL; + + return &list_entry(pos->prev, json_member_t, list)->value; +} + +static const json_value_t *__json_object_insert(const char *name, + int type, va_list ap, + struct list_head *pos, + json_object_t *obj) +{ /* Allocate a member holding a copy of name, initialise its value from type and ap, link it into obj at list position pos and increment obj->size. Returns the new value, or NULL on allocation failure (obj is then unchanged). */ + json_member_t *memb; + int len; + + len = strlen(name); + memb = (json_member_t *)malloc(offsetof(json_member_t, name) + len + 1); + if (!memb) + return NULL; + + memcpy(memb->name, name, len + 1); + if (__set_json_value(type, ap, &memb->value) < 0) + { + free(memb); + return NULL; + } + + __insert_json_member(memb, pos, obj); + obj->size++; + return &memb->value; +} + +const json_value_t *json_object_append(json_object_t *obj, + const char *name, + int type, ...) 
+{ /* Insert a new member after the current last list entry of obj. Returns the new value, or NULL on allocation failure. */ + const json_value_t *val; + va_list ap; + + va_start(ap, type); + val = __json_object_insert(name, type, ap, obj->head.prev, obj); + va_end(ap); + return val; +} + +const json_value_t *json_object_insert_after(const json_value_t *val, + json_object_t *obj, + const char *name, + int type, ...) +{ /* Insert a new member right after the member that owns val; a NULL val uses the list head as the position. Returns the new value, or NULL on allocation failure. */ + struct list_head *pos; + va_list ap; + + if (val) + pos = &list_entry(val, json_member_t, value)->list; + else + pos = &obj->head; + + va_start(ap, type); + val = __json_object_insert(name, type, ap, pos, obj); + va_end(ap); + return val; +} + +const json_value_t *json_object_insert_before(const json_value_t *val, + json_object_t *obj, + const char *name, + int type, ...) +{ /* Insert a new member right before the member that owns val, by inserting after its predecessor; a NULL val uses the predecessor of the list head. Returns the new value, or NULL on allocation failure. */ + struct list_head *pos; + va_list ap; + + if (val) + pos = &list_entry(val, json_member_t, value)->list; + else + pos = &obj->head; + + va_start(ap, type); + val = __json_object_insert(name, type, ap, pos->prev, obj); + va_end(ap); + return val; +} + +json_value_t *json_object_remove(const json_value_t *val, + json_object_t *obj) +{ /* Unlink the member that owns val from obj and return its value in a newly malloc-ed json_value_t (payload moved, not copied); the member itself is freed. Returns NULL and leaves obj unchanged if that allocation fails. */ + json_member_t *memb = list_entry(val, json_member_t, value); + + val = (json_value_t *)malloc(sizeof (json_value_t)); + if (!val) + return NULL; + + list_del(&memb->list); + rb_erase(&memb->rb, &obj->root); + obj->size--; + + __move_json_value(&memb->value, (json_value_t *)val); + free(memb); + return (json_value_t *)val; +} + +int json_array_size(const json_array_t *arr) +{ /* Return the stored element count of arr. */ + return arr->size; +} + +const json_value_t *json_array_next_value(const json_value_t *val, + const json_array_t *arr) +{ /* Return the element value that follows val, or the first one when val is NULL; NULL when there is no such element. val is mapped back to its element with list_entry(), so it has to be a value that belongs to an element of arr. */ + const struct list_head *pos; + + if (val) + pos = &list_entry(val, json_element_t, value)->list; + else + pos = &arr->head; + + if (pos->next == &arr->head) + return NULL; + + return &list_entry(pos->next, json_element_t, list)->value; +} + +const json_value_t *json_array_prev_value(const json_value_t *val, + const json_array_t *arr) +{ /* Backward counterpart of json_array_next_value(): a NULL val starts from the last element; returns NULL when there is no preceding element. */ + const struct list_head *pos; + + if (val) + pos = &list_entry(val, json_element_t, value)->list; + else + pos = &arr->head; + + if 
(pos->prev == &arr->head) + return NULL; + + return &list_entry(pos->prev, json_element_t, list)->value; +} + +static const json_value_t *__json_array_insert(int type, va_list ap, + struct list_head *pos, + json_array_t *arr) +{ /* Allocate an element, initialise its value from type and ap, link it into arr at list position pos with list_add() and increment arr->size. Returns the new value, or NULL on allocation failure (arr is then unchanged). */ + json_element_t *elem; + + elem = (json_element_t *)malloc(sizeof (json_element_t)); + if (!elem) + return NULL; + + if (__set_json_value(type, ap, &elem->value) < 0) + { + free(elem); + return NULL; + } + + list_add(&elem->list, pos); + arr->size++; + return &elem->value; +} + +const json_value_t *json_array_append(json_array_t *arr, + int type, ...) +{ /* Insert a new element after the current last list entry of arr. Returns the new value, or NULL on allocation failure. */ + const json_value_t *val; + va_list ap; + + va_start(ap, type); + val = __json_array_insert(type, ap, arr->head.prev, arr); + va_end(ap); + return val; +} + +const json_value_t *json_array_insert_after(const json_value_t *val, + json_array_t *arr, + int type, ...) +{ /* Insert a new element right after the element that owns val; a NULL val uses the list head as the position. Returns the new value, or NULL on allocation failure. */ + struct list_head *pos; + va_list ap; + + if (val) + pos = &list_entry(val, json_element_t, value)->list; + else + pos = &arr->head; + + va_start(ap, type); + val = __json_array_insert(type, ap, pos, arr); + va_end(ap); + return val; +} + +const json_value_t *json_array_insert_before(const json_value_t *val, + json_array_t *arr, + int type, ...) 
+{ /* Insert a new element right before the element that owns val, by inserting after its predecessor; a NULL val uses the predecessor of the list head. Returns the new value, or NULL on allocation failure. */ + struct list_head *pos; + va_list ap; + + if (val) + pos = &list_entry(val, json_element_t, value)->list; + else + pos = &arr->head; + + va_start(ap, type); + val = __json_array_insert(type, ap, pos->prev, arr); + va_end(ap); + return val; +} + +json_value_t *json_array_remove(const json_value_t *val, + json_array_t *arr) +{ /* Unlink the element that owns val from arr and return its value in a newly malloc-ed json_value_t (payload moved, not copied); the element itself is freed. Returns NULL and leaves arr unchanged if that allocation fails. */ + json_element_t *elem = list_entry(val, json_element_t, value); + + val = (json_value_t *)malloc(sizeof (json_value_t)); + if (!val) + return NULL; + + list_del(&elem->list); + arr->size--; + + __move_json_value(&elem->value, (json_value_t *)val); + free(elem); + return (json_value_t *)val; +} + diff --git a/src/util/json_parser.h b/src/util/json_parser.h new file mode 100644 index 0000000..4f21ebd --- /dev/null +++ b/src/util/json_parser.h @@ -0,0 +1,114 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _JSON_PARSER_H_ +#define _JSON_PARSER_H_ + +#include + +#define JSON_VALUE_STRING 1 /* type codes reported by json_value_type() and accepted by the create / append / insert calls */ +#define JSON_VALUE_NUMBER 2 +#define JSON_VALUE_OBJECT 3 +#define JSON_VALUE_ARRAY 4 +#define JSON_VALUE_TRUE 5 +#define JSON_VALUE_FALSE 6 +#define JSON_VALUE_NULL 7 + +typedef struct __json_value json_value_t; +typedef struct __json_object json_object_t; +typedef struct __json_array json_array_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +json_value_t *json_value_parse(const char *text); /* NULL on allocation failure, on a parse error, or when anything other than whitespace follows the value */ +json_value_t *json_value_create(int type, ...); /* the variadic argument depends on type: const char * for JSON_VALUE_STRING, double for JSON_VALUE_NUMBER, json_value_t * for type 0 (its contents are moved and the struct freed), none otherwise */ +json_value_t *json_value_copy(const json_value_t *val); /* deep copy; NULL on allocation failure */ +void json_value_destroy(json_value_t *val); /* releases the value and everything it owns */ + +int json_value_type(const json_value_t *val); +const char *json_value_string(const json_value_t *val); /* NULL when val is not a string */ +double json_value_number(const json_value_t *val); /* NAN when val is not a number */ +json_object_t *json_value_object(const json_value_t *val); /* NULL when val is not an object */ +json_array_t *json_value_array(const json_value_t *val); /* NULL when val is not an array */ + +const json_value_t *json_object_find(const char *name, + const json_object_t *obj); /* NULL when no member has that name */ +int json_object_size(const json_object_t *obj); +const char *json_object_next_name(const char *name, + const json_object_t *obj); /* name must be NULL or a name pointer obtained from a member of obj */ +const json_value_t *json_object_next_value(const json_value_t *val, + const json_object_t *obj); /* val must be NULL or a value obtained from a member of obj */ +const char *json_object_prev_name(const char *name, + const json_object_t *obj); +const json_value_t *json_object_prev_value(const json_value_t *val, + const json_object_t *obj); +const json_value_t *json_object_append(json_object_t *obj, + const char *name, + int type, ...); /* variadic argument as for json_value_create(); NULL on allocation failure */ +const json_value_t *json_object_insert_after(const json_value_t *val, + json_object_t *obj, + const char *name, + int type, ...); +const json_value_t *json_object_insert_before(const json_value_t *val, + json_object_t *obj, + const char *name, + int type, ...); +json_value_t *json_object_remove(const json_value_t *val, + json_object_t *obj); /* returns the detached value in a newly allocated json_value_t, or NULL (object unchanged) on allocation failure */ + +int json_array_size(const json_array_t *arr); +const json_value_t 
*json_array_next_value(const json_value_t *val, + const json_array_t *arr); /* a NULL val starts from the first element; returns NULL past the last one */ +const json_value_t *json_array_prev_value(const json_value_t *val, + const json_array_t *arr); /* a NULL val starts from the last element; returns NULL before the first one */ +const json_value_t *json_array_append(json_array_t *arr, + int type, ...); /* variadic argument as for json_value_create(); NULL on allocation failure */ +const json_value_t *json_array_insert_after(const json_value_t *val, + json_array_t *arr, + int type, ...); +const json_value_t *json_array_insert_before(const json_value_t *val, + json_array_t *arr, + int type, ...); +json_value_t *json_array_remove(const json_value_t *val, + json_array_t *arr); /* returns the detached value in a newly allocated json_value_t, or NULL (array unchanged) on allocation failure */ + +#ifdef __cplusplus +} +#endif + +#define json_object_for_each(name, val, obj) \ + for (name = NULL, val = NULL; \ + name = json_object_next_name(name, obj), \ + val = json_object_next_value(val, obj), val; ) + +#define json_object_for_each_prev(name, val, obj) \ + for (name = NULL, val = NULL; \ + name = json_object_prev_name(name, obj), \ + val = json_object_prev_value(val, obj), val; ) + +#define json_array_for_each(val, arr) \ + for (val = NULL; val = json_array_next_value(val, arr), val; ) + +#define json_array_for_each_prev(val, arr) \ + for (val = NULL; val = json_array_prev_value(val, arr), val; ) + +#endif + diff --git a/src/util/xmake.lua b/src/util/xmake.lua new file mode 100644 index 0000000..6559045 --- /dev/null +++ b/src/util/xmake.lua @@ -0,0 +1,13 @@ +target("util") + set_kind("object") + add_files("*.c") + add_files("*.cc") + remove_files("crc32c.c") + +target("kafka_util") + if has_config("kafka") then + set_kind("object") + add_files("crc32c.c") + else + set_kind("phony") + end diff --git a/src/xmake.lua b/src/xmake.lua new file mode 100644 index 0000000..7f8d103 --- /dev/null +++ b/src/xmake.lua @@ -0,0 +1,62 @@ +includes("**/xmake.lua") + +after_build(function (target) + local lib_dir = get_config("workflow_lib") + if (not os.isdir(lib_dir)) then + os.mkdir(lib_dir) + end + local shared_suffix = "*.so" --[[ local: keeps the suffix out of the global environment ]] + if is_plat("macosx") then + shared_suffix = "*.dylib" + end + if target:is_static() then + 
os.mv(path.join("$(projectdir)", target:targetdir(), "*.a"), lib_dir) + else + os.mv(path.join("$(projectdir)", target:targetdir(), shared_suffix), lib_dir) + end +end) + +target("workflow") + set_kind("$(kind)") + add_deps("client", "factory", "kernel", "manager", + "nameservice", "protocol", "server", "util") + + on_load(function (package) --[[ copies the public headers (.h and .inl under src/include) into <workflow_inc>/workflow, creating the directory if needed ]] + local include_path = path.join(get_config("workflow_inc"), "workflow") + if (not os.isdir(include_path)) then + os.mkdir(include_path) + end + + os.cp(path.join("$(projectdir)", "src/include/**.h"), include_path) + os.cp(path.join("$(projectdir)", "src/include/**.inl"), include_path) + end) + + after_clean(function (target) --[[ removes the generated include and lib directories and the build directory ]] + os.rm(get_config("workflow_inc")) + os.rm(get_config("workflow_lib")) + os.rm("$(buildir)") + end) + + on_install(function (target) --[[ installs the collected headers and the static or shared library matching the target kind ]] + os.mkdir(path.join(target:installdir(), "include/workflow")) + os.mkdir(path.join(target:installdir(), "lib")) + os.cp(path.join(get_config("workflow_inc"), "workflow"), path.join(target:installdir(), "include")) + local shared_suffix = "*.so" --[[ local: keeps the suffix out of the global environment ]] + if is_plat("macosx") then + shared_suffix = "*.dylib" + end + if target:is_static() then + os.cp(path.join(get_config("workflow_lib"), "*.a"), path.join(target:installdir(), "lib")) + else + os.cp(path.join(get_config("workflow_lib"), shared_suffix), path.join(target:installdir(), "lib")) + end + end) + +target("wfkafka") + if has_config("kafka") then + set_kind("$(kind)") + add_deps("kafka_client", "kafka_factory", "kafka_protocol", "kafka_util", "workflow") + else + set_kind("phony") + end +