Compare commits

..

3 Commits

Binary file not shown.

Binary file not shown.

@ -28,281 +28,250 @@ using namespace protocol;
using DnsCtx = std::function<void (WFDnsTask *)>; using DnsCtx = std::function<void (WFDnsTask *)>;
using ComplexTask = WFComplexClientTask<DnsRequest, DnsResponse, DnsCtx>; using ComplexTask = WFComplexClientTask<DnsRequest, DnsResponse, DnsCtx>;
// 定义一个DnsParams类用于存储DNS查询参数
class DnsParams class DnsParams
{ {
public: public:
// 定义一个内部结构体dns_params用于存储具体的DNS参数 struct dns_params
struct dns_params {
{ std::vector<ParsedURI> uris;
std::vector<ParsedURI> uris; // 存储解析后的URI std::vector<std::string> search_list;
std::vector<std::string> search_list; // 搜索列表 int ndots;
int ndots; // 域名中允许的最小点数 int attempts;
int attempts; // 最大尝试次数 bool rotate;
bool rotate; // 是否轮询 };
};
public: public:
// DnsParams类的构造函数 DnsParams()
DnsParams() {
{ this->ref = new std::atomic<size_t>(1);
this->ref = new std::atomic<size_t>(1); // 初始化引用计数为1 this->params = new dns_params();
this->params = new dns_params(); // 初始化参数结构体 }
}
DnsParams(const DnsParams& p)
// DnsParams类的拷贝构造函数 {
DnsParams(const DnsParams& p) this->ref = p.ref;
{ this->params = p.params;
this->ref = p.ref; // 拷贝引用计数指针 this->incref();
this->params = p.params; // 拷贝参数结构体指针 }
this->incref(); // 增加引用计数
} DnsParams& operator=(const DnsParams& p)
{
// DnsParams类的赋值运算符 if (this != &p)
DnsParams& operator=(const DnsParams& p) {
{ this->decref();
if (this != &p) // 如果不是自赋值 this->ref = p.ref;
{ this->params = p.params;
this->decref(); // 减少当前对象的引用计数 this->incref();
this->ref = p.ref; // 拷贝引用计数指针 }
this->params = p.params; // 拷贝参数结构体指针 return *this;
this->incref(); // 增加引用计数 }
}
return *this; // 返回当前对象的引用 ~DnsParams() { this->decref(); }
}
const dns_params *get_params() const { return this->params; }
// DnsParams类的析构函数 dns_params *get_params() { return this->params; }
~DnsParams() { this->decref(); } // 减少引用计数
// 获取const类型的参数指针
const dns_params *get_params() const { return this->params; }
// 获取非const类型的参数指针
dns_params *get_params() { return this->params; }
private: private:
// 增加引用计数 void incref() { (*this->ref)++; }
void incref() { (*this->ref)++; } void decref()
// 减少引用计数并在计数为0时释放资源 {
void decref() if (--*this->ref == 0)
{ {
if (--*this->ref == 0) delete this->params;
{ delete this->ref;
delete this->params; // 删除参数结构体 }
delete this->ref; // 删除引用计数 }
}
}
private: private:
dns_params *params; // 指向参数结构体的指针 dns_params *params;
std::atomic<size_t> *ref; // 指向引用计数的指针 std::atomic<size_t> *ref;
}; };
// 定义DNS状态枚举
enum enum
{ {
DNS_STATUS_TRY_ORIGIN_DONE = 0, DNS_STATUS_TRY_ORIGIN_DONE = 0,
DNS_STATUS_TRY_ORIGIN_FIRST = 1, DNS_STATUS_TRY_ORIGIN_FIRST = 1,
DNS_STATUS_TRY_ORIGIN_LAST = 2 DNS_STATUS_TRY_ORIGIN_LAST = 2
}; };
// 定义DnsStatus结构体用于存储DNS查询状态
struct DnsStatus struct DnsStatus
{ {
std::string origin_name; // 原始域名 std::string origin_name;
std::string current_name; // 当前域名 std::string current_name;
size_t next_server; // 下一个要尝试的服务器 size_t next_server; // next server to try
size_t last_server; // 上一个尝试的服务器 size_t last_server; // last server to try
size_t next_domain; // 下一个要尝试的搜索域 size_t next_domain; // next search domain to try
int attempts_left; // 剩余尝试次数 int attempts_left;
int try_origin_state; // 尝试原始域名的状态 int try_origin_state;
}; };
// 计算字符串中点的数量
static int __get_ndots(const std::string& s) static int __get_ndots(const std::string& s)
{ {
int ndots = 0; int ndots = 0;
for (size_t i = 0; i < s.size(); i++) for (size_t i = 0; i < s.size(); i++)
ndots += s[i] == '.'; // 统计点的数量 ndots += s[i] == '.';
return ndots; return ndots;
} }
// 检查是否有下一个域名要尝试
static bool __has_next_name(const DnsParams::dns_params *p, static bool __has_next_name(const DnsParams::dns_params *p,
struct DnsStatus *s) struct DnsStatus *s)
{ {
if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST) if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST)
{ {
s->current_name = s->origin_name; // 设置当前域名为原始域名 s->current_name = s->origin_name;
s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; // 更新状态 s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE;
return true; return true;
} }
if (s->next_domain < p->search_list.size()) // 如果还有搜索域要尝试 if (s->next_domain < p->search_list.size())
{ {
s->current_name = s->origin_name; // 设置当前域名为原始域名 s->current_name = s->origin_name;
s->current_name.push_back('.'); // 添加点 s->current_name.push_back('.');
s->current_name.append(p->search_list[s->next_domain]); // 添加搜索域 s->current_name.append(p->search_list[s->next_domain]);
s->next_domain++; // 移动到下一个搜索域 s->next_domain++;
return true; return true;
} }
if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_LAST) if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_LAST)
{ {
s->current_name = s->origin_name; // 设置当前域名为原始域名 s->current_name = s->origin_name;
s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; // 更新状态 s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE;
return true; return true;
} }
return false; // 没有下一个域名要尝试 return false;
} }
// DNS查询回调内部函数
static void __callback_internal(WFDnsTask *task, const DnsParams& params, static void __callback_internal(WFDnsTask *task, const DnsParams& params,
struct DnsStatus& s) struct DnsStatus& s)
{ {
ComplexTask *ctask = static_cast<ComplexTask *>(task); // 转换任务类型 ComplexTask *ctask = static_cast<ComplexTask *>(task);
int state = task->get_state(); // 获取任务状态 int state = task->get_state();
DnsRequest *req = task->get_req(); // 获取DNS请求 DnsRequest *req = task->get_req();
DnsResponse *resp = task->get_resp(); // 获取DNS响应 DnsResponse *resp = task->get_resp();
const auto *p = params.get_params(); // 获取DNS参数 const auto *p = params.get_params();
int rcode = resp->get_rcode(); // 获取响应码 int rcode = resp->get_rcode();
bool try_next_server = state != WFT_STATE_SUCCESS || // 如果状态不是成功 bool try_next_server = state != WFT_STATE_SUCCESS ||
rcode == DNS_RCODE_SERVER_FAILURE || // 或者响应码是服务器失败 rcode == DNS_RCODE_SERVER_FAILURE ||
rcode == DNS_RCODE_NOT_IMPLEMENTED || // 或者响应码是未实现 rcode == DNS_RCODE_NOT_IMPLEMENTED ||
rcode == DNS_RCODE_REFUSED; // 或者响应码是拒绝 rcode == DNS_RCODE_REFUSED;
bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR || // 如果响应码是格式错误 bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR ||
rcode == DNS_RCODE_NAME_ERROR || // 或者响应码是名字错误 rcode == DNS_RCODE_NAME_ERROR ||
resp->get_ancount() == 0; // 或者响应中没有答案 resp->get_ancount() == 0;
if (try_next_server) if (try_next_server)
{ {
if (s.last_server == s.next_server) // 如果已经是最后一个服务器 if (s.last_server == s.next_server)
s.attempts_left--; // 减少尝试次数 s.attempts_left--;
if (s.attempts_left <= 0) // 如果尝试次数用完 if (s.attempts_left <= 0)
return; // 返回 return;
s.next_server = (s.next_server + 1) % p->uris.size(); // 计算下一个服务器 s.next_server = (s.next_server + 1) % p->uris.size();
ctask->set_redirect(p->uris[s.next_server]); // 设置重定向 ctask->set_redirect(p->uris[s.next_server]);
return; // 返回 return;
} }
if (try_next_name && __has_next_name(p, &s)) // 如果需要尝试下一个名字 if (try_next_name && __has_next_name(p, &s))
{ {
req->set_question_name(s.current_name.c_str()); // 设置查询名字 req->set_question_name(s.current_name.c_str());
ctask->set_redirect(p->uris[s.next_server]); // 设置重定向 ctask->set_redirect(p->uris[s.next_server]);
return; // 返回 return;
} }
} }
// WFDnsClient类的初始化函数带一个URL参数
int WFDnsClient::init(const std::string& url) int WFDnsClient::init(const std::string& url)
{ {
return this->init(url, "", 1, 2, false); // 调用另一个初始化函数 return this->init(url, "", 1, 2, false);
} }
// WFDnsClient类的初始化函数用于配置DNS客户端
int WFDnsClient::init(const std::string& url, const std::string& search_list, int WFDnsClient::init(const std::string& url, const std::string& search_list,
int ndots, int attempts, bool rotate) int ndots, int attempts, bool rotate)
{ {
std::vector<std::string> hosts; // 用于存储分割后的主机名 std::vector<std::string> hosts;
std::vector<ParsedURI> uris; // 用于存储解析后的URI std::vector<ParsedURI> uris;
std::string host; // 单个主机名字符串 std::string host;
ParsedURI uri; // 用于存储解析后的单个URI ParsedURI uri;
this->id = 0; // 初始化客户端ID为0 this->id = 0;
hosts = StringUtil::split_filter_empty(url, ','); // 根据逗号分割URL字符串获取主机名列表 hosts = StringUtil::split_filter_empty(url, ',');
// 遍历主机名列表,对每个主机名进行处理 for (size_t i = 0; i < hosts.size(); i++)
for (size_t i = 0; i < hosts.size(); i++) {
{ host = hosts[i];
host = hosts[i]; // 获取当前主机名 if (strncasecmp(host.c_str(), "dns://", 6) != 0 &&
// 检查主机名是否以"dns://"或"dnss://"开头,如果不是,则添加"dns://"前缀 strncasecmp(host.c_str(), "dnss://", 7) != 0)
if (strncasecmp(host.c_str(), "dns://", 6) != 0 && {
strncasecmp(host.c_str(), "dnss://", 7) != 0) host = "dns://" + host;
{ }
host = "dns://" + host;
} if (URIParser::parse(host, uri) != 0)
return -1;
// 使用URIParser解析当前主机名如果解析失败则返回错误码-1
if (URIParser::parse(host, uri) != 0) uris.emplace_back(std::move(uri));
return -1; }
// 将解析后的URI添加到uris列表中 if (uris.empty() || ndots < 0 || attempts < 1)
uris.emplace_back(std::move(uri)); {
} errno = EINVAL;
return -1;
// 如果uris列表为空或者ndots小于0或者attempts小于1则设置errno为EINVAL并返回错误码-1 }
if (uris.empty() || ndots < 0 || attempts < 1)
{ this->params = new DnsParams;
errno = EINVAL; DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params();
return -1; q->uris = std::move(uris);
} q->search_list = StringUtil::split_filter_empty(search_list, ',');
q->ndots = ndots > 15 ? 15 : ndots;
// 创建一个新的DnsParams对象来存储DNS参数 q->attempts = attempts > 5 ? 5 : attempts;
this->params = new DnsParams; q->rotate = rotate;
DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params(); // 获取DNS参数的指针
q->uris = std::move(uris); // 将解析后的URI列表移动到DNS参数中 return 0;
q->search_list = StringUtil::split_filter_empty(search_list, ','); // 根据逗号分割search_list字符串获取搜索域列表
// 设置ndots的值如果大于15则设置为15
q->ndots = ndots > 15 ? 15 : ndots;
// 设置attempts的值如果大于5则设置为5
q->attempts = attempts > 5 ? 5 : attempts;
q->rotate = rotate; // 设置是否轮询
return 0; // 初始化成功返回0
} }
// WFDnsClient类的析构函数用于释放资源
void WFDnsClient::deinit() void WFDnsClient::deinit()
{ {
delete (DnsParams *)this->params; // 删除分配的DnsParams对象 delete (DnsParams *)this->params;
this->params = NULL; // 将params指针设置为NULL this->params = NULL;
} }
// WFDnsClient类的方法用于创建一个新的DNS任务
WFDnsTask *WFDnsClient::create_dns_task(const std::string& name, WFDnsTask *WFDnsClient::create_dns_task(const std::string& name,
dns_callback_t callback) dns_callback_t callback)
{ {
DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params(); // 获取DNS参数 DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params();
struct DnsStatus status; // 创建DNS状态结构体 struct DnsStatus status;
size_t next_server; // 下一个要尝试的服务器索引 size_t next_server;
WFDnsTask *task; // 指向新创建的DNS任务的指针 WFDnsTask *task;
DnsRequest *req; // 指向DNS请求的指针 DnsRequest *req;
// 如果启用轮询,则计算下一个服务器索引,否则使用第一个服务器 next_server = p->rotate ? this->id++ % p->uris.size() : 0;
next_server = p->rotate ? this->id++ % p->uris.size() : 0;
status.origin_name = name;
status.origin_name = name; // 设置原始域名 status.next_domain = 0;
status.next_domain = 0; // 设置下一个要尝试的搜索域索引 status.attempts_left = p->attempts;
status.attempts_left = p->attempts; // 设置剩余尝试次数 status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST;
status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST; // 设置尝试原始域名的状态
if (!name.empty() && name.back() == '.')
// 如果域名以点结尾,则跳过搜索域 status.next_domain = p->search_list.size();
if (!name.empty() && name.back() == '.') else if (__get_ndots(name) < p->ndots)
status.next_domain = p->search_list.size(); status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST;
// 如果域名中的点数少于ndots则设置尝试原始域名的状态为最后尝试
else if (__get_ndots(name) < p->ndots) __has_next_name(p, &status);
status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST;
task = WFTaskFactory::create_dns_task(p->uris[next_server], 0,
// 检查是否有下一个域名要尝试,并更新状态 std::move(callback));
__has_next_name(p, &status); status.next_server = next_server;
status.last_server = (next_server + p->uris.size() - 1) % p->uris.size();
// 创建一个新的DNS任务使用下一个服务器的URI和提供的回调函数
task = WFTaskFactory::create_dns_task(p->uris[next_server], 0, std::move(callback)); req = task->get_req();
status.next_server = next_server; // 设置当前服务器索引 req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN);
status.last_server = (next_server + p->uris.size() - 1) % p->uris.size(); // 设置最后一个服务器索引 req->set_rd(1);
req = task->get_req(); // 获取DNS请求对象 ComplexTask *ctask = static_cast<ComplexTask *>(task);
req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN); // 设置DNS请求的问题部分 *ctask->get_mutable_ctx() = std::bind(__callback_internal,
req->set_rd(1); // 设置DNS请求的递归标志 std::placeholders::_1,
*(DnsParams *)params, status);
ComplexTask *ctask = static_cast<ComplexTask *>(task); // 将任务转换为ComplexTask类型
// 设置任务的上下文回调当DNS任务完成时将调用__callback_internal函数 return task;
*ctask->get_mutable_ctx() = std::bind(__callback_internal,
std::placeholders::_1,
*(DnsParams *)params, status);
return task; // 返回新创建的DNS任务
} }

File diff suppressed because it is too large Load Diff

@ -26,81 +26,64 @@
#include "URIParser.h" #include "URIParser.h"
#include "WFTaskFactory.h" #include "WFTaskFactory.h"
// 定义WFMySQLConnection类用于管理MySQL数据库连接
class WFMySQLConnection class WFMySQLConnection
{ {
public: public:
// 初始化函数接受一个包含连接信息的URL字符串 /* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8
// 例如mysql://username:passwd@127.0.0.1/dbname?character_set=utf8 * IP string is recommmended in url. When using a domain name, the first
// 推荐使用IP地址而不是域名因为域名会被解析为第一个地址并且不建议使用上游名称作为主机名 * address resovled will be used. Don't use upstream name as a host. */
int init(const std::string& url) int init(const std::string& url)
{ {
// 调用另一个init函数传入URL和一个NULL的SSL上下文
return this->init(url, NULL); return this->init(url, NULL);
} }
// 重载的init函数接受URL和一个可选的SSL上下文
int init(const std::string& url, SSL_CTX *ssl_ctx); int init(const std::string& url, SSL_CTX *ssl_ctx);
// 反初始化函数,当前为空实现
void deinit() { } void deinit() { }
public: public:
// 创建一个查询任务接受一个SQL查询字符串和一个回调函数
WFMySQLTask *create_query_task(const std::string& query, WFMySQLTask *create_query_task(const std::string& query,
mysql_callback_t callback) mysql_callback_t callback)
{ {
// 使用WFTaskFactory创建一个新的MySQL任务传入连接URI、超时时间和回调函数
WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0,
std::move(callback)); std::move(callback));
// 设置任务的SSL上下文
this->set_ssl_ctx(task); this->set_ssl_ctx(task);
// 设置任务的查询字符串
task->get_req()->set_query(query); task->get_req()->set_query(query);
// 返回创建的任务
return task; return task;
} }
// 创建一个断开连接的任务,接受一个回调函数 /* If you don't disconnect manually, the TCP connection will be
// 如果不手动断开连接TCP连接将在对象删除后保持活动并可能被具有相同ID和URL的另一个WFMySQLConnection对象重用 * kept alive after this object is deleted, and maybe reused by
* another WFMySQLConnection object with same id and url. */
WFMySQLTask *create_disconnect_task(mysql_callback_t callback) WFMySQLTask *create_disconnect_task(mysql_callback_t callback)
{ {
// 创建一个空查询的任务,以此作为断开连接的任务
WFMySQLTask *task = this->create_query_task("", std::move(callback)); WFMySQLTask *task = this->create_query_task("", std::move(callback));
// 设置任务的SSL上下文
this->set_ssl_ctx(task); this->set_ssl_ctx(task);
// 设置任务不保持活动状态,暗示这是一个断开连接的任务
task->set_keep_alive(0); task->set_keep_alive(0);
// 返回创建的任务
return task; return task;
} }
protected: protected:
// 设置任务的SSL上下文
void set_ssl_ctx(WFMySQLTask *task) const void set_ssl_ctx(WFMySQLTask *task) const
{ {
// 使用别名简化类型名称
using MySQLRequest = protocol::MySQLRequest; using MySQLRequest = protocol::MySQLRequest;
using MySQLResponse = protocol::MySQLResponse; using MySQLResponse = protocol::MySQLResponse;
// 将任务强制转换为WFComplexClientTask类型
auto *t = (WFComplexClientTask<MySQLRequest, MySQLResponse> *)task; auto *t = (WFComplexClientTask<MySQLRequest, MySQLResponse> *)task;
// 设置任务的SSL上下文如果为NULL则使用默认 /* 'ssl_ctx' can be NULL and will use default. */
t->set_ssl_ctx(this->ssl_ctx); t->set_ssl_ctx(this->ssl_ctx);
} }
protected: protected:
// 解析后的URI用于存储连接信息
ParsedURI uri; ParsedURI uri;
// SSL上下文用于加密通信
SSL_CTX *ssl_ctx; SSL_CTX *ssl_ctx;
// 连接ID确保并发连接具有不同的ID
int id; int id;
public: public:
// 确保并发连接的ID不同当连接对象被删除时ID可以被重用 /* Make sure that concurrent connections have different id.
* When a connection object is deleted, id can be reused. */
WFMySQLConnection(int id) { this->id = id; } WFMySQLConnection(int id) { this->id = id; }
// 析构函数,当前为空实现
virtual ~WFMySQLConnection() { } virtual ~WFMySQLConnection() { }
}; };
#endif #endif

