yangsiyi #1

Open
p8keuqptb wants to merge 13 commits from yangsiyi into main

File diff suppressed because it is too large Load Diff

@ -0,0 +1,308 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Liu Kai (liukaidx@sogou-inc.com)
*/
#include <string>
#include <vector>
#include <atomic>
#include "URIParser.h"
#include "StringUtil.h"
#include "WFDnsClient.h"
using namespace protocol;
using DnsCtx = std::function<void (WFDnsTask *)>;
using ComplexTask = WFComplexClientTask<DnsRequest, DnsResponse, DnsCtx>;
// 定义一个DnsParams类用于存储DNS查询参数
class DnsParams
{
public:
// 定义一个内部结构体dns_params用于存储具体的DNS参数
struct dns_params
{
std::vector<ParsedURI> uris; // 存储解析后的URI
std::vector<std::string> search_list; // 搜索列表
int ndots; // 域名中允许的最小点数
int attempts; // 最大尝试次数
bool rotate; // 是否轮询
};
public:
// DnsParams类的构造函数
DnsParams()
{
this->ref = new std::atomic<size_t>(1); // 初始化引用计数为1
this->params = new dns_params(); // 初始化参数结构体
}
// DnsParams类的拷贝构造函数
DnsParams(const DnsParams& p)
{
this->ref = p.ref; // 拷贝引用计数指针
this->params = p.params; // 拷贝参数结构体指针
this->incref(); // 增加引用计数
}
// DnsParams类的赋值运算符
DnsParams& operator=(const DnsParams& p)
{
if (this != &p) // 如果不是自赋值
{
this->decref(); // 减少当前对象的引用计数
this->ref = p.ref; // 拷贝引用计数指针
this->params = p.params; // 拷贝参数结构体指针
this->incref(); // 增加引用计数
}
return *this; // 返回当前对象的引用
}
// DnsParams类的析构函数
~DnsParams() { this->decref(); } // 减少引用计数
// 获取const类型的参数指针
const dns_params *get_params() const { return this->params; }
// 获取非const类型的参数指针
dns_params *get_params() { return this->params; }
private:
// 增加引用计数
void incref() { (*this->ref)++; }
// 减少引用计数并在计数为0时释放资源
void decref()
{
if (--*this->ref == 0)
{
delete this->params; // 删除参数结构体
delete this->ref; // 删除引用计数
}
}
private:
dns_params *params; // 指向参数结构体的指针
std::atomic<size_t> *ref; // 指向引用计数的指针
};
// 定义DNS状态枚举
enum
{
DNS_STATUS_TRY_ORIGIN_DONE = 0,
DNS_STATUS_TRY_ORIGIN_FIRST = 1,
DNS_STATUS_TRY_ORIGIN_LAST = 2
};
// 定义DnsStatus结构体用于存储DNS查询状态
struct DnsStatus
{
std::string origin_name; // 原始域名
std::string current_name; // 当前域名
size_t next_server; // 下一个要尝试的服务器
size_t last_server; // 上一个尝试的服务器
size_t next_domain; // 下一个要尝试的搜索域
int attempts_left; // 剩余尝试次数
int try_origin_state; // 尝试原始域名的状态
};
// 计算字符串中点的数量
static int __get_ndots(const std::string& s)
{
int ndots = 0;
for (size_t i = 0; i < s.size(); i++)
ndots += s[i] == '.'; // 统计点的数量
return ndots;
}
// 检查是否有下一个域名要尝试
static bool __has_next_name(const DnsParams::dns_params *p,
struct DnsStatus *s)
{
if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST)
{
s->current_name = s->origin_name; // 设置当前域名为原始域名
s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; // 更新状态
return true;
}
if (s->next_domain < p->search_list.size()) // 如果还有搜索域要尝试
{
s->current_name = s->origin_name; // 设置当前域名为原始域名
s->current_name.push_back('.'); // 添加点
s->current_name.append(p->search_list[s->next_domain]); // 添加搜索域
s->next_domain++; // 移动到下一个搜索域
return true;
}
if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_LAST)
{
s->current_name = s->origin_name; // 设置当前域名为原始域名
s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; // 更新状态
return true;
}
return false; // 没有下一个域名要尝试
}
// DNS查询回调内部函数
static void __callback_internal(WFDnsTask *task, const DnsParams& params,
struct DnsStatus& s)
{
ComplexTask *ctask = static_cast<ComplexTask *>(task); // 转换任务类型
int state = task->get_state(); // 获取任务状态
DnsRequest *req = task->get_req(); // 获取DNS请求
DnsResponse *resp = task->get_resp(); // 获取DNS响应
const auto *p = params.get_params(); // 获取DNS参数
int rcode = resp->get_rcode(); // 获取响应码
bool try_next_server = state != WFT_STATE_SUCCESS || // 如果状态不是成功
rcode == DNS_RCODE_SERVER_FAILURE || // 或者响应码是服务器失败
rcode == DNS_RCODE_NOT_IMPLEMENTED || // 或者响应码是未实现
rcode == DNS_RCODE_REFUSED; // 或者响应码是拒绝
bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR || // 如果响应码是格式错误
rcode == DNS_RCODE_NAME_ERROR || // 或者响应码是名字错误
resp->get_ancount() == 0; // 或者响应中没有答案
if (try_next_server)
{
if (s.last_server == s.next_server) // 如果已经是最后一个服务器
s.attempts_left--; // 减少尝试次数
if (s.attempts_left <= 0) // 如果尝试次数用完
return; // 返回
s.next_server = (s.next_server + 1) % p->uris.size(); // 计算下一个服务器
ctask->set_redirect(p->uris[s.next_server]); // 设置重定向
return; // 返回
}
if (try_next_name && __has_next_name(p, &s)) // 如果需要尝试下一个名字
{
req->set_question_name(s.current_name.c_str()); // 设置查询名字
ctask->set_redirect(p->uris[s.next_server]); // 设置重定向
return; // 返回
}
}
// WFDnsClient类的初始化函数带一个URL参数
int WFDnsClient::init(const std::string& url)
{
return this->init(url, "", 1, 2, false); // 调用另一个初始化函数
}
// WFDnsClient类的初始化函数用于配置DNS客户端
int WFDnsClient::init(const std::string& url, const std::string& search_list,
int ndots, int attempts, bool rotate)
{
std::vector<std::string> hosts; // 用于存储分割后的主机名
std::vector<ParsedURI> uris; // 用于存储解析后的URI
std::string host; // 单个主机名字符串
ParsedURI uri; // 用于存储解析后的单个URI
this->id = 0; // 初始化客户端ID为0
hosts = StringUtil::split_filter_empty(url, ','); // 根据逗号分割URL字符串获取主机名列表
// 遍历主机名列表,对每个主机名进行处理
for (size_t i = 0; i < hosts.size(); i++)
{
host = hosts[i]; // 获取当前主机名
// 检查主机名是否以"dns://"或"dnss://"开头,如果不是,则添加"dns://"前缀
if (strncasecmp(host.c_str(), "dns://", 6) != 0 &&
strncasecmp(host.c_str(), "dnss://", 7) != 0)
{
host = "dns://" + host;
}
// 使用URIParser解析当前主机名如果解析失败则返回错误码-1
if (URIParser::parse(host, uri) != 0)
return -1;
// 将解析后的URI添加到uris列表中
uris.emplace_back(std::move(uri));
}
// 如果uris列表为空或者ndots小于0或者attempts小于1则设置errno为EINVAL并返回错误码-1
if (uris.empty() || ndots < 0 || attempts < 1)
{
errno = EINVAL;
return -1;
}
// 创建一个新的DnsParams对象来存储DNS参数
this->params = new DnsParams;
DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params(); // 获取DNS参数的指针
q->uris = std::move(uris); // 将解析后的URI列表移动到DNS参数中
q->search_list = StringUtil::split_filter_empty(search_list, ','); // 根据逗号分割search_list字符串获取搜索域列表
// 设置ndots的值如果大于15则设置为15
q->ndots = ndots > 15 ? 15 : ndots;
// 设置attempts的值如果大于5则设置为5
q->attempts = attempts > 5 ? 5 : attempts;
q->rotate = rotate; // 设置是否轮询
return 0; // 初始化成功返回0
}
// WFDnsClient类的析构函数用于释放资源
void WFDnsClient::deinit()
{
delete (DnsParams *)this->params; // 删除分配的DnsParams对象
this->params = NULL; // 将params指针设置为NULL
}
// WFDnsClient类的方法用于创建一个新的DNS任务
WFDnsTask *WFDnsClient::create_dns_task(const std::string& name,
dns_callback_t callback)
{
DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params(); // 获取DNS参数
struct DnsStatus status; // 创建DNS状态结构体
size_t next_server; // 下一个要尝试的服务器索引
WFDnsTask *task; // 指向新创建的DNS任务的指针
DnsRequest *req; // 指向DNS请求的指针
// 如果启用轮询,则计算下一个服务器索引,否则使用第一个服务器
next_server = p->rotate ? this->id++ % p->uris.size() : 0;
status.origin_name = name; // 设置原始域名
status.next_domain = 0; // 设置下一个要尝试的搜索域索引
status.attempts_left = p->attempts; // 设置剩余尝试次数
status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST; // 设置尝试原始域名的状态
// 如果域名以点结尾,则跳过搜索域
if (!name.empty() && name.back() == '.')
status.next_domain = p->search_list.size();
// 如果域名中的点数少于ndots则设置尝试原始域名的状态为最后尝试
else if (__get_ndots(name) < p->ndots)
status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST;
// 检查是否有下一个域名要尝试,并更新状态
__has_next_name(p, &status);
// 创建一个新的DNS任务使用下一个服务器的URI和提供的回调函数
task = WFTaskFactory::create_dns_task(p->uris[next_server], 0, std::move(callback));
status.next_server = next_server; // 设置当前服务器索引
status.last_server = (next_server + p->uris.size() - 1) % p->uris.size(); // 设置最后一个服务器索引
req = task->get_req(); // 获取DNS请求对象
req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN); // 设置DNS请求的问题部分
req->set_rd(1); // 设置DNS请求的递归标志
ComplexTask *ctask = static_cast<ComplexTask *>(task); // 将任务转换为ComplexTask类型
// 设置任务的上下文回调当DNS任务完成时将调用__callback_internal函数
*ctask->get_mutable_ctx() = std::bind(__callback_internal,
std::placeholders::_1,
*(DnsParams *)params, status);
return task; // 返回新创建的DNS任务
}

