From 8a8c117accd0a5238243fad4e60e828bbfcd3ff5 Mon Sep 17 00:00:00 2001 From: zhaojinze <1799174451@qq.com> Date: Mon, 16 Dec 2024 23:41:18 +0800 Subject: [PATCH] zhaojinze --- src/protocol/ConsulDataTypes.h.txt | 302 ++++++++++++++++++++ src/protocol/mysql_parser.c | 433 ++++++----------------------- src/protocol/mysql_stream.c | 91 ++---- 3 files changed, 411 insertions(+), 415 deletions(-) create mode 100644 src/protocol/ConsulDataTypes.h.txt diff --git a/src/protocol/ConsulDataTypes.h.txt b/src/protocol/ConsulDataTypes.h.txt new file mode 100644 index 0000000..ec14c6d --- /dev/null +++ b/src/protocol/ConsulDataTypes.h.txt @@ -0,0 +1,302 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) +*/ + +#ifndef _CONSULDATATYPES_H_ +#define _CONSULDATATYPES_H_ + +#include +#include +#include +#include +#include + +namespace protocol +{ + +/** + * @class ConsulConfig + * @brief 配置类,用于管理 Consul 的各种设置。 + */ +class ConsulConfig +{ +public: + // === Common Config === + // 设置和获取 token 配置,用于鉴权 + void set_token(const std::string& token) { this->ptr->token = token; } + std::string get_token() const { return this->ptr->token; } + + // === Discover Config === + // 设置和获取数据中心名称 + void set_datacenter(const std::string& data_center) { this->ptr->dc = data_center; } + std::string get_datacenter() const { return this->ptr->dc; } + + // 设置和获取指定的邻近节点 + void set_near_node(const std::string& near_node) { this->ptr->near = near_node; } + std::string get_near_node() const { return this->ptr->near; } + + // 设置和获取过滤表达式 + void set_filter_expr(const std::string& filter_expr) { this->ptr->filter = filter_expr; } + std::string get_filter_expr() const { return this->ptr->filter; } + + // 设置和获取阻塞查询的等待时间(单位:毫秒) + void set_wait_ttl(int wait_ttl) { this->ptr->wait_ttl = wait_ttl; } + int get_wait_ttl() const { return this->ptr->wait_ttl; } + + // 启用或禁用阻塞查询 + void set_blocking_query(bool enable_flag) { this->ptr->blocking_query = enable_flag; } + bool blocking_query() const { return this->ptr->blocking_query; } + + // 设置和获取仅获取健康通过状态的服务实例 + void set_passing(bool passing) { this->ptr->passing = passing; } + bool get_passing() const { return this->ptr->passing; } + + // === Register Config === + // 设置和获取是否替换现有检查 + void set_replace_checks(bool replace_checks) { this->ptr->replace_checks = replace_checks; } + bool get_replace_checks() const { return this->ptr->replace_checks; } + + // 设置和获取健康检查的名称 + void set_check_name(const std::string& check_name) { this->ptr->check_cfg.check_name = check_name; } + std::string get_check_name() const { return this->ptr->check_cfg.check_name; } + + // 设置和获取 HTTP 检查 URL + void set_check_http_url(const std::string& http_url) { this->ptr->check_cfg.http_url = http_url; } + std::string get_check_http_url() const { return this->ptr->check_cfg.http_url; } + + // 设置和获取 HTTP 检查方法 + void set_check_http_method(const std::string& method) { this->ptr->check_cfg.http_method = method; } + std::string get_check_http_method() const { return this->ptr->check_cfg.http_method; } + + // 添加 HTTP 请求头 + void add_http_header(const std::string& key, const std::vector& values) + { + this->ptr->check_cfg.headers.emplace(key, values); + } + const std::map>* get_http_headers() const + { + return &this->ptr->check_cfg.headers; + } + + // 设置和获取 HTTP 检查请求体 + void set_http_body(const std::string& body) { this->ptr->check_cfg.http_body = body; } + std::string get_http_body() const { return this->ptr->check_cfg.http_body; } + + // 设置和获取检查的时间间隔 + void set_check_interval(int interval) { this->ptr->check_cfg.interval = interval; } + int get_check_interval() const { return this->ptr->check_cfg.interval; } + + // 设置和获取检查超时时间 + void set_check_timeout(int timeout) { this->ptr->check_cfg.timeout = timeout; } + int get_check_timeout() const { return this->ptr->check_cfg.timeout; } + + // 设置和获取健康检查的备注 + void set_check_notes(const std::string& notes) { this->ptr->check_cfg.notes = notes; } + std::string get_check_notes() const { return this->ptr->check_cfg.notes; } + + // 设置和获取 TCP 地址 + void set_check_tcp(const std::string& tcp_address) { this->ptr->check_cfg.tcp_address = tcp_address; } + std::string get_check_tcp() const { return this->ptr->check_cfg.tcp_address; } + + // 设置和获取初始状态 + void set_initial_status(const std::string& initial_status) { this->ptr->check_cfg.initial_status = initial_status; } + std::string get_initial_status() const { return this->ptr->check_cfg.initial_status; } + + // 设置和获取自动注销时间 + void set_auto_deregister_time(int milliseconds) + { + this->ptr->check_cfg.auto_deregister_time = milliseconds; + } + int get_auto_deregister_time() const { return this->ptr->check_cfg.auto_deregister_time; } + + // 设置和获取通过前的成功次数 + void set_success_times(int times) { this->ptr->check_cfg.success_times = times; } + int get_success_times() const { return this->ptr->check_cfg.success_times; } + + // 设置和获取进入临界状态前的失败次数 + void set_failure_times(int times) { this->ptr->check_cfg.failure_times = times; } + int get_failure_times() const { return this->ptr->check_cfg.failure_times; } + + // 启用或禁用健康检查 + void set_health_check(bool enable_flag) { this->ptr->check_cfg.health_check = enable_flag; } + bool get_health_check() const { return this->ptr->check_cfg.health_check; } + +public: + // 构造函数,初始化默认配置 + ConsulConfig() + { + this->ptr = new Config; + this->ptr->blocking_query = false; + this->ptr->passing = false; + this->ptr->replace_checks = false; + this->ptr->wait_ttl = 300 * 1000; + this->ptr->check_cfg.interval = 5000; + this->ptr->check_cfg.timeout = 10000; + this->ptr->check_cfg.http_method = "GET"; + this->ptr->check_cfg.initial_status = "critical"; + this->ptr->check_cfg.auto_deregister_time = 10 * 60 * 1000; + this->ptr->check_cfg.success_times = 0; + this->ptr->check_cfg.failure_times = 0; + this->ptr->check_cfg.health_check = false; + + this->ref = new std::atomic(1); + } + + // 析构函数,释放内存 + virtual ~ConsulConfig() + { + if (--*this->ref == 0) + { + delete this->ptr; + delete this->ref; + } + } + + // 移动构造函数 + ConsulConfig(ConsulConfig&& move) + { + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new Config; + move.ref = new std::atomic(1); + } + + // 拷贝构造函数 + ConsulConfig(const ConsulConfig& copy) + { + this->ptr = copy.ptr; + this->ref = copy.ref; + ++(*this->ref); + } + + // 移动赋值运算符 + ConsulConfig& operator= (ConsulConfig&& move) + { + if (this != &move) + { + this->~ConsulConfig(); + this->ptr = move.ptr; + this->ref = move.ref; + move.ptr = new Config; + move.ref = new std::atomic(1); + } + return *this; + } + + // 拷贝赋值运算符 + ConsulConfig& operator= (const ConsulConfig& copy) + { + if (this != ©) + { + this->~ConsulConfig(); + this->ptr = copy.ptr; + this->ref = copy.ref; + ++(*this->ref); + } + return *this; + } + +private: + // 健康检查配置结构体 + struct HealthCheckConfig + { + std::string check_name; // 检查名称 + std::string notes; // 备注 + std::string http_url; // HTTP URL + std::string http_method; // HTTP 方法 + std::string http_body; // HTTP 请求体 + std::string tcp_address; // TCP 地址 + std::string initial_status; // 初始状态 + std::map> headers; // 请求头 + int auto_deregister_time; // 自动注销时间 + int interval; // 时间间隔 + int timeout; // 超时时间 + int success_times; // 成功次数 + int failure_times; // 失败次数 + bool health_check; // 是否启用健康检查 + }; + + // 配置结构体 + struct Config + { + // 通用配置 + std::string token; + + // 服务发现配置 + std::string dc; // 数据中心 + std::string near; // 邻近节点 + std::string filter; // 过滤表达式 + int wait_ttl; // 等待时间 + bool blocking_query; // 阻塞查询 + bool passing; // 健康状态 + + // 注册配置 + bool replace_checks; // 替换现有检查 + HealthCheckConfig check_cfg; // 健康检查配置 + }; + +private: + struct Config *ptr; // 配置指针 + std::atomic *ref; // 引用计数 +}; + +// 地址类型,包含地址和端口 +using ConsulAddress = std::pair; + +/** + * @struct ConsulService + * @brief 表示一个服务的基本信息。 + */ +struct ConsulService +{ + std::string service_name; // 服务名称 + std::string service_namespace; // 服务命名空间 + std::string service_id; // 服务 ID + std::vector tags; // 服务标签 + ConsulAddress service_address; // 服务地址 + ConsulAddress lan; // 局域网地址 + ConsulAddress lan_ipv4; // 局域网 IPv4 地址 + ConsulAddress lan_ipv6; // 局域网 IPv6 地址 + ConsulAddress virtual_address; // 虚拟地址 + ConsulAddress wan; // 广域网地址 + ConsulAddress wan_ipv4; // 广域网 IPv4 地址 + ConsulAddress wan_ipv6; // 广域网 IPv6 地址 + std::map meta; // 元信息 + bool tag_override; // 标签覆盖标志 +}; + +/** + * @struct ConsulServiceInstance + * @brief 表示服务实例,包括节点和健康检查信息。 + */ +struct ConsulServiceInstance +{ + // 节点信息 + std::string node_id; // 节点 ID + std::string node_name; // 节点名称 + std::string node_address; // 节点地址 + std::string dc; // 数据中心 + std::map node_meta; // 节点元信息 + long long create_index; // 创建索引 + long long modify_index; // 修改索引 + + // 服务信息 + struct ConsulService service; + + // 健康检查信息 + std::string check_name; // 检查名称 + std diff --git a/src/protocol/mysql_parser.c b/src/protocol/mysql_parser.c index babe03b..efe6d2a 100644 --- a/src/protocol/mysql_parser.c +++ b/src/protocol/mysql_parser.c @@ -21,40 +21,51 @@ #include "mysql_byteorder.h" #include "mysql_parser.h" +// === 静态函数声明 === +// 用于处理不同类型的 MySQL 数据包,分别解析基础数据包、错误包、OK 包、EOF 包等。 static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_ok_packet(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_eof_packet(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_field_eof_packet(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_field_count(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_column_def_packet(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_local_inline(const void *buf, size_t len, mysql_parser_t *parser); - static int parse_row_packet(const void *buf, size_t len, mysql_parser_t *parser); +// === MySQL Parser 初始化与销毁 === + +/** + * @brief 初始化 MySQL 数据包解析器。 + * + * 设置解析器的初始状态,例如偏移量、命令类型和数据包类型。 + * + * @param parser 指向待初始化的解析器对象。 + */ void mysql_parser_init(mysql_parser_t *parser) { - parser->offset = 0; - parser->cmd = MYSQL_COM_QUERY; - parser->packet_type = MYSQL_PACKET_OTHER; - parser->parse = parse_base_packet; - parser->result_set_count = 0; - INIT_LIST_HEAD(&parser->result_set_list); + parser->offset = 0; // 初始化数据偏移量 + parser->cmd = MYSQL_COM_QUERY; // 默认命令类型为查询 + parser->packet_type = MYSQL_PACKET_OTHER; // 数据包类型默认为其他 + parser->parse = parse_base_packet; // 设置初始解析函数 + parser->result_set_count = 0; // 初始化结果集计数 + INIT_LIST_HEAD(&parser->result_set_list); // 初始化结果集链表 } +/** + * @brief 销毁 MySQL 数据包解析器。 + * + * 清理解析器中分配的内存,释放所有结果集的资源。 + * + * @param parser 指向待销毁的解析器对象。 + */ void mysql_parser_deinit(mysql_parser_t *parser) { struct __mysql_result_set *result_set; struct list_head *pos, *tmp; int i; + // 遍历结果集链表,释放每个结果集的内存 list_for_each_safe(pos, tmp, &parser->result_set_list) { result_set = list_entry(pos, struct __mysql_result_set, list); @@ -72,17 +83,28 @@ void mysql_parser_deinit(mysql_parser_t *parser) } } +// === 解析器核心逻辑 === + +/** + * @brief 解析 MySQL 数据包。 + * + * 该函数根据当前解析器的状态调用对应的解析函数。 + * + * @param buf 待解析的数据缓冲区。 + * @param len 缓冲区的长度。 + * @param parser 解析器对象。 + * @return 成功返回 0,解析完成返回 1,失败返回负值。 + */ int mysql_parser_parse(const void *buf, size_t len, mysql_parser_t *parser) { -// const char *end = (const char *)buf + len; int ret; do { - ret = parser->parse(buf, len, parser); + ret = parser->parse(buf, len, parser); // 调用当前的解析函数 if (ret < 0) return ret; - if (ret > 0 && parser->offset != len) + if (ret > 0 && parser->offset != len) // 检查是否完全解析 return -2; } while (parser->offset < len); @@ -90,6 +112,15 @@ int mysql_parser_parse(const void *buf, size_t len, mysql_parser_t *parser) return ret; } +/** + * @brief 获取网络状态信息。 + * + * 从解析器中提取当前的网络状态字符串及其长度。 + * + * @param net_state_str 返回的状态字符串指针。 + * @param net_state_len 返回的状态字符串长度。 + * @param parser 解析器对象。 + */ void mysql_parser_get_net_state(const char **net_state_str, size_t *net_state_len, mysql_parser_t *parser) @@ -98,6 +129,15 @@ void mysql_parser_get_net_state(const char **net_state_str, *net_state_len = MYSQL_STATE_LENGTH; } +/** + * @brief 获取错误信息。 + * + * 从解析器中提取当前的错误消息及其长度。 + * + * @param err_msg_str 返回的错误消息字符串指针。 + * @param err_msg_len 返回的错误消息字符串长度。 + * @param parser 解析器对象。 + */ void mysql_parser_get_err_msg(const char **err_msg_str, size_t *err_msg_len, mysql_parser_t *parser) @@ -112,30 +152,37 @@ void mysql_parser_get_err_msg(const char **err_msg_str, } } +// === 数据包解析函数 === + +/** + * @brief 解析基础数据包。 + * + * 根据数据包头部字节判断数据包类型,并切换解析函数。 + * + * @param buf 数据缓冲区。 + * @param len 缓冲区长度。 + * @param parser 解析器对象。 + * @return 成功返回 0。 + */ static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; switch (*p) { - // OK PACKET - case MYSQL_PACKET_HEADER_OK: + case MYSQL_PACKET_HEADER_OK: // OK 包 parser->parse = parse_ok_packet; break; - // ERR PACKET - case MYSQL_PACKET_HEADER_ERROR: + case MYSQL_PACKET_HEADER_ERROR: // 错误包 parser->parse = parse_error_packet; break; - // EOF PACKET - case MYSQL_PACKET_HEADER_EOF: + case MYSQL_PACKET_HEADER_EOF: // EOF 包 parser->parse = parse_eof_packet; break; - // LOCAL INFILE PACKET - case MYSQL_PACKET_HEADER_NULL: - // if (field_count == -1) + case MYSQL_PACKET_HEADER_NULL: // LOCAL INFILE 包 parser->parse = parse_local_inline; break; - default: + default: // 字段计数包 parser->parse = parse_field_count; break; } @@ -143,7 +190,18 @@ static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser return 0; } -// 1:0xFF|2:err_no|1:#|5:server_state|0-512:err_msg +// === 示例: 解析错误包 === + +/** + * @brief 解析错误包。 + * + * 提取错误码、服务器状态和错误消息。 + * + * @param buf 数据缓冲区。 + * @param len 缓冲区长度。 + * @param parser 解析器对象。 + * @return 成功返回 1,数据不足返回 -2。 + */ static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; @@ -152,10 +210,10 @@ static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parse if (p + 9 > buf_end) return -2; - parser->error = uint2korr(p + 1); + parser->error = uint2korr(p + 1); // 提取错误码 p += 3; - if (*p == '#') + if (*p == '#') // 提取服务器状态 { p += 1; parser->net_state_offset = p - (const unsigned char *)buf; @@ -165,317 +223,4 @@ static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parse parser->err_msg_offset = p - (const unsigned char *)buf; parser->err_msg_len = msg_len; } else { - parser->err_msg_offset = (size_t)-1; - parser->err_msg_len = 0; - } - - parser->offset = len; - parser->packet_type = MYSQL_PACKET_ERROR; - parser->buf = buf; - return 1; -} - -// 1:0x00|1-9:affect_row|1-9:insert_id|2:server_status|2:warning_count|0-n:server_msg -static int parse_ok_packet(const void *buf, size_t len, mysql_parser_t *parser) -{ - const unsigned char *p = (const unsigned char *)buf + parser->offset; - const unsigned char *buf_end = (const unsigned char *)buf + len; - - unsigned long long affected_rows, insert_id, info_len; - const unsigned char *str; - struct __mysql_result_set *result_set; - unsigned int warning_count; - int server_status; - - p += 1;// 0x00 - if (decode_length_safe(&affected_rows, &p, buf_end) <= 0) - return -2; - - if (decode_length_safe(&insert_id, &p, buf_end) <= 0) - return -2; - - if (p + 4 > buf_end) - return -2; - - server_status = uint2korr(p); - p += 2; - warning_count = uint2korr(p); - p += 2; - - if (p != buf_end) - { - if (decode_string(&str, &info_len, &p, buf_end) == 0) - return -2; - - if (p != buf_end) - { - if (server_status & MYSQL_SERVER_SESSION_STATE_CHANGED) - { - const unsigned char *tmp_str; - unsigned long long tmp_len; - if (decode_string(&tmp_str, &tmp_len, &p, buf_end) == 0) - return -2; - } else - return -2; - } - } else { - str = p; - info_len = 0; - } - - result_set = (struct __mysql_result_set *)malloc(sizeof(struct __mysql_result_set)); - if (result_set == NULL) - return -1; - - result_set->info_offset = str - (const unsigned char *)buf; - result_set->info_len = info_len; - result_set->affected_rows = (affected_rows == ~0ULL) ? 0 : affected_rows; - result_set->insert_id = (insert_id == ~0ULL) ? 0 : insert_id; - result_set->server_status = server_status; - result_set->warning_count = warning_count; - result_set->type = MYSQL_PACKET_OK; - result_set->field_count = 0; - - list_add_tail(&result_set->list, &parser->result_set_list); - parser->current_result_set = result_set; - parser->result_set_count++; - parser->packet_type = MYSQL_PACKET_OK; - - parser->buf = buf; - parser->offset = p - (const unsigned char *)buf; - - if (server_status & MYSQL_SERVER_MORE_RESULTS_EXIST) - { - parser->parse = parse_base_packet; - return 0; - } - - return 1; -} - -// 1:0xfe|2:warnings|2:status_flag -static int parse_eof_packet(const void *buf, size_t len, mysql_parser_t *parser) -{ - const unsigned char *p = (const unsigned char *)buf + parser->offset; - const unsigned char *buf_end = (const unsigned char *)buf + len; - - if (p + 5 > buf_end) - return -2; - - parser->offset += 5; - parser->packet_type = MYSQL_PACKET_EOF; - parser->buf = buf; - - int status_flag = uint2korr(p + 3); - if (status_flag & MYSQL_SERVER_MORE_RESULTS_EXIST) - { - parser->parse = parse_base_packet; - return 0; - } - - return 1; -} - -static int parse_field_eof_packet(const void *buf, size_t len, mysql_parser_t *parser) -{ - const unsigned char *p = (const unsigned char *)buf + parser->offset; - const unsigned char *buf_end = (const unsigned char *)buf + len; - - if (p + 5 > buf_end) - return -2; - - parser->offset += 5; - parser->current_result_set->rows_begin_offset = parser->offset; - parser->parse = parse_row_packet; - return 0; -} - -//raw file data -static int parse_local_inline(const void *buf, size_t len, mysql_parser_t *parser) -{ - parser->local_inline_offset = parser->offset; - parser->local_inline_length = len - parser->offset; - parser->offset = len; - parser->packet_type = MYSQL_PACKET_LOCAL_INLINE; - parser->buf = buf; - return 1; -} - -// for each field: -// NULL as 0xfb, or a length-encoded-string -static int parse_row_packet(const void *buf, size_t len, mysql_parser_t *parser) -{ - const unsigned char *p = (const unsigned char *)buf + parser->offset; - const unsigned char *buf_end = (const unsigned char *)buf + len; - - unsigned long long cell_len; - const unsigned char *cell_data; - - size_t i; - - if (*p == MYSQL_PACKET_HEADER_ERROR) - { - parser->parse = parse_error_packet; - return 0; - } - - if (*p == MYSQL_PACKET_HEADER_EOF) - { - parser->parse = parse_eof_packet; - parser->current_result_set->rows_end_offset = parser->offset; - - return 0; - } - - for (i = 0; i < parser->current_result_set->field_count; i++) - { - if (*p == MYSQL_PACKET_HEADER_NULL) - { - p++; - } else { - if (decode_string(&cell_data, &cell_len, &p, buf_end) == 0) - break; - } - } - - if (i != parser->current_result_set->field_count) - return -2; - - parser->current_result_set->row_count++; - parser->offset = p - (const unsigned char *)buf; - return 0; -} - -static int parse_field_count(const void *buf, size_t len, mysql_parser_t *parser) -{ - const unsigned char *p = (const unsigned char *)buf + parser->offset; - const unsigned char *buf_end = (const unsigned char *)buf + len; - - unsigned long long field_count; - struct __mysql_result_set *result_set; - - if (decode_length_safe(&field_count, &p, buf_end) <= 0) - return -2; - - field_count = (field_count == ~0ULL) ? 0 : field_count; - - if (field_count) - { - result_set = (struct __mysql_result_set *)malloc(sizeof (struct __mysql_result_set)); - if (result_set == NULL) - return -1; - - result_set->fields = (mysql_field_t **)calloc(field_count, sizeof (mysql_field_t *)); - if (result_set->fields == NULL) - { - free(result_set); - return -1; - } - - result_set->field_count = field_count; - result_set->row_count = 0; - result_set->type = MYSQL_PACKET_GET_RESULT; - - list_add_tail(&result_set->list, &parser->result_set_list); - parser->current_result_set = result_set; - parser->current_field_count = 0; - parser->result_set_count++; - parser->packet_type = MYSQL_PACKET_GET_RESULT; - - parser->parse = parse_column_def_packet; - parser->offset = p - (const unsigned char *)buf; - } else { - parser->parse = parse_ok_packet; - } - return 0; -} - -// COLUMN DEFINATION PACKET. for one field: (after protocol 41) -// str:catalog|str:db|str:table|str:org_table|str:name|str:org_name| -// 2:charsetnr|4:length|1:type|2:flags|1:decimals|1:0x00|1:0x00|n:str(if COM_FIELD_LIST) -static int parse_column_def_packet(const void *buf, size_t len, mysql_parser_t *parser) -{ - const unsigned char *p = (const unsigned char *)buf + parser->offset; - const unsigned char *buf_end = (const unsigned char *)buf + len; - - int flag = 0; - const unsigned char *str; - unsigned long long str_len; - mysql_field_t *field = (mysql_field_t *)malloc(sizeof(mysql_field_t)); - - if (!field) - return -1; - - do { - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->catalog_offset = str - (const unsigned char *)buf; - field->catalog_length = str_len; - - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->db_offset = str - (const unsigned char *)buf; - field->db_length = str_len; - - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->table_offset = str - (const unsigned char *)buf; - field->table_length = str_len; - - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->org_table_offset = str - (const unsigned char *)buf; - field->org_table_length = str_len; - - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->name_offset = str - (const unsigned char *)buf; - field->name_length = str_len; - - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->org_name_offset = str - (const unsigned char *)buf; - field->org_name_length = str_len; - - // the rest needs at least 13 - if (p + 13 > buf_end) - break; - - p++; // length of the following fields (always 0x0c) - field->charsetnr = uint2korr(p); - field->length = uint4korr(p + 2); - field->data_type = *(p + 6); - field->flags = uint2korr(p + 7); - field->decimals = (int)p[9]; - p += 12; - // if is COM_FIELD_LIST, the rest is a string - // 0x03 for COM_QUERY - if (parser->cmd == MYSQL_COM_FIELD_LIST) - { - if (decode_string(&str, &str_len, &p, buf_end) == 0) - break; - field->def_offset = str - (const unsigned char *)buf; - field->def_length = str_len; - } else { - field->def_offset = (size_t)-1; - field->def_length = 0; - } - flag = 1; - } while (0); - - if (flag == 0) - { - free(field); - return -2; - } - - //parser->fields.emplace_back(std::move(field)); - parser->current_result_set->fields[parser->current_field_count] = field; - - parser->offset = p - (const unsigned char *)buf; - if (++parser->current_field_count == parser->current_result_set->field_count) - parser->parse = parse_field_eof_packet; - - return 0; -} - + parser->err diff --git a/src/protocol/mysql_stream.c b/src/protocol/mysql_stream.c index b659582..740807c 100644 --- a/src/protocol/mysql_stream.c +++ b/src/protocol/mysql_stream.c @@ -22,77 +22,26 @@ #define MAX(x, y) ((x) >= (y) ? (x) : (y)) +/** + * @brief 写入 MySQL 数据包的负载部分。 + * + * 将数据缓冲区的内容写入到流中的负载部分,支持分块写入。 + * + * @param buf 待写入的数据缓冲区 + * @param n 指向待写入数据长度的指针,在返回时更新为实际写入长度 + * @param stream MySQL 流对象 + * @return int 成功返回 0,失败返回 -1 + */ static int __mysql_stream_write_payload(const void *buf, size_t *n, mysql_stream_t *stream); -static int __mysql_stream_write_head(const void *buf, size_t *n, - mysql_stream_t *stream) -{ - void *p = &stream->head[4 - stream->head_left]; - - if (*n < stream->head_left) - { - memcpy(p, buf, *n); - stream->head_left -= *n; - return 0; - } - - memcpy(p, buf, stream->head_left); - stream->payload_length = (stream->head[2] << 16) + - (stream->head[1] << 8) + - stream->head[0]; - stream->payload_left = stream->payload_length; - stream->sequence_id = stream->head[3]; - if (stream->bufsize < stream->length + stream->payload_left) - { - size_t new_size = MAX(2048, 2 * stream->bufsize); - void *new_base; - - while (new_size < stream->length + stream->payload_left) - new_size *= 2; - - new_base = realloc(stream->buf, new_size); - if (!new_base) - return -1; - - stream->buf = new_base; - stream->bufsize = new_size; - } - - *n = stream->head_left; - stream->write = __mysql_stream_write_payload; - return 0; -} - -static int __mysql_stream_write_payload(const void *buf, size_t *n, - mysql_stream_t *stream) -{ - char *p = (char *)stream->buf + stream->length; - - if (*n < stream->payload_left) - { - memcpy(p, buf, *n); - stream->length += *n; - stream->payload_left -= *n; - return 0; - } - - memcpy(p, buf, stream->payload_left); - stream->length += stream->payload_left; - - *n = stream->payload_left; - stream->head_left = 4; - stream->write = __mysql_stream_write_head; - return stream->payload_length != (1 << 24) - 1; -} - -void mysql_stream_init(mysql_stream_t *stream) -{ - stream->head_left = 4; - stream->sequence_id = 0; - stream->buf = NULL; - stream->length = 0; - stream->bufsize = 0; - stream->write = __mysql_stream_write_head; -} - +/** + * @brief 写入 MySQL 数据包的头部。 + * + * 将数据包的头部写入到流中,并初始化负载部分的写入逻辑。 + * 如果头部完整写入,则切换到负载写入逻辑。 + * + * @param buf 待写入的数据缓冲区 + * @param n 指向待写入数据长度的指针,在返回时更新为实际写入长度 + * @param stream MySQL 流对象 + * @return int 成功返回 0,失败返回 -