@ -30,188 +30,211 @@
#include "WFTask.h" #include "WFTask.h"
#include "WFTaskFactory.h" #include "WFTaskFactory.h"
// 定义WFRedisSubscribeTask类它继承自WFGenericTask
class WFRedisSubscribeTask : public WFGenericTask class WFRedisSubscribeTask : public WFGenericTask
{ {
public: public:
// 获取与任务关联的Redis响应对象 /* Note: Call 'get_resp()' only in the 'extract' function or
// 注意:只能在'extract'函数中或在任务开始之前调用'get_resp()'来设置响应大小限制。 before the task is started to set response size limit. */
protocol::RedisResponse *get_resp() protocol::RedisResponse *get_resp()
{ {
return this->task->get_resp(); return this->task->get_resp();
} }
public: public:
// 用户需要恰好调用一次'release()'函数,可以在任何地方调用。 /* User needs to call 'release()' exactly once, anywhere. */
void release() void release()
{ {
// 使用原子操作来检查并设置标志如果标志已经是true则不执行删除操作 if (this->flag.exchange(true))
// 这样可以防止多次释放同一个对象 delete this;
if (this->flag.exchange(true)) }
delete this;
}
// ... (省略了其他Redis命令的函数如subscribe, unsubscribe等)
// 这些函数通过发送Redis命令来执行订阅、取消订阅等操作
public: public:
// 设置与消息接收相关的超时时间 /* Note: After 'release()' is called, all the requesting functions
// 这些函数只能在任务开始之前或在'extract'中调用 should not be called except in 'extract', because the task
point may have been deleted because 'callback' finished. */
// 设置等待每个消息的超时时间。非常有用。如果不设置,则最大等待时间将是全局的'response_timeout'
void set_watch_timeout(int timeout) int subscribe(const std::vector<std::string>& channels)
{ {
this->task->set_watch_timeout(timeout); return this->sync_send("SUBSCRIBE", channels);
} }
// 设置接收完整消息的超时时间。 int unsubscribe(const std::vector<std::string>& channels)
void set_recv_timeout(int timeout) {
{ return this->sync_send("UNSUBSCRIBE", channels);
this->task->set_receive_timeout(timeout); }
}
int unsubscribe()
// 设置发送第一个订阅请求的超时时间。 {
void set_send_timeout(int timeout) return this->sync_send("UNSUBSCRIBE", { });
{ }
this->task->set_send_timeout(timeout);
} int psubscribe(const std::vector<std::string>& patterns)
{
// 设置保持连接活跃的超时时间。默认值为0。如果你想保持连接活跃确保在所有频道/模式取消订阅后不再发送任何请求。 return this->sync_send("PSUBSCRIBE", patterns);
void set_keep_alive(int timeout) }
{
this->task->set_keep_alive(timeout); int punsubscribe(const std::vector<std::string>& patterns)
} {
return this->sync_send("PUNSUBSCRIBE", patterns);
}
int punsubscribe()
{
return this->sync_send("PUNSUBSCRIBE", { });
}
int ping(const std::string& message)
{
return this->sync_send("PING", { message });
}
int ping()
{
return this->sync_send("PING", { });
}
int quit()
{
return this->sync_send("QUIT", { });
}
public: public:
// 设置提取函数或回调函数,这些函数只能在任务开始之前或在'extract'中调用 /* All 'timeout' proxy functions can only be called only before
the task is started or in 'extract'. */
// 设置提取函数,当任务完成时,这个函数将被调用以处理结果 /* Timeout of waiting for each message. Very useful. If not set,
void set_extract(std::function<void (WFRedisSubscribeTask *)> ex) the max waiting time will be the global 'response_timeout'*/
{ void set_watch_timeout(int timeout)
this->extract = std::move(ex); {
} this->task->set_watch_timeout(timeout);
}
// 设置回调函数,当任务完成时(可能是在提取函数之后),这个函数将被调用 /* Timeout of receiving a complete message. */
void set_callback(std::function<void (WFRedisSubscribeTask *)> cb) void set_recv_timeout(int timeout)
{ {
this->callback = std::move(cb); this->task->set_receive_timeout(timeout);
} }
protected: /* Timeout of sending the first subscribe request. */
// 虚拟的dispatch函数用于将任务添加到执行序列中 void set_send_timeout(int timeout)
virtual void dispatch() {
{ this->task->set_send_timeout(timeout);
series_of(this)->push_front(this->task); }
this->subtask_done();
} /* The default keep alive timeout is 0. If you want to keep
the connection alive, make sure not to send any request
// 虚拟的done函数用于从执行序列中弹出任务并返回 after all channels/patterns were unsubscribed. */
virtual SubTask *done() void set_keep_alive(int timeout)
{ {
return series_of(this)->pop(); this->task->set_keep_alive(timeout);
} }
public:
/* Call 'set_extract' or 'set_callback' only before the task
is started, or in 'extract'. */
void set_extract(std::function<void (WFRedisSubscribeTask *)> ex)
{
this->extract = std::move(ex);
}
void set_callback(std::function<void (WFRedisSubscribeTask *)> cb)
{
this->callback = std::move(cb);
}
protected: protected:
// 同步发送Redis命令的函数它封装了与Redis服务器的通信细节 virtual void dispatch()
int sync_send(const std::string& command, {
const std::vector<std::string>& params); series_of(this)->push_front(this->task);
this->subtask_done();
}
virtual SubTask *done()
{
return series_of(this)->pop();
}
// 静态的提取函数和回调函数它们被设计为与Redis任务交互的接口 protected:
static void task_extract(WFRedisTask *task); int sync_send(const std::string& command,
static void task_callback(WFRedisTask *task); const std::vector<std::string>& params);
static void task_extract(WFRedisTask *task);
static void task_callback(WFRedisTask *task);
protected: protected:
// 指向底层Redis任务的指针 WFRedisTask *task;
WFRedisTask *task; std::mutex mutex;
// 用于同步访问的互斥锁 std::atomic<bool> flag;
std::mutex mutex; std::function<void (WFRedisSubscribeTask *)> extract;
// 原子标志,用于跟踪对象是否已被释放 std::function<void (WFRedisSubscribeTask *)> callback;
std::atomic<bool> flag;
// 提取函数和回调函数,当任务完成时它们将被调用
std::function<void (WFRedisSubscribeTask *)> extract;
std::function<void (WFRedisSubscribeTask *)> callback;
protected: protected:
// 构造函数接受一个Redis任务、提取函数和回调函数 WFRedisSubscribeTask(WFRedisTask *task,
WFRedisSubscribeTask(WFRedisTask *task, std::function<void (WFRedisSubscribeTask *)>&& ex,
std::function<void (WFRedisSubscribeTask *)>&& ex, std::function<void (WFRedisSubscribeTask *)>&& cb) :
std::function<void (WFRedisSubscribeTask *)>&& cb) : flag(false),
flag(false), extract(std::move(ex)),
extract(std::move(ex)), callback(std::move(cb))
callback(std::move(cb)) {
{ task->user_data = this;
// 将用户数据指针设置为当前对象,以便在回调中可以访问它 this->task = task;
task->user_data = this; }
this->task = task;
} virtual ~WFRedisSubscribeTask()
{
// 析构函数,释放与任务关联的资源 if (this->task)
virtual ~WFRedisSubscribeTask() this->task->dismiss();
{ }
if (this->task)
this->task->dismiss(); // 可能是取消任务或释放相关资源 friend class WFRedisSubscriber;
}
// 允许WFRedisSubscriber类访问受保护的成员
friend class WFRedisSubscriber;
}; };
// 定义一个名为 WFRedisSubscriber 的类
class WFRedisSubscriber class WFRedisSubscriber
{ {
public: public:
// 重载的 init 方法,只接受一个 URL 参数,内部调用另一个 init 方法并传入 NULL 作为 SSL 上下文
int init(const std::string& url) int init(const std::string& url)
{ {
return this->init(url, NULL); // 调用重载的 init 方法,传入 URL 和 NULL 作为 SSL 上下文 return this->init(url, NULL);
} }
// init 方法,接受一个 URL 和一个可选的 SSL 上下文参数 int init(const std::string& url, SSL_CTX *ssl_ctx);
int init(const std::string& url, SSL_CTX *ssl_ctx); // 此方法的具体实现在代码段之外
// 释放资源的 deinit 方法,当前为空实现
void deinit() { } void deinit() { }
public: public:
// 定义提取任务数据的函数类型
using extract_t = std::function<void (WFRedisSubscribeTask *)>; using extract_t = std::function<void (WFRedisSubscribeTask *)>;
// 定义任务回调的函数类型
using callback_t = std::function<void (WFRedisSubscribeTask *)>; using callback_t = std::function<void (WFRedisSubscribeTask *)>;
public: public:
// 创建一个订阅任务,接受频道列表、提取函数和回调函数 WFRedisSubscribeTask *
WFRedisSubscribeTask *create_subscribe_task(const std::vector<std::string>& channels, create_subscribe_task(const std::vector<std::string>& channels,
extract_t extract, callback_t callback); // 方法的具体实现在代码段之外 extract_t extract, callback_t callback);
// 创建一个模式订阅任务,接受模式列表、提取函数和回调函数 WFRedisSubscribeTask *
WFRedisSubscribeTask *create_psubscribe_task(const std::vector<std::string>& patterns, create_psubscribe_task(const std::vector<std::string>& patterns,
extract_t extract, callback_t callback); // 方法的具体实现在代码段之外 extract_t extract, callback_t callback);
protected: protected:
// 为任务设置 SSL 上下文
void set_ssl_ctx(WFRedisTask *task) const void set_ssl_ctx(WFRedisTask *task) const
{ {
// 使用别名简化类型名称
using RedisRequest = protocol::RedisRequest; using RedisRequest = protocol::RedisRequest;
using RedisResponse = protocol::RedisResponse; using RedisResponse = protocol::RedisResponse;
// 将传入的 WFRedisTask 强制转换为 WFComplexClientTask<RedisRequest, RedisResponse> 类型
auto *t = (WFComplexClientTask<RedisRequest, RedisResponse> *)task; auto *t = (WFComplexClientTask<RedisRequest, RedisResponse> *)task;
// 为任务设置 SSL 上下文,如果 ssl_ctx 为 NULL则使用默认设置 /* 'ssl_ctx' can be NULL and will use default. */
t->set_ssl_ctx(this->ssl_ctx); t->set_ssl_ctx(this->ssl_ctx);
} }
protected: protected:
// 创建一个 Redis 任务,接受命令和参数列表
WFRedisTask *create_redis_task(const std::string& command, WFRedisTask *create_redis_task(const std::string& command,
const std::vector<std::string>& params); // 方法的具体实现在代码段之外 const std::vector<std::string>& params);
protected: protected:
// 存储解析后的 URI用于连接 Redis 服务器
ParsedURI uri; ParsedURI uri;
// 存储 SSL 上下文,用于加密通信(如果需要)
SSL_CTX *ssl_ctx; SSL_CTX *ssl_ctx;
public: public:
virtual ~WFRedisSubscriber() { } virtual ~WFRedisSubscriber() { }
}; };
#endif #endif

@ -1,18 +1,15 @@
cmake_minimum_required(VERSION 3.6) # CMake 3.6 cmake_minimum_required(VERSION 3.6)
project(server) # "server" project(server)
#
set(SRC set(SRC
WFServer.cc # WFServer.cc WFServer.cc
) )
# MySQL "n"
if (NOT MYSQL STREQUAL "n") if (NOT MYSQL STREQUAL "n")
set(SRC set(SRC
${SRC} # ${SRC}
WFMySQLServer.cc # WFMySQLServer.cc WFMySQLServer.cc
) )
endif () endif ()
# add_library(${PROJECT_NAME} OBJECT ${SRC})
add_library(${PROJECT_NAME} OBJECT ${SRC}) # "server"

@ -19,52 +19,44 @@
#ifndef _WFDNSSERVER_H_ #ifndef _WFDNSSERVER_H_
#define _WFDNSSERVER_H_ #define _WFDNSSERVER_H_
#include "DnsMessage.h" // 引入 DNS 消息相关的头文件 #include "DnsMessage.h"
#include "WFServer.h" // 引入服务器功能的头文件 #include "WFServer.h"
#include "WFTaskFactory.h" // 引入任务工厂以创建 DNS 任务 #include "WFTaskFactory.h"
// 定义一个 dns_process_t 类型,它是一个接受 WFDnsTask 指针并返回 void 的函数对象
using dns_process_t = std::function<void (WFDnsTask *)>; using dns_process_t = std::function<void (WFDnsTask *)>;
using WFDnsServer = WFServer<protocol::DnsRequest,
protocol::DnsResponse>;
// 定义 WFDnsServer 类型,基于 WFServer 模板类,使用 DnsRequest 和 DnsResponse 作为协议
using WFDnsServer = WFServer<protocol::DnsRequest, protocol::DnsResponse>;
// 设置 DNS 服务器的默认参数
static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_UDP, // 传输类型为 UDP .transport_type = TT_UDP,
.max_connections = 2000, // 最大连接数 .max_connections = 2000,
.peer_response_timeout = 10 * 1000, // 对端响应超时时间 .peer_response_timeout = 10 * 1000,
.receive_timeout = -1, // 接收超时时间,-1 表示不超时 .receive_timeout = -1,
.keep_alive_timeout = 300 * 1000, // 保持连接的超时时间 .keep_alive_timeout = 300 * 1000,
.request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制 .request_size_limit = (size_t)-1,
.ssl_accept_timeout = 5000, // SSL 接受超时时间 .ssl_accept_timeout = 5000,
}; };
// WFDnsServer 构造函数的特化模板
template<> inline template<> inline
WFDnsServer::WFServer(dns_process_t proc) : WFDnsServer::WFServer(dns_process_t proc) :
WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数 WFServerBase(&DNS_SERVER_PARAMS_DEFAULT),
process(std::move(proc)) // 移动构造函数以提高性能 process(std::move(proc))
{ {
} }
// 创建新的会话,重写基类方法
template<> inline template<> inline
CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn) CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn)
{ {
WFDnsTask *task; WFDnsTask *task;
// 使用任务工厂创建一个新的 DNS 任务
task = WFServerTaskFactory::create_dns_task(this, this->process); task = WFServerTaskFactory::create_dns_task(this, this->process);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit);
return task; // 返回创建的任务 return task;
} }
#endif // _WFDNSSERVER_H_ #endif