@ -0,0 +1 @@
[1217/001811.230:ERROR:registration_protocol_win.cc(108)] CreateFile: 系统找不到指定的文件。 (0x2)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,18 @@
cmake_minimum_required(VERSION 3.6) # CMake 3.6
project(server) # "server"
#
set(SRC
WFServer.cc # WFServer.cc
)
# MySQL "n"
if (NOT MYSQL STREQUAL "n")
set(SRC
${SRC} #
WFMySQLServer.cc # WFMySQLServer.cc
)
endif ()
#
add_library(${PROJECT_NAME} OBJECT ${SRC}) # "server"

@ -0,0 +1,70 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Liu Kai (liukaidx@sogou-inc.com)
*/
#ifndef _WFDNSSERVER_H_
#define _WFDNSSERVER_H_
#include "DnsMessage.h" // 引入 DNS 消息相关的头文件
#include "WFServer.h" // 引入服务器功能的头文件
#include "WFTaskFactory.h" // 引入任务工厂以创建 DNS 任务
// 定义一个 dns_process_t 类型,它是一个接受 WFDnsTask 指针并返回 void 的函数对象
using dns_process_t = std::function<void (WFDnsTask *)>;
// 定义 WFDnsServer 类型,基于 WFServer 模板类,使用 DnsRequest 和 DnsResponse 作为协议
using WFDnsServer = WFServer<protocol::DnsRequest, protocol::DnsResponse>;
// 设置 DNS 服务器的默认参数
static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT =
{
.transport_type = TT_UDP, // 传输类型为 UDP
.max_connections = 2000, // 最大连接数
.peer_response_timeout = 10 * 1000, // 对端响应超时时间
.receive_timeout = -1, // 接收超时时间,-1 表示不超时
.keep_alive_timeout = 300 * 1000, // 保持连接的超时时间
.request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制
.ssl_accept_timeout = 5000, // SSL 接受超时时间
};
// WFDnsServer 构造函数的特化模板
template<> inline
WFDnsServer::WFServer(dns_process_t proc) :
WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数
process(std::move(proc)) // 移动构造函数以提高性能
{
}
// 创建新的会话,重写基类方法
template<> inline
CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn)
{
WFDnsTask *task;
// 使用任务工厂创建一个新的 DNS 任务
task = WFServerTaskFactory::create_dns_task(this, this->process);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit);
return task; // 返回创建的任务
}
#endif // _WFDNSSERVER_H_

@ -0,0 +1,10 @@
-- 定义一个 xmake 目标,名为 server
target("server")
set_kind("object") -- 设置目标类型为对象文件
add_files("*.cc") -- 添加当前目录下所有的 .cc 文件到目标中
-- 检查是否未定义 MySQL 配置
if not has_config("mysql") then
remove_files("WFMySQLServer.cc") -- 如果未定义 MySQL 配置,则移除 MySQL 服务器相关的源文件
end

Binary file not shown.

@ -661,101 +661,103 @@ static int __dns_parser_parse_record(int idx, dns_parser_t *parser)
*/ */
static int __dns_parser_parse_question(dns_parser_t *parser) static int __dns_parser_parse_question(dns_parser_t *parser)
{ {
uint16_t qtype; uint16_t qtype;
uint16_t qclass; uint16_t qclass;
const char *msgend; const char *msgend;
const char **cur; const char **cur;
int ret; int ret;
char host[DNS_NAMES_MAX + 2]; char host[DNS_NAMES_MAX + 2];
msgend = (const char *)parser->msgbuf + parser->msgsize; msgend = (const char *)parser->msgbuf + parser->msgsize;
cur = &(parser->cur); cur = &(parser->cur);
// question count != 1 is an error // question count != 1 is an error
if (parser->header.qdcount != 1) if (parser->header.qdcount != 1)
return -2; return -2;
// parse qname // parse qname
ret = __dns_parser_parse_host(host, parser); ret = __dns_parser_parse_host(host, parser);
if (ret < 0) if (ret < 0)
return ret; return ret;
// parse qtype and qclass // parse qtype and qclass
if (*cur + 4 > msgend) if (*cur + 4 > msgend)
return -2; return -2;
qtype = __dns_parser_uint16(*cur); qtype = __dns_parser_uint16(*cur);
qclass = __dns_parser_uint16(*cur + 2); qclass = __dns_parser_uint16(*cur + 2);
*cur += 4; *cur += 4;
if (parser->question.qname) if (parser->question.qname)
free(parser->question.qname); free(parser->question.qname);
parser->question.qname = strdup(host); parser->question.qname = strdup(host);
if (parser->question.qname == NULL) if (parser->question.qname == NULL)
return -1; return -1;
parser->question.qtype = qtype; parser->question.qtype = qtype;
parser->question.qclass = qclass; parser->question.qclass = qclass;
return 0; return 0;
} }
void dns_parser_init(dns_parser_t *parser) void dns_parser_init(dns_parser_t *parser)
{ {
parser->msgbuf = NULL; // 初始化解析器结构体,包括分配内存、设置指针、初始化计数器和列表头
parser->msgbase = NULL; parser->msgbuf = NULL;
parser->cur = NULL; parser->msgbase = NULL;
parser->msgsize = 0; parser->cur = NULL;
parser->bufsize = 0; parser->msgsize = 0;
parser->complete = 0; parser->bufsize = 0;
parser->single_packet = 0; parser->complete = 0;
memset(&parser->header, 0, sizeof (struct dns_header)); parser->single_packet = 0;
memset(&parser->question, 0, sizeof (struct dns_question)); memset(&parser->header, 0, sizeof (struct dns_header));
INIT_LIST_HEAD(&parser->answer_list); memset(&parser->question, 0, sizeof (struct dns_question));
INIT_LIST_HEAD(&parser->authority_list); INIT_LIST_HEAD(&parser->answer_list);
INIT_LIST_HEAD(&parser->additional_list); INIT_LIST_HEAD(&parser->authority_list);
INIT_LIST_HEAD(&parser->additional_list);
} }
int dns_parser_set_question(const char *name, int dns_parser_set_question(const char *name,
uint16_t qtype, uint16_t qtype,
uint16_t qclass, uint16_t qclass,
dns_parser_t *parser) dns_parser_t *parser)
{ {
int ret; int ret;
ret = dns_parser_set_question_name(name, parser); // 设置DNS查询的问题部分包括域名、查询类型和查询类
if (ret < 0) ret = dns_parser_set_question_name(name, parser);
return ret; if (ret < 0)
return ret;
parser->question.qtype = qtype; parser->question.qtype = qtype;
parser->question.qclass = qclass; parser->question.qclass = qclass;
parser->header.qdcount = 1; parser->header.qdcount = 1;
return 0; return 0;
} }
int dns_parser_set_question_name(const char *name, dns_parser_t *parser) int dns_parser_set_question_name(const char *name, dns_parser_t *parser)
{ {
char *newname; char *newname;
size_t len; size_t len;
len = strlen(name); len = strlen(name);
newname = (char *)malloc(len + 1); newname = (char *)malloc(len + 1);
if (!newname) if (!newname)
return -1; return -1;
memcpy(newname, name, len + 1); memcpy(newname, name, len + 1);
// Remove trailing dot, except name is "." // Remove trailing dot, except name is "."
if (len > 1 && newname[len - 1] == '.') if (len > 1 && newname[len - 1] == '.')
newname[len - 1] = '\0'; newname[len - 1] = '\0';
if (parser->question.qname) if (parser->question.qname)
free(parser->question.qname); free(parser->question.qname);
parser->question.qname = newname; parser->question.qname = newname;
return 0; return 0;
} }
void dns_parser_set_id(uint16_t id, dns_parser_t *parser) void dns_parser_set_id(uint16_t id, dns_parser_t *parser)