@ -19,53 +19,45 @@
#ifndef _WFHTTPSERVER_H_ #ifndef _WFHTTPSERVER_H_
#define _WFHTTPSERVER_H_ #define _WFHTTPSERVER_H_
#include <utility> // 引入标准库中的工具函数 #include <utility>
#include "HttpMessage.h" // 引入 HTTP 消息相关的头文件 #include "HttpMessage.h"
#include "WFServer.h" // 引入服务器功能的头文件 #include "WFServer.h"
#include "WFTaskFactory.h" // 引入任务工厂以创建 HTTP 任务 #include "WFTaskFactory.h"
// 定义一个 http_process_t 类型,它是一个接受 WFHttpTask 指针并返回 void 的函数对象
using http_process_t = std::function<void (WFHttpTask *)>; using http_process_t = std::function<void (WFHttpTask *)>;
using WFHttpServer = WFServer<protocol::HttpRequest,
protocol::HttpResponse>;
// 定义 WFHttpServer 类型,基于 WFServer 模板类,使用 HttpRequest 和 HttpResponse 作为协议
using WFHttpServer = WFServer<protocol::HttpRequest, protocol::HttpResponse>;
// 设置 HTTP 服务器的默认参数
static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_TCP, // 传输类型为 TCP .transport_type = TT_TCP,
.max_connections = 2000, // 最大连接数 .max_connections = 2000,
.peer_response_timeout = 10 * 1000, // 对端响应超时时间 .peer_response_timeout = 10 * 1000,
.receive_timeout = -1, // 接收超时时间,-1 表示不超时 .receive_timeout = -1,
.keep_alive_timeout = 60 * 1000, // 保持连接的超时时间 .keep_alive_timeout = 60 * 1000,
.request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制 .request_size_limit = (size_t)-1,
.ssl_accept_timeout = 10 * 1000, // SSL 接受超时时间 .ssl_accept_timeout = 10 * 1000,
}; };
// WFHttpServer 构造函数的特化模板
template<> inline template<> inline
WFHttpServer::WFServer(http_process_t proc) : WFHttpServer::WFServer(http_process_t proc) :
WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数 WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT),
process(std::move(proc)) // 移动构造函数以提高性能 process(std::move(proc))
{ {
} }
// 创建新的会话,重写基类的方法
template<> inline template<> inline
CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn) CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn)
{ {
WFHttpTask *task; // 声明一个 HTTP 任务指针 WFHttpTask *task;
// 使用任务工厂创建一个新的 HTTP 任务
task = WFServerTaskFactory::create_http_task(this, this->process); task = WFServerTaskFactory::create_http_task(this, this->process);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit);
return task; // 返回创建的任务 return task;
} }
#endif // _WFHTTPSERVER_H_ #endif