@ -1,15 +1,18 @@
cmake_minimum_required(VERSION 3.6) cmake_minimum_required(VERSION 3.6) # CMake 3.6
project(server) project(server) # "server"
#
set(SRC set(SRC
WFServer.cc WFServer.cc # WFServer.cc
) )
# MySQL "n"
if (NOT MYSQL STREQUAL "n") if (NOT MYSQL STREQUAL "n")
set(SRC set(SRC
${SRC} ${SRC} #
WFMySQLServer.cc WFMySQLServer.cc # WFMySQLServer.cc
) )
endif () endif ()
add_library(${PROJECT_NAME} OBJECT ${SRC}) #
add_library(${PROJECT_NAME} OBJECT ${SRC}) # "server"

@ -19,44 +19,52 @@
#ifndef _WFDNSSERVER_H_ #ifndef _WFDNSSERVER_H_
#define _WFDNSSERVER_H_ #define _WFDNSSERVER_H_
#include "DnsMessage.h" #include "DnsMessage.h" // 引入 DNS 消息相关的头文件
#include "WFServer.h" #include "WFServer.h" // 引入服务器功能的头文件
#include "WFTaskFactory.h" #include "WFTaskFactory.h" // 引入任务工厂以创建 DNS 任务
// 定义一个 dns_process_t 类型,它是一个接受 WFDnsTask 指针并返回 void 的函数对象
using dns_process_t = std::function<void (WFDnsTask *)>; using dns_process_t = std::function<void (WFDnsTask *)>;
using WFDnsServer = WFServer<protocol::DnsRequest,
protocol::DnsResponse>;
// 定义 WFDnsServer 类型,基于 WFServer 模板类,使用 DnsRequest 和 DnsResponse 作为协议
using WFDnsServer = WFServer<protocol::DnsRequest, protocol::DnsResponse>;
// 设置 DNS 服务器的默认参数
static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_UDP, .transport_type = TT_UDP, // 传输类型为 UDP
.max_connections = 2000, .max_connections = 2000, // 最大连接数
.peer_response_timeout = 10 * 1000, .peer_response_timeout = 10 * 1000, // 对端响应超时时间
.receive_timeout = -1, .receive_timeout = -1, // 接收超时时间,-1 表示不超时
.keep_alive_timeout = 300 * 1000, .keep_alive_timeout = 300 * 1000, // 保持连接的超时时间
.request_size_limit = (size_t)-1, .request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制
.ssl_accept_timeout = 5000, .ssl_accept_timeout = 5000, // SSL 接受超时时间
}; };
// WFDnsServer 构造函数的特化模板
template<> inline template<> inline
WFDnsServer::WFServer(dns_process_t proc) : WFDnsServer::WFServer(dns_process_t proc) :
WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数
process(std::move(proc)) process(std::move(proc)) // 移动构造函数以提高性能
{ {
} }
// 创建新的会话,重写基类方法
template<> inline template<> inline
CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn) CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn)
{ {
WFDnsTask *task; WFDnsTask *task;
// 使用任务工厂创建一个新的 DNS 任务
task = WFServerTaskFactory::create_dns_task(this, this->process); task = WFServerTaskFactory::create_dns_task(this, this->process);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit);
return task; return task; // 返回创建的任务
} }
#endif #endif // _WFDNSSERVER_H_

@ -19,45 +19,53 @@
#ifndef _WFHTTPSERVER_H_ #ifndef _WFHTTPSERVER_H_
#define _WFHTTPSERVER_H_ #define _WFHTTPSERVER_H_
#include <utility> #include <utility> // 引入标准库中的工具函数
#include "HttpMessage.h" #include "HttpMessage.h" // 引入 HTTP 消息相关的头文件
#include "WFServer.h" #include "WFServer.h" // 引入服务器功能的头文件
#include "WFTaskFactory.h" #include "WFTaskFactory.h" // 引入任务工厂以创建 HTTP 任务
// 定义一个 http_process_t 类型,它是一个接受 WFHttpTask 指针并返回 void 的函数对象
using http_process_t = std::function<void (WFHttpTask *)>; using http_process_t = std::function<void (WFHttpTask *)>;
using WFHttpServer = WFServer<protocol::HttpRequest,
protocol::HttpResponse>;
// 定义 WFHttpServer 类型,基于 WFServer 模板类,使用 HttpRequest 和 HttpResponse 作为协议
using WFHttpServer = WFServer<protocol::HttpRequest, protocol::HttpResponse>;
// 设置 HTTP 服务器的默认参数
static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_TCP, .transport_type = TT_TCP, // 传输类型为 TCP
.max_connections = 2000, .max_connections = 2000, // 最大连接数
.peer_response_timeout = 10 * 1000, .peer_response_timeout = 10 * 1000, // 对端响应超时时间
.receive_timeout = -1, .receive_timeout = -1, // 接收超时时间,-1 表示不超时
.keep_alive_timeout = 60 * 1000, .keep_alive_timeout = 60 * 1000, // 保持连接的超时时间
.request_size_limit = (size_t)-1, .request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制
.ssl_accept_timeout = 10 * 1000, .ssl_accept_timeout = 10 * 1000, // SSL 接受超时时间
}; };
// WFHttpServer 构造函数的特化模板
template<> inline template<> inline
WFHttpServer::WFServer(http_process_t proc) : WFHttpServer::WFServer(http_process_t proc) :
WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数
process(std::move(proc)) process(std::move(proc)) // 移动构造函数以提高性能
{ {
} }
// 创建新的会话,重写基类的方法
template<> inline template<> inline
CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn) CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn)
{ {
WFHttpTask *task; WFHttpTask *task; // 声明一个 HTTP 任务指针
// 使用任务工厂创建一个新的 HTTP 任务
task = WFServerTaskFactory::create_http_task(this, this->process); task = WFServerTaskFactory::create_http_task(this, this->process);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit);
return task; return task; // 返回创建的任务
} }
#endif #endif // _WFHTTPSERVER_H_