@ -16,54 +16,45 @@
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/ */
#include <sys/uio.h> // 引入系统I/O头文件用于向客户端发送数据 #include <sys/uio.h>
#include "WFMySQLServer.h" // 引入 MySQL 服务器的头文件 #include "WFMySQLServer.h"
// 创建新的连接
WFConnection *WFMySQLServer::new_connection(int accept_fd) WFConnection *WFMySQLServer::new_connection(int accept_fd)
{ {
// 调用基类的 new_connection 方法创建连接
WFConnection *conn = this->WFServer::new_connection(accept_fd); WFConnection *conn = this->WFServer::new_connection(accept_fd);
if (conn) // 如果连接成功 if (conn)
{ {
protocol::MySQLHandshakeResponse resp; // 创建 MySQL 握手响应对象 protocol::MySQLHandshakeResponse resp;
struct iovec vec[8]; // 定义 I/O 向量以便于发送数据 struct iovec vec[8];
int count; // 记录向量中有效数据的数量 int count;
// 设置服务器的握手响应参数 resp.server_set(0x0a, "5.5", 1, (const uint8_t *)"12345678901234567890",
resp.server_set(0x0a, "5.5", 1, (const uint8_t *)"12345678901234567890",
0, 33, 0); 0, 33, 0);
// 编码握手响应信息到 I/O 向量中
count = resp.encode(vec, 8); count = resp.encode(vec, 8);
if (count >= 0) // 如果编码成功 if (count >= 0)
{ {
// 使用 writev 函数发送握手响应数据到客户端
if (writev(accept_fd, vec, count) >= 0) if (writev(accept_fd, vec, count) >= 0)
return conn; // 返回创建的连接 return conn;
} }
// 如果发送失败,删除该连接
this->delete_connection(conn); this->delete_connection(conn);
} }
return NULL; // 如果连接失败,则返回 NULL return NULL;
} }
// 创建新的会话,重写基类的方法
CommSession *WFMySQLServer::new_session(long long seq, CommConnection *conn) CommSession *WFMySQLServer::new_session(long long seq, CommConnection *conn)
{ {
static mysql_process_t empty = [](WFMySQLTask *){ }; // 定义一个空的 MySQL 处理函数 static mysql_process_t empty = [](WFMySQLTask *){ };
WFMySQLTask *task; // 声明一个 MySQL 任务指针 WFMySQLTask *task;
// 使用任务工厂创建一个新的 MySQL 任务 task = WFServerTaskFactory::create_mysql_task(this, seq ? this->process :
task = WFServerTaskFactory::create_mysql_task(this, seq ? this->process : empty); empty);
// 设置保持连接的超时时间
task->set_keep_alive(this->params.keep_alive_timeout); task->set_keep_alive(this->params.keep_alive_timeout);
// 设置接收超时时间
task->set_receive_timeout(this->params.receive_timeout); task->set_receive_timeout(this->params.receive_timeout);
// 设置请求大小限制
task->get_req()->set_size_limit(this->params.request_size_limit); task->get_req()->set_size_limit(this->params.request_size_limit);
return task; // 返回创建的任务 return task;
} }