@ -16,45 +16,54 @@
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/ */
#include <sys/uio.h> #include <sys/uio.h> // 引入系统I/O头文件用于向客户端发送数据
#include "WFMySQLServer.h" #include "WFMySQLServer.h" // 引入 MySQL 服务器的头文件
// 创建新的连接
WFConnection *WFMySQLServer::new_connection(int accept_fd) WFConnection *WFMySQLServer::new_connection(int accept_fd)
{ {
// 调用基类的 new_connection 方法创建连接
WFConnection *conn = this->WFServer::new_connection(accept_fd); WFConnection *conn = this->WFServer::new_connection(accept_fd);
if (conn) if (conn) // 如果连接成功
{ {
protocol::MySQLHandshakeResponse resp; protocol::MySQLHandshakeResponse resp; // 创建 MySQL 握手响应对象
struct iovec vec[8]; struct iovec vec[8]; // 定义 I/O 向量以便于发送数据
int count; int count; // 记录向量中有效数据的数量
// 设置服务器的握手响应参数
resp.server_set(0x0a, "5.5", 1, (const uint8_t *)"12345678901234567890", resp.server_set(0x0a, "5.5", 1, (const uint8_t *)"12345678901234567890",
0, 33, 0); 0, 33, 0);
// 编码握手响应信息到 I/O 向量中
count = resp.encode(vec, 8); count = resp.encode(vec, 8);
if (count >= 0) if (count >= 0) // 如果编码成功
{ {
// 使用 writev 函数发送握手响应数据到客户端
if (writev(accept_fd, vec, count) >= 0) if (writev(accept_fd, vec, count) >= 0)
return conn; return conn; // 返回创建的连接
} }
// 如果发送失败,删除该连接
this->delete_connection(conn); this->delete_connection(conn);
} }
return NULL; return NULL; // 如果连接失败,则返回 NULL
} }
// 创建新的会话,重写基类的方法
CommSession *WFMySQLServer::new_session(long long seq, CommConnection *conn) CommSession *WFMySQLServer::new_session(long long seq, CommConnection *conn)
{ {
static mysql_process_t empty = [](WFMySQLTask *){ }; static mysql_process_t empty = [](WFMySQLTask *){ }; // 定义一个空的 MySQL 处理函数
WFMySQLTask *task; WFMySQLTask *task; // 声明一个 MySQL 任务指针
task = WFServerTaskFactory::create_mysql_task(this, seq ? this->process : // 使用任务工厂创建一个新的 MySQL 任务
empty); task = WFServerTaskFactory::create_mysql_task(this, seq ? this->process : empty);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit);
return task; return task; // 返回创建的任务
} }

@ -19,39 +19,43 @@
#ifndef _WFMYSQLSERVER_H_ #ifndef _WFMYSQLSERVER_H_
#define _WFMYSQLSERVER_H_ #define _WFMYSQLSERVER_H_
#include <utility> #include <utility> // 引入标准库的实用工具
#include "MySQLMessage.h" #include "MySQLMessage.h" // 引入 MySQL 消息相关的头文件
#include "WFServer.h" #include "WFServer.h" // 引入服务器功能的头文件
#include "WFTaskFactory.h" #include "WFTaskFactory.h" // 引入任务工厂以创建 MySQL 任务
#include "WFConnection.h" #include "WFConnection.h" // 引入连接类的头文件
// 定义一个 mysql_process_t 类型,这是一个接受 WFMySQLTask 指针并返回 void 的函数对象类型
using mysql_process_t = std::function<void (WFMySQLTask *)>; using mysql_process_t = std::function<void (WFMySQLTask *)>;
class MySQLServer;
static constexpr struct WFServerParams MYSQL_SERVER_PARAMS_DEFAULT = // 定义 WFMySQLServer 类,继承自 WFServer 模板类
{ class WFMySQLServer : public WFServer<protocol::MySQLRequest, protocol::MySQLResponse>
.transport_type = TT_TCP,
.max_connections = 2000,
.peer_response_timeout = 10 * 1000,
.receive_timeout = -1,
.keep_alive_timeout = 28800 * 1000,
.request_size_limit = (size_t)-1,
.ssl_accept_timeout = 10 * 1000,
};
class WFMySQLServer : public WFServer<protocol::MySQLRequest,
protocol::MySQLResponse>
{ {
public: public:
WFMySQLServer(mysql_process_t proc): // 构造函数,接收一个 MySQL 处理函数
WFServer(&MYSQL_SERVER_PARAMS_DEFAULT, std::move(proc)) WFMySQLServer(mysql_process_t proc) :
{ WFServer(&MYSQL_SERVER_PARAMS_DEFAULT, std::move(proc)) // 调用基类构造函数,使用默认参数
} {
}
protected: protected:
virtual WFConnection *new_connection(int accept_fd); // 重写 new_connection 方法,创建新的连接
virtual CommSession *new_session(long long seq, CommConnection *conn); virtual WFConnection *new_connection(int accept_fd);
// 重写 new_session 方法,创建新的会话
virtual CommSession *new_session(long long seq, CommConnection *conn);
}; };
#endif // 设置 MySQL 服务器的默认参数
static constexpr struct WFServerParams MYSQL_SERVER_PARAMS_DEFAULT =
{
.transport_type = TT_TCP, // 设置传输类型为 TCP
.max_connections = 2000, // 设置最大连接数
.peer_response_timeout = 10 * 1000, // 设置对端响应超时时间
.receive_timeout = -1, // 设置接收超时时间,-1 表示不超时
.keep_alive_timeout = 28800 * 1000, // 设置保持连接的超时时间
.request_size_limit = (size_t)-1, // 设置请求大小限制,-1 表示无限制
.ssl_accept_timeout = 10 * 1000, // 设置 SSL 接受超时时间
};
#endif // _WFMYSQLSERVER_H_