@ -19,43 +19,39 @@
#ifndef _WFMYSQLSERVER_H_ #ifndef _WFMYSQLSERVER_H_
#define _WFMYSQLSERVER_H_ #define _WFMYSQLSERVER_H_
#include <utility> // 引入标准库的实用工具 #include <utility>
#include "MySQLMessage.h" // 引入 MySQL 消息相关的头文件 #include "MySQLMessage.h"
#include "WFServer.h" // 引入服务器功能的头文件 #include "WFServer.h"
#include "WFTaskFactory.h" // 引入任务工厂以创建 MySQL 任务 #include "WFTaskFactory.h"
#include "WFConnection.h" // 引入连接类的头文件 #include "WFConnection.h"
// 定义一个 mysql_process_t 类型,这是一个接受 WFMySQLTask 指针并返回 void 的函数对象类型
using mysql_process_t = std::function<void (WFMySQLTask *)>; using mysql_process_t = std::function<void (WFMySQLTask *)>;
class MySQLServer;
// 定义 WFMySQLServer 类,继承自 WFServer 模板类 static constexpr struct WFServerParams MYSQL_SERVER_PARAMS_DEFAULT =
class WFMySQLServer : public WFServer<protocol::MySQLRequest, protocol::MySQLResponse> {
.transport_type = TT_TCP,
.max_connections = 2000,
.peer_response_timeout = 10 * 1000,
.receive_timeout = -1,
.keep_alive_timeout = 28800 * 1000,
.request_size_limit = (size_t)-1,
.ssl_accept_timeout = 10 * 1000,
};
class WFMySQLServer : public WFServer<protocol::MySQLRequest,
protocol::MySQLResponse>
{ {
public: public:
// 构造函数,接收一个 MySQL 处理函数 WFMySQLServer(mysql_process_t proc):
WFMySQLServer(mysql_process_t proc) : WFServer(&MYSQL_SERVER_PARAMS_DEFAULT, std::move(proc))
WFServer(&MYSQL_SERVER_PARAMS_DEFAULT, std::move(proc)) // 调用基类构造函数,使用默认参数 {
{ }
}
protected: protected:
// 重写 new_connection 方法,创建新的连接 virtual WFConnection *new_connection(int accept_fd);
virtual WFConnection *new_connection(int accept_fd); virtual CommSession *new_session(long long seq, CommConnection *conn);
// 重写 new_session 方法,创建新的会话
virtual CommSession *new_session(long long seq, CommConnection *conn);
}; };
// 设置 MySQL 服务器的默认参数 #endif
static constexpr struct WFServerParams MYSQL_SERVER_PARAMS_DEFAULT =
{
.transport_type = TT_TCP, // 设置传输类型为 TCP
.max_connections = 2000, // 设置最大连接数
.peer_response_timeout = 10 * 1000, // 设置对端响应超时时间
.receive_timeout = -1, // 设置接收超时时间,-1 表示不超时
.keep_alive_timeout = 28800 * 1000, // 设置保持连接的超时时间
.request_size_limit = (size_t)-1, // 设置请求大小限制,-1 表示无限制
.ssl_accept_timeout = 10 * 1000, // 设置 SSL 接受超时时间
};
#endif // _WFMYSQLSERVER_H_