@ -19,31 +19,35 @@
#ifndef _WFREDISSERVER_H_ #ifndef _WFREDISSERVER_H_
#define _WFREDISSERVER_H_ #define _WFREDISSERVER_H_
#include "RedisMessage.h" #include "RedisMessage.h" // 引入 Redis 消息相关的头文件
#include "WFServer.h" #include "WFServer.h" // 引入服务器功能的头文件
#include "WFTaskFactory.h" #include "WFTaskFactory.h" // 引入任务工厂以创建 Redis 任务
// 定义一个 redis_process_t 类型,这是一个接受 WFRedisTask 指针并返回 void 的函数对象类型
using redis_process_t = std::function<void (WFRedisTask *)>; using redis_process_t = std::function<void (WFRedisTask *)>;
using WFRedisServer = WFServer<protocol::RedisRequest,
protocol::RedisResponse>;
// 定义 WFRedisServer 类型,基于 WFServer 模板类,使用 RedisRequest 和 RedisResponse 作为协议
using WFRedisServer = WFServer<protocol::RedisRequest, protocol::RedisResponse>;
// 设置 Redis 服务器的默认参数
static constexpr struct WFServerParams REDIS_SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams REDIS_SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_TCP, .transport_type = TT_TCP, // 传输类型设置为 TCP
.max_connections = 2000, .max_connections = 2000, // 设置最大连接数为 2000
.peer_response_timeout = 10 * 1000, .peer_response_timeout = 10 * 1000, // 对端响应超时时间为 10 秒
.receive_timeout = -1, .receive_timeout = -1, // 接收超时时间,-1 表示不超时
.keep_alive_timeout = 300 * 1000, .keep_alive_timeout = 300 * 1000, // 保持连接的超时时间为 300 秒
.request_size_limit = (size_t)-1, .request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制
.ssl_accept_timeout = 5000, .ssl_accept_timeout = 5000, // SSL 接受超时时间为 5 秒
}; };
// WFRedisServer 构造函数的特化模板
template<> inline template<> inline
WFRedisServer::WFServer(redis_process_t proc) : WFRedisServer::WFServer(redis_process_t proc) :
WFServerBase(&REDIS_SERVER_PARAMS_DEFAULT), WFServerBase(&REDIS_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数
process(std::move(proc)) process(std::move(proc)) // 采用移动语义以避免不必要的拷贝
{ {
} }
#endif // 生成新会话的方法,重写基类的方法
#endif // _WFREDISSERVER_H_

@ -17,272 +17,300 @@
Wu Jiaxu (wujiaxu@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com)
*/ */
#include <sys/types.h> #include <sys/types.h> // 引入数据类型定义
#include <sys/socket.h> #include <sys/socket.h> // 引入套接字相关的函数
#include <errno.h> #include <errno.h> // 引入错误码定义
#include <unistd.h> #include <unistd.h> // 引入 UNIX 标准函数
#include <stdio.h> #include <stdio.h> // 引入标准输入输出库
#include <atomic> #include <atomic> // 引入原子操作库
#include <mutex> #include <mutex> // 引入互斥锁库
#include <condition_variable> #include <condition_variable> // 引入条件变量库
#include <openssl/ssl.h> #include <openssl/ssl.h> // 引入 OpenSSL 库以支持 SSL
#include "CommScheduler.h" #include "CommScheduler.h" // 引入调度器相关的头文件
#include "EndpointParams.h" #include "EndpointParams.h" // 引入端点参数定义
#include "WFConnection.h" #include "WFConnection.h" // 引入连接类的定义
#include "WFGlobal.h" #include "WFGlobal.h" // 引入全局配置相关的定义
#include "WFServer.h" #include "WFServer.h" // 引入服务器功能的头文件
#define PORT_STR_MAX 5 #define PORT_STR_MAX 5 // 定义最大端口字符串长度
// 定义 WFServerConnection 类,继承自 WFConnection
class WFServerConnection : public WFConnection class WFServerConnection : public WFConnection
{ {
public: public:
// 构造函数,接受连接计数的原子指针
WFServerConnection(std::atomic<size_t> *conn_count) WFServerConnection(std::atomic<size_t> *conn_count)
{ {
this->conn_count = conn_count; this->conn_count = conn_count; // 初始化连接计数指针
} }
// 析构函数
virtual ~WFServerConnection() virtual ~WFServerConnection()
{ {
(*this->conn_count)--; (*this->conn_count)--; // 每当一个连接被销毁,减少连接计数
} }
private: private:
std::atomic<size_t> *conn_count; std::atomic<size_t> *conn_count; // 存储连接计数的原子变量指针
}; };
// SSL 上下文回调函数,用于处理 SSL 连接
int WFServerBase::ssl_ctx_callback(SSL *ssl, int *al, void *arg) int WFServerBase::ssl_ctx_callback(SSL *ssl, int *al, void *arg)
{ {
WFServerBase *server = (WFServerBase *)arg; WFServerBase *server = (WFServerBase *)arg; // 将参数转换为 WFServerBase 类型
const char *servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); const char *servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); // 获取服务器名称
SSL_CTX *ssl_ctx = server->get_server_ssl_ctx(servername); SSL_CTX *ssl_ctx = server->get_server_ssl_ctx(servername); // 根据服务器名称获取 SSL 上下文
if (!ssl_ctx) if (!ssl_ctx) // 如果未找到 SSL 上下文
return SSL_TLSEXT_ERR_NOACK; return SSL_TLSEXT_ERR_NOACK; // 返回错误
// 如果找到的 SSL 上下文与当前不一致,则设置为找到的上下文
if (ssl_ctx != server->get_ssl_ctx()) if (ssl_ctx != server->get_ssl_ctx())
SSL_set_SSL_CTX(ssl, ssl_ctx); SSL_set_SSL_CTX(ssl, ssl_ctx);
return SSL_TLSEXT_ERR_OK; return SSL_TLSEXT_ERR_OK; // 返回成功
} }
// 创建新的 SSL 上下文
SSL_CTX *WFServerBase::new_ssl_ctx(const char *cert_file, const char *key_file) SSL_CTX *WFServerBase::new_ssl_ctx(const char *cert_file, const char *key_file)
{ {
SSL_CTX *ssl_ctx = WFGlobal::new_ssl_server_ctx(); SSL_CTX *ssl_ctx = WFGlobal::new_ssl_server_ctx(); // 创建新的 SSL 服务器上下文
if (!ssl_ctx) if (!ssl_ctx) // 如果创建失败
return NULL; return NULL; // 返回 NULL
// 加载证书链、私钥,并检查私钥有效性
if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) > 0 && if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) > 0 &&
SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, SSL_FILETYPE_PEM) > 0 && SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, SSL_FILETYPE_PEM) > 0 &&
SSL_CTX_check_private_key(ssl_ctx) > 0 && SSL_CTX_check_private_key(ssl_ctx) > 0 &&
SSL_CTX_set_tlsext_servername_callback(ssl_ctx, ssl_ctx_callback) > 0 && SSL_CTX_set_tlsext_servername_callback(ssl_ctx, ssl_ctx_callback) > 0 &&
SSL_CTX_set_tlsext_servername_arg(ssl_ctx, this) > 0) SSL_CTX_set_tlsext_servername_arg(ssl_ctx, this) > 0)
{ {
return ssl_ctx; return ssl_ctx; // 返回创建的 SSL 上下文
} }
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx); // 释放 SSL 上下文
return NULL; return NULL; // 返回 NULL
} }
// 初始化服务器
int WFServerBase::init(const struct sockaddr *bind_addr, socklen_t addrlen, int WFServerBase::init(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
int timeout = this->params.peer_response_timeout; int timeout = this->params.peer_response_timeout; // 将超时时间初始化为对端响应超时时间
if (this->params.receive_timeout >= 0) if (this->params.receive_timeout >= 0) // 如果接收超时时间有效
{ {
if ((unsigned int)timeout > (unsigned int)this->params.receive_timeout) if ((unsigned int)timeout > (unsigned int)this->params.receive_timeout)
timeout = this->params.receive_timeout; timeout = this->params.receive_timeout; // 将超时时间更新为接收超时时间
} }
// 检查是否需要 SSL/TLS 连接
if (this->params.transport_type == TT_TCP_SSL || if (this->params.transport_type == TT_TCP_SSL ||
this->params.transport_type == TT_SCTP_SSL) this->params.transport_type == TT_SCTP_SSL)
{ {
if (!cert_file || !key_file) if (!cert_file || !key_file) // 如果没有提供证书或密钥文件
{ {
errno = EINVAL; errno = EINVAL; // 设置错误号为无效参数
return -1; return -1; // 返回错误
} }
} }
// 初始化基本通信服务
if (this->CommService::init(bind_addr, addrlen, -1, timeout) < 0) if (this->CommService::init(bind_addr, addrlen, -1, timeout) < 0)
return -1; return -1; // 初始化失败,返回错误
// 如果提供了证书和密钥,设置 SSL 上下文
if (cert_file && key_file && this->params.transport_type != TT_UDP) if (cert_file && key_file && this->params.transport_type != TT_UDP)
{ {
SSL_CTX *ssl_ctx = this->new_ssl_ctx(cert_file, key_file); SSL_CTX *ssl_ctx = this->new_ssl_ctx(cert_file, key_file); // 创建新的 SSL 上下文
if (!ssl_ctx) if (!ssl_ctx) // 如果创建失败
{ {
this->deinit(); this->deinit(); // 反初始化
return -1; return -1; // 返回错误
} }
this->set_ssl(ssl_ctx, this->params.ssl_accept_timeout); this->set_ssl(ssl_ctx, this->params.ssl_accept_timeout); // 设置 SSL 上下文
} }
this->scheduler = WFGlobal::get_scheduler(); this->scheduler = WFGlobal::get_scheduler(); // 获取全局调度器
return 0; return 0; // 返回成功
} }
// 创建监听文件描述符
int WFServerBase::create_listen_fd() int WFServerBase::create_listen_fd()
{ {
if (this->listen_fd < 0) if (this->listen_fd < 0) // 如果尚未创建监听文件描述符
{ {
const struct sockaddr *bind_addr; const struct sockaddr *bind_addr; // 定义绑定地址
socklen_t addrlen; socklen_t addrlen; // 定义地址长度
int type, protocol; int type, protocol; // 定义套接字类型和协议
int reuse = 1; int reuse = 1; // 允许地址重用
// 根据传输类型设置套接字类型和协议
switch (this->params.transport_type) switch (this->params.transport_type)
{ {
case TT_TCP: case TT_TCP:
case TT_TCP_SSL: case TT_TCP_SSL:
type = SOCK_STREAM; type = SOCK_STREAM; // TCP 使用流式套接字
protocol = 0; protocol = 0; // 默认协议
break; break;
case TT_UDP: case TT_UDP:
type = SOCK_DGRAM; type = SOCK_DGRAM; // UDP 使用数据报套接字
protocol = 0; protocol = 0; // 默认协议
break; break;
#ifdef IPPROTO_SCTP #ifdef IPPROTO_SCTP
case TT_SCTP: case TT_SCTP:
case TT_SCTP_SSL: case TT_SCTP_SSL:
type = SOCK_STREAM; type = SOCK_STREAM; // SCTP 使用流式套接字
protocol = IPPROTO_SCTP; protocol = IPPROTO_SCTP; // SCTP 协议
break; break;
#endif #endif
default: default:
errno = EPROTONOSUPPORT; errno = EPROTONOSUPPORT; // 不支持的协议类型
return -1; return -1; // 返回错误
} }
this->get_addr(&bind_addr, &addrlen); this->get_addr(&bind_addr, &addrlen); // 获取绑定地址和长度
this->listen_fd = socket(bind_addr->sa_family, type, protocol); this->listen_fd = socket(bind_addr->sa_family, type, protocol); // 创建套接字
if (this->listen_fd >= 0) if (this->listen_fd >= 0) // 如果套接字创建成功
{ {
setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR, setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR, // 设置选项以重用地址
&reuse, sizeof (int)); &reuse, sizeof (int));
} }
} }
else else
this->listen_fd = dup(this->listen_fd); this->listen_fd = dup(this->listen_fd); // 复制已存在的监听文件描述符
return this->listen_fd; return this->listen_fd; // 返回监听文件描述符
} }
// 创建新的连接
WFConnection *WFServerBase::new_connection(int accept_fd) WFConnection *WFServerBase::new_connection(int accept_fd)
{ {
// 检查当前连接数是否在最大连接数限制内
if (++this->conn_count <= this->params.max_connections || if (++this->conn_count <= this->params.max_connections ||
this->drain(1) == 1) this->drain(1) == 1) // 如果连接数未超过最大值
{ {
int reuse = 1; int reuse = 1; // 允许地址重用
setsockopt(accept_fd, SOL_SOCKET, SO_REUSEADDR, setsockopt(accept_fd, SOL_SOCKET, SO_REUSEADDR, // 设置选项以重用地址
&reuse, sizeof (int)); &reuse, sizeof (int));
return new WFServerConnection(&this->conn_count); return new WFServerConnection(&this->conn_count); // 创建新的连接对象
} }
this->conn_count--; this->conn_count--; // 如果连接数超过最大值,减少连接计数
errno = EMFILE; errno = EMFILE; // 设置错误号为文件描述符过多
return NULL; return NULL; // 返回 NULL
} }
// 删除连接
void WFServerBase::delete_connection(WFConnection *conn) void WFServerBase::delete_connection(WFConnection *conn)
{ {
delete (WFServerConnection *)conn; delete (WFServerConnection *)conn; // 删除连接对象
} }
// 处理未绑定的情况
void WFServerBase::handle_unbound() void WFServerBase::handle_unbound()
{ {
this->mutex.lock(); this->mutex.lock(); // 加锁
this->unbind_finish = true; this->unbind_finish = true; // 设置解除绑定状态
this->cond.notify_one(); this->cond.notify_one(); // 通知等待的线程
this->mutex.unlock(); this->mutex.unlock(); // 解锁
} }
// 启动服务器
int WFServerBase::start(const struct sockaddr *bind_addr, socklen_t addrlen, int WFServerBase::start(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
SSL_CTX *ssl_ctx; SSL_CTX *ssl_ctx; // 声明 SSL 上下文指针
// 初始化服务器
if (this->init(bind_addr, addrlen, cert_file, key_file) >= 0) if (this->init(bind_addr, addrlen, cert_file, key_file) >= 0)
{ {
// 如果调度器绑定成功,返回 0
if (this->scheduler->bind(this) >= 0) if (this->scheduler->bind(this) >= 0)
return 0; return 0;
ssl_ctx = this->get_ssl_ctx(); ssl_ctx = this->get_ssl_ctx(); // 获取当前 SSL 上下文
this->deinit(); this->deinit(); // 反初始化
if (ssl_ctx) if (ssl_ctx) // 如果 SSL 上下文存在
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx); // 释放 SSL 上下文
} }
this->listen_fd = -1; this->listen_fd = -1; // 设置监听文件描述符为无效
return -1; return -1; // 返回错误
} }
// 启动服务器,使用主机和端口号
int WFServerBase::start(int family, const char *host, unsigned short port, int WFServerBase::start(int family, const char *host, unsigned short port,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
struct addrinfo hints = { struct addrinfo hints = { // 设置地址信息结构
.ai_flags = AI_PASSIVE, .ai_flags = AI_PASSIVE, // 使套接字可用于绑定
.ai_family = family, .ai_family = family, // 设置地址族
.ai_socktype = SOCK_STREAM, .ai_socktype = SOCK_STREAM, // 设置为流套接字
}; };
struct addrinfo *addrinfo; struct addrinfo *addrinfo; // 用于存储返回的地址信息
char port_str[PORT_STR_MAX + 1]; char port_str[PORT_STR_MAX + 1]; // 存储端口字符串
int ret; int ret; // 存储返回值
// 将端口号转换为字符串
snprintf(port_str, PORT_STR_MAX + 1, "%d", port); snprintf(port_str, PORT_STR_MAX + 1, "%d", port);
ret = getaddrinfo(host, port_str, &hints, &addrinfo); ret = getaddrinfo(host, port_str, &hints, &addrinfo); // 获取地址信息
if (ret == 0) if (ret == 0) // 如果成功获取地址信息
{ {
// 启动服务器
ret = start(addrinfo->ai_addr, (socklen_t)addrinfo->ai_addrlen, ret = start(addrinfo->ai_addr, (socklen_t)addrinfo->ai_addrlen,
cert_file, key_file); cert_file, key_file);
freeaddrinfo(addrinfo); freeaddrinfo(addrinfo); // 释放地址信息
} }
else else // 获取地址信息失败
{ {
if (ret != EAI_SYSTEM) if (ret != EAI_SYSTEM) // 如果不是系统错误
errno = EINVAL; errno = EINVAL; // 设置错误号为无效参数
ret = -1; ret = -1; // 返回错误
} }
return ret; return ret; // 返回结果
} }
// 启动服务器,使用已存在的监听文件描述符
int WFServerBase::serve(int listen_fd, int WFServerBase::serve(int listen_fd,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
struct sockaddr_storage ss; struct sockaddr_storage ss; // 存储 socket 地址
socklen_t len = sizeof ss; socklen_t len = sizeof ss; // 定义地址长度
// 获取当前 socket 名称
if (getsockname(listen_fd, (struct sockaddr *)&ss, &len) < 0) if (getsockname(listen_fd, (struct sockaddr *)&ss, &len) < 0)
return -1; return -1; // 返回错误
this->listen_fd = listen_fd; this->listen_fd = listen_fd; // 设置监听文件描述符
return start((struct sockaddr *)&ss, len, cert_file, key_file); return start((struct sockaddr *)&ss, len, cert_file, key_file); // 启动服务器
} }
// 关闭服务器
void WFServerBase::shutdown() void WFServerBase::shutdown()
{ {
this->listen_fd = -1; this->listen_fd = -1; // 设置监听文件描述符为无效
this->scheduler->unbind(this); this->scheduler->unbind(this); // 解除调度器的绑定
} }
// 等待服务器完成
void WFServerBase::wait_finish() void WFServerBase::wait_finish()
{ {
SSL_CTX *ssl_ctx = this->get_ssl_ctx(); SSL_CTX *ssl_ctx = this->get_ssl_ctx(); // 获取当前 SSL 上下文
std::unique_lock<std::mutex> lock(this->mutex); std::unique_lock<std::mutex> lock(this->mutex); // 使用唯一锁管理 mutex
// 等待解除绑定完成
while (!this->unbind_finish) while (!this->unbind_finish)
this->cond.wait(lock); this->cond.wait(lock); // 条件等待
this->deinit(); this->deinit(); // 反初始化
this->unbind_finish = false; this->unbind_finish = false; // 重置解除绑定状态
lock.unlock(); lock.unlock(); // 解锁
if (ssl_ctx)
SSL_CTX_free(ssl_ctx);
}
if (ssl_ctx) // 如果 SSL 上下文存在
SSL_CTX_free(ssl_ctx); // 释放 SSL 上下文
}

@ -17,113 +17,117 @@
Wu Jiaxu (wujiaxu@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com)
*/ */
#ifndef _WFSERVER_H_ #ifndef _WFSERVER_H_ // 防止头文件被多次包含
#define _WFSERVER_H_ #define _WFSERVER_H_
#include <sys/types.h> #include <sys/types.h> // 引入基本数据类型定义
#include <sys/socket.h> #include <sys/socket.h> // 引入 socket 相关的函数
#include <errno.h> #include <errno.h> // 引入错误码定义
#include <functional> #include <functional> // 引入函数对象
#include <atomic> #include <atomic> // 引入原子操作的支持
#include <mutex> #include <mutex> // 引入互斥量支持
#include <condition_variable> #include <condition_variable> // 引入条件变量支持
#include <openssl/ssl.h> #include <openssl/ssl.h> // 引入 OpenSSL 以支持 SSL/TLS
#include "EndpointParams.h" #include "EndpointParams.h" // 引入端点参数定义
#include "WFTaskFactory.h" #include "WFTaskFactory.h" // 引入任务工厂以创建任务
// 定义服务器参数结构体
struct WFServerParams struct WFServerParams
{ {
enum TransportType transport_type; enum TransportType transport_type; // 传输类型TCP, UDP, SSL等
size_t max_connections; size_t max_connections; // 最大并发连接数
int peer_response_timeout; /* timeout of each read or write operation */ int peer_response_timeout; // 每次读取或写入操作的超时
int receive_timeout; /* timeout of receiving the whole message */ int receive_timeout; // 接收整个消息的超时
int keep_alive_timeout; int keep_alive_timeout; // 连接保持存活的超时
size_t request_size_limit; size_t request_size_limit; // 请求大小限制
int ssl_accept_timeout; /* if not ssl, this will be ignored */ int ssl_accept_timeout; // SSL 接受超时时间
}; };
// 设置服务器参数的默认值
static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_TCP, .transport_type = TT_TCP, // 默认使用 TCP
.max_connections = 2000, .max_connections = 2000, // 默认最多 2000 个连接
.peer_response_timeout = 10 * 1000, .peer_response_timeout = 10 * 1000, // 默认对端响应超时 10 秒
.receive_timeout = -1, .receive_timeout = -1, // 默认不超时
.keep_alive_timeout = 60 * 1000, .keep_alive_timeout = 60 * 1000, // 默认连接保持存活 60 秒
.request_size_limit = (size_t)-1, .request_size_limit = (size_t)-1, // 请求大小限制为无限制
.ssl_accept_timeout = 10 * 1000, .ssl_accept_timeout = 10 * 1000, // SSL 接受超时设置为 10 秒
}; };
// 定义 WFServerBase 基类,用于建立服务器
class WFServerBase : protected CommService class WFServerBase : protected CommService
{ {
public: public:
// 构造函数,接收服务器参数并初始化
WFServerBase(const struct WFServerParams *params) : WFServerBase(const struct WFServerParams *params) :
conn_count(0) conn_count(0) // 初始化连接计数为 0
{ {
this->params = *params; this->params = *params; // 保存服务器参数
this->unbind_finish = false; this->unbind_finish = false; // 初始化解除绑定状态
this->listen_fd = -1; this->listen_fd = -1; // 初始化监听文件描述符为无效值
} }
public: public:
/* To start a TCP server */ // 启动一个 TCP 服务器函数
/* Start on port with IPv4. */ /* 在指定端口上启动服务器IPv4 */
int start(unsigned short port) int start(unsigned short port)
{ {
return start(AF_INET, NULL, port, NULL, NULL); return start(AF_INET, NULL, port, NULL, NULL);
} }
/* Start with family. AF_INET or AF_INET6. */ /* 根据家庭类型启动服务器IPv4 或 IPv6 */
int start(int family, unsigned short port) int start(int family, unsigned short port)
{ {
return start(family, NULL, port, NULL, NULL); return start(family, NULL, port, NULL, NULL);
} }
/* Start with hostname and port. */ /* 使用主机名和端口启动服务器 */
int start(const char *host, unsigned short port) int start(const char *host, unsigned short port)
{ {
return start(AF_INET, host, port, NULL, NULL); return start(AF_INET, host, port, NULL, NULL);
} }
/* Start with family, hostname and port. */ /* 使用家庭类型、主机名和端口启动服务器 */
int start(int family, const char *host, unsigned short port) int start(int family, const char *host, unsigned short port)
{ {
return start(family, host, port, NULL, NULL); return start(family, host, port, NULL, NULL);
} }
/* Start with binding address. */ /* 使用指定的地址绑定启动服务器 */
int start(const struct sockaddr *bind_addr, socklen_t addrlen) int start(const struct sockaddr *bind_addr, socklen_t addrlen)
{ {
return start(bind_addr, addrlen, NULL, NULL); return start(bind_addr, addrlen, NULL, NULL);
} }
/* To start an SSL server. */ // 启动一个 SSL 服务器
/* 在指定端口启动 SSL 服务器 */
int start(unsigned short port, const char *cert_file, const char *key_file) int start(unsigned short port, const char *cert_file, const char *key_file)
{ {
return start(AF_INET, NULL, port, cert_file, key_file); return start(AF_INET, NULL, port, cert_file, key_file);
} }
/* 使用家庭类型和端口启动 SSL 服务器 */
int start(int family, unsigned short port, int start(int family, unsigned short port,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
return start(family, NULL, port, cert_file, key_file); return start(family, NULL, port, cert_file, key_file);
} }
/* 使用主机名和端口启动 SSL 服务器 */
int start(const char *host, unsigned short port, int start(const char *host, unsigned short port,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
return start(AF_INET, host, port, cert_file, key_file); return start(AF_INET, host, port, cert_file, key_file);
} }
/* 使用家庭类型、主机名和端口启动 SSL 服务器 */
int start(int family, const char *host, unsigned short port, int start(int family, const char *host, unsigned short port,
const char *cert_file, const char *key_file); const char *cert_file, const char *key_file);
/* This is the only necessary start function. */ /* 使用指定文件描述符启动服务器,用于优雅重启或 SCTP 服务器 */
int start(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file);
/* To start with a specified fd. For graceful restart or SCTP server. */
int serve(int listen_fd) int serve(int listen_fd)
{ {
return serve(listen_fd, NULL, NULL); return serve(listen_fd, NULL, NULL);
@ -131,114 +135,119 @@ public:
int serve(int listen_fd, const char *cert_file, const char *key_file); int serve(int listen_fd, const char *cert_file, const char *key_file);
/* stop() is a blocking operation. */ /* 停止服务器是一个阻塞操作 */
void stop() void stop()
{ {
this->shutdown(); this->shutdown(); // 关闭服务器
this->wait_finish(); this->wait_finish(); // 等待完成
} }
/* Nonblocking terminating the server. For stopping multiple servers. /* 非阻塞地终止服务器。用于停止多个服务器。
* Typically, call shutdown() and then wait_finish(). * shutdown() wait_finish()
* But indeed wait_finish() can be called before shutdown(), even before * 线 shutdown() start() wait_finish() */
* start() in another thread. */ void shutdown(); // 停止服务器
void shutdown(); void wait_finish(); // 等待所有连接和任务完成
void wait_finish();
public: public:
// 获取当前连接数
size_t get_conn_count() const { return this->conn_count; } size_t get_conn_count() const { return this->conn_count; }
/* Get the listening address. This is often used after starting /* 获取监听地址。这通常在服务器在随机端口上启动后使用start() 端口为 0。 */
* server on a random port (start() with port == 0). */
int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const
{ {
if (this->listen_fd >= 0) if (this->listen_fd >= 0)
return getsockname(this->listen_fd, addr, addrlen); return getsockname(this->listen_fd, addr, addrlen); // 获取当前的 socket 地址
errno = ENOTCONN; errno = ENOTCONN; // 如果没有连接,设置错误号为没有连接
return -1; return -1; // 返回错误
} }
// 获取服务器参数
const struct WFServerParams *get_params() const { return &this->params; } const struct WFServerParams *get_params() const { return &this->params; }
protected: protected:
/* Override this function to create the initial SSL CTX of the server */ /* 重写此函数以创建服务器的初始 SSL CTX */
virtual SSL_CTX *new_ssl_ctx(const char *cert_file, const char *key_file); virtual SSL_CTX *new_ssl_ctx(const char *cert_file, const char *key_file);
/* Override this function to implement server that supports TLS SNI. /* 重写此函数以实现支持 TLS SNI 的服务器。
* "servername" will be NULL if client does not set a host name. * "servername" NULL
* Returning NULL to indicate that servername is not supported. */ * NULL */
virtual SSL_CTX *get_server_ssl_ctx(const char *servername) virtual SSL_CTX *get_server_ssl_ctx(const char *servername)
{ {
return this->get_ssl_ctx(); return this->get_ssl_ctx(); // 返回当前的 SSL 上下文
} }
/* This can be used by the implementation of 'new_ssl_ctx'. */ /* 这个函数可以在 'new_ssl_ctx' 的实现中使用。 */
static int ssl_ctx_callback(SSL *ssl, int *al, void *arg); static int ssl_ctx_callback(SSL *ssl, int *al, void *arg);
protected: protected:
WFServerParams params; WFServerParams params; // 服务器参数结构体
protected: protected:
// 创建新连接的方法
virtual int create_listen_fd(); virtual int create_listen_fd();
virtual WFConnection *new_connection(int accept_fd); virtual WFConnection *new_connection(int accept_fd); // 创建新连接
void delete_connection(WFConnection *conn); void delete_connection(WFConnection *conn); // 删除连接
private: private:
// 初始化服务器的方法
int init(const struct sockaddr *bind_addr, socklen_t addrlen, int init(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file); const char *cert_file, const char *key_file);
virtual void handle_unbound(); virtual void handle_unbound(); // 处理解除绑定
protected: protected:
std::atomic<size_t> conn_count; std::atomic<size_t> conn_count; // 当前连接计数
private: private:
int listen_fd; int listen_fd; // 监听文件描述符
bool unbind_finish; bool unbind_finish; // 标记解除绑定是否完成
std::mutex mutex; std::mutex mutex; // 互斥锁,用于线程安全
std::condition_variable cond; std::condition_variable cond; // 条件变量,用于线程间通知
class CommScheduler *scheduler; class CommScheduler *scheduler; // 调度器
}; };
// 模板类,用于定义特定协议的服务器
template<class REQ, class RESP> template<class REQ, class RESP>
class WFServer : public WFServerBase class WFServer : public WFServerBase
{ {
public: public:
// 构造函数,接收参数和处理函数
WFServer(const struct WFServerParams *params, WFServer(const struct WFServerParams *params,
std::function<void (WFNetworkTask<REQ, RESP> *)> proc) : std::function<void (WFNetworkTask<REQ, RESP> *)> proc) :
WFServerBase(params), WFServerBase(params), // 调用基类构造函数
process(std::move(proc)) process(std::move(proc)) // 移动处理函数,避免不必要的拷贝
{ {
} }
// 构造函数,使用默认参数
WFServer(std::function<void (WFNetworkTask<REQ, RESP> *)> proc) : WFServer(std::function<void (WFNetworkTask<REQ, RESP> *)> proc) :
WFServerBase(&SERVER_PARAMS_DEFAULT), WFServerBase(&SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数
process(std::move(proc)) process(std::move(proc)) // 移动处理函数
{ {
} }
protected: protected:
virtual CommSession *new_session(long long seq, CommConnection *conn); virtual CommSession *new_session(long long seq, CommConnection *conn); // 创建新会话
protected: protected:
std::function<void (WFNetworkTask<REQ, RESP> *)> process; std::function<void (WFNetworkTask<REQ, RESP> *)> process; // 处理函数
}; };
// 创建新会话的方法,实现具体的会话处理
template<class REQ, class RESP> template<class REQ, class RESP>
CommSession *WFServer<REQ, RESP>::new_session(long long seq, CommConnection *conn) CommSession *WFServer<REQ, RESP>::new_session(long long seq, CommConnection *conn)
{ {
using factory = WFNetworkTaskFactory<REQ, RESP>; using factory = WFNetworkTaskFactory<REQ, RESP>; // 使用工厂来创建任务
WFNetworkTask<REQ, RESP> *task; WFNetworkTask<REQ, RESP> *task; // 声明任务指针
task = factory::create_server_task(this, this->process); task = factory::create_server_task(this, this->process); // 使用工厂创建任务
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout); // 设置保持连接的超时
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout); // 设置接收超时
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit); // 设置请求大小限制
return task; return task; // 返回创建的会话任务
} }
#endif #endif // _WFSERVER_H_