@ -19,35 +19,31 @@
#ifndef _WFREDISSERVER_H_ #ifndef _WFREDISSERVER_H_
#define _WFREDISSERVER_H_ #define _WFREDISSERVER_H_
#include "RedisMessage.h" // 引入 Redis 消息相关的头文件 #include "RedisMessage.h"
#include "WFServer.h" // 引入服务器功能的头文件 #include "WFServer.h"
#include "WFTaskFactory.h" // 引入任务工厂以创建 Redis 任务 #include "WFTaskFactory.h"
// 定义一个 redis_process_t 类型,这是一个接受 WFRedisTask 指针并返回 void 的函数对象类型
using redis_process_t = std::function<void (WFRedisTask *)>; using redis_process_t = std::function<void (WFRedisTask *)>;
using WFRedisServer = WFServer<protocol::RedisRequest,
protocol::RedisResponse>;
// 定义 WFRedisServer 类型,基于 WFServer 模板类,使用 RedisRequest 和 RedisResponse 作为协议
using WFRedisServer = WFServer<protocol::RedisRequest, protocol::RedisResponse>;
// 设置 Redis 服务器的默认参数
static constexpr struct WFServerParams REDIS_SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams REDIS_SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_TCP, // 传输类型设置为 TCP .transport_type = TT_TCP,
.max_connections = 2000, // 设置最大连接数为 2000 .max_connections = 2000,
.peer_response_timeout = 10 * 1000, // 对端响应超时时间为 10 秒 .peer_response_timeout = 10 * 1000,
.receive_timeout = -1, // 接收超时时间,-1 表示不超时 .receive_timeout = -1,
.keep_alive_timeout = 300 * 1000, // 保持连接的超时时间为 300 秒 .keep_alive_timeout = 300 * 1000,
.request_size_limit = (size_t)-1, // 请求大小限制,-1 表示无限制 .request_size_limit = (size_t)-1,
.ssl_accept_timeout = 5000, // SSL 接受超时时间为 5 秒 .ssl_accept_timeout = 5000,
}; };
// WFRedisServer 构造函数的特化模板
template<> inline template<> inline
WFRedisServer::WFServer(redis_process_t proc) : WFRedisServer::WFServer(redis_process_t proc) :
WFServerBase(&REDIS_SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数 WFServerBase(&REDIS_SERVER_PARAMS_DEFAULT),
process(std::move(proc)) // 采用移动语义以避免不必要的拷贝 process(std::move(proc))
{ {
} }
// 生成新会话的方法,重写基类的方法 #endif
#endif // _WFREDISSERVER_H_

@ -17,300 +17,272 @@
Wu Jiaxu (wujiaxu@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com)
*/ */
#include <sys/types.h> // 引入数据类型定义 #include <sys/types.h>
#include <sys/socket.h> // 引入套接字相关的函数 #include <sys/socket.h>
#include <errno.h> // 引入错误码定义 #include <errno.h>
#include <unistd.h> // 引入 UNIX 标准函数 #include <unistd.h>
#include <stdio.h> // 引入标准输入输出库 #include <stdio.h>
#include <atomic> // 引入原子操作库 #include <atomic>
#include <mutex> // 引入互斥锁库 #include <mutex>
#include <condition_variable> // 引入条件变量库 #include <condition_variable>
#include <openssl/ssl.h> // 引入 OpenSSL 库以支持 SSL #include <openssl/ssl.h>
#include "CommScheduler.h" // 引入调度器相关的头文件 #include "CommScheduler.h"
#include "EndpointParams.h" // 引入端点参数定义 #include "EndpointParams.h"
#include "WFConnection.h" // 引入连接类的定义 #include "WFConnection.h"
#include "WFGlobal.h" // 引入全局配置相关的定义 #include "WFGlobal.h"
#include "WFServer.h" // 引入服务器功能的头文件 #include "WFServer.h"
#define PORT_STR_MAX 5 // 定义最大端口字符串长度 #define PORT_STR_MAX 5
// 定义 WFServerConnection 类,继承自 WFConnection
class WFServerConnection : public WFConnection class WFServerConnection : public WFConnection
{ {
public: public:
// 构造函数,接受连接计数的原子指针
WFServerConnection(std::atomic<size_t> *conn_count) WFServerConnection(std::atomic<size_t> *conn_count)
{ {
this->conn_count = conn_count; // 初始化连接计数指针 this->conn_count = conn_count;
} }
// 析构函数
virtual ~WFServerConnection() virtual ~WFServerConnection()
{ {
(*this->conn_count)--; // 每当一个连接被销毁,减少连接计数 (*this->conn_count)--;
} }
private: private:
std::atomic<size_t> *conn_count; // 存储连接计数的原子变量指针 std::atomic<size_t> *conn_count;
}; };
// SSL 上下文回调函数,用于处理 SSL 连接
int WFServerBase::ssl_ctx_callback(SSL *ssl, int *al, void *arg) int WFServerBase::ssl_ctx_callback(SSL *ssl, int *al, void *arg)
{ {
WFServerBase *server = (WFServerBase *)arg; // 将参数转换为 WFServerBase 类型 WFServerBase *server = (WFServerBase *)arg;
const char *servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); // 获取服务器名称 const char *servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
SSL_CTX *ssl_ctx = server->get_server_ssl_ctx(servername); // 根据服务器名称获取 SSL 上下文 SSL_CTX *ssl_ctx = server->get_server_ssl_ctx(servername);
if (!ssl_ctx) // 如果未找到 SSL 上下文 if (!ssl_ctx)
return SSL_TLSEXT_ERR_NOACK; // 返回错误 return SSL_TLSEXT_ERR_NOACK;
// 如果找到的 SSL 上下文与当前不一致,则设置为找到的上下文
if (ssl_ctx != server->get_ssl_ctx()) if (ssl_ctx != server->get_ssl_ctx())
SSL_set_SSL_CTX(ssl, ssl_ctx); SSL_set_SSL_CTX(ssl, ssl_ctx);
return SSL_TLSEXT_ERR_OK; // 返回成功 return SSL_TLSEXT_ERR_OK;
} }
// 创建新的 SSL 上下文
SSL_CTX *WFServerBase::new_ssl_ctx(const char *cert_file, const char *key_file) SSL_CTX *WFServerBase::new_ssl_ctx(const char *cert_file, const char *key_file)
{ {
SSL_CTX *ssl_ctx = WFGlobal::new_ssl_server_ctx(); // 创建新的 SSL 服务器上下文 SSL_CTX *ssl_ctx = WFGlobal::new_ssl_server_ctx();
if (!ssl_ctx) // 如果创建失败 if (!ssl_ctx)
return NULL; // 返回 NULL return NULL;
// 加载证书链、私钥,并检查私钥有效性
if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) > 0 && if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) > 0 &&
SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, SSL_FILETYPE_PEM) > 0 && SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, SSL_FILETYPE_PEM) > 0 &&
SSL_CTX_check_private_key(ssl_ctx) > 0 && SSL_CTX_check_private_key(ssl_ctx) > 0 &&
SSL_CTX_set_tlsext_servername_callback(ssl_ctx, ssl_ctx_callback) > 0 && SSL_CTX_set_tlsext_servername_callback(ssl_ctx, ssl_ctx_callback) > 0 &&
SSL_CTX_set_tlsext_servername_arg(ssl_ctx, this) > 0) SSL_CTX_set_tlsext_servername_arg(ssl_ctx, this) > 0)
{ {
return ssl_ctx; // 返回创建的 SSL 上下文 return ssl_ctx;
} }
SSL_CTX_free(ssl_ctx); // 释放 SSL 上下文 SSL_CTX_free(ssl_ctx);
return NULL; // 返回 NULL return NULL;
} }
// 初始化服务器
int WFServerBase::init(const struct sockaddr *bind_addr, socklen_t addrlen, int WFServerBase::init(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
int timeout = this->params.peer_response_timeout; // 将超时时间初始化为对端响应超时时间 int timeout = this->params.peer_response_timeout;
if (this->params.receive_timeout >= 0) // 如果接收超时时间有效 if (this->params.receive_timeout >= 0)
{ {
if ((unsigned int)timeout > (unsigned int)this->params.receive_timeout) if ((unsigned int)timeout > (unsigned int)this->params.receive_timeout)
timeout = this->params.receive_timeout; // 将超时时间更新为接收超时时间 timeout = this->params.receive_timeout;
} }
// 检查是否需要 SSL/TLS 连接
if (this->params.transport_type == TT_TCP_SSL || if (this->params.transport_type == TT_TCP_SSL ||
this->params.transport_type == TT_SCTP_SSL) this->params.transport_type == TT_SCTP_SSL)
{ {
if (!cert_file || !key_file) // 如果没有提供证书或密钥文件 if (!cert_file || !key_file)
{ {
errno = EINVAL; // 设置错误号为无效参数 errno = EINVAL;
return -1; // 返回错误 return -1;
} }
} }
// 初始化基本通信服务
if (this->CommService::init(bind_addr, addrlen, -1, timeout) < 0) if (this->CommService::init(bind_addr, addrlen, -1, timeout) < 0)
return -1; // 初始化失败,返回错误 return -1;
// 如果提供了证书和密钥,设置 SSL 上下文
if (cert_file && key_file && this->params.transport_type != TT_UDP) if (cert_file && key_file && this->params.transport_type != TT_UDP)
{ {
SSL_CTX *ssl_ctx = this->new_ssl_ctx(cert_file, key_file); // 创建新的 SSL 上下文 SSL_CTX *ssl_ctx = this->new_ssl_ctx(cert_file, key_file);
if (!ssl_ctx) // 如果创建失败 if (!ssl_ctx)
{ {
this->deinit(); // 反初始化 this->deinit();
return -1; // 返回错误 return -1;
} }
this->set_ssl(ssl_ctx, this->params.ssl_accept_timeout); // 设置 SSL 上下文 this->set_ssl(ssl_ctx, this->params.ssl_accept_timeout);
} }
this->scheduler = WFGlobal::get_scheduler(); // 获取全局调度器 this->scheduler = WFGlobal::get_scheduler();
return 0; // 返回成功 return 0;
} }
// 创建监听文件描述符
int WFServerBase::create_listen_fd() int WFServerBase::create_listen_fd()
{ {
if (this->listen_fd < 0) // 如果尚未创建监听文件描述符 if (this->listen_fd < 0)
{ {
const struct sockaddr *bind_addr; // 定义绑定地址 const struct sockaddr *bind_addr;
socklen_t addrlen; // 定义地址长度 socklen_t addrlen;
int type, protocol; // 定义套接字类型和协议 int type, protocol;
int reuse = 1; // 允许地址重用 int reuse = 1;
// 根据传输类型设置套接字类型和协议
switch (this->params.transport_type) switch (this->params.transport_type)
{ {
case TT_TCP: case TT_TCP:
case TT_TCP_SSL: case TT_TCP_SSL:
type = SOCK_STREAM; // TCP 使用流式套接字 type = SOCK_STREAM;
protocol = 0; // 默认协议 protocol = 0;
break; break;
case TT_UDP: case TT_UDP:
type = SOCK_DGRAM; // UDP 使用数据报套接字 type = SOCK_DGRAM;
protocol = 0; // 默认协议 protocol = 0;
break; break;
#ifdef IPPROTO_SCTP #ifdef IPPROTO_SCTP
case TT_SCTP: case TT_SCTP:
case TT_SCTP_SSL: case TT_SCTP_SSL:
type = SOCK_STREAM; // SCTP 使用流式套接字 type = SOCK_STREAM;
protocol = IPPROTO_SCTP; // SCTP 协议 protocol = IPPROTO_SCTP;
break; break;
#endif #endif
default: default:
errno = EPROTONOSUPPORT; // 不支持的协议类型 errno = EPROTONOSUPPORT;
return -1; // 返回错误 return -1;
} }
this->get_addr(&bind_addr, &addrlen); // 获取绑定地址和长度 this->get_addr(&bind_addr, &addrlen);
this->listen_fd = socket(bind_addr->sa_family, type, protocol); // 创建套接字 this->listen_fd = socket(bind_addr->sa_family, type, protocol);
if (this->listen_fd >= 0) // 如果套接字创建成功 if (this->listen_fd >= 0)
{ {
setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR, // 设置选项以重用地址 setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR,
&reuse, sizeof (int)); &reuse, sizeof (int));
} }
} }
else else
this->listen_fd = dup(this->listen_fd); // 复制已存在的监听文件描述符 this->listen_fd = dup(this->listen_fd);
return this->listen_fd; // 返回监听文件描述符 return this->listen_fd;
} }
// 创建新的连接
WFConnection *WFServerBase::new_connection(int accept_fd) WFConnection *WFServerBase::new_connection(int accept_fd)
{ {
// 检查当前连接数是否在最大连接数限制内
if (++this->conn_count <= this->params.max_connections || if (++this->conn_count <= this->params.max_connections ||
this->drain(1) == 1) // 如果连接数未超过最大值 this->drain(1) == 1)
{ {
int reuse = 1; // 允许地址重用 int reuse = 1;
setsockopt(accept_fd, SOL_SOCKET, SO_REUSEADDR, // 设置选项以重用地址 setsockopt(accept_fd, SOL_SOCKET, SO_REUSEADDR,
&reuse, sizeof (int)); &reuse, sizeof (int));
return new WFServerConnection(&this->conn_count); // 创建新的连接对象 return new WFServerConnection(&this->conn_count);
} }
this->conn_count--; // 如果连接数超过最大值,减少连接计数 this->conn_count--;
errno = EMFILE; // 设置错误号为文件描述符过多 errno = EMFILE;
return NULL; // 返回 NULL return NULL;
} }
// 删除连接
void WFServerBase::delete_connection(WFConnection *conn) void WFServerBase::delete_connection(WFConnection *conn)
{ {
delete (WFServerConnection *)conn; // 删除连接对象 delete (WFServerConnection *)conn;
} }
// 处理未绑定的情况
void WFServerBase::handle_unbound() void WFServerBase::handle_unbound()
{ {
this->mutex.lock(); // 加锁 this->mutex.lock();
this->unbind_finish = true; // 设置解除绑定状态 this->unbind_finish = true;
this->cond.notify_one(); // 通知等待的线程 this->cond.notify_one();
this->mutex.unlock(); // 解锁 this->mutex.unlock();
} }
// 启动服务器
int WFServerBase::start(const struct sockaddr *bind_addr, socklen_t addrlen, int WFServerBase::start(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
SSL_CTX *ssl_ctx; // 声明 SSL 上下文指针 SSL_CTX *ssl_ctx;
// 初始化服务器
if (this->init(bind_addr, addrlen, cert_file, key_file) >= 0) if (this->init(bind_addr, addrlen, cert_file, key_file) >= 0)
{ {
// 如果调度器绑定成功,返回 0
if (this->scheduler->bind(this) >= 0) if (this->scheduler->bind(this) >= 0)
return 0; return 0;
ssl_ctx = this->get_ssl_ctx(); // 获取当前 SSL 上下文 ssl_ctx = this->get_ssl_ctx();
this->deinit(); // 反初始化 this->deinit();
if (ssl_ctx) // 如果 SSL 上下文存在 if (ssl_ctx)
SSL_CTX_free(ssl_ctx); // 释放 SSL 上下文 SSL_CTX_free(ssl_ctx);
} }
this->listen_fd = -1; // 设置监听文件描述符为无效 this->listen_fd = -1;
return -1; // 返回错误 return -1;
} }
// 启动服务器,使用主机和端口号
int WFServerBase::start(int family, const char *host, unsigned short port, int WFServerBase::start(int family, const char *host, unsigned short port,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
struct addrinfo hints = { // 设置地址信息结构 struct addrinfo hints = {
.ai_flags = AI_PASSIVE, // 使套接字可用于绑定 .ai_flags = AI_PASSIVE,
.ai_family = family, // 设置地址族 .ai_family = family,
.ai_socktype = SOCK_STREAM, // 设置为流套接字 .ai_socktype = SOCK_STREAM,
}; };
struct addrinfo *addrinfo; // 用于存储返回的地址信息 struct addrinfo *addrinfo;
char port_str[PORT_STR_MAX + 1]; // 存储端口字符串 char port_str[PORT_STR_MAX + 1];
int ret; // 存储返回值 int ret;
// 将端口号转换为字符串
snprintf(port_str, PORT_STR_MAX + 1, "%d", port); snprintf(port_str, PORT_STR_MAX + 1, "%d", port);
ret = getaddrinfo(host, port_str, &hints, &addrinfo); // 获取地址信息 ret = getaddrinfo(host, port_str, &hints, &addrinfo);
if (ret == 0) // 如果成功获取地址信息 if (ret == 0)
{ {
// 启动服务器
ret = start(addrinfo->ai_addr, (socklen_t)addrinfo->ai_addrlen, ret = start(addrinfo->ai_addr, (socklen_t)addrinfo->ai_addrlen,
cert_file, key_file); cert_file, key_file);
freeaddrinfo(addrinfo); // 释放地址信息 freeaddrinfo(addrinfo);
} }
else // 获取地址信息失败 else
{ {
if (ret != EAI_SYSTEM) // 如果不是系统错误 if (ret != EAI_SYSTEM)
errno = EINVAL; // 设置错误号为无效参数 errno = EINVAL;
ret = -1; // 返回错误 ret = -1;
} }
return ret; // 返回结果 return ret;
} }
// 启动服务器,使用已存在的监听文件描述符
int WFServerBase::serve(int listen_fd, int WFServerBase::serve(int listen_fd,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
struct sockaddr_storage ss; // 存储 socket 地址 struct sockaddr_storage ss;
socklen_t len = sizeof ss; // 定义地址长度 socklen_t len = sizeof ss;
// 获取当前 socket 名称
if (getsockname(listen_fd, (struct sockaddr *)&ss, &len) < 0) if (getsockname(listen_fd, (struct sockaddr *)&ss, &len) < 0)
return -1; // 返回错误 return -1;
this->listen_fd = listen_fd; // 设置监听文件描述符 this->listen_fd = listen_fd;
return start((struct sockaddr *)&ss, len, cert_file, key_file); // 启动服务器 return start((struct sockaddr *)&ss, len, cert_file, key_file);
} }
// 关闭服务器
void WFServerBase::shutdown() void WFServerBase::shutdown()
{ {
this->listen_fd = -1; // 设置监听文件描述符为无效 this->listen_fd = -1;
this->scheduler->unbind(this); // 解除调度器的绑定 this->scheduler->unbind(this);
} }
// 等待服务器完成
void WFServerBase::wait_finish() void WFServerBase::wait_finish()
{ {
SSL_CTX *ssl_ctx = this->get_ssl_ctx(); // 获取当前 SSL 上下文 SSL_CTX *ssl_ctx = this->get_ssl_ctx();
std::unique_lock<std::mutex> lock(this->mutex); // 使用唯一锁管理 mutex std::unique_lock<std::mutex> lock(this->mutex);
// 等待解除绑定完成
while (!this->unbind_finish) while (!this->unbind_finish)
this->cond.wait(lock); // 条件等待 this->cond.wait(lock);
this->deinit(); // 反初始化 this->deinit();
this->unbind_finish = false; // 重置解除绑定状态 this->unbind_finish = false;
lock.unlock(); // 解锁 lock.unlock();
if (ssl_ctx)
if (ssl_ctx) // 如果 SSL 上下文存在 SSL_CTX_free(ssl_ctx);
SSL_CTX_free(ssl_ctx); // 释放 SSL 上下文
} }

@ -17,117 +17,113 @@
Wu Jiaxu (wujiaxu@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com)
*/ */
#ifndef _WFSERVER_H_ // 防止头文件被多次包含 #ifndef _WFSERVER_H_
#define _WFSERVER_H_ #define _WFSERVER_H_
#include <sys/types.h> // 引入基本数据类型定义 #include <sys/types.h>
#include <sys/socket.h> // 引入 socket 相关的函数 #include <sys/socket.h>
#include <errno.h> // 引入错误码定义 #include <errno.h>
#include <functional> // 引入函数对象 #include <functional>
#include <atomic> // 引入原子操作的支持 #include <atomic>
#include <mutex> // 引入互斥量支持 #include <mutex>
#include <condition_variable> // 引入条件变量支持 #include <condition_variable>
#include <openssl/ssl.h> // 引入 OpenSSL 以支持 SSL/TLS #include <openssl/ssl.h>
#include "EndpointParams.h" // 引入端点参数定义 #include "EndpointParams.h"
#include "WFTaskFactory.h" // 引入任务工厂以创建任务 #include "WFTaskFactory.h"
// 定义服务器参数结构体
struct WFServerParams struct WFServerParams
{ {
enum TransportType transport_type; // 传输类型TCP, UDP, SSL等 enum TransportType transport_type;
size_t max_connections; // 最大并发连接数 size_t max_connections;
int peer_response_timeout; // 每次读取或写入操作的超时 int peer_response_timeout; /* timeout of each read or write operation */
int receive_timeout; // 接收整个消息的超时 int receive_timeout; /* timeout of receiving the whole message */
int keep_alive_timeout; // 连接保持存活的超时 int keep_alive_timeout;
size_t request_size_limit; // 请求大小限制 size_t request_size_limit;
int ssl_accept_timeout; // SSL 接受超时时间 int ssl_accept_timeout; /* if not ssl, this will be ignored */
}; };
// 设置服务器参数的默认值
static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT = static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT =
{ {
.transport_type = TT_TCP, // 默认使用 TCP .transport_type = TT_TCP,
.max_connections = 2000, // 默认最多 2000 个连接 .max_connections = 2000,
.peer_response_timeout = 10 * 1000, // 默认对端响应超时 10 秒 .peer_response_timeout = 10 * 1000,
.receive_timeout = -1, // 默认不超时 .receive_timeout = -1,
.keep_alive_timeout = 60 * 1000, // 默认连接保持存活 60 秒 .keep_alive_timeout = 60 * 1000,
.request_size_limit = (size_t)-1, // 请求大小限制为无限制 .request_size_limit = (size_t)-1,
.ssl_accept_timeout = 10 * 1000, // SSL 接受超时设置为 10 秒 .ssl_accept_timeout = 10 * 1000,
}; };
// 定义 WFServerBase 基类,用于建立服务器
class WFServerBase : protected CommService class WFServerBase : protected CommService
{ {
public: public:
// 构造函数,接收服务器参数并初始化
WFServerBase(const struct WFServerParams *params) : WFServerBase(const struct WFServerParams *params) :
conn_count(0) // 初始化连接计数为 0 conn_count(0)
{ {
this->params = *params; // 保存服务器参数 this->params = *params;
this->unbind_finish = false; // 初始化解除绑定状态 this->unbind_finish = false;
this->listen_fd = -1; // 初始化监听文件描述符为无效值 this->listen_fd = -1;
} }
public: public:
// 启动一个 TCP 服务器函数 /* To start a TCP server */
/* 在指定端口上启动服务器IPv4 */ /* Start on port with IPv4. */
int start(unsigned short port) int start(unsigned short port)
{ {
return start(AF_INET, NULL, port, NULL, NULL); return start(AF_INET, NULL, port, NULL, NULL);
} }
/* 根据家庭类型启动服务器IPv4 或 IPv6 */ /* Start with family. AF_INET or AF_INET6. */
int start(int family, unsigned short port) int start(int family, unsigned short port)
{ {
return start(family, NULL, port, NULL, NULL); return start(family, NULL, port, NULL, NULL);
} }
/* 使用主机名和端口启动服务器 */ /* Start with hostname and port. */
int start(const char *host, unsigned short port) int start(const char *host, unsigned short port)
{ {
return start(AF_INET, host, port, NULL, NULL); return start(AF_INET, host, port, NULL, NULL);
} }
/* 使用家庭类型、主机名和端口启动服务器 */ /* Start with family, hostname and port. */
int start(int family, const char *host, unsigned short port) int start(int family, const char *host, unsigned short port)
{ {
return start(family, host, port, NULL, NULL); return start(family, host, port, NULL, NULL);
} }
/* 使用指定的地址绑定启动服务器 */ /* Start with binding address. */
int start(const struct sockaddr *bind_addr, socklen_t addrlen) int start(const struct sockaddr *bind_addr, socklen_t addrlen)
{ {
return start(bind_addr, addrlen, NULL, NULL); return start(bind_addr, addrlen, NULL, NULL);
} }
// 启动一个 SSL 服务器 /* To start an SSL server. */
/* 在指定端口启动 SSL 服务器 */
int start(unsigned short port, const char *cert_file, const char *key_file) int start(unsigned short port, const char *cert_file, const char *key_file)
{ {
return start(AF_INET, NULL, port, cert_file, key_file); return start(AF_INET, NULL, port, cert_file, key_file);
} }
/* 使用家庭类型和端口启动 SSL 服务器 */
int start(int family, unsigned short port, int start(int family, unsigned short port,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
return start(family, NULL, port, cert_file, key_file); return start(family, NULL, port, cert_file, key_file);
} }
/* 使用主机名和端口启动 SSL 服务器 */
int start(const char *host, unsigned short port, int start(const char *host, unsigned short port,
const char *cert_file, const char *key_file) const char *cert_file, const char *key_file)
{ {
return start(AF_INET, host, port, cert_file, key_file); return start(AF_INET, host, port, cert_file, key_file);
} }
/* 使用家庭类型、主机名和端口启动 SSL 服务器 */
int start(int family, const char *host, unsigned short port, int start(int family, const char *host, unsigned short port,
const char *cert_file, const char *key_file); const char *cert_file, const char *key_file);
/* 使用指定文件描述符启动服务器,用于优雅重启或 SCTP 服务器 */ /* This is the only necessary start function. */
int start(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file);
/* To start with a specified fd. For graceful restart or SCTP server. */
int serve(int listen_fd) int serve(int listen_fd)
{ {
return serve(listen_fd, NULL, NULL); return serve(listen_fd, NULL, NULL);
@ -135,119 +131,114 @@ public:
int serve(int listen_fd, const char *cert_file, const char *key_file); int serve(int listen_fd, const char *cert_file, const char *key_file);
/* 停止服务器是一个阻塞操作 */ /* stop() is a blocking operation. */
void stop() void stop()
{ {
this->shutdown(); // 关闭服务器 this->shutdown();
this->wait_finish(); // 等待完成 this->wait_finish();
} }
/* 非阻塞地终止服务器。用于停止多个服务器。 /* Nonblocking terminating the server. For stopping multiple servers.
* shutdown() wait_finish() * Typically, call shutdown() and then wait_finish().
* 线 shutdown() start() wait_finish() */ * But indeed wait_finish() can be called before shutdown(), even before
void shutdown(); // 停止服务器 * start() in another thread. */
void wait_finish(); // 等待所有连接和任务完成 void shutdown();
void wait_finish();
public: public:
// 获取当前连接数
size_t get_conn_count() const { return this->conn_count; } size_t get_conn_count() const { return this->conn_count; }
/* 获取监听地址。这通常在服务器在随机端口上启动后使用start() 端口为 0。 */ /* Get the listening address. This is often used after starting
* server on a random port (start() with port == 0). */
int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const
{ {
if (this->listen_fd >= 0) if (this->listen_fd >= 0)
return getsockname(this->listen_fd, addr, addrlen); // 获取当前的 socket 地址 return getsockname(this->listen_fd, addr, addrlen);
errno = ENOTCONN; // 如果没有连接,设置错误号为没有连接 errno = ENOTCONN;
return -1; // 返回错误 return -1;
} }
// 获取服务器参数
const struct WFServerParams *get_params() const { return &this->params; } const struct WFServerParams *get_params() const { return &this->params; }
protected: protected:
/* 重写此函数以创建服务器的初始 SSL CTX */ /* Override this function to create the initial SSL CTX of the server */
virtual SSL_CTX *new_ssl_ctx(const char *cert_file, const char *key_file); virtual SSL_CTX *new_ssl_ctx(const char *cert_file, const char *key_file);
/* 重写此函数以实现支持 TLS SNI 的服务器。 /* Override this function to implement server that supports TLS SNI.
* "servername" NULL * "servername" will be NULL if client does not set a host name.
* NULL */ * Returning NULL to indicate that servername is not supported. */
virtual SSL_CTX *get_server_ssl_ctx(const char *servername) virtual SSL_CTX *get_server_ssl_ctx(const char *servername)
{ {
return this->get_ssl_ctx(); // 返回当前的 SSL 上下文 return this->get_ssl_ctx();
} }
/* 这个函数可以在 'new_ssl_ctx' 的实现中使用。 */ /* This can be used by the implementation of 'new_ssl_ctx'. */
static int ssl_ctx_callback(SSL *ssl, int *al, void *arg); static int ssl_ctx_callback(SSL *ssl, int *al, void *arg);
protected: protected:
WFServerParams params; // 服务器参数结构体 WFServerParams params;
protected: protected:
// 创建新连接的方法
virtual int create_listen_fd(); virtual int create_listen_fd();
virtual WFConnection *new_connection(int accept_fd); // 创建新连接 virtual WFConnection *new_connection(int accept_fd);
void delete_connection(WFConnection *conn); // 删除连接 void delete_connection(WFConnection *conn);
private: private:
// 初始化服务器的方法
int init(const struct sockaddr *bind_addr, socklen_t addrlen, int init(const struct sockaddr *bind_addr, socklen_t addrlen,
const char *cert_file, const char *key_file); const char *cert_file, const char *key_file);
virtual void handle_unbound(); // 处理解除绑定 virtual void handle_unbound();
protected: protected:
std::atomic<size_t> conn_count; // 当前连接计数 std::atomic<size_t> conn_count;
private: private:
int listen_fd; // 监听文件描述符 int listen_fd;
bool unbind_finish; // 标记解除绑定是否完成 bool unbind_finish;
std::mutex mutex; // 互斥锁,用于线程安全 std::mutex mutex;
std::condition_variable cond; // 条件变量,用于线程间通知 std::condition_variable cond;
class CommScheduler *scheduler; // 调度器 class CommScheduler *scheduler;
}; };
// 模板类,用于定义特定协议的服务器
template<class REQ, class RESP> template<class REQ, class RESP>
class WFServer : public WFServerBase class WFServer : public WFServerBase
{ {
public: public:
// 构造函数,接收参数和处理函数
WFServer(const struct WFServerParams *params, WFServer(const struct WFServerParams *params,
std::function<void (WFNetworkTask<REQ, RESP> *)> proc) : std::function<void (WFNetworkTask<REQ, RESP> *)> proc) :
WFServerBase(params), // 调用基类构造函数 WFServerBase(params),
process(std::move(proc)) // 移动处理函数,避免不必要的拷贝 process(std::move(proc))
{ {
} }
// 构造函数,使用默认参数
WFServer(std::function<void (WFNetworkTask<REQ, RESP> *)> proc) : WFServer(std::function<void (WFNetworkTask<REQ, RESP> *)> proc) :
WFServerBase(&SERVER_PARAMS_DEFAULT), // 调用基类构造函数,传入默认参数 WFServerBase(&SERVER_PARAMS_DEFAULT),
process(std::move(proc)) // 移动处理函数 process(std::move(proc))
{ {
} }
protected: protected:
virtual CommSession *new_session(long long seq, CommConnection *conn); // 创建新会话 virtual CommSession *new_session(long long seq, CommConnection *conn);
protected: protected:
std::function<void (WFNetworkTask<REQ, RESP> *)> process; // 处理函数 std::function<void (WFNetworkTask<REQ, RESP> *)> process;
}; };
// 创建新会话的方法,实现具体的会话处理
template<class REQ, class RESP> template<class REQ, class RESP>
CommSession *WFServer<REQ, RESP>::new_session(long long seq, CommConnection *conn) CommSession *WFServer<REQ, RESP>::new_session(long long seq, CommConnection *conn)
{ {
using factory = WFNetworkTaskFactory<REQ, RESP>; // 使用工厂来创建任务 using factory = WFNetworkTaskFactory<REQ, RESP>;
WFNetworkTask<REQ, RESP> *task; // 声明任务指针 WFNetworkTask<REQ, RESP> *task;
task = factory::create_server_task(this, this->process); // 使用工厂创建任务 task = factory::create_server_task(this, this->process);
task->set_keep_alive(this->params.keep_alive_timeout); // 设置保持连接的超时 task->set_keep_alive(this->params.keep_alive_timeout);
task->set_receive_timeout(this->params.receive_timeout); // 设置接收超时 task->set_receive_timeout(this->params.receive_timeout);
task->get_req()->set_size_limit(this->params.request_size_limit); // 设置请求大小限制 task->get_req()->set_size_limit(this->params.request_size_limit);
return task; // 返回创建的会话任务 return task;
} }
#endif // _WFSERVER_H_ #endif

@ -1,10 +1,6 @@
-- 定义一个 xmake 目标,名为 server
target("server") target("server")
set_kind("object") -- 设置目标类型为对象文件 set_kind("object")
add_files("*.cc")
add_files("*.cc") -- 添加当前目录下所有的 .cc 文件到目标中
-- 检查是否未定义 MySQL 配置
if not has_config("mysql") then if not has_config("mysql") then
remove_files("WFMySQLServer.cc") -- 如果未定义 MySQL 配置,则移除 MySQL 服务器相关的源文件 remove_files("WFMySQLServer.cc")
end end

Loading…
Cancel
Save