@ -1,6 +1,10 @@
-- 定义一个 xmake 目标,名为 server
target("server") target("server")
set_kind("object") set_kind("object") -- 设置目标类型为对象文件
add_files("*.cc")
add_files("*.cc") -- 添加当前目录下所有的 .cc 文件到目标中
-- 检查是否未定义 MySQL 配置
if not has_config("mysql") then if not has_config("mysql") then
remove_files("WFMySQLServer.cc") remove_files("WFMySQLServer.cc") -- 如果未定义 MySQL 配置,则移除 MySQL 服务器相关的源文件
end end

@ -0,0 +1,71 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFHTTPSERVER_H_
#define _WFHTTPSERVER_H_
#include <utility> // 引入标准库中的工具函数
#include "HttpMessage.h" // 引入 HTTP 消息相关的头文件
#include "WFServer.h" // 引入服务器功能的头文件
#include "WFTaskFactory.h" // 引入任务工厂以创建 HTTP 任务
// 定义一个 http_process_t 类型,它是一个接受 WFHttpTask 指针并返回 void 的函数对象
using http_process_t = std::function<void (WFHttpTask *)>;
// 定义 WFHttpServer 类型,基于 WFServer 模板类,使用 HttpRequest 和 HttpResponse 作为协议
using WFHttpServer = WFServer<protocol::HttpRequest, protocol::HttpResponse>;
// 设置 HTTP 服务器的默认参数
static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT =
{
.transport_type = TT_TCP, // 传输类型为 TCP
.max_connections = 2000, // 最大连接数
.peer_response_timeout = 10 * 1000, // 对端响应超时时间
.receive_timeout = -1, // 接收超时时间,-1 表示不超时
.keep_alive_timeout = 60 * 1000, // 保持连接的超时时间
.request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制
.ssl_accept_timeout = 10 * 1000, // SSL 接受超时时间
};
// WFHttpServer 构造函数的特化模板
template<> inline
WFHttpServer::WFServer(http_process_t proc) :
WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数
process(std::move(proc)) // 移动构造函数以提高性能
{
}
// 创建新的会话,重写基类的方法
template<> inline
CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn)
{
WFHttpTask *task; // 声明一个 HTTP 任务指针
// 使用任务工厂创建一个新的 HTTP 任务
task = WFServerTaskFactory::create_http_task(this, this->process);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit);
return task; // 返回创建的任务
}
#endif // _WFHTTPSERVER_H_
Loading…
Cancel
Save