liangtingyu
yyy 3 days ago
parent 6fd305c6e4
commit ea46511c5e

@ -0,0 +1,148 @@
cmake_minimum_required(VERSION 3.6)
if(ANDROID)
include_directories(${OPENSSL_INCLUDE_DIR})
link_directories(${OPENSSL_LINK_DIR})
else()
find_package(OpenSSL REQUIRED)
endif ()
include_directories(${OPENSSL_INCLUDE_DIR} ${INC_DIR}/workflow)
if (KAFKA STREQUAL "y")
find_path(SNAPPY_INCLUDE_PATH NAMES snappy.h)
find_library(SNAPPY_LIB NAMES snappy)
if ((NOT SNAPPY_INCLUDE_PATH) OR (NOT SNAPPY_LIB))
message(FATAL_ERROR "Fail to find snappy with KAFKA=y")
endif ()
include_directories(${SNAPPY_INCLUDE_PATH})
endif ()
add_subdirectory(kernel)
add_subdirectory(util)
add_subdirectory(manager)
add_subdirectory(protocol)
add_subdirectory(factory)
add_subdirectory(nameservice)
add_subdirectory(server)
add_subdirectory(client)
add_dependencies(kernel LINK_HEADERS)
add_dependencies(util LINK_HEADERS)
add_dependencies(manager LINK_HEADERS)
add_dependencies(protocol LINK_HEADERS)
add_dependencies(factory LINK_HEADERS)
add_dependencies(nameservice LINK_HEADERS)
add_dependencies(server LINK_HEADERS)
add_dependencies(client LINK_HEADERS)
set(STATIC_LIB_NAME ${PROJECT_NAME}-static)
set(SHARED_LIB_NAME ${PROJECT_NAME}-shared)
add_library(
${STATIC_LIB_NAME} STATIC
$<TARGET_OBJECTS:kernel>
$<TARGET_OBJECTS:util>
$<TARGET_OBJECTS:manager>
$<TARGET_OBJECTS:protocol>
$<TARGET_OBJECTS:factory>
$<TARGET_OBJECTS:nameservice>
$<TARGET_OBJECTS:server>
$<TARGET_OBJECTS:client>
)
add_library(
${SHARED_LIB_NAME} SHARED
$<TARGET_OBJECTS:kernel>
$<TARGET_OBJECTS:util>
$<TARGET_OBJECTS:manager>
$<TARGET_OBJECTS:protocol>
$<TARGET_OBJECTS:factory>
$<TARGET_OBJECTS:nameservice>
$<TARGET_OBJECTS:server>
$<TARGET_OBJECTS:client>
)
if(ANDROID)
target_link_libraries(${SHARED_LIB_NAME} ssl crypto c)
target_link_libraries(${STATIC_LIB_NAME} ssl crypto c)
else()
target_link_libraries(${SHARED_LIB_NAME} OpenSSL::SSL OpenSSL::Crypto pthread)
target_link_libraries(${STATIC_LIB_NAME} OpenSSL::SSL OpenSSL::Crypto pthread)
endif ()
set_target_properties(${STATIC_LIB_NAME} PROPERTIES OUTPUT_NAME ${PROJECT_NAME})
set_target_properties(${SHARED_LIB_NAME} PROPERTIES OUTPUT_NAME ${PROJECT_NAME} VERSION ${PROJECT_VERSION} SOVERSION ${PROJECT_VERSION_MAJOR})
if (KAFKA STREQUAL "y")
add_dependencies(client_kafka LINK_HEADERS)
add_dependencies(util_kafka LINK_HEADERS)
add_dependencies(protocol_kafka LINK_HEADERS)
add_dependencies(factory_kafka LINK_HEADERS)
set(KAFKA_STATIC_LIB_NAME "wfkafka-static")
add_library(
${KAFKA_STATIC_LIB_NAME} STATIC
$<TARGET_OBJECTS:client_kafka>
$<TARGET_OBJECTS:util_kafka>
$<TARGET_OBJECTS:protocol_kafka>
$<TARGET_OBJECTS:factory_kafka>
)
set_target_properties(${KAFKA_STATIC_LIB_NAME} PROPERTIES OUTPUT_NAME "wfkafka")
set(KAFKA_SHARED_LIB_NAME "wfkafka-shared")
add_library(
${KAFKA_SHARED_LIB_NAME} SHARED
$<TARGET_OBJECTS:client_kafka>
$<TARGET_OBJECTS:util_kafka>
$<TARGET_OBJECTS:protocol_kafka>
$<TARGET_OBJECTS:factory_kafka>
)
if (APPLE)
target_link_libraries(${KAFKA_SHARED_LIB_NAME}
${SHARED_LIB_NAME} z lz4 zstd ${SNAPPY_LIB})
else ()
target_link_libraries(${KAFKA_SHARED_LIB_NAME} ${SHARED_LIB_NAME})
endif ()
set_target_properties(${KAFKA_SHARED_LIB_NAME} PROPERTIES OUTPUT_NAME "wfkafka" VERSION ${PROJECT_VERSION} SOVERSION ${PROJECT_VERSION_MAJOR})
endif ()
install(
TARGETS ${STATIC_LIB_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT devel
)
install(
TARGETS ${SHARED_LIB_NAME}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT devel
)
if (KAFKA STREQUAL "y")
install(
TARGETS ${KAFKA_STATIC_LIB_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT devel
)
install(
TARGETS ${KAFKA_SHARED_LIB_NAME}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT devel
)
endif ()
target_include_directories(${STATIC_LIB_NAME} BEFORE PUBLIC
"$<BUILD_INTERFACE:${INC_DIR}>"
"$<INSTALL_INTERFACE:${INC_DIR}>")
target_include_directories(${SHARED_LIB_NAME} BEFORE PUBLIC
"$<BUILD_INTERFACE:${INC_DIR}>"
"$<INSTALL_INTERFACE:${INC_DIR}>")
if (KAFKA STREQUAL "y")
target_include_directories(${KAFKA_STATIC_LIB_NAME} BEFORE PUBLIC
"$<BUILD_INTERFACE:${INC_DIR}>"
"$<INSTALL_INTERFACE:${INC_DIR}>")
target_include_directories(${KAFKA_SHARED_LIB_NAME} BEFORE PUBLIC
"$<BUILD_INTERFACE:${INC_DIR}>"
"$<INSTALL_INTERFACE:${INC_DIR}>")
endif ()

@ -0,0 +1,37 @@
cmake_minimum_required(VERSION 3.6)
project(client)
set(SRC
WFDnsClient.cc
)
if (NOT REDIS STREQUAL "n")
set(SRC
${SRC}
WFRedisSubscriber.cc
)
endif ()
if (NOT MYSQL STREQUAL "n")
set(SRC
${SRC}
WFMySQLConnection.cc
)
endif ()
if (NOT CONSUL STREQUAL "n")
set(SRC
${SRC}
WFConsulClient.cc
)
endif ()
add_library(${PROJECT_NAME} OBJECT ${SRC})
if (KAFKA STREQUAL "y")
set(SRC
WFKafkaClient.cc
)
add_library("client_kafka" OBJECT ${SRC})
endif ()

File diff suppressed because it is too large Load Diff

@ -0,0 +1,161 @@
/*
Copyright (c) 2022 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com)
*/
#ifndef _WFCONSULCLIENT_H_
#define _WFCONSULCLIENT_H_
#include <string>
#include <vector>
#include <utility>
#include <functional>
#include "HttpMessage.h"
#include "WFTaskFactory.h"
#include "ConsulDataTypes.h"
class WFConsulTask;
using consul_callback_t = std::function<void (WFConsulTask *)>;
enum
{
CONSUL_API_TYPE_UNKNOWN = 0,
CONSUL_API_TYPE_DISCOVER,
CONSUL_API_TYPE_LIST_SERVICE,
CONSUL_API_TYPE_REGISTER,
CONSUL_API_TYPE_DEREGISTER,
};
class WFConsulTask : public WFGenericTask
{
public:
bool get_discover_result(
std::vector<struct protocol::ConsulServiceInstance>& result);
bool get_list_service_result(
std::vector<struct protocol::ConsulServiceTags>& result);
public:
void set_service(const struct protocol::ConsulService *service);
void set_api_type(int api_type)
{
this->api_type = api_type;
}
int get_api_type() const
{
return this->api_type;
}
void set_callback(consul_callback_t cb)
{
this->callback = std::move(cb);
}
void set_consul_index(long long consul_index)
{
this->consul_index = consul_index;
}
long long get_consul_index() const { return this->consul_index; }
const protocol::HttpResponse *get_http_resp() const
{
return &this->http_resp;
}
protected:
void set_config(protocol::ConsulConfig conf)
{
this->config = std::move(conf);
}
protected:
virtual void dispatch();
virtual SubTask *done();
WFHttpTask *create_discover_task();
WFHttpTask *create_list_service_task();
WFHttpTask *create_register_task();
WFHttpTask *create_deregister_task();
std::string generate_discover_request();
long long get_consul_index(protocol::HttpResponse *resp);
static bool check_task_result(WFHttpTask *task, WFConsulTask *consul_task);
static void discover_callback(WFHttpTask *task);
static void list_service_callback(WFHttpTask *task);
static void register_callback(WFHttpTask *task);
protected:
protocol::ConsulConfig config;
struct protocol::ConsulService service;
std::string proxy_url;
int retry_max;
int api_type;
bool finish;
long long consul_index;
protocol::HttpResponse http_resp;
consul_callback_t callback;
protected:
WFConsulTask(const std::string& proxy_url,
const std::string& service_namespace,
const std::string& service_name,
const std::string& service_id,
int retry_max, consul_callback_t&& cb);
virtual ~WFConsulTask() { }
friend class WFConsulClient;
};
class WFConsulClient
{
public:
// example: http://127.0.0.1:8500
int init(const std::string& proxy_url);
int init(const std::string& proxy_url, protocol::ConsulConfig config);
void deinit() { }
WFConsulTask *create_discover_task(const std::string& service_namespace,
const std::string& service_name,
int retry_max,
consul_callback_t cb);
WFConsulTask *create_list_service_task(const std::string& service_namespace,
int retry_max,
consul_callback_t cb);
WFConsulTask *create_register_task(const std::string& service_namespace,
const std::string& service_name,
const std::string& service_id,
int retry_max,
consul_callback_t cb);
WFConsulTask *create_deregister_task(const std::string& service_namespace,
const std::string& service_id,
int retry_max,
consul_callback_t cb);
private:
std::string proxy_url;
protocol::ConsulConfig config;
public:
virtual ~WFConsulClient() { }
};
#endif

@ -0,0 +1,277 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Liu Kai (liukaidx@sogou-inc.com)
*/
#include <string>
#include <vector>
#include <atomic>
#include "URIParser.h"
#include "StringUtil.h"
#include "WFDnsClient.h"
using namespace protocol;
using DnsCtx = std::function<void (WFDnsTask *)>;
using ComplexTask = WFComplexClientTask<DnsRequest, DnsResponse, DnsCtx>;
class DnsParams
{
public:
struct dns_params
{
std::vector<ParsedURI> uris;
std::vector<std::string> search_list;
int ndots;
int attempts;
bool rotate;
};
public:
DnsParams()
{
this->ref = new std::atomic<size_t>(1);
this->params = new dns_params();
}
DnsParams(const DnsParams& p)
{
this->ref = p.ref;
this->params = p.params;
this->incref();
}
DnsParams& operator=(const DnsParams& p)
{
if (this != &p)
{
this->decref();
this->ref = p.ref;
this->params = p.params;
this->incref();
}
return *this;
}
~DnsParams() { this->decref(); }
const dns_params *get_params() const { return this->params; }
dns_params *get_params() { return this->params; }
private:
void incref() { (*this->ref)++; }
void decref()
{
if (--*this->ref == 0)
{
delete this->params;
delete this->ref;
}
}
private:
dns_params *params;
std::atomic<size_t> *ref;
};
enum
{
DNS_STATUS_TRY_ORIGIN_DONE = 0,
DNS_STATUS_TRY_ORIGIN_FIRST = 1,
DNS_STATUS_TRY_ORIGIN_LAST = 2
};
struct DnsStatus
{
std::string origin_name;
std::string current_name;
size_t next_server; // next server to try
size_t last_server; // last server to try
size_t next_domain; // next search domain to try
int attempts_left;
int try_origin_state;
};
static int __get_ndots(const std::string& s)
{
int ndots = 0;
for (size_t i = 0; i < s.size(); i++)
ndots += s[i] == '.';
return ndots;
}
static bool __has_next_name(const DnsParams::dns_params *p,
struct DnsStatus *s)
{
if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST)
{
s->current_name = s->origin_name;
s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE;
return true;
}
if (s->next_domain < p->search_list.size())
{
s->current_name = s->origin_name;
s->current_name.push_back('.');
s->current_name.append(p->search_list[s->next_domain]);
s->next_domain++;
return true;
}
if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_LAST)
{
s->current_name = s->origin_name;
s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE;
return true;
}
return false;
}
static void __callback_internal(WFDnsTask *task, const DnsParams& params,
struct DnsStatus& s)
{
ComplexTask *ctask = static_cast<ComplexTask *>(task);
int state = task->get_state();
DnsRequest *req = task->get_req();
DnsResponse *resp = task->get_resp();
const auto *p = params.get_params();
int rcode = resp->get_rcode();
bool try_next_server = state != WFT_STATE_SUCCESS ||
rcode == DNS_RCODE_SERVER_FAILURE ||
rcode == DNS_RCODE_NOT_IMPLEMENTED ||
rcode == DNS_RCODE_REFUSED;
bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR ||
rcode == DNS_RCODE_NAME_ERROR ||
resp->get_ancount() == 0;
if (try_next_server)
{
if (s.last_server == s.next_server)
s.attempts_left--;
if (s.attempts_left <= 0)
return;
s.next_server = (s.next_server + 1) % p->uris.size();
ctask->set_redirect(p->uris[s.next_server]);
return;
}
if (try_next_name && __has_next_name(p, &s))
{
req->set_question_name(s.current_name.c_str());
ctask->set_redirect(p->uris[s.next_server]);
return;
}
}
int WFDnsClient::init(const std::string& url)
{
return this->init(url, "", 1, 2, false);
}
int WFDnsClient::init(const std::string& url, const std::string& search_list,
int ndots, int attempts, bool rotate)
{
std::vector<std::string> hosts;
std::vector<ParsedURI> uris;
std::string host;
ParsedURI uri;
this->id = 0;
hosts = StringUtil::split_filter_empty(url, ',');
for (size_t i = 0; i < hosts.size(); i++)
{
host = hosts[i];
if (strncasecmp(host.c_str(), "dns://", 6) != 0 &&
strncasecmp(host.c_str(), "dnss://", 7) != 0)
{
host = "dns://" + host;
}
if (URIParser::parse(host, uri) != 0)
return -1;
uris.emplace_back(std::move(uri));
}
if (uris.empty() || ndots < 0 || attempts < 1)
{
errno = EINVAL;
return -1;
}
this->params = new DnsParams;
DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params();
q->uris = std::move(uris);
q->search_list = StringUtil::split_filter_empty(search_list, ',');
q->ndots = ndots > 15 ? 15 : ndots;
q->attempts = attempts > 5 ? 5 : attempts;
q->rotate = rotate;
return 0;
}
void WFDnsClient::deinit()
{
delete (DnsParams *)this->params;
this->params = NULL;
}
WFDnsTask *WFDnsClient::create_dns_task(const std::string& name,
dns_callback_t callback)
{
DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params();
struct DnsStatus status;
size_t next_server;
WFDnsTask *task;
DnsRequest *req;
next_server = p->rotate ? this->id++ % p->uris.size() : 0;
status.origin_name = name;
status.next_domain = 0;
status.attempts_left = p->attempts;
status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST;
if (!name.empty() && name.back() == '.')
status.next_domain = p->search_list.size();
else if (__get_ndots(name) < p->ndots)
status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST;
__has_next_name(p, &status);
task = WFTaskFactory::create_dns_task(p->uris[next_server], 0,
std::move(callback));
status.next_server = next_server;
status.last_server = (next_server + p->uris.size() - 1) % p->uris.size();
req = task->get_req();
req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN);
req->set_rd(1);
ComplexTask *ctask = static_cast<ComplexTask *>(task);
*ctask->get_mutable_ctx() = std::bind(__callback_internal,
std::placeholders::_1,
*(DnsParams *)params, status);
return task;
}

@ -0,0 +1,47 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Liu Kai (liukaidx@sogou-inc.com)
*/
#ifndef _WFDNSCLIENT_H_
#define _WFDNSCLIENT_H_
#include <string>
#include <atomic>
#include "WFTaskFactory.h"
#include "DnsMessage.h"
class WFDnsClient
{
public:
int init(const std::string& url);
int init(const std::string& url, const std::string& search_list,
int ndots, int attempts, bool rotate);
void deinit();
WFDnsTask *create_dns_task(const std::string& name,
dns_callback_t callback);
private:
void *params;
std::atomic<size_t> id;
public:
virtual ~WFDnsClient() { }
};
#endif

File diff suppressed because it is too large Load Diff

@ -0,0 +1,194 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
*/
#ifndef _WFKAFKACLIENT_H_
#define _WFKAFKACLIENT_H_
#include <string>
#include <vector>
#include <functional>
#include <openssl/ssl.h>
#include "WFTask.h"
#include "KafkaMessage.h"
#include "KafkaResult.h"
class WFKafkaTask;
class WFKafkaClient;
using kafka_callback_t = std::function<void (WFKafkaTask *)>;
using kafka_partitioner_t = std::function<int (const char *topic_name,
const void *key,
size_t key_len,
int partition_num)>;
class WFKafkaTask : public WFGenericTask
{
public:
virtual bool add_topic(const std::string& topic) = 0;
virtual bool add_toppar(const protocol::KafkaToppar& toppar) = 0;
virtual bool add_produce_record(const std::string& topic, int partition,
protocol::KafkaRecord record) = 0;
virtual bool add_offset_toppar(const protocol::KafkaToppar& toppar) = 0;
void add_commit_record(const protocol::KafkaRecord& record)
{
protocol::KafkaToppar toppar;
toppar.set_topic_partition(record.get_topic(), record.get_partition());
toppar.set_offset(record.get_offset());
toppar.set_error(0);
this->toppar_list.add_item(std::move(toppar));
}
void add_commit_toppar(const protocol::KafkaToppar& toppar)
{
protocol::KafkaToppar toppar_t;
toppar_t.set_topic_partition(toppar.get_topic(), toppar.get_partition());
toppar_t.set_offset(toppar.get_offset());
toppar_t.set_error(0);
this->toppar_list.add_item(std::move(toppar_t));
}
void add_commit_item(const std::string& topic, int partition,
long long offset)
{
protocol::KafkaToppar toppar;
toppar.set_topic_partition(topic, partition);
toppar.set_offset(offset);
toppar.set_error(0);
this->toppar_list.add_item(std::move(toppar));
}
void set_api_type(int api_type)
{
this->api_type = api_type;
}
int get_api_type() const
{
return this->api_type;
}
void set_config(protocol::KafkaConfig conf)
{
this->config = std::move(conf);
}
void set_partitioner(kafka_partitioner_t partitioner)
{
this->partitioner = std::move(partitioner);
}
protocol::KafkaResult *get_result()
{
return &this->result;
}
int get_kafka_error() const
{
return this->kafka_error;
}
void set_callback(kafka_callback_t cb)
{
this->callback = std::move(cb);
}
protected:
WFKafkaTask(int retry_max, kafka_callback_t&& cb)
{
this->callback = std::move(cb);
this->retry_max = retry_max;
this->finish = false;
}
virtual ~WFKafkaTask() { }
virtual SubTask *done();
protected:
protocol::KafkaConfig config;
protocol::KafkaTopparList toppar_list;
protocol::KafkaMetaList meta_list;
protocol::KafkaResult result;
kafka_callback_t callback;
kafka_partitioner_t partitioner;
int api_type;
int kafka_error;
int retry_max;
bool finish;
private:
friend class WFKafkaClient;
};
class WFKafkaClient
{
public:
// example: kafka://10.160.23.23:9000
// example: kafka://kafka.sogou
// example: kafka.sogou:9090
// example: kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou
// example: kafkas://kafka.sogou -> kafka over TLS
int init(const std::string& broker_url)
{
return this->init(broker_url, NULL);
}
int init(const std::string& broker_url, const std::string& group)
{
return this->init(broker_url, group, NULL);
}
// With a specific SSL_CTX. Effective only on brokers over TLS.
int init(const std::string& broker_url, SSL_CTX *ssl_ctx);
int init(const std::string& broker_url, const std::string& group,
SSL_CTX *ssl_ctx);
int deinit();
// example: topic=xxx&topic=yyy&api=fetch
// example: api=commit
WFKafkaTask *create_kafka_task(const std::string& query,
int retry_max,
kafka_callback_t cb);
WFKafkaTask *create_kafka_task(int retry_max, kafka_callback_t cb);
void set_config(protocol::KafkaConfig conf);
public:
/* If you don't leavegroup manually, rebalance would be triggered */
WFKafkaTask *create_leavegroup_task(int retry_max,
kafka_callback_t callback);
public:
protocol::KafkaMetaList *get_meta_list();
protocol::KafkaBrokerList *get_broker_list();
private:
class KafkaMember *member;
friend class KafkaClientTask;
};
#endif

@ -0,0 +1,55 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <utility>
#include "URIParser.h"
#include "WFMySQLConnection.h"
int WFMySQLConnection::init(const std::string& url, SSL_CTX *ssl_ctx)
{
std::string query;
ParsedURI uri;
if (URIParser::parse(url, uri) >= 0)
{
if (uri.query)
{
query = uri.query;
query += '&';
}
query += "transaction=INTERNAL_CONN_ID_" + std::to_string(this->id);
free(uri.query);
uri.query = strdup(query.c_str());
if (uri.query)
{
this->uri = std::move(uri);
this->ssl_ctx = ssl_ctx;
return 0;
}
}
else if (uri.state == URI_STATE_INVALID)
errno = EINVAL;
return -1;
}

@ -0,0 +1,89 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFMYSQLCONNECTION_H_
#define _WFMYSQLCONNECTION_H_
#include <string>
#include <utility>
#include <functional>
#include <openssl/ssl.h>
#include "URIParser.h"
#include "WFTaskFactory.h"
class WFMySQLConnection
{
public:
/* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8
* IP string is recommmended in url. When using a domain name, the first
* address resovled will be used. Don't use upstream name as a host. */
int init(const std::string& url)
{
return this->init(url, NULL);
}
int init(const std::string& url, SSL_CTX *ssl_ctx);
void deinit() { }
public:
WFMySQLTask *create_query_task(const std::string& query,
mysql_callback_t callback)
{
WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0,
std::move(callback));
this->set_ssl_ctx(task);
task->get_req()->set_query(query);
return task;
}
/* If you don't disconnect manually, the TCP connection will be
* kept alive after this object is deleted, and maybe reused by
* another WFMySQLConnection object with same id and url. */
WFMySQLTask *create_disconnect_task(mysql_callback_t callback)
{
WFMySQLTask *task = this->create_query_task("", std::move(callback));
this->set_ssl_ctx(task);
task->set_keep_alive(0);
return task;
}
protected:
void set_ssl_ctx(WFMySQLTask *task) const
{
using MySQLRequest = protocol::MySQLRequest;
using MySQLResponse = protocol::MySQLResponse;
auto *t = (WFComplexClientTask<MySQLRequest, MySQLResponse> *)task;
/* 'ssl_ctx' can be NULL and will use default. */
t->set_ssl_ctx(this->ssl_ctx);
}
protected:
ParsedURI uri;
SSL_CTX *ssl_ctx;
int id;
public:
/* Make sure that concurrent connections have different id.
* When a connection object is deleted, id can be reused. */
WFMySQLConnection(int id) { this->id = id; }
virtual ~WFMySQLConnection() { }
};
#endif

@ -0,0 +1,129 @@
/*
Copyright (c) 2024 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <errno.h>
#include <string>
#include <vector>
#include <mutex>
#include "URIParser.h"
#include "RedisTaskImpl.inl"
#include "WFRedisSubscriber.h"
int WFRedisSubscribeTask::sync_send(const std::string& command,
const std::vector<std::string>& params)
{
std::string str("*" + std::to_string(1 + params.size()) + "\r\n");
int ret;
str += "$" + std::to_string(command.size()) + "\r\n" + command + "\r\n";
for (const std::string& p : params)
str += "$" + std::to_string(p.size()) + "\r\n" + p + "\r\n";
this->mutex.lock();
if (this->task)
{
ret = this->task->push(str.c_str(), str.size());
if (ret == (int)str.size())
ret = 0;
else
{
if (ret >= 0)
errno = ENOBUFS;
ret = -1;
}
}
else
{
errno = ENOENT;
ret = -1;
}
this->mutex.unlock();
return ret;
}
void WFRedisSubscribeTask::task_extract(WFRedisTask *task)
{
auto *t = (WFRedisSubscribeTask *)task->user_data;
if (t->extract)
t->extract(t);
}
void WFRedisSubscribeTask::task_callback(WFRedisTask *task)
{
auto *t = (WFRedisSubscribeTask *)task->user_data;
t->mutex.lock();
t->task = NULL;
t->mutex.unlock();
t->state = task->get_state();
t->error = task->get_error();
if (t->callback)
t->callback(t);
t->release();
}
int WFRedisSubscriber::init(const std::string& url, SSL_CTX *ssl_ctx)
{
if (URIParser::parse(url, this->uri) >= 0)
{
this->ssl_ctx = ssl_ctx;
return 0;
}
if (this->uri.state == URI_STATE_INVALID)
errno = EINVAL;
return -1;
}
WFRedisTask *
WFRedisSubscriber::create_redis_task(const std::string& command,
const std::vector<std::string>& params)
{
WFRedisTask *task = __WFRedisTaskFactory::create_subscribe_task(this->uri,
WFRedisSubscribeTask::task_extract,
WFRedisSubscribeTask::task_callback);
this->set_ssl_ctx(task);
task->get_req()->set_request(command, params);
return task;
}
WFRedisSubscribeTask *
WFRedisSubscriber::create_subscribe_task(
const std::vector<std::string>& channels,
extract_t extract, callback_t callback)
{
WFRedisTask *task = this->create_redis_task("SUBSCRIBE", channels);
return new WFRedisSubscribeTask(task, std::move(extract),
std::move(callback));
}
WFRedisSubscribeTask *
WFRedisSubscriber::create_psubscribe_task(
const std::vector<std::string>& patterns,
extract_t extract, callback_t callback)
{
WFRedisTask *task = this->create_redis_task("PSUBSCRIBE", patterns);
return new WFRedisSubscribeTask(task, std::move(extract),
std::move(callback));
}

@ -0,0 +1,240 @@
/*
Copyright (c) 2024 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFREDISSUBSCRIBER_H_
#define _WFREDISSUBSCRIBER_H_
#include <errno.h>
#include <string>
#include <vector>
#include <utility>
#include <functional>
#include <atomic>
#include <mutex>
#include <openssl/ssl.h>
#include "WFTask.h"
#include "WFTaskFactory.h"
class WFRedisSubscribeTask : public WFGenericTask
{
public:
/* Note: Call 'get_resp()' only in the 'extract' function or
before the task is started to set response size limit. */
protocol::RedisResponse *get_resp()
{
return this->task->get_resp();
}
public:
/* User needs to call 'release()' exactly once, anywhere. */
void release()
{
if (this->flag.exchange(true))
delete this;
}
public:
/* Note: After 'release()' is called, all the requesting functions
should not be called except in 'extract', because the task
point may have been deleted because 'callback' finished. */
int subscribe(const std::vector<std::string>& channels)
{
return this->sync_send("SUBSCRIBE", channels);
}
int unsubscribe(const std::vector<std::string>& channels)
{
return this->sync_send("UNSUBSCRIBE", channels);
}
int unsubscribe()
{
return this->sync_send("UNSUBSCRIBE", { });
}
int psubscribe(const std::vector<std::string>& patterns)
{
return this->sync_send("PSUBSCRIBE", patterns);
}
int punsubscribe(const std::vector<std::string>& patterns)
{
return this->sync_send("PUNSUBSCRIBE", patterns);
}
int punsubscribe()
{
return this->sync_send("PUNSUBSCRIBE", { });
}
int ping(const std::string& message)
{
return this->sync_send("PING", { message });
}
int ping()
{
return this->sync_send("PING", { });
}
int quit()
{
return this->sync_send("QUIT", { });
}
public:
/* All 'timeout' proxy functions can only be called only before
the task is started or in 'extract'. */
/* Timeout of waiting for each message. Very useful. If not set,
the max waiting time will be the global 'response_timeout'*/
void set_watch_timeout(int timeout)
{
this->task->set_watch_timeout(timeout);
}
/* Timeout of receiving a complete message. */
void set_recv_timeout(int timeout)
{
this->task->set_receive_timeout(timeout);
}
/* Timeout of sending the first subscribe request. */
void set_send_timeout(int timeout)
{
this->task->set_send_timeout(timeout);
}
/* The default keep alive timeout is 0. If you want to keep
the connection alive, make sure not to send any request
after all channels/patterns were unsubscribed. */
void set_keep_alive(int timeout)
{
this->task->set_keep_alive(timeout);
}
public:
/* Call 'set_extract' or 'set_callback' only before the task
is started, or in 'extract'. */
void set_extract(std::function<void (WFRedisSubscribeTask *)> ex)
{
this->extract = std::move(ex);
}
void set_callback(std::function<void (WFRedisSubscribeTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual void dispatch()
{
series_of(this)->push_front(this->task);
this->subtask_done();
}
virtual SubTask *done()
{
return series_of(this)->pop();
}
protected:
int sync_send(const std::string& command,
const std::vector<std::string>& params);
static void task_extract(WFRedisTask *task);
static void task_callback(WFRedisTask *task);
protected:
WFRedisTask *task;
std::mutex mutex;
std::atomic<bool> flag;
std::function<void (WFRedisSubscribeTask *)> extract;
std::function<void (WFRedisSubscribeTask *)> callback;
protected:
WFRedisSubscribeTask(WFRedisTask *task,
std::function<void (WFRedisSubscribeTask *)>&& ex,
std::function<void (WFRedisSubscribeTask *)>&& cb) :
flag(false),
extract(std::move(ex)),
callback(std::move(cb))
{
task->user_data = this;
this->task = task;
}
virtual ~WFRedisSubscribeTask()
{
if (this->task)
this->task->dismiss();
}
friend class WFRedisSubscriber;
};
class WFRedisSubscriber
{
public:
int init(const std::string& url)
{
return this->init(url, NULL);
}
int init(const std::string& url, SSL_CTX *ssl_ctx);
void deinit() { }
public:
using extract_t = std::function<void (WFRedisSubscribeTask *)>;
using callback_t = std::function<void (WFRedisSubscribeTask *)>;
public:
WFRedisSubscribeTask *
create_subscribe_task(const std::vector<std::string>& channels,
extract_t extract, callback_t callback);
WFRedisSubscribeTask *
create_psubscribe_task(const std::vector<std::string>& patterns,
extract_t extract, callback_t callback);
protected:
void set_ssl_ctx(WFRedisTask *task) const
{
using RedisRequest = protocol::RedisRequest;
using RedisResponse = protocol::RedisResponse;
auto *t = (WFComplexClientTask<RedisRequest, RedisResponse> *)task;
/* 'ssl_ctx' can be NULL and will use default. */
t->set_ssl_ctx(this->ssl_ctx);
}
protected:
WFRedisTask *create_redis_task(const std::string& command,
const std::vector<std::string>& params);
protected:
ParsedURI uri;
SSL_CTX *ssl_ctx;
public:
virtual ~WFRedisSubscriber() { }
};
#endif

@ -0,0 +1,23 @@
target("client")
set_kind("object")
add_files("*.cc")
remove_files("WFKafkaClient.cc")
if not has_config("redis") then
remove_files("WFRedisSubscriber.cc")
end
if not has_config("mysql") then
remove_files("WFMySQLConnection.cc")
end
if not has_config("consul") then
remove_files("WFConsulClient.cc")
end
target("kafka_client")
if has_config("kafka") then
add_files("WFKafkaClient.cc")
set_kind("object")
add_deps("client")
add_packages("zlib", "snappy", "zstd", "lz4")
else
set_kind("phony")
end

@ -0,0 +1,37 @@
cmake_minimum_required(VERSION 3.6)
project(factory)
set(SRC
WFGraphTask.cc
DnsTaskImpl.cc
WFTaskFactory.cc
Workflow.cc
HttpTaskImpl.cc
WFResourcePool.cc
WFMessageQueue.cc
FileTaskImpl.cc
)
if (NOT MYSQL STREQUAL "n")
set(SRC
${SRC}
MySQLTaskImpl.cc
)
endif ()
if (NOT REDIS STREQUAL "n")
set(SRC
${SRC}
RedisTaskImpl.cc
)
endif ()
add_library(${PROJECT_NAME} OBJECT ${SRC})
if (KAFKA STREQUAL "y")
set(SRC
KafkaTaskImpl.cc
)
add_library("factory_kafka" OBJECT ${SRC})
set_property(SOURCE KafkaTaskImpl.cc APPEND PROPERTY COMPILE_OPTIONS "-fno-rtti")
endif ()

@ -0,0 +1,226 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Liu Kai (liukaidx@sogou-inc.com)
*/
#include <string>
#include <atomic>
#include "DnsMessage.h"
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "WFServer.h"
using namespace protocol;
#define DNS_KEEPALIVE_DEFAULT (60 * 1000)
/**********Client**********/
class ComplexDnsTask : public WFComplexClientTask<DnsRequest, DnsResponse,
std::function<void (WFDnsTask *)>>
{
static struct addrinfo hints;
static std::atomic<size_t> seq;
public:
ComplexDnsTask(int retry_max, dns_callback_t&& cb):
WFComplexClientTask(retry_max, std::move(cb))
{
this->set_transport_type(TT_UDP);
}
protected:
virtual CommMessageOut *message_out();
virtual bool init_success();
virtual bool finish_once();
private:
bool need_redirect();
};
struct addrinfo ComplexDnsTask::hints =
{
.ai_flags = AI_NUMERICSERV | AI_NUMERICHOST,
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM
};
std::atomic<size_t> ComplexDnsTask::seq(0);
CommMessageOut *ComplexDnsTask::message_out()
{
DnsRequest *req = this->get_req();
DnsResponse *resp = this->get_resp();
enum TransportType type = this->get_transport_type();
if (req->get_id() == 0)
req->set_id(++ComplexDnsTask::seq * 99991 % 65535 + 1);
resp->set_request_id(req->get_id());
resp->set_request_name(req->get_question_name());
req->set_single_packet(type == TT_UDP);
resp->set_single_packet(type == TT_UDP);
return this->WFClientTask::message_out();
}
bool ComplexDnsTask::init_success()
{
if (uri_.scheme && strcasecmp(uri_.scheme, "dnss") == 0)
this->WFComplexClientTask::set_transport_type(TT_TCP_SSL);
else if (!uri_.scheme || strcasecmp(uri_.scheme, "dns") != 0)
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
return false;
}
if (!this->route_result_.request_object)
{
enum TransportType type = this->get_transport_type();
struct addrinfo *addr;
int ret;
ret = getaddrinfo(uri_.host, uri_.port, &hints, &addr);
if (ret != 0)
{
this->state = WFT_STATE_DNS_ERROR;
this->error = ret;
return false;
}
auto *ep = &WFGlobal::get_global_settings()->dns_server_params;
ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep,
uri_.host, ssl_ctx_,
route_result_);
freeaddrinfo(addr);
if (ret < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
return false;
}
}
return true;
}
bool ComplexDnsTask::finish_once()
{
if (this->state == WFT_STATE_SUCCESS)
{
if (need_redirect())
this->set_redirect(uri_);
else if (this->state != WFT_STATE_SUCCESS)
this->disable_retry();
}
/* If retry times meet retry max and there is no redirect,
* we ask the client for a retry or redirect.
*/
if (retry_times_ == retry_max_ && !redirect_ && *this->get_mutable_ctx())
{
/* Reset type to UDP before a client redirect. */
this->set_transport_type(TT_UDP);
(*this->get_mutable_ctx())(this);
}
return true;
}
bool ComplexDnsTask::need_redirect()
{
DnsResponse *client_resp = this->get_resp();
enum TransportType type = this->get_transport_type();
if (type == TT_UDP && client_resp->get_tc() == 1)
{
this->set_transport_type(TT_TCP);
return true;
}
return false;
}
/**********Client Factory**********/
WFDnsTask *WFTaskFactory::create_dns_task(const std::string& url,
int retry_max,
dns_callback_t callback)
{
ParsedURI uri;
URIParser::parse(url, uri);
return WFTaskFactory::create_dns_task(uri, retry_max, std::move(callback));
}
WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri,
int retry_max,
dns_callback_t callback)
{
ComplexDnsTask *task = new ComplexDnsTask(retry_max, std::move(callback));
const char *name;
if (uri.path && uri.path[0] && uri.path[1])
name = uri.path + 1;
else
name = ".";
DnsRequest *req = task->get_req();
req->set_question(name, DNS_TYPE_A, DNS_CLASS_IN);
task->init(uri);
task->set_keep_alive(DNS_KEEPALIVE_DEFAULT);
return task;
}
/**********Server**********/
class WFDnsServerTask : public WFServerTask<DnsRequest, DnsResponse>
{
public:
WFDnsServerTask(CommService *service,
std::function<void (WFDnsTask *)>& proc) :
WFServerTask(service, WFGlobal::get_scheduler(), proc)
{
this->type = ((WFServerBase *)service)->get_params()->transport_type;
}
protected:
virtual CommMessageIn *message_in()
{
this->get_req()->set_single_packet(this->type == TT_UDP);
return this->WFServerTask::message_in();
}
virtual CommMessageOut *message_out()
{
this->get_resp()->set_single_packet(this->type == TT_UDP);
return this->WFServerTask::message_out();
}
protected:
enum TransportType type;
};
/**********Server Factory**********/
WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service,
std::function<void (WFDnsTask *)>& proc)
{
return new WFDnsServerTask(service, proc);
}

@ -0,0 +1,400 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
Li Jinghao (lijinghao@sogou-inc.com)
*/
#include <fcntl.h>
#include <unistd.h>
#include <string>
#include "WFGlobal.h"
#include "WFTaskFactory.h"
class WFFilepreadTask : public WFFileIOTask
{
public:
WFFilepreadTask(int fd, void *buf, size_t count, off_t offset,
IOService *service, fio_callback_t&& cb) :
WFFileIOTask(service, std::move(cb))
{
this->args.fd = fd;
this->args.buf = buf;
this->args.count = count;
this->args.offset = offset;
}
protected:
virtual int prepare()
{
this->prep_pread(this->args.fd, this->args.buf, this->args.count,
this->args.offset);
return 0;
}
};
class WFFilepwriteTask : public WFFileIOTask
{
public:
WFFilepwriteTask(int fd, const void *buf, size_t count, off_t offset,
IOService *service, fio_callback_t&& cb) :
WFFileIOTask(service, std::move(cb))
{
this->args.fd = fd;
this->args.buf = (void *)buf;
this->args.count = count;
this->args.offset = offset;
}
protected:
virtual int prepare()
{
this->prep_pwrite(this->args.fd, this->args.buf, this->args.count,
this->args.offset);
return 0;
}
};
class WFFilepreadvTask : public WFFileVIOTask
{
public:
WFFilepreadvTask(int fd, const struct iovec *iov, int iovcnt, off_t offset,
IOService *service, fvio_callback_t&& cb) :
WFFileVIOTask(service, std::move(cb))
{
this->args.fd = fd;
this->args.iov = iov;
this->args.iovcnt = iovcnt;
this->args.offset = offset;
}
protected:
virtual int prepare()
{
this->prep_preadv(this->args.fd, this->args.iov, this->args.iovcnt,
this->args.offset);
return 0;
}
};
class WFFilepwritevTask : public WFFileVIOTask
{
public:
WFFilepwritevTask(int fd, const struct iovec *iov, int iovcnt, off_t offset,
IOService *service, fvio_callback_t&& cb) :
WFFileVIOTask(service, std::move(cb))
{
this->args.fd = fd;
this->args.iov = iov;
this->args.iovcnt = iovcnt;
this->args.offset = offset;
}
protected:
virtual int prepare()
{
this->prep_pwritev(this->args.fd, this->args.iov, this->args.iovcnt,
this->args.offset);
return 0;
}
};
class WFFilefsyncTask : public WFFileSyncTask
{
public:
WFFilefsyncTask(int fd, IOService *service, fsync_callback_t&& cb) :
WFFileSyncTask(service, std::move(cb))
{
this->args.fd = fd;
}
protected:
virtual int prepare()
{
this->prep_fsync(this->args.fd);
return 0;
}
};
class WFFilefdsyncTask : public WFFileSyncTask
{
public:
WFFilefdsyncTask(int fd, IOService *service, fsync_callback_t&& cb) :
WFFileSyncTask(service, std::move(cb))
{
this->args.fd = fd;
}
protected:
virtual int prepare()
{
this->prep_fdsync(this->args.fd);
return 0;
}
};
/* File tasks created with path name. */
class __WFFilepreadTask : public WFFilepreadTask
{
public:
__WFFilepreadTask(const std::string& path, void *buf, size_t count,
off_t offset, IOService *service, fio_callback_t&& cb):
WFFilepreadTask(-1, buf, count, offset, service, std::move(cb)),
pathname(path)
{
}
protected:
virtual int prepare()
{
this->args.fd = open(this->pathname.c_str(), O_RDONLY);
if (this->args.fd < 0)
return -1;
return WFFilepreadTask::prepare();
}
virtual SubTask *done()
{
if (this->args.fd >= 0)
{
close(this->args.fd);
this->args.fd = -1;
}
return WFFilepreadTask::done();
}
protected:
std::string pathname;
};
class __WFFilepwriteTask : public WFFilepwriteTask
{
public:
__WFFilepwriteTask(const std::string& path, const void *buf, size_t count,
off_t offset, IOService *service, fio_callback_t&& cb):
WFFilepwriteTask(-1, buf, count, offset, service, std::move(cb)),
pathname(path)
{
}
protected:
virtual int prepare()
{
this->args.fd = open(this->pathname.c_str(), O_WRONLY | O_CREAT, 0644);
if (this->args.fd < 0)
return -1;
return WFFilepwriteTask::prepare();
}
virtual SubTask *done()
{
if (this->args.fd >= 0)
{
close(this->args.fd);
this->args.fd = -1;
}
return WFFilepwriteTask::done();
}
protected:
std::string pathname;
};
class __WFFilepreadvTask : public WFFilepreadvTask
{
public:
__WFFilepreadvTask(const std::string& path, const struct iovec *iov,
int iovcnt, off_t offset, IOService *service,
fvio_callback_t&& cb) :
WFFilepreadvTask(-1, iov, iovcnt, offset, service, std::move(cb)),
pathname(path)
{
}
protected:
virtual int prepare()
{
this->args.fd = open(this->pathname.c_str(), O_RDONLY);
if (this->args.fd < 0)
return -1;
return WFFilepreadvTask::prepare();
}
virtual SubTask *done()
{
if (this->args.fd >= 0)
{
close(this->args.fd);
this->args.fd = -1;
}
return WFFilepreadvTask::done();
}
protected:
std::string pathname;
};
class __WFFilepwritevTask : public WFFilepwritevTask
{
public:
__WFFilepwritevTask(const std::string& path, const struct iovec *iov,
int iovcnt, off_t offset, IOService *service,
fvio_callback_t&& cb) :
WFFilepwritevTask(-1, iov, iovcnt, offset, service, std::move(cb)),
pathname(path)
{
}
protected:
virtual int prepare()
{
this->args.fd = open(this->pathname.c_str(), O_WRONLY | O_CREAT, 0644);
if (this->args.fd < 0)
return -1;
return WFFilepwritevTask::prepare();
}
protected:
virtual SubTask *done()
{
if (this->args.fd >= 0)
{
close(this->args.fd);
this->args.fd = -1;
}
return WFFilepwritevTask::done();
}
protected:
std::string pathname;
};
/* Factory functions with fd. */
WFFileIOTask *WFTaskFactory::create_pread_task(int fd,
void *buf,
size_t count,
off_t offset,
fio_callback_t callback)
{
return new WFFilepreadTask(fd, buf, count, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileIOTask *WFTaskFactory::create_pwrite_task(int fd,
const void *buf,
size_t count,
off_t offset,
fio_callback_t callback)
{
return new WFFilepwriteTask(fd, buf, count, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileVIOTask *WFTaskFactory::create_preadv_task(int fd,
const struct iovec *iovec,
int iovcnt,
off_t offset,
fvio_callback_t callback)
{
return new WFFilepreadvTask(fd, iovec, iovcnt, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileVIOTask *WFTaskFactory::create_pwritev_task(int fd,
const struct iovec *iovec,
int iovcnt,
off_t offset,
fvio_callback_t callback)
{
return new WFFilepwritevTask(fd, iovec, iovcnt, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileSyncTask *WFTaskFactory::create_fsync_task(int fd,
fsync_callback_t callback)
{
return new WFFilefsyncTask(fd,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileSyncTask *WFTaskFactory::create_fdsync_task(int fd,
fsync_callback_t callback)
{
return new WFFilefdsyncTask(fd,
WFGlobal::get_io_service(),
std::move(callback));
}
/* Factory functions with path name. */
WFFileIOTask *WFTaskFactory::create_pread_task(const std::string& path,
void *buf,
size_t count,
off_t offset,
fio_callback_t callback)
{
return new __WFFilepreadTask(path, buf, count, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileIOTask *WFTaskFactory::create_pwrite_task(const std::string& path,
const void *buf,
size_t count,
off_t offset,
fio_callback_t callback)
{
return new __WFFilepwriteTask(path, buf, count, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileVIOTask *WFTaskFactory::create_preadv_task(const std::string& path,
const struct iovec *iovec,
int iovcnt,
off_t offset,
fvio_callback_t callback)
{
return new __WFFilepreadvTask(path, iovec, iovcnt, offset,
WFGlobal::get_io_service(),
std::move(callback));
}
WFFileVIOTask *WFTaskFactory::create_pwritev_task(const std::string& path,
const struct iovec *iovec,
int iovcnt,
off_t offset,
fvio_callback_t callback)
{
return new __WFFilepwritevTask(path, iovec, iovcnt, offset,
WFGlobal::get_io_service(),
std::move(callback));
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,784 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
Xie Han (xiehan@sogou-inc.com)
*/
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <string>
#include <set>
#include <openssl/sha.h>
#include <openssl/evp.h>
#include "StringUtil.h"
#include "KafkaTaskImpl.inl"
using namespace protocol;
#define KAFKA_KEEPALIVE_DEFAULT (60 * 1000)
#define KAFKA_ROUNDTRIP_TIMEOUT (5 * 1000)
static KafkaCgroup __create_cgroup(const KafkaCgroup *c)
{
KafkaCgroup g;
const char *member_id = c->get_member_id();
if (member_id)
g.set_member_id(member_id);
g.set_group(c->get_group());
return g;
}
/**********Client**********/
class __ComplexKafkaTask : public WFComplexClientTask<KafkaRequest, KafkaResponse, int>
{
public:
__ComplexKafkaTask(int retry_max, __kafka_callback_t&& callback) :
WFComplexClientTask(retry_max, std::move(callback))
{
is_user_request_ = true;
is_redirect_ = false;
ctx_ = 0;
}
protected:
virtual CommMessageOut *message_out();
virtual CommMessageIn *message_in();
virtual bool init_success();
virtual bool finish_once();
private:
struct KafkaConnectionInfo
{
kafka_api_t api;
kafka_sasl_t sasl;
std::string mechanisms;
KafkaConnectionInfo()
{
kafka_api_init(&this->api);
kafka_sasl_init(&this->sasl);
}
~KafkaConnectionInfo()
{
kafka_api_deinit(&this->api);
kafka_sasl_deinit(&this->sasl);
}
bool init(const char *mechanisms)
{
this->mechanisms = mechanisms;
if (strncasecmp(mechanisms, "SCRAM", 5) == 0)
{
if (strcasecmp(mechanisms, "SCRAM-SHA-1") == 0)
{
this->sasl.scram.evp = EVP_sha1();
this->sasl.scram.scram_h = SHA1;
this->sasl.scram.scram_h_size = SHA_DIGEST_LENGTH;
}
else if (strcasecmp(mechanisms, "SCRAM-SHA-256") == 0)
{
this->sasl.scram.evp = EVP_sha256();
this->sasl.scram.scram_h = SHA256;
this->sasl.scram.scram_h_size = SHA256_DIGEST_LENGTH;
}
else if (strcasecmp(mechanisms, "SCRAM-SHA-512") == 0)
{
this->sasl.scram.evp = EVP_sha512();
this->sasl.scram.scram_h = SHA512;
this->sasl.scram.scram_h_size = SHA512_DIGEST_LENGTH;
}
else
return false;
}
return true;
}
};
virtual int keep_alive_timeout();
virtual int first_timeout();
bool has_next();
bool process_produce();
bool process_fetch();
bool process_metadata();
bool process_list_offsets();
bool process_find_coordinator();
bool process_join_group();
bool process_sync_group();
bool process_sasl_authenticate();
bool process_sasl_handshake();
bool is_user_request_;
bool is_redirect_;
std::string user_info_;
};
CommMessageOut *__ComplexKafkaTask::message_out()
{
long long seqid = this->get_seq();
if (seqid == 0)
{
KafkaConnectionInfo *conn_info = new KafkaConnectionInfo;
this->get_req()->set_api(&conn_info->api);
this->get_connection()->set_context(conn_info, [](void *ctx) {
delete (KafkaConnectionInfo *)ctx;
});
if (!this->get_req()->get_config()->get_broker_version())
{
KafkaRequest *req = new KafkaRequest;
req->duplicate(*this->get_req());
req->set_api_type(Kafka_ApiVersions);
is_user_request_ = false;
return req;
}
else
{
kafka_api_version_t *api;
size_t api_cnt;
const char *v = this->get_req()->get_config()->get_broker_version();
int ret = kafka_api_version_is_queryable(v, &api, &api_cnt);
kafka_api_version_t *p = NULL;
if (ret == 0)
{
p = (kafka_api_version_t *)malloc(api_cnt * sizeof(*p));
if (p)
{
memcpy(p, api, api_cnt * sizeof(kafka_api_version_t));
conn_info->api.api = p;
conn_info->api.elements = api_cnt;
conn_info->api.features = kafka_get_features(p, api_cnt);
}
}
if (!p)
return NULL;
seqid++;
}
}
if (seqid == 1)
{
const char *sasl_mech = this->get_req()->get_config()->get_sasl_mech();
KafkaConnectionInfo *conn_info =
(KafkaConnectionInfo *)this->get_connection()->get_context();
if (sasl_mech && conn_info->sasl.status == 0)
{
if (!conn_info->init(sasl_mech))
return NULL;
this->get_req()->set_api(&conn_info->api);
this->get_req()->set_sasl(&conn_info->sasl);
KafkaRequest *req = new KafkaRequest;
req->duplicate(*this->get_req());
if (conn_info->api.features & KAFKA_FEATURE_SASL_HANDSHAKE)
req->set_api_type(Kafka_SaslHandshake);
else
req->set_api_type(Kafka_SaslAuthenticate);
req->set_correlation_id(1);
is_user_request_ = false;
return req;
}
}
KafkaConnectionInfo *conn_info =
(KafkaConnectionInfo *)this->get_connection()->get_context();
KafkaRequest *req = this->get_req();
req->set_api(&conn_info->api);
if (req->get_api_type() == Kafka_Fetch ||
req->get_api_type() == Kafka_ListOffsets)
{
KafkaTopparList *req_toppar_lst = req->get_toppar_list();
KafkaToppar *toppar;
KafkaTopparList toppar_list;
bool flag = false;
long long cfg_ts = req->get_config()->get_offset_timestamp();
long long tp_ts;
req_toppar_lst->rewind();
while ((toppar = req_toppar_lst->get_next()) != NULL)
{
tp_ts = toppar->get_offset_timestamp();
if (tp_ts == KAFKA_TIMESTAMP_UNINIT)
tp_ts = cfg_ts;
if (toppar->get_offset() == KAFKA_OFFSET_UNINIT)
{
if (tp_ts == KAFKA_TIMESTAMP_EARLIEST)
toppar->set_offset(toppar->get_low_watermark());
else if (tp_ts < 0)
{
toppar->set_offset(toppar->get_high_watermark());
tp_ts = KAFKA_TIMESTAMP_LATEST;
}
}
else if (toppar->get_offset() == KAFKA_OFFSET_OVERFLOW)
{
if (tp_ts == KAFKA_TIMESTAMP_EARLIEST)
toppar->set_offset(toppar->get_low_watermark());
else
{
toppar->set_offset(toppar->get_high_watermark());
tp_ts = KAFKA_TIMESTAMP_LATEST;
}
}
if (toppar->get_offset() < 0)
{
toppar->set_offset_timestamp(tp_ts);
toppar_list.add_item(*toppar);
flag = true;
}
}
if (flag)
{
KafkaRequest *new_req = new KafkaRequest;
new_req->set_api(&conn_info->api);
new_req->set_broker(*req->get_broker());
new_req->set_toppar_list(toppar_list);
new_req->set_config(*req->get_config());
new_req->set_api_type(Kafka_ListOffsets);
new_req->set_correlation_id(seqid);
is_user_request_ = false;
return new_req;
}
}
this->get_req()->set_correlation_id(seqid);
return this->WFComplexClientTask::message_out();
}
CommMessageIn *__ComplexKafkaTask::message_in()
{
KafkaRequest *req = static_cast<KafkaRequest *>(this->get_message_out());
KafkaResponse *resp = this->get_resp();
KafkaCgroup *cgroup;
resp->set_api_type(req->get_api_type());
resp->set_api_version(req->get_api_version());
resp->duplicate(*req);
switch (req->get_api_type())
{
case Kafka_FindCoordinator:
case Kafka_Heartbeat:
cgroup = req->get_cgroup();
if (cgroup->get_group())
resp->set_cgroup(__create_cgroup(cgroup));
break;
default:
break;
}
return this->WFComplexClientTask::message_in();
}
bool __ComplexKafkaTask::init_success()
{
enum TransportType type;
if (uri_.scheme && strcasecmp(uri_.scheme, "kafka") == 0)
type = TT_TCP;
else if (uri_.scheme && strcasecmp(uri_.scheme, "kafkas") == 0)
type = TT_TCP_SSL;
else
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
return false;
}
std::string username, password, sasl, client;
if (uri_.userinfo)
{
const char *pos = strchr(uri_.userinfo, ':');
if (pos)
{
username = std::string(uri_.userinfo, pos - uri_.userinfo);
StringUtil::url_decode(username);
const char *pos1 = strchr(pos + 1, ':');
if (pos1)
{
password = std::string(pos + 1, pos1 - pos - 1);
StringUtil::url_decode(password);
const char *pos2 = strchr(pos1 + 1, ':');
if (pos2)
{
sasl = std::string(pos1 + 1, pos2 - pos1 - 1);
client = std::string(pos1 + 1);
}
}
}
if (username.empty() || password.empty() || sasl.empty() || client.empty())
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
return false;
}
user_info_ = uri_.userinfo;
size_t info_len = username.size() + password.size() + sasl.size() +
client.size() + 50;
char *info = new char[info_len];
snprintf(info, info_len, "%s|user:%s|pass:%s|sasl:%s|client:%s|", "kafka",
username.c_str(), password.c_str(), sasl.c_str(), client.c_str());
this->WFComplexClientTask::set_info(info);
delete []info;
}
this->WFComplexClientTask::set_transport_type(type);
return true;
}
int __ComplexKafkaTask::keep_alive_timeout()
{
if (this->get_resp()->get_broker()->get_error())
return 0;
return this->WFComplexClientTask::keep_alive_timeout();
}
int __ComplexKafkaTask::first_timeout()
{
KafkaRequest *client_req = this->get_req();
int ret = 0;
switch(client_req->get_api_type())
{
case Kafka_Fetch:
ret = client_req->get_config()->get_fetch_timeout();
break;
case Kafka_JoinGroup:
ret = client_req->get_config()->get_session_timeout();
break;
case Kafka_SyncGroup:
ret = client_req->get_config()->get_rebalance_timeout();
break;
case Kafka_Produce:
ret = client_req->get_config()->get_produce_timeout();
break;
default:
return 0;
}
return ret + KAFKA_ROUNDTRIP_TIMEOUT;
}
bool __ComplexKafkaTask::process_find_coordinator()
{
KafkaCgroup *cgroup = this->get_resp()->get_cgroup();
ctx_ = cgroup->get_error();
if (ctx_)
{
this->error = WFT_ERR_KAFKA_CGROUP_FAILED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
else
{
this->get_req()->set_cgroup(*cgroup);
KafkaBroker *coordinator = cgroup->get_coordinator();
std::string url(uri_.scheme);
url += "://";
url += user_info_ + "@";
url += coordinator->get_host();
url += ":" + std::to_string(coordinator->get_port());
ParsedURI uri;
URIParser::parse(url, uri);
set_redirect(std::move(uri));
this->get_req()->set_api_type(Kafka_JoinGroup);
is_redirect_ = true;
return true;
}
}
bool __ComplexKafkaTask::process_join_group()
{
KafkaResponse *msg = this->get_resp();
switch(msg->get_cgroup()->get_error())
{
case KAFKA_MEMBER_ID_REQUIRED:
this->get_req()->set_api_type(Kafka_JoinGroup);
break;
case KAFKA_UNKNOWN_MEMBER_ID:
msg->get_cgroup()->set_member_id("");
this->get_req()->set_api_type(Kafka_JoinGroup);
break;
case 0:
this->get_req()->set_api_type(Kafka_Metadata);
break;
default:
ctx_ = msg->get_cgroup()->get_error();
this->error = WFT_ERR_KAFKA_CGROUP_FAILED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
return true;
}
bool __ComplexKafkaTask::process_sync_group()
{
ctx_ = this->get_resp()->get_cgroup()->get_error();
if (ctx_)
{
this->error = WFT_ERR_KAFKA_CGROUP_FAILED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
else
{
this->get_req()->set_api_type(Kafka_OffsetFetch);
return true;
}
}
bool __ComplexKafkaTask::process_metadata()
{
KafkaResponse *msg = this->get_resp();
msg->get_meta_list()->rewind();
KafkaMeta *meta;
while ((meta = msg->get_meta_list()->get_next()) != NULL)
{
switch (meta->get_error())
{
case KAFKA_LEADER_NOT_AVAILABLE:
this->get_req()->set_api_type(Kafka_Metadata);
return true;
case 0:
break;
default:
ctx_ = meta->get_error();
this->error = WFT_ERR_KAFKA_META_FAILED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
}
this->get_req()->set_meta_list(*msg->get_meta_list());
if (msg->get_cgroup()->get_group())
{
if (msg->get_cgroup()->is_leader())
{
KafkaCgroup *cgroup = msg->get_cgroup();
if (cgroup->run_assignor(msg->get_meta_list(),
cgroup->get_protocol_name()) < 0)
{
this->error = WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
}
this->get_req()->set_api_type(Kafka_SyncGroup);
return true;
}
return false;
}
bool __ComplexKafkaTask::process_fetch()
{
bool ret = false;
KafkaToppar *toppar;
this->get_resp()->get_toppar_list()->rewind();
while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL)
{
if (toppar->get_error() == KAFKA_OFFSET_OUT_OF_RANGE)
{
toppar->set_offset(KAFKA_OFFSET_OVERFLOW);
toppar->set_low_watermark(KAFKA_OFFSET_UNINIT);
toppar->set_high_watermark(KAFKA_OFFSET_UNINIT);
ret = true;
}
switch (toppar->get_error())
{
case KAFKA_UNKNOWN_TOPIC_OR_PARTITION:
case KAFKA_LEADER_NOT_AVAILABLE:
case KAFKA_NOT_LEADER_FOR_PARTITION:
case KAFKA_BROKER_NOT_AVAILABLE:
case KAFKA_REPLICA_NOT_AVAILABLE:
case KAFKA_KAFKA_STORAGE_ERROR:
case KAFKA_FENCED_LEADER_EPOCH:
this->get_req()->set_api_type(Kafka_Metadata);
return true;
case 0:
case KAFKA_OFFSET_OUT_OF_RANGE:
break;
default:
ctx_ = toppar->get_error();
this->error = WFT_ERR_KAFKA_FETCH_FAILED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
}
return ret;
}
bool __ComplexKafkaTask::process_list_offsets()
{
KafkaToppar *toppar;
this->get_resp()->get_toppar_list()->rewind();
while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL)
{
if (toppar->get_error())
{
this->error = toppar->get_error();
this->state = WFT_STATE_TASK_ERROR;
}
}
return false;
}
bool __ComplexKafkaTask::process_produce()
{
KafkaToppar *toppar;
this->get_resp()->get_toppar_list()->rewind();
while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL)
{
if (!toppar->record_reach_end())
{
this->get_req()->set_api_type(Kafka_Produce);
return true;
}
switch (toppar->get_error())
{
case KAFKA_UNKNOWN_TOPIC_OR_PARTITION:
case KAFKA_LEADER_NOT_AVAILABLE:
case KAFKA_NOT_LEADER_FOR_PARTITION:
case KAFKA_BROKER_NOT_AVAILABLE:
case KAFKA_REPLICA_NOT_AVAILABLE:
case KAFKA_KAFKA_STORAGE_ERROR:
case KAFKA_FENCED_LEADER_EPOCH:
this->get_req()->set_api_type(Kafka_Metadata);
return true;
case 0:
break;
default:
this->error = toppar->get_error();
this->state = WFT_STATE_TASK_ERROR;
return false;
}
}
return false;
}
bool __ComplexKafkaTask::process_sasl_handshake()
{
ctx_ = this->get_resp()->get_broker()->get_error();
if (ctx_)
{
this->error = WFT_ERR_KAFKA_SASL_DISALLOWED;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
return true;
}
bool __ComplexKafkaTask::process_sasl_authenticate()
{
ctx_ = this->get_resp()->get_broker()->get_error();
if (ctx_)
{
this->error = WFT_ERR_KAFKA_SASL_DISALLOWED;
this->state = WFT_STATE_TASK_ERROR;
}
return false;
}
bool __ComplexKafkaTask::has_next()
{
switch (this->get_resp()->get_api_type())
{
case Kafka_Produce:
return this->process_produce();
case Kafka_Fetch:
return this->process_fetch();
case Kafka_Metadata:
return this->process_metadata();
case Kafka_FindCoordinator:
return this->process_find_coordinator();
case Kafka_JoinGroup:
return this->process_join_group();
case Kafka_SyncGroup:
return this->process_sync_group();
case Kafka_SaslHandshake:
return this->process_sasl_handshake();
case Kafka_SaslAuthenticate:
return this->process_sasl_authenticate();
case Kafka_ListOffsets:
return this->process_list_offsets();
case Kafka_OffsetCommit:
case Kafka_OffsetFetch:
case Kafka_LeaveGroup:
case Kafka_DescribeGroups:
case Kafka_Heartbeat:
ctx_ = this->get_resp()->get_cgroup()->get_error();
if (ctx_)
{
this->error = WFT_ERR_KAFKA_CGROUP_FAILED;
this->state = WFT_STATE_TASK_ERROR;
}
break;
case Kafka_ApiVersions:
break;
default:
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_KAFKA_API_UNKNOWN;
break;
}
return false;
}
bool __ComplexKafkaTask::finish_once()
{
bool finish = true;
if (this->state == WFT_STATE_SUCCESS)
finish = !has_next();
if (!is_user_request_)
{
delete this->get_message_out();
this->get_resp()->clear_buf();
}
if (is_redirect_ && this->state == WFT_STATE_UNDEFINED)
{
this->get_req()->clear_buf();
is_redirect_ = false;
}
else if (this->state == WFT_STATE_SUCCESS)
{
if (!is_user_request_)
{
is_user_request_ = true;
return false;
}
if (!finish)
{
this->get_req()->clear_buf();
this->get_resp()->clear_buf();
return false;
}
}
else
{
this->get_resp()->set_api_type(this->get_req()->get_api_type());
this->get_resp()->set_api_version(this->get_req()->get_api_version());
}
is_user_request_ = true;
return true;
}
/**********Factory**********/
// kafka://user:password:sasl@host:port/api=type&topic=name
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const std::string& url,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback)
{
auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
ParsedURI uri;
URIParser::parse(url, uri);
task->init(std::move(uri));
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
return task;
}
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const ParsedURI& uri,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback)
{
auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
task->init(uri);
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
return task;
}
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(enum TransportType type,
const char *host,
unsigned short port,
SSL_CTX *ssl_ctx,
const std::string& info,
int retry_max,
__kafka_callback_t callback)
{
auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
ParsedURI uri;
char buf[32];
if (type == TT_TCP_SSL)
uri.scheme = strdup("kafkas");
else
uri.scheme = strdup("kafka");
if (!info.empty())
uri.userinfo = strdup(info.c_str());
uri.host = strdup(host);
sprintf(buf, "%u", port);
uri.port = strdup(buf);
if (!uri.scheme || !uri.host || !uri.port ||
(!info.empty() && !uri.userinfo))
{
uri.state = URI_STATE_ERROR;
uri.error = errno;
}
else
uri.state = URI_STATE_SUCCESS;
task->init(std::move(uri));
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
return task;
}

@ -0,0 +1,53 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
*/
#include <openssl/ssl.h>
#include "WFTaskFactory.h"
#include "KafkaMessage.h"
// Kafka internal task. For __ComplexKafkaTask usage only
using __WFKafkaTask = WFNetworkTask<protocol::KafkaRequest,
protocol::KafkaResponse>;
using __kafka_callback_t = std::function<void (__WFKafkaTask *)>;
class __WFKafkaTaskFactory
{
public:
/* __WFKafkaTask is create by __ComplexKafkaTask. This is an internal
* interface for create internal task. It should not be created directly by common
* user task.
*/
static __WFKafkaTask *create_kafka_task(const ParsedURI& uri,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback);
static __WFKafkaTask *create_kafka_task(const std::string& url,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback);
static __WFKafkaTask *create_kafka_task(enum TransportType type,
const char *host,
unsigned short port,
SSL_CTX *ssl_ctx,
const std::string& info,
int retry_max,
__kafka_callback_t callback);
};

@ -0,0 +1,855 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
Wu Jiaxu (wujiaxu@sogou-inc.com)
Li Yingxin (liyingxin@sogou-inc.com)
*/
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <string>
#include <unordered_map>
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "StringUtil.h"
#include "WFGlobal.h"
#include "mysql_types.h"
using namespace protocol;
#define MYSQL_KEEPALIVE_DEFAULT (60 * 1000)
#define MYSQL_KEEPALIVE_TRANSACTION (3600 * 1000)
/**********Client**********/
class ComplexMySQLTask : public WFComplexClientTask<MySQLRequest, MySQLResponse>
{
protected:
virtual bool check_request();
virtual CommMessageOut *message_out();
virtual CommMessageIn *message_in();
virtual int keep_alive_timeout();
virtual int first_timeout();
virtual bool init_success();
virtual bool finish_once();
protected:
virtual WFConnection *get_connection() const
{
WFConnection *conn = this->WFComplexClientTask::get_connection();
if (conn)
{
void *ctx = conn->get_context();
if (ctx)
conn = (WFConnection *)ctx;
}
return conn;
}
private:
enum ConnState
{
ST_SSL_REQUEST,
ST_AUTH_REQUEST,
ST_AUTH_SWITCH_REQUEST,
ST_CLEAR_PASSWORD_REQUEST,
ST_SHA256_PUBLIC_KEY_REQUEST,
ST_CSHA2_PUBLIC_KEY_REQUEST,
ST_RSA_AUTH_REQUEST,
ST_CHARSET_REQUEST,
ST_FIRST_USER_REQUEST,
ST_USER_REQUEST
};
struct MyConnection : public WFConnection
{
std::string str; // shared by auth, auth_swich and rsa_auth requests
unsigned char seed[20];
enum ConnState state;
unsigned char mysql_seqid;
SSL *ssl;
SSLWrapper wrapper;
MyConnection(SSL *ssl) : wrapper(&wrapper, ssl)
{
this->ssl = ssl;
}
};
int check_handshake(MySQLHandshakeResponse *resp);
int auth_switch(MySQLAuthResponse *resp, MyConnection *conn);
struct MySSLWrapper : public SSLWrapper
{
MySSLWrapper(ProtocolMessage *msg, SSL *ssl) :
SSLWrapper(msg, ssl)
{ }
ProtocolMessage *get_msg() const { return this->message; }
virtual ~MySSLWrapper() { delete this->message; }
};
private:
std::string username_;
std::string password_;
std::string db_;
std::string res_charset_;
short character_set_;
short state_;
int error_;
bool is_ssl_;
bool is_user_request_;
public:
ComplexMySQLTask(int retry_max, mysql_callback_t&& callback):
WFComplexClientTask(retry_max, std::move(callback)),
character_set_(33),
is_user_request_(true)
{}
};
bool ComplexMySQLTask::check_request()
{
if (this->req.query_is_unset() == false)
{
if (this->req.get_command() == MYSQL_COM_QUERY)
{
std::string query = this->req.get_query();
if (strncasecmp(query.c_str(), "USE ", 4) &&
strncasecmp(query.c_str(), "SET NAMES ", 10) &&
strncasecmp(query.c_str(), "SET CHARSET ", 12) &&
strncasecmp(query.c_str(), "SET CHARACTER SET ", 18))
{
return true;
}
}
this->error = WFT_ERR_MYSQL_COMMAND_DISALLOWED;
}
else
this->error = WFT_ERR_MYSQL_QUERY_NOT_SET;
this->state = WFT_STATE_TASK_ERROR;
return false;
}
static SSL *__create_ssl(SSL_CTX *ssl_ctx)
{
BIO *wbio;
BIO *rbio;
SSL *ssl;
rbio = BIO_new(BIO_s_mem());
if (rbio)
{
wbio = BIO_new(BIO_s_mem());
if (wbio)
{
ssl = SSL_new(ssl_ctx);
if (ssl)
{
SSL_set_bio(ssl, rbio, wbio);
return ssl;
}
BIO_free(wbio);
}
BIO_free(rbio);
}
return NULL;
}
CommMessageOut *ComplexMySQLTask::message_out()
{
MySQLAuthSwitchRequest *auth_switch_req;
MySQLRSAAuthRequest *rsa_auth_req;
MySQLAuthRequest *auth_req;
MySQLRequest *req;
is_user_request_ = false;
if (this->get_seq() == 0)
return new MySQLHandshakeRequest;
auto *conn = (MyConnection *)this->get_connection();
switch (conn->state)
{
case ST_SSL_REQUEST:
req = new MySQLSSLRequest(character_set_, conn->ssl);
req->set_seqid(conn->mysql_seqid);
return req;
case ST_AUTH_REQUEST:
req = new MySQLAuthRequest;
auth_req = (MySQLAuthRequest *)req;
auth_req->set_auth(username_, password_, db_, character_set_);
auth_req->set_auth_plugin_name(std::move(conn->str));
auth_req->set_seed(conn->seed);
break;
case ST_CLEAR_PASSWORD_REQUEST:
conn->str = "mysql_clear_password";
case ST_AUTH_SWITCH_REQUEST:
req = new MySQLAuthSwitchRequest;
auth_switch_req = (MySQLAuthSwitchRequest *)req;
auth_switch_req->set_password(password_);
auth_switch_req->set_auth_plugin_name(std::move(conn->str));
auth_switch_req->set_seed(conn->seed);
#if OPENSSL_VERSION_NUMBER < 0x10100000L
WFGlobal::get_ssl_client_ctx();
#endif
break;
case ST_SHA256_PUBLIC_KEY_REQUEST:
req = new MySQLPublicKeyRequest;
((MySQLPublicKeyRequest *)req)->set_sha256();
break;
case ST_CSHA2_PUBLIC_KEY_REQUEST:
req = new MySQLPublicKeyRequest;
((MySQLPublicKeyRequest *)req)->set_caching_sha2();
break;
case ST_RSA_AUTH_REQUEST:
req = new MySQLRSAAuthRequest;
rsa_auth_req = (MySQLRSAAuthRequest *)req;
rsa_auth_req->set_password(password_);
rsa_auth_req->set_public_key(std::move(conn->str));
rsa_auth_req->set_seed(conn->seed);
break;
case ST_CHARSET_REQUEST:
req = new MySQLRequest;
req->set_query("SET NAMES " + res_charset_);
break;
case ST_FIRST_USER_REQUEST:
if (this->is_fixed_conn())
{
auto *target = (RouteManager::RouteTarget *)this->target;
/* If it's a transaction task, generate a ECONNRESET error when
* the target was reconnected. */
if (target->state)
{
is_user_request_ = true;
errno = ECONNRESET;
return NULL;
}
target->state = 1;
}
case ST_USER_REQUEST:
is_user_request_ = true;
req = (MySQLRequest *)this->WFComplexClientTask::message_out();
break;
default:
assert(0);
return NULL;
}
if (!is_user_request_ && conn->state != ST_CHARSET_REQUEST)
req->set_seqid(conn->mysql_seqid);
if (!is_ssl_)
return req;
if (is_user_request_)
{
conn->wrapper = SSLWrapper(req, conn->ssl);
return &conn->wrapper;
}
else
return new MySSLWrapper(req, conn->ssl);
}
CommMessageIn *ComplexMySQLTask::message_in()
{
MySQLResponse *resp;
if (this->get_seq() == 0)
return new MySQLHandshakeResponse;
auto *conn = (MyConnection *)this->get_connection();
switch (conn->state)
{
case ST_SSL_REQUEST:
return new SSLHandshaker(conn->ssl);
case ST_AUTH_REQUEST:
case ST_AUTH_SWITCH_REQUEST:
resp = new MySQLAuthResponse;
break;
case ST_CLEAR_PASSWORD_REQUEST:
case ST_RSA_AUTH_REQUEST:
resp = new MySQLResponse;
break;
case ST_SHA256_PUBLIC_KEY_REQUEST:
case ST_CSHA2_PUBLIC_KEY_REQUEST:
resp = new MySQLPublicKeyResponse;
break;
case ST_CHARSET_REQUEST:
resp = new MySQLResponse;
break;
case ST_FIRST_USER_REQUEST:
case ST_USER_REQUEST:
resp = (MySQLResponse *)this->WFComplexClientTask::message_in();
break;
default:
assert(0);
return NULL;
}
if (!is_ssl_)
return resp;
if (is_user_request_)
{
conn->wrapper = SSLWrapper(resp, conn->ssl);
return &conn->wrapper;
}
else
return new MySSLWrapper(resp, conn->ssl);
}
int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
{
SSL *ssl = NULL;
if (resp->host_disallowed())
{
this->resp = std::move(*(MySQLResponse *)resp);
state_ = WFT_STATE_TASK_ERROR;
error_ = WFT_ERR_MYSQL_HOST_NOT_ALLOWED;
return 0;
}
if (is_ssl_)
{
if (resp->get_capability_flags() & 0x800)
{
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
if (!ssl)
{
state_ = WFT_STATE_SYS_ERROR;
error_ = errno;
return 0;
}
SSL_set_connect_state(ssl);
}
else
{
this->resp = std::move(*(MySQLResponse *)resp);
state_ = WFT_STATE_TASK_ERROR;
error_ = WFT_ERR_MYSQL_SSL_NOT_SUPPORTED;
return 0;
}
}
auto *conn = this->get_connection();
auto *my_conn = new MyConnection(ssl);
my_conn->str = resp->get_auth_plugin_name();
if (!password_.empty() && my_conn->str == "sha256_password")
my_conn->str = "caching_sha2_password";
resp->get_seed(my_conn->seed);
my_conn->state = is_ssl_ ? ST_SSL_REQUEST : ST_AUTH_REQUEST;
my_conn->mysql_seqid = resp->get_seqid() + 1;
conn->set_context(my_conn, [](void *ctx) {
auto *my_conn = (MyConnection *)ctx;
if (my_conn->ssl)
SSL_free(my_conn->ssl);
delete my_conn;
});
return MYSQL_KEEPALIVE_DEFAULT;
}
int ComplexMySQLTask::auth_switch(MySQLAuthResponse *resp, MyConnection *conn)
{
std::string name = resp->get_auth_plugin_name();
if (conn->state != ST_AUTH_REQUEST ||
(name == "mysql_clear_password" && !is_ssl_))
{
state_ = WFT_STATE_SYS_ERROR;
error_ = EBADMSG;
return 0;
}
if (password_.empty())
{
conn->state = ST_CLEAR_PASSWORD_REQUEST;
}
else if (name == "sha256_password")
{
if (is_ssl_)
conn->state = ST_CLEAR_PASSWORD_REQUEST;
else
conn->state = ST_SHA256_PUBLIC_KEY_REQUEST;
}
else
{
conn->str = std::move(name);
conn->state = ST_AUTH_SWITCH_REQUEST;
}
resp->get_seed(conn->seed);
conn->mysql_seqid = resp->get_seqid() + 1;
return MYSQL_KEEPALIVE_DEFAULT;
}
int ComplexMySQLTask::keep_alive_timeout()
{
auto *msg = (ProtocolMessage *)this->get_message_in();
MySQLAuthResponse *auth_resp;
MySQLResponse *resp;
state_ = WFT_STATE_SUCCESS;
error_ = 0;
if (this->get_seq() == 0)
return check_handshake((MySQLHandshakeResponse *)msg);
auto *conn = (MyConnection *)this->get_connection();
if (conn->state == ST_SSL_REQUEST)
{
conn->state = ST_AUTH_REQUEST;
conn->mysql_seqid++;
return MYSQL_KEEPALIVE_DEFAULT;
}
if (is_ssl_)
resp = (MySQLResponse *)((MySSLWrapper *)msg)->get_msg();
else
resp = (MySQLResponse *)msg;
switch (conn->state)
{
case ST_AUTH_REQUEST:
case ST_AUTH_SWITCH_REQUEST:
case ST_CLEAR_PASSWORD_REQUEST:
case ST_RSA_AUTH_REQUEST:
if (resp->is_ok_packet())
{
if (!res_charset_.empty())
conn->state = ST_CHARSET_REQUEST;
else
conn->state = ST_FIRST_USER_REQUEST;
break;
}
if (resp->is_error_packet() ||
conn->state == ST_CLEAR_PASSWORD_REQUEST ||
conn->state == ST_RSA_AUTH_REQUEST)
{
this->resp = std::move(*resp);
state_ = WFT_STATE_TASK_ERROR;
error_ = WFT_ERR_MYSQL_ACCESS_DENIED;
return 0;
}
auth_resp = (MySQLAuthResponse *)resp;
if (auth_resp->is_continue())
{
if (is_ssl_)
conn->state = ST_CLEAR_PASSWORD_REQUEST;
else
conn->state = ST_CSHA2_PUBLIC_KEY_REQUEST;
break;
}
return auth_switch(auth_resp, conn);
case ST_SHA256_PUBLIC_KEY_REQUEST:
case ST_CSHA2_PUBLIC_KEY_REQUEST:
conn->str = ((MySQLPublicKeyResponse *)resp)->get_public_key();
conn->state = ST_RSA_AUTH_REQUEST;
break;
case ST_CHARSET_REQUEST:
if (!resp->is_ok_packet())
{
this->resp = std::move(*resp);
state_ = WFT_STATE_TASK_ERROR;
error_ = WFT_ERR_MYSQL_INVALID_CHARACTER_SET;
return 0;
}
conn->state = ST_FIRST_USER_REQUEST;
return MYSQL_KEEPALIVE_DEFAULT;
case ST_FIRST_USER_REQUEST:
conn->state = ST_USER_REQUEST;
case ST_USER_REQUEST:
return this->keep_alive_timeo;
default:
assert(0);
return 0;
}
conn->mysql_seqid = resp->get_seqid() + 1;
return MYSQL_KEEPALIVE_DEFAULT;
}
int ComplexMySQLTask::first_timeout()
{
return is_user_request_ ? this->watch_timeo : 0;
}
/*
+--------------------+---------------------+-----+
| CHARACTER_SET_NAME | COLLATION_NAME | ID |
+--------------------+---------------------+-----+
| big5 | big5_chinese_ci | 1 |
| dec8 | dec8_swedish_ci | 3 |
| cp850 | cp850_general_ci | 4 |
| hp8 | hp8_english_ci | 6 |
| koi8r | koi8r_general_ci | 7 |
| latin1 | latin1_swedish_ci | 8 |
| latin2 | latin2_general_ci | 9 |
| swe7 | swe7_swedish_ci | 10 |
| ascii | ascii_general_ci | 11 |
| ujis | ujis_japanese_ci | 12 |
| sjis | sjis_japanese_ci | 13 |
| hebrew | hebrew_general_ci | 16 |
| tis620 | tis620_thai_ci | 18 |
| euckr | euckr_korean_ci | 19 |
| koi8u | koi8u_general_ci | 22 |
| gb2312 | gb2312_chinese_ci | 24 |
| greek | greek_general_ci | 25 |
| cp1250 | cp1250_general_ci | 26 |
| gbk | gbk_chinese_ci | 28 |
| latin5 | latin5_turkish_ci | 30 |
| armscii8 | armscii8_general_ci | 32 |
| utf8 | utf8_general_ci | 33 |
| ucs2 | ucs2_general_ci | 35 |
| cp866 | cp866_general_ci | 36 |
| keybcs2 | keybcs2_general_ci | 37 |
| macce | macce_general_ci | 38 |
| macroman | macroman_general_ci | 39 |
| cp852 | cp852_general_ci | 40 |
| latin7 | latin7_general_ci | 41 |
| cp1251 | cp1251_general_ci | 51 |
| utf16 | utf16_general_ci | 54 |
| utf16le | utf16le_general_ci | 56 |
| cp1256 | cp1256_general_ci | 57 |
| cp1257 | cp1257_general_ci | 59 |
| utf32 | utf32_general_ci | 60 |
| binary | binary | 63 |
| geostd8 | geostd8_general_ci | 92 |
| cp932 | cp932_japanese_ci | 95 |
| eucjpms | eucjpms_japanese_ci | 97 |
| gb18030 | gb18030_chinese_ci | 248 |
| utf8mb4 | utf8mb4_0900_ai_ci | 255 |
+--------------------+---------------------+-----+
*/
static int __mysql_get_character_set(const std::string& charset)
{
static std::unordered_map<std::string, int> charset_map = {
{"big5", 1},
{"dec8", 3},
{"cp850", 4},
{"hp8", 5},
{"koi8r", 6},
{"latin1", 7},
{"latin2", 8},
{"swe7", 10},
{"ascii", 11},
{"ujis", 12},
{"sjis", 13},
{"hebrew", 16},
{"tis620", 18},
{"euckr", 19},
{"koi8u", 22},
{"gb2312", 24},
{"greek", 25},
{"cp1250", 26},
{"gbk", 28},
{"latin5", 30},
{"armscii8",32},
{"utf8", 33},
{"ucs2", 35},
{"cp866", 36},
{"keybcs2", 37},
{"macce", 38},
{"macroman",39},
{"cp852", 40},
{"latin7", 41},
{"cp1251", 51},
{"utf16", 54},
{"utf16le", 56},
{"cp1256", 57},
{"cp1257", 59},
{"utf32", 60},
{"binary", 63},
{"geostd8", 92},
{"cp932", 95},
{"eucjpms", 97},
{"gb18030", 248},
{"utf8mb4", 255},
};
const auto it = charset_map.find(charset);
if (it != charset_map.cend())
return it->second;
return -1;
}
bool ComplexMySQLTask::init_success()
{
if (uri_.scheme && strcasecmp(uri_.scheme, "mysql") == 0)
is_ssl_ = false;
else if (uri_.scheme && strcasecmp(uri_.scheme, "mysqls") == 0)
is_ssl_ = true;
else
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
return false;
}
//todo mysql+unix
username_.clear();
password_.clear();
db_.clear();
if (uri_.userinfo)
{
const char *colon = NULL;
const char *pos = uri_.userinfo;
while (*pos && *pos != ':')
pos++;
if (*pos == ':')
colon = pos++;
if (colon)
{
if (colon > uri_.userinfo)
{
username_.assign(uri_.userinfo, colon - uri_.userinfo);
StringUtil::url_decode(username_);
}
if (*pos)
{
password_.assign(pos);
StringUtil::url_decode(password_);
}
}
else
{
username_.assign(uri_.userinfo);
StringUtil::url_decode(username_);
}
}
if (uri_.path && uri_.path[0] == '/' && uri_.path[1])
{
db_.assign(uri_.path + 1);
StringUtil::url_decode(db_);
}
std::string transaction;
if (uri_.query)
{
auto query_kv = URIParser::split_query(uri_.query);
for (auto& kv : query_kv)
{
if (strcasecmp(kv.first.c_str(), "transaction") == 0)
transaction = std::move(kv.second);
else if (strcasecmp(kv.first.c_str(), "character_set") == 0)
{
character_set_ = __mysql_get_character_set(kv.second);
if (character_set_ < 0)
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_MYSQL_INVALID_CHARACTER_SET;
return false;
}
}
else if (strcasecmp(kv.first.c_str(), "character_set_results") == 0)
res_charset_ = std::move(kv.second);
}
}
size_t info_len = username_.size() + password_.size() + db_.size() +
res_charset_.size() + 50;
char *info = new char[info_len];
snprintf(info, info_len, "%s|user:%s|pass:%s|db:%s|"
"charset:%d|rcharset:%s",
is_ssl_ ? "mysqls" : "mysql", username_.c_str(), password_.c_str(),
db_.c_str(), character_set_, res_charset_.c_str());
this->WFComplexClientTask::set_transport_type(TT_TCP);
if (!transaction.empty())
{
this->set_fixed_addr(true);
this->set_fixed_conn(true);
this->WFComplexClientTask::set_info(info + ("|txn:" + transaction));
}
else
this->WFComplexClientTask::set_info(info);
delete []info;
return true;
}
bool ComplexMySQLTask::finish_once()
{
if (!is_user_request_)
{
delete this->get_message_out();
delete this->get_message_in();
if (this->state == WFT_STATE_SUCCESS && state_ != WFT_STATE_SUCCESS)
{
this->state = state_;
this->error = error_;
this->disable_retry();
}
is_user_request_ = true;
return false;
}
if (this->is_fixed_conn())
{
if (this->state != WFT_STATE_SUCCESS || this->keep_alive_timeo == 0)
{
if (this->target)
((RouteManager::RouteTarget *)this->target)->state = 0;
}
}
return true;
}
/**********Client Factory**********/
// mysql://user:password@host:port/db_name
// url = "mysql://admin:123456@192.168.1.101:3301/test"
// url = "mysql://127.0.0.1:3306"
WFMySQLTask *WFTaskFactory::create_mysql_task(const std::string& url,
int retry_max,
mysql_callback_t callback)
{
auto *task = new ComplexMySQLTask(retry_max, std::move(callback));
ParsedURI uri;
URIParser::parse(url, uri);
task->init(std::move(uri));
if (task->is_fixed_conn())
task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION);
else
task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT);
return task;
}
WFMySQLTask *WFTaskFactory::create_mysql_task(const ParsedURI& uri,
int retry_max,
mysql_callback_t callback)
{
auto *task = new ComplexMySQLTask(retry_max, std::move(callback));
task->init(uri);
if (task->is_fixed_conn())
task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION);
else
task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT);
return task;
}
/**********Server**********/
class WFMySQLServerTask : public WFServerTask<MySQLRequest, MySQLResponse>
{
public:
WFMySQLServerTask(CommService *service,
std::function<void (WFMySQLTask *)>& proc):
WFServerTask(service, WFGlobal::get_scheduler(), proc)
{}
protected:
virtual SubTask *done();
virtual CommMessageOut *message_out();
virtual CommMessageIn *message_in();
};
SubTask *WFMySQLServerTask::done()
{
if (this->get_seq() == 0)
delete this->get_message_in();
return this->WFServerTask::done();
}
CommMessageOut *WFMySQLServerTask::message_out()
{
long long seqid = this->get_seq();
if (seqid == 0)
this->resp.set_ok_packet(); // always success
return this->WFServerTask::message_out();
}
CommMessageIn *WFMySQLServerTask::message_in()
{
long long seqid = this->get_seq();
if (seqid == 0)
return new MySQLAuthRequest;
return this->WFServerTask::message_in();
}
/**********Server Factory**********/
WFMySQLTask *WFServerTaskFactory::create_mysql_task(CommService *service,
std::function<void (WFMySQLTask *)>& process)
{
return new WFMySQLServerTask(service, process);
}

@ -0,0 +1,446 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
Li Yingxin (liyingxin@sogou-inc.com)
Liu Kai (liukaidx@sogou-inc.com)
Xie Han (xiehan@sogou-inc.com)
*/
#include <stdio.h>
#include <string.h>
#include <string>
#include "PackageWrapper.h"
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "StringUtil.h"
#include "RedisTaskImpl.inl"
using namespace protocol;
#define REDIS_KEEPALIVE_DEFAULT (60 * 1000)
#define REDIS_REDIRECT_MAX 3
/**********Client**********/
class ComplexRedisTask : public WFComplexClientTask<RedisRequest, RedisResponse>
{
public:
ComplexRedisTask(int retry_max, redis_callback_t&& callback):
WFComplexClientTask(retry_max, std::move(callback)),
db_num_(0),
is_user_request_(true),
redirect_count_(0)
{}
protected:
virtual bool check_request();
virtual CommMessageOut *message_out();
virtual CommMessageIn *message_in();
virtual int keep_alive_timeout();
virtual int first_timeout();
virtual bool init_success();
virtual bool finish_once();
protected:
bool need_redirect();
std::string username_;
std::string password_;
int db_num_;
bool succ_;
bool is_user_request_;
int redirect_count_;
};
bool ComplexRedisTask::check_request()
{
std::string command;
if (this->req.get_command(command) &&
(strcasecmp(command.c_str(), "AUTH") == 0 ||
strcasecmp(command.c_str(), "SELECT") == 0 ||
strcasecmp(command.c_str(), "RESET") == 0 ||
strcasecmp(command.c_str(), "ASKING") == 0))
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_REDIS_COMMAND_DISALLOWED;
return false;
}
return true;
}
CommMessageOut *ComplexRedisTask::message_out()
{
long long seqid = this->get_seq();
if (seqid <= 1)
{
if (seqid == 0 && (!password_.empty() || !username_.empty()))
{
auto *auth_req = new RedisRequest;
if (!username_.empty())
auth_req->set_request("AUTH", {username_, password_});
else
auth_req->set_request("AUTH", {password_});
succ_ = false;
is_user_request_ = false;
return auth_req;
}
if (db_num_ > 0 &&
(seqid == 0 || !password_.empty() || !username_.empty()))
{
auto *select_req = new RedisRequest;
char buf[32];
sprintf(buf, "%d", db_num_);
select_req->set_request("SELECT", {buf});
succ_ = false;
is_user_request_ = false;
return select_req;
}
}
return this->WFComplexClientTask::message_out();
}
CommMessageIn *ComplexRedisTask::message_in()
{
RedisRequest *req = this->get_req();
RedisResponse *resp = this->get_resp();
if (is_user_request_)
resp->set_asking(req->is_asking());
else
resp->set_asking(false);
return this->WFComplexClientTask::message_in();
}
int ComplexRedisTask::keep_alive_timeout()
{
if (this->is_user_request_)
return this->keep_alive_timeo;
RedisResponse *resp = this->get_resp();
succ_ = (resp->parse_success() &&
resp->result_ptr()->type != REDIS_REPLY_TYPE_ERROR);
return succ_ ? REDIS_KEEPALIVE_DEFAULT : 0;
}
int ComplexRedisTask::first_timeout()
{
return is_user_request_ ? this->watch_timeo : 0;
}
bool ComplexRedisTask::init_success()
{
enum TransportType type;
if (uri_.scheme && strcasecmp(uri_.scheme, "redis") == 0)
type = TT_TCP;
else if (uri_.scheme && strcasecmp(uri_.scheme, "rediss") == 0)
type = TT_TCP_SSL;
else
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
return false;
}
//todo redis+unix
//https://stackoverflow.com/questions/26964595/whats-the-correct-way-to-use-a-unix-domain-socket-in-requests-framework
//https://stackoverflow.com/questions/27037990/connecting-to-postgres-via-database-url-and-unix-socket-in-rails
if (uri_.userinfo)
{
char *p = strchr(uri_.userinfo, ':');
if (p)
{
username_.assign(uri_.userinfo, p);
password_.assign(p + 1);
StringUtil::url_decode(username_);
StringUtil::url_decode(password_);
}
else
{
username_.assign(uri_.userinfo);
StringUtil::url_decode(username_);
}
}
if (uri_.path && uri_.path[0] == '/' && uri_.path[1])
db_num_ = atoi(uri_.path + 1);
size_t info_len = username_.size() + password_.size() + 32 + 32;
char *info = new char[info_len];
sprintf(info, "redis|user:%s|pass:%s|db:%d", username_.c_str(),
password_.c_str(), db_num_);
this->WFComplexClientTask::set_transport_type(type);
this->WFComplexClientTask::set_info(info);
delete []info;
return true;
}
bool ComplexRedisTask::need_redirect()
{
RedisRequest *client_req = this->get_req();
RedisResponse *client_resp = this->get_resp();
redis_reply_t *reply = client_resp->result_ptr();
if (reply->type == REDIS_REPLY_TYPE_ERROR)
{
if (reply->str == NULL)
return false;
if (strncasecmp(reply->str, "NOAUTH ", 7) == 0)
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_REDIS_ACCESS_DENIED;
return false;
}
bool asking = false;
if (strncasecmp(reply->str, "ASK ", 4) == 0)
asking = true;
else if (strncasecmp(reply->str, "MOVED ", 6) != 0)
return false;
if (redirect_count_ >= REDIS_REDIRECT_MAX)
return false;
std::string err_str(reply->str, reply->len);
auto split_result = StringUtil::split_filter_empty(err_str, ' ');
if (split_result.size() == 3)
{
client_req->set_asking(asking);
// format: COMMAND SLOT HOSTPORT
// example: MOVED/ASK 123 127.0.0.1:6379
std::string& hostport = split_result[2];
redirect_count_++;
ParsedURI uri;
std::string url;
url.append(uri_.scheme);
url.append("://");
url.append(hostport);
URIParser::parse(url, uri);
std::swap(uri.host, uri_.host);
std::swap(uri.port, uri_.port);
std::swap(uri.state, uri_.state);
std::swap(uri.error, uri_.error);
return true;
}
}
return false;
}
bool ComplexRedisTask::finish_once()
{
if (!is_user_request_)
{
is_user_request_ = true;
delete this->get_message_out();
if (this->state == WFT_STATE_SUCCESS)
{
if (succ_)
this->clear_resp();
else
{
this->disable_retry();
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_REDIS_ACCESS_DENIED;
}
}
return false;
}
if (this->state == WFT_STATE_SUCCESS)
{
if (need_redirect())
this->set_redirect(uri_);
else if (this->state != WFT_STATE_SUCCESS)
this->disable_retry();
}
return true;
}
/****** Redis Subscribe ******/
class ComplexRedisSubscribeTask : public ComplexRedisTask
{
public:
virtual int push(const void *buf, size_t size)
{
if (finished_)
{
errno = ENOENT;
return -1;
}
if (!watching_)
{
errno = EAGAIN;
return -1;
}
return this->scheduler->push(buf, size, this);
}
protected:
virtual CommMessageIn *message_in()
{
if (!is_user_request_)
return this->ComplexRedisTask::message_in();
return &wrapper_;
}
virtual int first_timeout()
{
return watching_ ? this->watch_timeo : 0;
}
protected:
class SubscribeWrapper : public PackageWrapper
{
protected:
virtual ProtocolMessage *next_in(ProtocolMessage *message);
protected:
ComplexRedisSubscribeTask *task_;
public:
SubscribeWrapper(ComplexRedisSubscribeTask *task) :
PackageWrapper(task->get_resp())
{
task_ = task;
}
};
protected:
SubscribeWrapper wrapper_;
bool watching_;
bool finished_;
std::function<void (WFRedisTask *)> extract_;
public:
ComplexRedisSubscribeTask(std::function<void (WFRedisTask *)>&& extract,
redis_callback_t&& callback) :
ComplexRedisTask(0, std::move(callback)),
wrapper_(this),
extract_(std::move(extract))
{
watching_ = false;
finished_ = false;
}
};
ProtocolMessage *
ComplexRedisSubscribeTask::SubscribeWrapper::next_in(ProtocolMessage *message)
{
redis_reply_t *reply = ((RedisResponse *)message)->result_ptr();
if (reply->type != REDIS_REPLY_TYPE_ARRAY)
{
task_->finished_ = true;
return NULL;
}
if (reply->elements == 3 &&
reply->element[2]->type == REDIS_REPLY_TYPE_INTEGER &&
reply->element[2]->integer == 0)
{
task_->finished_ = true;
}
task_->watching_ = true;
task_->extract_(task_);
task_->clear_resp();
return task_->finished_ ? NULL : &task_->resp;
}
/**********Factory**********/
// redis://:password@host:port/db_num
// url = "redis://:admin@192.168.1.101:6001/3"
// url = "redis://127.0.0.1:6379"
WFRedisTask *WFTaskFactory::create_redis_task(const std::string& url,
int retry_max,
redis_callback_t callback)
{
auto *task = new ComplexRedisTask(retry_max, std::move(callback));
ParsedURI uri;
URIParser::parse(url, uri);
task->init(std::move(uri));
task->set_keep_alive(REDIS_KEEPALIVE_DEFAULT);
return task;
}
WFRedisTask *WFTaskFactory::create_redis_task(const ParsedURI& uri,
int retry_max,
redis_callback_t callback)
{
auto *task = new ComplexRedisTask(retry_max, std::move(callback));
task->init(uri);
task->set_keep_alive(REDIS_KEEPALIVE_DEFAULT);
return task;
}
WFRedisTask *
__WFRedisTaskFactory::create_subscribe_task(const std::string& url,
extract_t extract,
redis_callback_t callback)
{
auto *task = new ComplexRedisSubscribeTask(std::move(extract),
std::move(callback));
ParsedURI uri;
URIParser::parse(url, uri);
task->init(std::move(uri));
return task;
}
WFRedisTask *
__WFRedisTaskFactory::create_subscribe_task(const ParsedURI& uri,
extract_t extract,
redis_callback_t callback)
{
auto *task = new ComplexRedisSubscribeTask(std::move(extract),
std::move(callback));
task->init(uri);
return task;
}

@ -0,0 +1,37 @@
/*
Copyright (c) 2024 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
*/
#include "WFTaskFactory.h"
// Internal, for WFRedisSubscribeTask only.
class __WFRedisTaskFactory
{
private:
using extract_t = std::function<void (WFRedisTask *)>;
public:
static WFRedisTask *create_subscribe_task(const std::string& url,
extract_t extract,
redis_callback_t callback);
static WFRedisTask *create_subscribe_task(const ParsedURI& uri,
extract_t extract,
redis_callback_t callback);
};

@ -0,0 +1,259 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFALGOTASKFACTORY_H_
#define _WFALGOTASKFACTORY_H_
#include <utility>
#include <string>
#include <vector>
#include "WFTask.h"
namespace algorithm
{
template<typename T>
struct SortInput
{
T *first;
T *last;
};
template<typename T>
struct SortOutput
{
T *first;
T *last;
};
template<typename T>
struct MergeInput
{
T *first1;
T *last1;
T *first2;
T *last2;
T *d_first;
};
template<typename T>
struct MergeOutput
{
T *first;
T *last;
};
template<typename T>
struct ShuffleInput
{
T *first;
T *last;
};
template<typename T>
struct ShuffleOutput
{
T *first;
T *last;
};
template<typename T>
struct RemoveInput
{
T *first;
T *last;
T value;
};
template<typename T>
struct RemoveOutput
{
T *first;
T *last;
};
template<typename T>
struct UniqueInput
{
T *first;
T *last;
};
template<typename T>
struct UniqueOutput
{
T *first;
T *last;
};
template<typename T>
struct ReverseInput
{
T *first;
T *last;
};
template<typename T>
struct ReverseOutput
{
T *first;
T *last;
};
template<typename T>
struct RotateInput
{
T *first;
T *middle;
T *last;
};
template<typename T>
struct RotateOutput
{
T *first;
T *last;
};
template<typename KEY = std::string, typename VAL = std::string>
using ReduceInput = std::vector<std::pair<KEY, VAL>>;
template<typename KEY = std::string, typename VAL = std::string>
using ReduceOutput = std::vector<std::pair<KEY, VAL>>;
} /* namespace algorithm */
template<typename T>
using WFSortTask = WFThreadTask<algorithm::SortInput<T>,
algorithm::SortOutput<T>>;
template<typename T>
using sort_callback_t = std::function<void (WFSortTask<T> *)>;
template<typename T>
using WFMergeTask = WFThreadTask<algorithm::MergeInput<T>,
algorithm::MergeOutput<T>>;
template<typename T>
using merge_callback_t = std::function<void (WFMergeTask<T> *)>;
template<typename T>
using WFShuffleTask = WFThreadTask<algorithm::ShuffleInput<T>,
algorithm::ShuffleOutput<T>>;
template<typename T>
using shuffle_callback_t = std::function<void (WFShuffleTask<T> *)>;
template<typename T>
using WFRemoveTask = WFThreadTask<algorithm::RemoveInput<T>,
algorithm::RemoveOutput<T>>;
template<typename T>
using remove_callback_t = std::function<void (WFRemoveTask<T> *)>;
template<typename T>
using WFUniqueTask = WFThreadTask<algorithm::UniqueInput<T>,
algorithm::UniqueOutput<T>>;
template<typename T>
using unique_callback_t = std::function<void (WFUniqueTask<T> *)>;
template<typename T>
using WFReverseTask = WFThreadTask<algorithm::ReverseInput<T>,
algorithm::ReverseOutput<T>>;
template<typename T>
using reverse_callback_t = std::function<void (WFReverseTask<T> *)>;
template<typename T>
using WFRotateTask = WFThreadTask<algorithm::RotateInput<T>,
algorithm::RotateOutput<T>>;
template<typename T>
using rotate_callback_t = std::function<void (WFRotateTask<T> *)>;
class WFAlgoTaskFactory
{
public:
template<typename T, class CB = sort_callback_t<T>>
static WFSortTask<T> *create_sort_task(const std::string& queue_name,
T *first, T *last,
CB callback);
template<typename T, class CMP, class CB = sort_callback_t<T>>
static WFSortTask<T> *create_sort_task(const std::string& queue_name,
T *first, T *last,
CMP compare,
CB callback);
template<typename T, class CB = sort_callback_t<T>>
static WFSortTask<T> *create_psort_task(const std::string& queue_name,
T *first, T *last,
CB callback);
template<typename T, class CMP, class CB = sort_callback_t<T>>
static WFSortTask<T> *create_psort_task(const std::string& queue_name,
T *first, T *last,
CMP compare,
CB callback);
template<typename T, class CB = merge_callback_t<T>>
static WFMergeTask<T> *create_merge_task(const std::string& queue_name,
T *first1, T *last1,
T *first2, T *last2,
T *d_first,
CB callback);
template<typename T, class CMP, class CB = merge_callback_t<T>>
static WFMergeTask<T> *create_merge_task(const std::string& queue_name,
T *first1, T *last1,
T *first2, T *last2,
T *d_first,
CMP compare,
CB callback);
template<typename T, class CB = shuffle_callback_t<T>>
static WFShuffleTask<T> *create_shuffle_task(const std::string& queue_name,
T *first, T *last,
CB callback);
template<typename T, class URBG, class CB = shuffle_callback_t<T>>
static WFShuffleTask<T> *create_shuffle_task(const std::string& queue_name,
T *first, T *last,
URBG generator,
CB callback);
template<typename T, class CB = remove_callback_t<T>>
static WFRemoveTask<T> *create_remove_task(const std::string& queue_name,
T *first, T *last,
T value,
CB callback);
template<typename T, class CB = unique_callback_t<T>>
static WFUniqueTask<T> *create_unique_task(const std::string& queue_name,
T *first, T *last,
CB callback);
template<typename T, class CB = reverse_callback_t<T>>
static WFReverseTask<T> *create_reverse_task(const std::string& queue_name,
T *first, T *last,
CB callback);
template<typename T, class CB = rotate_callback_t<T>>
static WFRotateTask<T> *create_rotate_task(const std::string& queue_name,
T *first, T *middle, T *last,
CB callback);
};
#include "WFAlgoTaskFactory.inl"
#endif

@ -0,0 +1,651 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <stdlib.h>
#include <random>
#include <algorithm>
#include <vector>
#include <functional>
#include <utility>
#include "Workflow.h"
#include "WFGlobal.h"
/********** Classes without CMP **********/
template<typename T>
class __WFSortTask : public WFSortTask<T>
{
protected:
virtual void execute()
{
std::sort(this->input.first, this->input.last);
this->output.first = this->input.first;
this->output.last = this->input.last;
}
public:
__WFSortTask(ExecQueue *queue, Executor *executor,
T *first, T *last,
sort_callback_t<T>&& cb) :
WFSortTask<T>(queue, executor, std::move(cb))
{
this->input.first = first;
this->input.last = last;
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T>
class __WFMergeTask : public WFMergeTask<T>
{
protected:
virtual void execute();
public:
__WFMergeTask(ExecQueue *queue, Executor *executor,
T *first1, T *last1, T *first2, T *last2, T *d_first,
merge_callback_t<T>&& cb) :
WFMergeTask<T>(queue, executor, std::move(cb))
{
this->input.first1 = first1;
this->input.last1 = last1;
this->input.first2 = first2;
this->input.last2 = last2;
this->input.d_first = d_first;
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T>
void __WFMergeTask<T>::execute()
{
auto *input = &this->input;
auto *output = &this->output;
if (input->first1 == input->d_first && input->last1 == input->first2)
{
std::inplace_merge(input->first1, input->first2, input->last2);
output->last = input->last2;
}
else if (input->first2 == input->d_first && input->last2 == input->first1)
{
std::inplace_merge(input->first2, input->first1, input->last1);
output->last = input->last1;
}
else
{
output->last = std::merge(input->first1, input->last1,
input->first2, input->last2,
input->d_first);
}
output->first = input->d_first;
}
template<typename T>
class __WFParSortTask : public __WFSortTask<T>
{
public:
virtual void dispatch();
protected:
virtual SubTask *done()
{
if (this->flag)
return series_of(this)->pop();
return this->WFSortTask<T>::done();
}
virtual void execute();
protected:
int depth;
int flag;
public:
__WFParSortTask(ExecQueue *queue, Executor *executor,
T *first, T *last, int depth,
sort_callback_t<T>&& cb) :
__WFSortTask<T>(queue, executor, first, last, std::move(cb))
{
this->depth = depth;
this->flag = 0;
}
};
template<typename T>
void __WFParSortTask<T>::dispatch()
{
size_t n = this->input.last - this->input.first;
if (!this->flag && this->depth < 7 && n >= 32)
{
SeriesWork *series = series_of(this);
T *middle = this->input.first + n / 2;
auto *task1 =
new __WFParSortTask<T>(this->queue, this->executor,
this->input.first, middle,
this->depth + 1,
nullptr);
auto *task2 =
new __WFParSortTask<T>(this->queue, this->executor,
middle, this->input.last,
this->depth + 1,
nullptr);
SeriesWork *sub_series[2] = {
Workflow::create_series_work(task1, nullptr),
Workflow::create_series_work(task2, nullptr)
};
ParallelWork *parallel =
Workflow::create_parallel_work(sub_series, 2, nullptr);
series->push_front(this);
series->push_front(parallel);
this->flag = 1;
this->subtask_done();
}
else
this->__WFSortTask<T>::dispatch();
}
template<typename T>
void __WFParSortTask<T>::execute()
{
if (this->flag)
{
size_t n = this->input.last - this->input.first;
T *middle = this->input.first + n / 2;
std::inplace_merge(this->input.first, middle, this->input.last);
this->output.first = this->input.first;
this->output.last = this->input.last;
this->flag = 0;
}
else
this->__WFSortTask<T>::execute();
}
/********** Classes with CMP **********/
template<typename T, class CMP>
class __WFSortTaskCmp : public __WFSortTask<T>
{
protected:
virtual void execute()
{
std::sort(this->input.first, this->input.last,
std::move(this->compare));
this->output.first = this->input.first;
this->output.last = this->input.last;
}
protected:
CMP compare;
public:
__WFSortTaskCmp(ExecQueue *queue, Executor *executor,
T *first, T *last, CMP&& cmp,
sort_callback_t<T>&& cb) :
__WFSortTask<T>(queue, executor, first, last, std::move(cb)),
compare(std::move(cmp))
{
}
};
template<typename T, class CMP>
class __WFMergeTaskCmp : public __WFMergeTask<T>
{
protected:
virtual void execute();
protected:
CMP compare;
public:
__WFMergeTaskCmp(ExecQueue *queue, Executor *executor,
T *first1, T *last1, T *first2, T *last2,
T *d_first, CMP&& cmp,
merge_callback_t<T>&& cb) :
__WFMergeTask<T>(queue, executor, first1, last1, first2, last2, d_first,
std::move(cb)),
compare(std::move(cmp))
{
}
};
template<typename T, class CMP>
void __WFMergeTaskCmp<T, CMP>::execute()
{
auto *input = &this->input;
auto *output = &this->output;
if (input->first1 == input->d_first && input->last1 == input->first2)
{
std::inplace_merge(input->first1, input->first2, input->last2,
std::move(this->compare));
output->last = input->last2;
}
else if (input->first2 == input->d_first && input->last2 == input->first1)
{
std::inplace_merge(input->first2, input->first1, input->last1,
std::move(this->compare));
output->last = input->last1;
}
else
{
output->last = std::merge(input->first1, input->last1,
input->first2, input->last2,
input->d_first,
std::move(this->compare));
}
output->first = input->d_first;
}
template<typename T, class CMP>
class __WFParSortTaskCmp : public __WFSortTaskCmp<T, CMP>
{
public:
virtual void dispatch();
protected:
virtual SubTask *done()
{
if (this->flag)
return series_of(this)->pop();
return this->WFSortTask<T>::done();
}
virtual void execute();
protected:
int depth;
int flag;
public:
__WFParSortTaskCmp(ExecQueue *queue, Executor *executor,
T *first, T *last, CMP cmp, int depth,
sort_callback_t<T>&& cb) :
__WFSortTaskCmp<T, CMP>(queue, executor, first, last, std::move(cmp),
std::move(cb))
{
this->depth = depth;
this->flag = 0;
}
};
template<typename T, class CMP>
void __WFParSortTaskCmp<T, CMP>::dispatch()
{
size_t n = this->input.last - this->input.first;
if (!this->flag && this->depth < 7 && n >= 32)
{
SeriesWork *series = series_of(this);
T *middle = this->input.first + n / 2;
auto *task1 =
new __WFParSortTaskCmp<T, CMP>(this->queue, this->executor,
this->input.first, middle,
this->compare, this->depth + 1,
nullptr);
auto *task2 =
new __WFParSortTaskCmp<T, CMP>(this->queue, this->executor,
middle, this->input.last,
this->compare, this->depth + 1,
nullptr);
SeriesWork *sub_series[2] = {
Workflow::create_series_work(task1, nullptr),
Workflow::create_series_work(task2, nullptr)
};
ParallelWork *parallel =
Workflow::create_parallel_work(sub_series, 2, nullptr);
series->push_front(this);
series->push_front(parallel);
this->flag = 1;
this->subtask_done();
}
else
this->__WFSortTaskCmp<T, CMP>::dispatch();
}
template<typename T, class CMP>
void __WFParSortTaskCmp<T, CMP>::execute()
{
if (this->flag)
{
size_t n = this->input.last - this->input.first;
T *middle = this->input.first + n / 2;
std::inplace_merge(this->input.first, middle, this->input.last,
std::move(this->compare));
this->output.first = this->input.first;
this->output.last = this->input.last;
this->flag = 0;
}
else
this->__WFSortTaskCmp<T, CMP>::execute();
}
/********** Factory functions without CMP **********/
template<typename T, class CB>
WFSortTask<T> *WFAlgoTaskFactory::create_sort_task(const std::string& name,
T *first, T *last,
CB callback)
{
return new __WFSortTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last,
std::move(callback));
}
template<typename T, class CB>
WFMergeTask<T> *WFAlgoTaskFactory::create_merge_task(const std::string& name,
T *first1, T *last1,
T *first2, T *last2,
T *d_first,
CB callback)
{
return new __WFMergeTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first1, last1, first2, last2, d_first,
std::move(callback));
}
template<typename T, class CB>
WFSortTask<T> *WFAlgoTaskFactory::create_psort_task(const std::string& name,
T *first, T *last,
CB callback)
{
return new __WFParSortTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last, 0,
std::move(callback));
}
/********** Factory functions with CMP **********/
template<typename T, class CMP, class CB>
WFSortTask<T> *WFAlgoTaskFactory::create_sort_task(const std::string& name,
T *first, T *last,
CMP compare,
CB callback)
{
return new __WFSortTaskCmp<T, CMP>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last, std::move(compare),
std::move(callback));
}
template<typename T, class CMP, class CB>
WFMergeTask<T> *WFAlgoTaskFactory::create_merge_task(const std::string& name,
T *first1, T *last1,
T *first2, T *last2,
T *d_first,
CMP compare,
CB callback)
{
return new __WFMergeTaskCmp<T, CMP>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first1, last1, first2, last2,
d_first, std::move(compare),
std::move(callback));
}
template<typename T, class CMP, class CB>
WFSortTask<T> *WFAlgoTaskFactory::create_psort_task(const std::string& name,
T *first, T *last,
CMP compare,
CB callback)
{
return new __WFParSortTaskCmp<T, CMP>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last, std::move(compare), 0,
std::move(callback));
}
/****************** Shuffle ******************/
template<typename T>
class __WFShuffleTask : public WFShuffleTask<T>
{
protected:
virtual void execute()
{
std::shuffle(this->input.first, this->input.last,
std::mt19937_64(random()));
this->output.first = this->input.first;
this->output.last = this->input.last;
}
public:
__WFShuffleTask(ExecQueue *queue, Executor *executor,
T *first, T *last,
shuffle_callback_t<T>&& cb) :
WFShuffleTask<T>(queue, executor, std::move(cb))
{
this->input.first = first;
this->input.last = last;
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T, class URBG>
class __WFShuffleTaskGen : public __WFShuffleTask<T>
{
protected:
virtual void execute()
{
std::shuffle(this->input.first, this->input.last,
std::move(this->generator));
this->output.first = this->input.first;
this->output.last = this->input.last;
}
protected:
URBG generator;
public:
__WFShuffleTaskGen(ExecQueue *queue, Executor *executor,
T *first, T *last, URBG&& gen,
shuffle_callback_t<T>&& cb) :
__WFShuffleTask<T>(queue, executor, std::move(cb)),
generator(std::move(gen))
{
}
};
template<typename T, class CB>
WFShuffleTask<T> *WFAlgoTaskFactory::create_shuffle_task(const std::string& name,
T *first, T *last,
CB callback)
{
return new __WFShuffleTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last,
std::move(callback));
}
template<typename T, class URBG, class CB>
WFShuffleTask<T> *WFAlgoTaskFactory::create_shuffle_task(const std::string& name,
T *first, T *last,
URBG generator,
CB callback)
{
return new __WFShuffleTaskGen<T, URBG>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last, std::move(generator),
std::move(callback));
}
/****************** Remove ******************/
template<typename T>
class __WFRemoveTask : public WFRemoveTask<T>
{
protected:
virtual void execute()
{
this->output.last = std::remove(this->input.first, this->input.last,
this->input.value);
this->output.first = this->input.first;
}
public:
__WFRemoveTask(ExecQueue *queue, Executor *executor,
T *first, T *last, T&& value,
remove_callback_t<T>&& cb) :
WFRemoveTask<T>(queue, executor, std::move(cb))
{
this->input.first = first;
this->input.last = last;
this->input.value = std::move(value);
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T, class CB>
WFRemoveTask<T> *WFAlgoTaskFactory::create_remove_task(const std::string& name,
T *first, T *last,
T value,
CB callback)
{
return new __WFRemoveTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last, std::move(value),
std::move(callback));
}
/****************** Unique ******************/
template<typename T>
class __WFUniqueTask : public WFUniqueTask<T>
{
protected:
virtual void execute()
{
this->output.last = std::unique(this->input.first, this->input.last);
this->output.first = this->input.first;
}
public:
__WFUniqueTask(ExecQueue *queue, Executor *executor,
T *first, T *last,
unique_callback_t<T>&& cb) :
WFUniqueTask<T>(queue, executor, std::move(cb))
{
this->input.first = first;
this->input.last = last;
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T, class CB>
WFUniqueTask<T> *WFAlgoTaskFactory::create_unique_task(const std::string& name,
T *first, T *last,
CB callback)
{
return new __WFUniqueTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last,
std::move(callback));
}
/****************** Reverse ******************/
template<typename T>
class __WFReverseTask : public WFReverseTask<T>
{
protected:
virtual void execute()
{
std::reverse(this->input.first, this->input.last);
this->output.first = this->input.first;
this->output.last = this->input.last;
}
public:
__WFReverseTask(ExecQueue *queue, Executor *executor,
T *first, T *last,
reverse_callback_t<T>&& cb) :
WFReverseTask<T>(queue, executor, std::move(cb))
{
this->input.first = first;
this->input.last = last;
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T, class CB>
WFReverseTask<T> *WFAlgoTaskFactory::create_reverse_task(const std::string& name,
T *first, T *last,
CB callback)
{
return new __WFReverseTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, last,
std::move(callback));
}
/****************** Rotate ******************/
template<typename T>
class __WFRotateTask : public WFRotateTask<T>
{
protected:
virtual void execute()
{
std::rotate(this->input.first, this->input.middle, this->input.last);
this->output.first = this->input.first;
this->output.last = this->input.last;
}
public:
__WFRotateTask(ExecQueue *queue, Executor *executor,
T *first, T* middle, T *last,
rotate_callback_t<T>&& cb) :
WFRotateTask<T>(queue, executor, std::move(cb))
{
this->input.first = first;
this->input.middle = middle;
this->input.last = last;
this->output.first = NULL;
this->output.last = NULL;
}
};
template<typename T, class CB>
WFRotateTask<T> *WFAlgoTaskFactory::create_rotate_task(const std::string& name,
T *first, T *middle, T *last,
CB callback)
{
return new __WFRotateTask<T>(WFGlobal::get_exec_queue(name),
WFGlobal::get_compute_executor(),
first, middle, last,
std::move(callback));
}

@ -0,0 +1,69 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFCONNECTION_H_
#define _WFCONNECTION_H_
#include <utility>
#include <atomic>
#include <functional>
#include "Communicator.h"
class WFConnection : public CommConnection
{
public:
void *get_context() const
{
return this->context;
}
void set_context(void *context, std::function<void (void *)> deleter)
{
this->context = context;
this->deleter = std::move(deleter);
}
void *test_set_context(void *test_context, void *new_context,
std::function<void (void *)> deleter)
{
if (this->context.compare_exchange_strong(test_context, new_context))
{
this->deleter = std::move(deleter);
return new_context;
}
return test_context;
}
private:
std::atomic<void *> context;
std::function<void (void *)> deleter;
public:
WFConnection() : context(NULL) { }
protected:
virtual ~WFConnection()
{
if (this->deleter)
this->deleter(this->context);
}
};
#endif

@ -0,0 +1,110 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <vector>
#include "Workflow.h"
#include "WFGraphTask.h"
SubTask *WFGraphNode::done()
{
SeriesWork *series = series_of(this);
if (!this->user_data)
{
this->value = 1;
this->user_data = (void *)1;
}
else
delete this;
return series->pop();
}
WFGraphNode::~WFGraphNode()
{
if (this->user_data)
{
if (series_of(this)->is_canceled())
{
for (WFGraphNode *node : this->successors)
series_of(node)->SeriesWork::cancel();
}
for (WFGraphNode *node : this->successors)
node->WFCounterTask::count();
}
}
WFGraphNode& WFGraphTask::create_graph_node(SubTask *task)
{
WFGraphNode *node = new WFGraphNode;
SeriesWork *series = Workflow::create_series_work(node, node, nullptr);
series->push_back(task);
this->parallel->add_series(series);
return *node;
}
void WFGraphTask::dispatch()
{
SeriesWork *series = series_of(this);
if (this->parallel)
{
series->push_front(this);
series->push_front(this->parallel);
this->parallel = NULL;
}
else
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
SubTask *WFGraphTask::done()
{
SeriesWork *series = series_of(this);
if (this->state == WFT_STATE_SUCCESS)
{
if (this->callback)
this->callback(this);
delete this;
}
return series->pop();
}
WFGraphTask::~WFGraphTask()
{
SeriesWork *series;
size_t i;
if (this->parallel)
{
for (i = 0; i < this->parallel->size(); i++)
{
series = this->parallel->series_at(i);
series->unset_last_task();
}
this->parallel->dismiss();
}
}

@ -0,0 +1,107 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFGRAPHTASK_H_
#define _WFGRAPHTASK_H_
#include <vector>
#include <utility>
#include <functional>
#include "Workflow.h"
#include "WFTask.h"
class WFGraphNode : protected WFCounterTask
{
public:
void precede(WFGraphNode& node)
{
node.value++;
this->successors.push_back(&node);
}
void succeed(WFGraphNode& node)
{
node.precede(*this);
}
protected:
virtual SubTask *done();
protected:
std::vector<WFGraphNode *> successors;
protected:
WFGraphNode() : WFCounterTask(0, nullptr) { }
virtual ~WFGraphNode();
friend class WFGraphTask;
};
static inline WFGraphNode& operator --(WFGraphNode& node, int)
{
return node;
}
static inline WFGraphNode& operator > (WFGraphNode& prec, WFGraphNode& succ)
{
prec.precede(succ);
return succ;
}
static inline WFGraphNode& operator < (WFGraphNode& succ, WFGraphNode& prec)
{
succ.succeed(prec);
return prec;
}
static inline WFGraphNode& operator --(WFGraphNode& node)
{
return node;
}
class WFGraphTask : public WFGenericTask
{
public:
WFGraphNode& create_graph_node(SubTask *task);
public:
void set_callback(std::function<void (WFGraphTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual void dispatch();
virtual SubTask *done();
protected:
ParallelWork *parallel;
std::function<void (WFGraphTask *)> callback;
public:
WFGraphTask(std::function<void (WFGraphTask *)>&& cb) :
callback(std::move(cb))
{
this->parallel = Workflow::create_parallel_work(nullptr);
}
protected:
virtual ~WFGraphTask();
};
#endif

@ -0,0 +1,94 @@
/*
Copyright (c) 2022 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
*/
#include "list.h"
#include "WFTask.h"
#include "WFMessageQueue.h"
class __MQConditional : public WFConditional
{
public:
struct list_head list;
struct WFMessageQueue::Data *data;
public:
virtual void dispatch();
virtual void signal(void *msg) { }
public:
__MQConditional(SubTask *task, void **msgbuf,
struct WFMessageQueue::Data *data) :
WFConditional(task, msgbuf)
{
this->data = data;
}
__MQConditional(SubTask *task,
struct WFMessageQueue::Data *data) :
WFConditional(task)
{
this->data = data;
}
};
void __MQConditional::dispatch()
{
struct WFMessageQueue::Data *data = this->data;
data->mutex.lock();
if (!list_empty(&data->msg_list))
this->WFConditional::signal(data->pop());
else
list_add_tail(&this->list, &data->wait_list);
data->mutex.unlock();
this->WFConditional::dispatch();
}
WFConditional *WFMessageQueue::get(SubTask *task, void **msgbuf)
{
return new __MQConditional(task, msgbuf, &this->data);
}
WFConditional *WFMessageQueue::get(SubTask *task)
{
return new __MQConditional(task, &this->data);
}
void WFMessageQueue::post(void *msg)
{
struct WFMessageQueue::Data *data = &this->data;
WFConditional *cond;
data->mutex.lock();
if (!list_empty(&data->wait_list))
{
cond = list_entry(data->wait_list.next, __MQConditional, list);
list_del(data->wait_list.next);
}
else
{
cond = NULL;
this->push(msg);
}
data->mutex.unlock();
if (cond)
cond->WFConditional::signal(msg);
}

@ -0,0 +1,88 @@
/*
Copyright (c) 2022 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFMESSAGEQUEUE_H_
#define _WFMESSAGEQUEUE_H_
#include <mutex>
#include "list.h"
#include "WFTask.h"
class WFMessageQueue
{
public:
WFConditional *get(SubTask *task, void **msgbuf);
WFConditional *get(SubTask *task);
void post(void *msg);
public:
struct Data
{
void *pop() { return this->queue->pop(); }
void push(void *msg) { this->queue->push(msg); }
struct list_head msg_list;
struct list_head wait_list;
std::mutex mutex;
WFMessageQueue *queue;
};
protected:
struct MessageEntry
{
struct list_head list;
void *msg;
};
protected:
virtual void *pop()
{
struct MessageEntry *entry;
void *msg;
entry = list_entry(this->data.msg_list.next, struct MessageEntry, list);
list_del(&entry->list);
msg = entry->msg;
delete entry;
return msg;
}
virtual void push(void *msg)
{
struct MessageEntry *entry = new struct MessageEntry;
entry->msg = msg;
list_add_tail(&entry->list, &this->data.msg_list);
}
protected:
struct Data data;
public:
WFMessageQueue()
{
INIT_LIST_HEAD(&this->data.msg_list);
INIT_LIST_HEAD(&this->data.wait_list);
this->data.queue = this;
}
virtual ~WFMessageQueue() { }
};
#endif

@ -0,0 +1,341 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/
#ifndef _WFOPERATOR_H_
#define _WFOPERATOR_H_
#include "Workflow.h"
/**
* @file WFOperator.h
* @brief Workflow Series/Parallel/Task operator
*/
/**
* @brief S=S>P
* @note Equivalent to x.push_back(&y)
*/
static inline SeriesWork& operator>(SeriesWork& x, ParallelWork& y);
/**
* @brief S=P>S
* @note Equivalent to y.push_front(&x)
*/
static inline SeriesWork& operator>(ParallelWork& x, SeriesWork& y);
/**
* @brief S=S>t
* @note Equivalent to x.push_back(&y)
*/
static inline SeriesWork& operator>(SeriesWork& x, SubTask& y);
/**
* @brief S=S>t
* @note Equivalent to x.push_back(y)
*/
static inline SeriesWork& operator>(SeriesWork& x, SubTask *y);
/**
* @brief S=t>S
* @note Equivalent to y.push_front(&x)
*/
static inline SeriesWork& operator>(SubTask& x, SeriesWork& y);
/**
* @brief S=t>S
* @note Equivalent to y.push_front(x)
*/
static inline SeriesWork& operator>(SubTask *x, SeriesWork& y);
/**
* @brief S=P>P
* @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y)
*/
static inline SeriesWork& operator>(ParallelWork& x, ParallelWork& y);
/**
* @brief S=P>t
* @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y)
*/
static inline SeriesWork& operator>(ParallelWork& x, SubTask& y);
/**
* @brief S=P>t
* @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(y)
*/
static inline SeriesWork& operator>(ParallelWork& x, SubTask *y);
/**
* @brief S=t>P
* @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y)
*/
static inline SeriesWork& operator>(SubTask& x, ParallelWork& y);
/**
* @brief S=t>P
* @note Equivalent to Workflow::create_series_work(x, nullptr)->push_back(&y)
*/
static inline SeriesWork& operator>(SubTask *x, ParallelWork& y);
/**
* @brief S=t>t
* @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y)
*/
static inline SeriesWork& operator>(SubTask& x, SubTask& y);
/**
* @brief S=t>t
* @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(y)
*/
static inline SeriesWork& operator>(SubTask& x, SubTask *y);
/**
* @brief S=t>t
* @note Equivalent to Workflow::create_series_work(x, nullptr)->push_back(&y)
*/
static inline SeriesWork& operator>(SubTask *x, SubTask& y);
//static inline SeriesWork& operator>(SubTask *x, SubTask *y);//compile error!
/**
* @brief P=P*S
* @note Equivalent to x.add_series(&y)
*/
static inline ParallelWork& operator*(ParallelWork& x, SeriesWork& y);
/**
* @brief P=S*P
* @note Equivalent to y.push_back(&x)
*/
static inline ParallelWork& operator*(SeriesWork& x, ParallelWork& y);
/**
* @brief P=P*t
* @note Equivalent to x.add_series(Workflow::create_series_work(&y, nullptr))
*/
static inline ParallelWork& operator*(ParallelWork& x, SubTask& y);
/**
* @brief P=P*t
* @note Equivalent to x.add_series(Workflow::create_series_work(y, nullptr))
*/
static inline ParallelWork& operator*(ParallelWork& x, SubTask *y);
/**
* @brief P=t*P
* @note Equivalent to y.add_series(Workflow::create_series_work(&x, nullptr))
*/
static inline ParallelWork& operator*(SubTask& x, ParallelWork& y);
/**
* @brief P=t*P
* @note Equivalent to y.add_series(Workflow::create_series_work(x, nullptr))
*/
static inline ParallelWork& operator*(SubTask *x, ParallelWork& y);
/**
* @brief P=S*S
* @note Equivalent to Workflow::create_parallel_work({&x, &y}, 2, nullptr)
*/
static inline ParallelWork& operator*(SeriesWork& x, SeriesWork& y);
/**
* @brief P=S*t
* @note Equivalent to Workflow::create_parallel_work({&y}, 1, nullptr)->add_series(&x)
*/
static inline ParallelWork& operator*(SeriesWork& x, SubTask& y);
/**
* @brief P=S*t
* @note Equivalent to Workflow::create_parallel_work({y}, 1, nullptr)->add_series(&x)
*/
static inline ParallelWork& operator*(SeriesWork& x, SubTask *y);
/**
* @brief P=t*S
* @note Equivalent to Workflow::create_parallel_work({&x}, 1, nullptr)->add_series(&y)
*/
static inline ParallelWork& operator*(SubTask& x, SeriesWork& y);
/**
* @brief P=t*S
* @note Equivalent to Workflow::create_parallel_work({x}, 1, nullptr)->add_series(&y)
*/
static inline ParallelWork& operator*(SubTask *x, SeriesWork& y);
/**
* @brief P=t*t
* @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(&y, nullptr)}, 2, nullptr)
*/
static inline ParallelWork& operator*(SubTask& x, SubTask& y);
/**
* @brief P=t*t
* @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(y, nullptr)}, 2, nullptr)
*/
static inline ParallelWork& operator*(SubTask& x, SubTask *y);
/**
* @brief P=t*t
* @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(x, nullptr), Workflow::create_series_work(&y, nullptr)}, 2, nullptr)
*/
static inline ParallelWork& operator*(SubTask *x, SubTask& y);
//static inline ParallelWork& operator*(SubTask *x, SubTask *y);//compile error!
//S=S>t
static inline SeriesWork& operator>(SeriesWork& x, SubTask& y)
{
x.push_back(&y);
return x;
}
static inline SeriesWork& operator>(SeriesWork& x, SubTask *y)
{
return x > *y;
}
//S=t>S
static inline SeriesWork& operator>(SubTask& x, SeriesWork& y)
{
y.push_front(&x);
return y;
}
static inline SeriesWork& operator>(SubTask *x, SeriesWork& y)
{
return *x > y;
}
//S=t>t
static inline SeriesWork& operator>(SubTask& x, SubTask& y)
{
SeriesWork *series = Workflow::create_series_work(&x, nullptr);
series->push_back(&y);
return *series;
}
static inline SeriesWork& operator>(SubTask& x, SubTask *y)
{
return x > *y;
}
static inline SeriesWork& operator>(SubTask *x, SubTask& y)
{
return *x > y;
}
//S=S>P
static inline SeriesWork& operator>(SeriesWork& x, ParallelWork& y)
{
return x > (SubTask&)y;
}
//S=P>S
static inline SeriesWork& operator>(ParallelWork& x, SeriesWork& y)
{
return y > (SubTask&)x;
}
//S=P>P
static inline SeriesWork& operator>(ParallelWork& x, ParallelWork& y)
{
return (SubTask&)x > (SubTask&)y;
}
//S=P>t
static inline SeriesWork& operator>(ParallelWork& x, SubTask& y)
{
return (SubTask&)x > y;
}
static inline SeriesWork& operator>(ParallelWork& x, SubTask *y)
{
return x > *y;
}
//S=t>P
static inline SeriesWork& operator>(SubTask& x, ParallelWork& y)
{
return x > (SubTask&)y;
}
static inline SeriesWork& operator>(SubTask *x, ParallelWork& y)
{
return *x > y;
}
//P=P*S
static inline ParallelWork& operator*(ParallelWork& x, SeriesWork& y)
{
x.add_series(&y);
return x;
}
//P=S*P
static inline ParallelWork& operator*(SeriesWork& x, ParallelWork& y)
{
return y * x;
}
//P=P*t
static inline ParallelWork& operator*(ParallelWork& x, SubTask& y)
{
x.add_series(Workflow::create_series_work(&y, nullptr));
return x;
}
static inline ParallelWork& operator*(ParallelWork& x, SubTask *y)
{
return x * (*y);
}
//P=t*P
static inline ParallelWork& operator*(SubTask& x, ParallelWork& y)
{
return y * x;
}
static inline ParallelWork& operator*(SubTask *x, ParallelWork& y)
{
return (*x) * y;
}
//P=S*S
static inline ParallelWork& operator*(SeriesWork& x, SeriesWork& y)
{
SeriesWork *arr[2] = {&x, &y};
return *Workflow::create_parallel_work(arr, 2, nullptr);
}
//P=S*t
static inline ParallelWork& operator*(SeriesWork& x, SubTask& y)
{
return x * (*Workflow::create_series_work(&y, nullptr));
}
static inline ParallelWork& operator*(SeriesWork& x, SubTask *y)
{
return x * (*y);
}
//P=t*S
static inline ParallelWork& operator*(SubTask& x, SeriesWork& y)
{
return y * x;
}
static inline ParallelWork& operator*(SubTask *x, SeriesWork& y)
{
return (*x) * y;
}
//P=t*t
static inline ParallelWork& operator*(SubTask& x, SubTask& y)
{
SeriesWork *arr[2] = {Workflow::create_series_work(&x, nullptr),
Workflow::create_series_work(&y, nullptr)};
return *Workflow::create_parallel_work(arr, 2, nullptr);
}
static inline ParallelWork& operator*(SubTask& x, SubTask *y)
{
return x * (*y);
}
static inline ParallelWork& operator*(SubTask *x, SubTask& y)
{
return (*x) * y;
}
#endif

@ -0,0 +1,117 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Li Yingxin (liyingxin@sogou-inc.com)
Xie Han (xiehan@sogou-inc.com)
*/
#include <string.h>
#include "list.h"
#include "WFTask.h"
#include "WFResourcePool.h"
class __RPConditional : public WFConditional
{
public:
struct list_head list;
struct WFResourcePool::Data *data;
public:
virtual void dispatch();
virtual void signal(void *res) { }
public:
__RPConditional(SubTask *task, void **resbuf,
struct WFResourcePool::Data *data) :
WFConditional(task, resbuf)
{
this->data = data;
}
__RPConditional(SubTask *task,
struct WFResourcePool::Data *data) :
WFConditional(task)
{
this->data = data;
}
};
void __RPConditional::dispatch()
{
struct WFResourcePool::Data *data = this->data;
data->mutex.lock();
if (--data->value >= 0)
this->WFConditional::signal(data->pop());
else
list_add_tail(&this->list, &data->wait_list);
data->mutex.unlock();
this->WFConditional::dispatch();
}
WFConditional *WFResourcePool::get(SubTask *task, void **resbuf)
{
return new __RPConditional(task, resbuf, &this->data);
}
WFConditional *WFResourcePool::get(SubTask *task)
{
return new __RPConditional(task, &this->data);
}
void WFResourcePool::create(size_t n)
{
this->data.res = new void *[n];
this->data.value = n;
this->data.index = 0;
INIT_LIST_HEAD(&this->data.wait_list);
this->data.pool = this;
}
WFResourcePool::WFResourcePool(void *const *res, size_t n)
{
this->create(n);
memcpy(this->data.res, res, n * sizeof (void *));
}
WFResourcePool::WFResourcePool(size_t n)
{
this->create(n);
memset(this->data.res, 0, n * sizeof (void *));
}
void WFResourcePool::post(void *res)
{
struct WFResourcePool::Data *data = &this->data;
WFConditional *cond;
data->mutex.lock();
if (++data->value <= 0)
{
cond = list_entry(data->wait_list.next, __RPConditional, list);
list_del(data->wait_list.next);
}
else
{
cond = NULL;
this->push(res);
}
data->mutex.unlock();
if (cond)
cond->WFConditional::signal(res);
}

@ -0,0 +1,72 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Li Yingxin (liyingxin@sogou-inc.com)
Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFRESOURCEPOOL_H_
#define _WFRESOURCEPOOL_H_
#include <mutex>
#include "list.h"
#include "WFTask.h"
class WFResourcePool
{
public:
WFConditional *get(SubTask *task, void **resbuf);
WFConditional *get(SubTask *task);
void post(void *res);
public:
struct Data
{
void *pop() { return this->pool->pop(); }
void push(void *res) { this->pool->push(res); }
void **res;
long value;
size_t index;
struct list_head wait_list;
std::mutex mutex;
WFResourcePool *pool;
};
protected:
virtual void *pop()
{
return this->data.res[this->data.index++];
}
virtual void push(void *res)
{
this->data.res[--this->data.index] = res;
}
protected:
struct Data data;
private:
void create(size_t n);
public:
WFResourcePool(void *const *res, size_t n);
WFResourcePool(size_t n);
virtual ~WFResourcePool() { delete []this->data.res; }
};
#endif

@ -0,0 +1,872 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFTASK_H_
#define _WFTASK_H_
#include <errno.h>
#include <string.h>
#include <assert.h>
#include <atomic>
#include <utility>
#include <functional>
#include "Executor.h"
#include "ExecRequest.h"
#include "Communicator.h"
#include "CommScheduler.h"
#include "CommRequest.h"
#include "SleepRequest.h"
#include "IORequest.h"
#include "Workflow.h"
#include "WFConnection.h"
enum
{
WFT_STATE_UNDEFINED = -1,
WFT_STATE_SUCCESS = CS_STATE_SUCCESS,
WFT_STATE_TOREPLY = CS_STATE_TOREPLY, /* for server task only */
WFT_STATE_NOREPLY = CS_STATE_TOREPLY + 1, /* for server task only */
WFT_STATE_SYS_ERROR = CS_STATE_ERROR,
WFT_STATE_SSL_ERROR = 65,
WFT_STATE_DNS_ERROR = 66, /* for client task only */
WFT_STATE_TASK_ERROR = 67,
WFT_STATE_ABORTED = CS_STATE_STOPPED
};
template<class INPUT, class OUTPUT>
class WFThreadTask : public ExecRequest
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
INPUT *get_input() { return &this->input; }
OUTPUT *get_output() { return &this->output; }
public:
void *user_data;
public:
int get_state() const { return this->state; }
int get_error() const { return this->error; }
public:
void set_callback(std::function<void (WFThreadTask<INPUT, OUTPUT> *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
INPUT input;
OUTPUT output;
std::function<void (WFThreadTask<INPUT, OUTPUT> *)> callback;
public:
WFThreadTask(ExecQueue *queue, Executor *executor,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)>&& cb) :
ExecRequest(queue, executor),
callback(std::move(cb))
{
this->user_data = NULL;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
}
protected:
virtual ~WFThreadTask() { }
};
template<class REQ, class RESP>
class WFNetworkTask : public CommRequest
{
public:
/* start(), dismiss() are for client tasks only. */
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
REQ *get_req() { return &this->req; }
RESP *get_resp() { return &this->resp; }
public:
void *user_data;
public:
int get_state() const { return this->state; }
int get_error() const { return this->error; }
/* Call when error is ETIMEDOUT, return values:
* TOR_NOT_TIMEOUT, TOR_WAIT_TIMEOUT, TOR_CONNECT_TIMEOUT,
* TOR_TRANSMIT_TIMEOUT (send or receive).
* SSL connect timeout also returns TOR_CONNECT_TIMEOUT. */
int get_timeout_reason() const { return this->timeout_reason; }
/* Call only in callback or server's process. */
long long get_task_seq() const
{
if (!this->target)
{
errno = ENOTCONN;
return -1;
}
return this->get_seq();
}
int get_peer_addr(struct sockaddr *addr, socklen_t *addrlen) const;
virtual WFConnection *get_connection() const = 0;
public:
/* All in milliseconds. timeout == -1 for unlimited. */
void set_send_timeout(int timeout) { this->send_timeo = timeout; }
void set_receive_timeout(int timeout) { this->receive_timeo = timeout; }
void set_keep_alive(int timeout) { this->keep_alive_timeo = timeout; }
void set_watch_timeout(int timeout) { this->watch_timeo = timeout; }
public:
/* Do not reply this request. */
void noreply()
{
if (this->state == WFT_STATE_TOREPLY)
this->state = WFT_STATE_NOREPLY;
}
/* Push reply data synchronously. */
virtual int push(const void *buf, size_t size)
{
if (this->state != WFT_STATE_TOREPLY &&
this->state != WFT_STATE_NOREPLY)
{
errno = ENOENT;
return -1;
}
return this->scheduler->push(buf, size, this);
}
/* To check if the connection was closed before replying.
Always returns 'true' in callback. */
bool closed() const
{
switch (this->state)
{
case WFT_STATE_UNDEFINED:
return false;
case WFT_STATE_TOREPLY:
case WFT_STATE_NOREPLY:
return !this->target->has_idle_conn();
default:
return true;
}
}
public:
void set_callback(std::function<void (WFNetworkTask<REQ, RESP> *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual int send_timeout() { return this->send_timeo; }
virtual int receive_timeout() { return this->receive_timeo; }
virtual int keep_alive_timeout() { return this->keep_alive_timeo; }
virtual int first_timeout() { return this->watch_timeo; }
protected:
int send_timeo;
int receive_timeo;
int keep_alive_timeo;
int watch_timeo;
REQ req;
RESP resp;
std::function<void (WFNetworkTask<REQ, RESP> *)> callback;
protected:
WFNetworkTask(CommSchedObject *object, CommScheduler *scheduler,
std::function<void (WFNetworkTask<REQ, RESP> *)>&& cb) :
CommRequest(object, scheduler),
callback(std::move(cb))
{
this->send_timeo = -1;
this->receive_timeo = -1;
this->keep_alive_timeo = 0;
this->watch_timeo = 0;
this->target = NULL;
this->timeout_reason = TOR_NOT_TIMEOUT;
this->user_data = NULL;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
}
virtual ~WFNetworkTask() { }
};
class WFTimerTask : public SleepRequest
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
void *user_data;
public:
int get_state() const { return this->state; }
int get_error() const { return this->error; }
public:
void set_callback(std::function<void (WFTimerTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
std::function<void (WFTimerTask *)> callback;
public:
WFTimerTask(CommScheduler *scheduler,
std::function<void (WFTimerTask *)> cb) :
SleepRequest(scheduler),
callback(std::move(cb))
{
this->user_data = NULL;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
}
protected:
virtual ~WFTimerTask() { }
};
template<class ARGS>
class WFFileTask : public IORequest
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
ARGS *get_args() { return &this->args; }
long get_retval() const
{
if (this->state == WFT_STATE_SUCCESS)
return this->get_res();
else
return -1;
}
public:
void *user_data;
public:
int get_state() const { return this->state; }
int get_error() const { return this->error; }
public:
void set_callback(std::function<void (WFFileTask<ARGS> *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
ARGS args;
std::function<void (WFFileTask<ARGS> *)> callback;
public:
WFFileTask(IOService *service,
std::function<void (WFFileTask<ARGS> *)>&& cb) :
IORequest(service),
callback(std::move(cb))
{
this->user_data = NULL;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
}
protected:
virtual ~WFFileTask() { }
};
class WFGenericTask : public SubTask
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
void *user_data;
public:
int get_state() const { return this->state; }
int get_error() const { return this->error; }
protected:
virtual void dispatch()
{
this->subtask_done();
}
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
delete this;
return series->pop();
}
protected:
int state;
int error;
public:
WFGenericTask()
{
this->user_data = NULL;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
}
protected:
virtual ~WFGenericTask() { }
};
class WFCounterTask : public WFGenericTask
{
public:
virtual void count()
{
if (--this->value == 0)
{
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
}
public:
void set_callback(std::function<void (WFCounterTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual void dispatch()
{
this->WFCounterTask::count();
}
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
std::atomic<unsigned int> value;
std::function<void (WFCounterTask *)> callback;
public:
WFCounterTask(unsigned int target_value,
std::function<void (WFCounterTask *)>&& cb) :
value(target_value + 1),
callback(std::move(cb))
{
}
protected:
virtual ~WFCounterTask() { }
};
class WFMailboxTask : public WFGenericTask
{
public:
virtual void send(void *msg)
{
*this->mailbox = msg;
if (this->flag.exchange(true))
{
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
}
void **get_mailbox() const { return this->mailbox; }
public:
void set_callback(std::function<void (WFMailboxTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual void dispatch()
{
if (this->flag.exchange(true))
{
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
}
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
void **mailbox;
std::atomic<bool> flag;
std::function<void (WFMailboxTask *)> callback;
public:
WFMailboxTask(void **mailbox,
std::function<void (WFMailboxTask *)>&& cb) :
flag(false),
callback(std::move(cb))
{
this->mailbox = mailbox;
}
WFMailboxTask(std::function<void (WFMailboxTask *)>&& cb) :
flag(false),
callback(std::move(cb))
{
this->mailbox = &this->user_data;
}
protected:
virtual ~WFMailboxTask() { }
};
class WFSelectorTask : public WFGenericTask
{
public:
virtual int submit(void *msg)
{
void *tmp = NULL;
int ret = 0;
if (this->message.compare_exchange_strong(tmp, msg) && msg)
{
ret = 1;
if (this->flag.exchange(true))
{
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
}
if (--this->nleft == 0)
{
if (!this->message)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = ENOMSG;
this->subtask_done();
}
delete this;
}
return ret;
}
void *get_message() const { return this->message; }
public:
void set_callback(std::function<void (WFSelectorTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual void dispatch()
{
if (this->flag.exchange(true))
{
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
if (--this->nleft == 0)
{
if (!this->message)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = ENOMSG;
this->subtask_done();
}
delete this;
}
}
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
return series->pop();
}
protected:
std::atomic<void *> message;
std::atomic<bool> flag;
std::atomic<size_t> nleft;
std::function<void (WFSelectorTask *)> callback;
public:
WFSelectorTask(size_t candidates,
std::function<void (WFSelectorTask *)>&& cb) :
message(NULL),
flag(false),
nleft(candidates + 1),
callback(std::move(cb))
{
}
protected:
virtual ~WFSelectorTask() { }
};
class WFConditional : public WFGenericTask
{
public:
virtual void signal(void *msg)
{
*this->msgbuf = msg;
if (this->flag.exchange(true))
this->subtask_done();
}
protected:
virtual void dispatch()
{
series_of(this)->push_front(this->task);
this->task = NULL;
if (this->flag.exchange(true))
this->subtask_done();
}
protected:
std::atomic<bool> flag;
SubTask *task;
void **msgbuf;
public:
WFConditional(SubTask *task, void **msgbuf) :
flag(false)
{
this->task = task;
this->msgbuf = msgbuf;
}
WFConditional(SubTask *task) :
flag(false)
{
this->task = task;
this->msgbuf = &this->user_data;
}
protected:
virtual ~WFConditional()
{
delete this->task;
}
};
class WFGoTask : public ExecRequest
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
void *user_data;
public:
int get_state() const { return this->state; }
int get_error() const { return this->error; }
public:
void set_callback(std::function<void (WFGoTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
std::function<void (WFGoTask *)> callback;
public:
WFGoTask(ExecQueue *queue, Executor *executor) :
ExecRequest(queue, executor)
{
this->user_data = NULL;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
}
protected:
virtual ~WFGoTask() { }
};
class WFRepeaterTask : public WFGenericTask
{
public:
void set_create(std::function<SubTask *(WFRepeaterTask *)> create)
{
this->create = std::move(create);
}
void set_callback(std::function<void (WFRepeaterTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual void dispatch()
{
SubTask *task = this->create(this);
if (task)
{
series_of(this)->push_front(this);
series_of(this)->push_front(task);
}
else
this->state = WFT_STATE_SUCCESS;
this->subtask_done();
}
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->state != WFT_STATE_UNDEFINED)
{
if (this->callback)
this->callback(this);
delete this;
}
return series->pop();
}
protected:
std::function<SubTask *(WFRepeaterTask *)> create;
std::function<void (WFRepeaterTask *)> callback;
public:
WFRepeaterTask(std::function<SubTask *(WFRepeaterTask *)>&& create,
std::function<void (WFRepeaterTask *)>&& cb) :
create(std::move(create)),
callback(std::move(cb))
{
}
protected:
virtual ~WFRepeaterTask() { }
};
class WFModuleTask : public ParallelTask, protected SeriesWork
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
SeriesWork *sub_series() { return this; }
const SeriesWork *sub_series() const { return this; }
public:
void *user_data;
public:
void set_callback(std::function<void (const WFModuleTask *)> cb)
{
this->callback = std::move(cb);
}
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
protected:
SubTask *first;
std::function<void (const WFModuleTask *)> callback;
public:
WFModuleTask(SubTask *first,
std::function<void (const WFModuleTask *)>&& cb) :
ParallelTask(&this->first, 1),
SeriesWork(first, nullptr),
callback(std::move(cb))
{
this->first = first;
this->set_in_parallel(this);
this->user_data = NULL;
}
protected:
virtual ~WFModuleTask()
{
if (!this->is_finished())
this->dismiss_recursive();
}
};
#include "WFTask.inl"
#endif

@ -0,0 +1,256 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
template<class REQ, class RESP>
int WFNetworkTask<REQ, RESP>::get_peer_addr(struct sockaddr *addr,
socklen_t *addrlen) const
{
const struct sockaddr *p;
socklen_t len;
if (this->target)
{
this->target->get_addr(&p, &len);
if (*addrlen >= len)
{
memcpy(addr, p, len);
*addrlen = len;
return 0;
}
errno = ENOBUFS;
}
else
errno = ENOTCONN;
return -1;
}
template<class REQ, class RESP>
class WFClientTask : public WFNetworkTask<REQ, RESP>
{
protected:
virtual CommMessageOut *message_out()
{
/* By using prepare function, users can modify request after
* the connection is established. */
if (this->prepare)
this->prepare(this);
return &this->req;
}
virtual CommMessageIn *message_in() { return &this->resp; }
protected:
virtual WFConnection *get_connection() const
{
CommConnection *conn;
if (this->target)
{
conn = this->CommSession::get_connection();
if (conn)
return (WFConnection *)conn;
}
errno = ENOTCONN;
return NULL;
}
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->state == WFT_STATE_SYS_ERROR && this->error < 0)
{
this->state = WFT_STATE_SSL_ERROR;
this->error = -this->error;
}
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
public:
void set_prepare(std::function<void (WFNetworkTask<REQ, RESP> *)> prep)
{
this->prepare = std::move(prep);
}
protected:
std::function<void (WFNetworkTask<REQ, RESP> *)> prepare;
public:
WFClientTask(CommSchedObject *object, CommScheduler *scheduler,
std::function<void (WFNetworkTask<REQ, RESP> *)>&& cb) :
WFNetworkTask<REQ, RESP>(object, scheduler, std::move(cb))
{
}
protected:
virtual ~WFClientTask() { }
};
template<class REQ, class RESP>
class WFServerTask : public WFNetworkTask<REQ, RESP>
{
protected:
virtual CommMessageOut *message_out() { return &this->resp; }
virtual CommMessageIn *message_in() { return &this->req; }
virtual void handle(int state, int error);
protected:
/* CommSession::get_connection() is supposed to be called only in the
* implementations of it's virtual functions. As a server task, to call
* this function after process() and before callback() is very dangerous
* and should be blocked. */
virtual WFConnection *get_connection() const
{
if (this->processor.task)
return (WFConnection *)this->CommSession::get_connection();
errno = EPERM;
return NULL;
}
protected:
virtual void dispatch()
{
if (this->state == WFT_STATE_TOREPLY)
{
/* Enable get_connection() again if the reply() call is success. */
this->processor.task = this;
if (this->scheduler->reply(this) >= 0)
return;
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
this->processor.task = NULL;
}
else
this->scheduler->shutdown(this);
this->subtask_done();
}
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->state == WFT_STATE_SYS_ERROR && this->error < 0)
{
this->state = WFT_STATE_SSL_ERROR;
this->error = -this->error;
}
if (this->callback)
this->callback(this);
/* Defer deleting the task. */
return series->pop();
}
protected:
class Processor : public SubTask
{
public:
Processor(WFServerTask<REQ, RESP> *task,
std::function<void (WFNetworkTask<REQ, RESP> *)>& proc) :
process(proc)
{
this->task = task;
}
virtual void dispatch()
{
this->process(this->task);
this->task = NULL; /* As a flag. get_conneciton() disabled. */
this->subtask_done();
}
virtual SubTask *done()
{
return series_of(this)->pop();
}
std::function<void (WFNetworkTask<REQ, RESP> *)>& process;
WFServerTask<REQ, RESP> *task;
} processor;
class Series : public SeriesWork
{
public:
Series(WFServerTask<REQ, RESP> *task) :
SeriesWork(&task->processor, nullptr)
{
this->set_last_task(task);
this->task = task;
}
virtual ~Series()
{
delete this->task;
}
WFServerTask<REQ, RESP> *task;
};
public:
WFServerTask(CommService *service, CommScheduler *scheduler,
std::function<void (WFNetworkTask<REQ, RESP> *)>& proc) :
WFNetworkTask<REQ, RESP>(NULL, scheduler, nullptr),
processor(this, proc)
{
}
protected:
virtual ~WFServerTask()
{
if (this->target)
((Series *)series_of(this))->task = NULL;
}
};
template<class REQ, class RESP>
void WFServerTask<REQ, RESP>::handle(int state, int error)
{
if (state == WFT_STATE_TOREPLY)
{
this->state = WFT_STATE_TOREPLY;
this->target = this->get_target();
new Series(this);
this->processor.dispatch();
}
else if (this->state == WFT_STATE_TOREPLY)
{
this->state = state;
this->error = error;
if (error == ETIMEDOUT)
this->timeout_reason = TOR_TRANSMIT_TIMEOUT;
this->subtask_done();
}
else
delete this;
}

@ -0,0 +1,75 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/
#ifndef _WFTASKERROR_H_
#define _WFTASKERROR_H_
/**
* @file WFTaskError.h
* @brief Workflow Task Error Code List
*/
/**
* @brief Defination of task state code
* @note Only for WFNetworkTask and only when get_state()==WFT_STATE_TASK_ERROR
*/
enum
{
//COMMON
WFT_ERR_URI_PARSE_FAILED = 1001, ///< URI, parse failed
WFT_ERR_URI_SCHEME_INVALID = 1002, ///< URI, invalid scheme
WFT_ERR_URI_PORT_INVALID = 1003, ///< URI, invalid port
WFT_ERR_UPSTREAM_UNAVAILABLE = 1004, ///< Upstream, all target server down
//HTTP
WFT_ERR_HTTP_BAD_REDIRECT_HEADER = 2001, ///< Http, 301/302/303/307/308 Location header value is NULL
WFT_ERR_HTTP_PROXY_CONNECT_FAILED = 2002, ///< Http, proxy CONNECT return non 200
//REDIS
WFT_ERR_REDIS_ACCESS_DENIED = 3001, ///< Redis, invalid password
WFT_ERR_REDIS_COMMAND_DISALLOWED = 3002, ///< Redis, command disabled, cannot be "AUTH"/"SELECT"
//MYSQL
WFT_ERR_MYSQL_HOST_NOT_ALLOWED = 4001, ///< MySQL
WFT_ERR_MYSQL_ACCESS_DENIED = 4002, ///< MySQL, authentication failed
WFT_ERR_MYSQL_INVALID_CHARACTER_SET = 4003, ///< MySQL, invalid charset, not found in MySQL-Documentation
WFT_ERR_MYSQL_COMMAND_DISALLOWED = 4004, ///< MySQL, sql command disabled, cannot be "USE"/"SET NAMES"/"SET CHARSET"/"SET CHARACTER SET"
WFT_ERR_MYSQL_QUERY_NOT_SET = 4005, ///< MySQL, query not set sql, maybe forget please check
WFT_ERR_MYSQL_SSL_NOT_SUPPORTED = 4006, ///< MySQL, SSL not supported by the server
//KAFKA
WFT_ERR_KAFKA_PARSE_RESPONSE_FAILED = 5001, ///< Kafka parse response failed
WFT_ERR_KAFKA_PRODUCE_FAILED = 5002,
WFT_ERR_KAFKA_FETCH_FAILED = 5003,
WFT_ERR_KAFKA_CGROUP_FAILED = 5004,
WFT_ERR_KAFKA_COMMIT_FAILED = 5005,
WFT_ERR_KAFKA_META_FAILED = 5006,
WFT_ERR_KAFKA_LEAVEGROUP_FAILED = 5007,
WFT_ERR_KAFKA_API_UNKNOWN = 5008, ///< api type not supported
WFT_ERR_KAFKA_VERSION_DISALLOWED = 5009, ///< broker version not supported
WFT_ERR_KAFKA_SASL_DISALLOWED = 5010, ///< sasl not supported
WFT_ERR_KAFKA_ARRANGE_FAILED = 5011, ///< arrange toppar failed
WFT_ERR_KAFKA_LIST_OFFSETS_FAILED = 5012,
WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED = 5013,
//CONSUL
WFT_ERR_CONSUL_API_UNKNOWN = 6001, ///< api type not supported
WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED = 6002, ///< Consul http code failed
};
#endif

File diff suppressed because it is too large Load Diff

@ -0,0 +1,499 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
Wu Jiaxu (wujiaxu@sogou-inc.com)
*/
#ifndef _WFTASKFACTORY_H_
#define _WFTASKFACTORY_H_
#include <sys/types.h>
#include <sys/uio.h>
#include <time.h>
#include <utility>
#include <functional>
#include <openssl/ssl.h>
#include "URIParser.h"
#include "RedisMessage.h"
#include "HttpMessage.h"
#include "MySQLMessage.h"
#include "DnsMessage.h"
#include "Workflow.h"
#include "WFTask.h"
#include "WFGraphTask.h"
#include "EndpointParams.h"
// Network Client/Server tasks
using WFHttpTask = WFNetworkTask<protocol::HttpRequest,
protocol::HttpResponse>;
using http_callback_t = std::function<void (WFHttpTask *)>;
using WFRedisTask = WFNetworkTask<protocol::RedisRequest,
protocol::RedisResponse>;
using redis_callback_t = std::function<void (WFRedisTask *)>;
using WFMySQLTask = WFNetworkTask<protocol::MySQLRequest,
protocol::MySQLResponse>;
using mysql_callback_t = std::function<void (WFMySQLTask *)>;
using WFDnsTask = WFNetworkTask<protocol::DnsRequest,
protocol::DnsResponse>;
using dns_callback_t = std::function<void (WFDnsTask *)>;
// File IO tasks
struct FileIOArgs
{
int fd;
void *buf;
size_t count;
off_t offset;
};
struct FileVIOArgs
{
int fd;
const struct iovec *iov;
int iovcnt;
off_t offset;
};
struct FileSyncArgs
{
int fd;
};
using WFFileIOTask = WFFileTask<struct FileIOArgs>;
using fio_callback_t = std::function<void (WFFileIOTask *)>;
using WFFileVIOTask = WFFileTask<struct FileVIOArgs>;
using fvio_callback_t = std::function<void (WFFileVIOTask *)>;
using WFFileSyncTask = WFFileTask<struct FileSyncArgs>;
using fsync_callback_t = std::function<void (WFFileSyncTask *)>;
// Timer and counter
using timer_callback_t = std::function<void (WFTimerTask *)>;
using counter_callback_t = std::function<void (WFCounterTask *)>;
using mailbox_callback_t = std::function<void (WFMailboxTask *)>;
using selector_callback_t = std::function<void (WFSelectorTask *)>;
// Graph (DAG) task.
using graph_callback_t = std::function<void (WFGraphTask *)>;
using WFEmptyTask = WFGenericTask;
using WFDynamicTask = WFGenericTask;
using dynamic_create_t = std::function<SubTask *(WFDynamicTask *)>;
using repeated_create_t = std::function<SubTask *(WFRepeaterTask *)>;
using repeater_callback_t = std::function<void (WFRepeaterTask *)>;
using module_callback_t = std::function<void (const WFModuleTask *)>;
class WFTaskFactory
{
public:
static WFHttpTask *create_http_task(const std::string& url,
int redirect_max,
int retry_max,
http_callback_t callback);
static WFHttpTask *create_http_task(const ParsedURI& uri,
int redirect_max,
int retry_max,
http_callback_t callback);
static WFHttpTask *create_http_task(const std::string& url,
const std::string& proxy_url,
int redirect_max,
int retry_max,
http_callback_t callback);
static WFHttpTask *create_http_task(const ParsedURI& uri,
const ParsedURI& proxy_uri,
int redirect_max,
int retry_max,
http_callback_t callback);
static WFRedisTask *create_redis_task(const std::string& url,
int retry_max,
redis_callback_t callback);
static WFRedisTask *create_redis_task(const ParsedURI& uri,
int retry_max,
redis_callback_t callback);
static WFMySQLTask *create_mysql_task(const std::string& url,
int retry_max,
mysql_callback_t callback);
static WFMySQLTask *create_mysql_task(const ParsedURI& uri,
int retry_max,
mysql_callback_t callback);
static WFDnsTask *create_dns_task(const std::string& url,
int retry_max,
dns_callback_t callback);
static WFDnsTask *create_dns_task(const ParsedURI& uri,
int retry_max,
dns_callback_t callback);
public:
static WFFileIOTask *create_pread_task(int fd,
void *buf,
size_t count,
off_t offset,
fio_callback_t callback);
static WFFileIOTask *create_pwrite_task(int fd,
const void *buf,
size_t count,
off_t offset,
fio_callback_t callback);
/* preadv and pwritev tasks are supported by Linux aio only.
* On macOS or others, you will get an ENOSYS error in callback. */
static WFFileVIOTask *create_preadv_task(int fd,
const struct iovec *iov,
int iovcnt,
off_t offset,
fvio_callback_t callback);
static WFFileVIOTask *create_pwritev_task(int fd,
const struct iovec *iov,
int iovcnt,
off_t offset,
fvio_callback_t callback);
static WFFileSyncTask *create_fsync_task(int fd,
fsync_callback_t callback);
/* On systems that do not support fdatasync(), like macOS,
* fdsync task is equal to fsync task. */
static WFFileSyncTask *create_fdsync_task(int fd,
fsync_callback_t callback);
/* File tasks with path name. */
public:
static WFFileIOTask *create_pread_task(const std::string& path,
void *buf,
size_t count,
off_t offset,
fio_callback_t callback);
static WFFileIOTask *create_pwrite_task(const std::string& path,
const void *buf,
size_t count,
off_t offset,
fio_callback_t callback);
static WFFileVIOTask *create_preadv_task(const std::string& path,
const struct iovec *iov,
int iovcnt,
off_t offset,
fvio_callback_t callback);
static WFFileVIOTask *create_pwritev_task(const std::string& path,
const struct iovec *iov,
int iovcnt,
off_t offset,
fvio_callback_t callback);
public:
static WFTimerTask *create_timer_task(time_t seconds, long nanoseconds,
timer_callback_t callback);
/* create a named timer. */
static WFTimerTask *create_timer_task(const std::string& timer_name,
time_t seconds, long nanoseconds,
timer_callback_t callback);
/* cancel all timers under the name. */
static int cancel_by_name(const std::string& timer_name)
{
return WFTaskFactory::cancel_by_name(timer_name, (size_t)-1);
}
/* cancel at most 'max' timers under the name. */
static int cancel_by_name(const std::string& timer_name, size_t max);
/* timer in microseconds (deprecated) */
static WFTimerTask *create_timer_task(unsigned int microseconds,
timer_callback_t callback);
public:
/* Create an unnamed counter. Call counter->count() directly.
* NOTE: never call count() exceeding target_value. */
static WFCounterTask *create_counter_task(unsigned int target_value,
counter_callback_t callback)
{
return new WFCounterTask(target_value, std::move(callback));
}
/* Create a named counter. */
static WFCounterTask *create_counter_task(const std::string& counter_name,
unsigned int target_value,
counter_callback_t callback);
/* Count by a counter's name. When count_by_name(), it's safe to count
* exceeding target_value. When multiple counters share a same name,
* this operation will be performed on the first created. If no counter
* matches the name, nothing is performed. */
static int count_by_name(const std::string& counter_name)
{
return WFTaskFactory::count_by_name(counter_name, 1);
}
/* Count by name with a value n. When multiple counters share this name,
* the operation is performed on the counters in the sequence of its
* creation, and more than one counter may reach target value. */
static int count_by_name(const std::string& counter_name, unsigned int n);
public:
static WFMailboxTask *create_mailbox_task(void **mailbox,
mailbox_callback_t callback)
{
return new WFMailboxTask(mailbox, std::move(callback));
}
/* Use 'user_data' as mailbox. */
static WFMailboxTask *create_mailbox_task(mailbox_callback_t callback)
{
return new WFMailboxTask(std::move(callback));
}
static WFMailboxTask *create_mailbox_task(const std::string& mailbox_name,
void **mailbox,
mailbox_callback_t callback);
static WFMailboxTask *create_mailbox_task(const std::string& mailbox_name,
mailbox_callback_t callback);
/* The 'msg' will be sent to the all mailbox tasks under the name, and
* would be lost if no task matched. */
static int send_by_name(const std::string& mailbox_name, void *msg)
{
return WFTaskFactory::send_by_name(mailbox_name, msg, (size_t)-1);
}
static int send_by_name(const std::string& mailbox_name, void *msg,
size_t max);
static int send_by_name(const std::string& mailbox_name, void *const msg[],
size_t max);
public:
static WFSelectorTask *create_selector_task(size_t candidates,
selector_callback_t callback)
{
return new WFSelectorTask(candidates, std::move(callback));
}
public:
static WFConditional *create_conditional(SubTask *task, void **msgbuf)
{
return new WFConditional(task, msgbuf);
}
static WFConditional *create_conditional(SubTask *task)
{
return new WFConditional(task);
}
static WFConditional *create_conditional(const std::string& cond_name,
SubTask *task, void **msgbuf);
static WFConditional *create_conditional(const std::string& cond_name,
SubTask *task);
static int signal_by_name(const std::string& cond_name, void *msg)
{
return WFTaskFactory::signal_by_name(cond_name, msg, (size_t)-1);
}
static int signal_by_name(const std::string& cond_name, void *msg,
size_t max);
static int signal_by_name(const std::string& cond_name, void *const msg[],
size_t max);
public:
static WFConditional *create_guard(const std::string& resource_name,
SubTask *task);
static WFConditional *create_guard(const std::string& resource_name,
SubTask *task, void **msgbuf);
/* The 'guard' is acquired after started, so call 'release_guard' after
and only after the task is finished, typically in its callback.
The function returns 1 if another is signaled, otherwise returns 0. */
static int release_guard(const std::string& resource_name)
{
return WFTaskFactory::release_guard(resource_name, NULL);
}
static int release_guard(const std::string& resource_name, void *msg);
static int release_guard_safe(const std::string& resource_name)
{
return WFTaskFactory::release_guard_safe(resource_name, NULL);
}
static int release_guard_safe(const std::string& resource_name, void *msg);
public:
template<class FUNC, class... ARGS>
static WFGoTask *create_go_task(const std::string& queue_name,
FUNC&& func, ARGS&&... args);
/* Create 'Go' task with running time limit in seconds plus nanoseconds.
* If time exceeded, state WFT_STATE_SYS_ERROR and error ETIMEDOUT
* will be got in callback. */
template<class FUNC, class... ARGS>
static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds,
const std::string& queue_name,
FUNC&& func, ARGS&&... args);
/* Create 'Go' task on user's executor and execution queue. */
template<class FUNC, class... ARGS>
static WFGoTask *create_go_task(ExecQueue *queue, Executor *executor,
FUNC&& func, ARGS&&... args);
template<class FUNC, class... ARGS>
static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
FUNC&& func, ARGS&&... args);
/* For capturing 'task' itself in go task's running function. */
template<class FUNC, class... ARGS>
static void reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... args);
public:
static WFGraphTask *create_graph_task(graph_callback_t callback)
{
return new WFGraphTask(std::move(callback));
}
public:
static WFEmptyTask *create_empty_task()
{
return new WFEmptyTask;
}
static WFDynamicTask *create_dynamic_task(dynamic_create_t create);
static WFRepeaterTask *create_repeater_task(repeated_create_t create,
repeater_callback_t callback)
{
return new WFRepeaterTask(std::move(create), std::move(callback));
}
public:
static WFModuleTask *create_module_task(SubTask *first,
module_callback_t callback)
{
return new WFModuleTask(first, std::move(callback));
}
static WFModuleTask *create_module_task(SubTask *first, SubTask *last,
module_callback_t callback)
{
WFModuleTask *task = new WFModuleTask(first, std::move(callback));
task->sub_series()->set_last_task(last);
return task;
}
};
template<class REQ, class RESP>
class WFNetworkTaskFactory
{
private:
using T = WFNetworkTask<REQ, RESP>;
public:
static T *create_client_task(enum TransportType type,
const std::string& host,
unsigned short port,
int retry_max,
std::function<void (T *)> callback);
static T *create_client_task(enum TransportType type,
const std::string& url,
int retry_max,
std::function<void (T *)> callback);
static T *create_client_task(enum TransportType type,
const ParsedURI& uri,
int retry_max,
std::function<void (T *)> callback);
static T *create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
int retry_max,
std::function<void (T *)> callback);
static T *create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
SSL_CTX *ssl_ctx,
int retry_max,
std::function<void (T *)> callback);
public:
static T *create_server_task(CommService *service,
std::function<void (T *)>& process);
};
template<class INPUT, class OUTPUT>
class WFThreadTaskFactory
{
private:
using T = WFThreadTask<INPUT, OUTPUT>;
public:
static T *create_thread_task(const std::string& queue_name,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (T *)> callback);
/* Create thread task with running time limit. */
static T *create_thread_task(time_t seconds, long nanoseconds,
const std::string& queue_name,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (T *)> callback);
public:
/* Create thread task on user's executor and execution queue. */
static T *create_thread_task(ExecQueue *queue, Executor *executor,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (T *)> callback);
/* With running time limit. */
static T *create_thread_task(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (T *)> callback);
};
#include "WFTaskFactory.inl"
#endif

@ -0,0 +1,916 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
Wu Jiaxu (wujiaxu@sogou-inc.com)
Li Yingxin (liyingxin@sogou-inc.com)
*/
#include <sys/types.h>
#include <sys/socket.h>
#include <errno.h>
#include <time.h>
#include <netdb.h>
#include <stdio.h>
#include <string>
#include <functional>
#include <utility>
#include <atomic>
#include <openssl/ssl.h>
#include "WFGlobal.h"
#include "Workflow.h"
#include "WFTask.h"
#include "RouteManager.h"
#include "URIParser.h"
#include "WFTaskError.h"
#include "EndpointParams.h"
#include "WFNameService.h"
class __WFDynamicTask : public WFDynamicTask
{
protected:
virtual void dispatch()
{
series_of(this)->push_front(this->create(this));
this->WFDynamicTask::dispatch();
}
protected:
std::function<SubTask *(WFDynamicTask *)> create;
public:
__WFDynamicTask(std::function<SubTask *(WFDynamicTask *)>&& create) :
create(std::move(create))
{
}
};
inline WFDynamicTask *
WFTaskFactory::create_dynamic_task(dynamic_create_t create)
{
return new __WFDynamicTask(std::move(create));
}
template<class REQ, class RESP, typename CTX = bool>
class WFComplexClientTask : public WFClientTask<REQ, RESP>
{
protected:
using task_callback_t = std::function<void (WFNetworkTask<REQ, RESP> *)>;
public:
WFComplexClientTask(int retry_max, task_callback_t&& cb):
WFClientTask<REQ, RESP>(NULL, WFGlobal::get_scheduler(), std::move(cb))
{
type_ = TT_TCP;
ssl_ctx_ = NULL;
fixed_addr_ = false;
fixed_conn_ = false;
retry_max_ = retry_max;
retry_times_ = 0;
redirect_ = false;
ns_policy_ = NULL;
router_task_ = NULL;
}
protected:
// new api for children
virtual bool init_success() { return true; }
virtual void init_failed() {}
virtual bool check_request() { return true; }
virtual WFRouterTask *route();
virtual bool finish_once() { return true; }
public:
void init(const ParsedURI& uri)
{
uri_ = uri;
init_with_uri();
}
void init(ParsedURI&& uri)
{
uri_ = std::move(uri);
init_with_uri();
}
void init(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
const std::string& info);
void set_transport_type(enum TransportType type)
{
type_ = type;
}
enum TransportType get_transport_type() const { return type_; }
void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; }
virtual const ParsedURI *get_current_uri() const { return &uri_; }
void set_redirect(const ParsedURI& uri)
{
redirect_ = true;
init(uri);
}
void set_redirect(enum TransportType type, const struct sockaddr *addr,
socklen_t addrlen, const std::string& info)
{
redirect_ = true;
init(type, addr, addrlen, info);
}
bool is_fixed_addr() const { return this->fixed_addr_; }
bool is_fixed_conn() const { return this->fixed_conn_; }
protected:
void set_fixed_addr(int fixed) { this->fixed_addr_ = fixed; }
void set_fixed_conn(int fixed) { this->fixed_conn_ = fixed; }
void set_info(const std::string& info)
{
info_.assign(info);
}
void set_info(const char *info)
{
info_.assign(info);
}
protected:
virtual void dispatch();
virtual SubTask *done();
void clear_resp()
{
RESP resp;
*(protocol::ProtocolMessage *)&resp = std::move(this->resp);
this->resp = std::move(resp);
}
void disable_retry()
{
retry_times_ = retry_max_;
}
protected:
enum TransportType type_;
ParsedURI uri_;
std::string info_;
SSL_CTX *ssl_ctx_;
bool fixed_addr_;
bool fixed_conn_;
bool redirect_;
CTX ctx_;
int retry_max_;
int retry_times_;
WFNSPolicy *ns_policy_;
WFRouterTask *router_task_;
RouteManager::RouteResult route_result_;
WFNSTracing tracing_;
public:
CTX *get_mutable_ctx() { return &ctx_; }
private:
void clear_prev_state();
void init_with_uri();
bool set_port();
void router_callback(void *t);
void switch_callback(void *t);
};
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::clear_prev_state()
{
ns_policy_ = NULL;
route_result_.clear();
if (tracing_.deleter)
{
tracing_.deleter(tracing_.data);
tracing_.deleter = NULL;
}
tracing_.data = NULL;
retry_times_ = 0;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
this->timeout_reason = TOR_NOT_TIMEOUT;
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::init(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
const std::string& info)
{
if (redirect_)
clear_prev_state();
auto params = WFGlobal::get_global_settings()->endpoint_params;
struct addrinfo addrinfo = { };
addrinfo.ai_family = addr->sa_family;
addrinfo.ai_addr = (struct sockaddr *)addr;
addrinfo.ai_addrlen = addrlen;
type_ = type;
info_.assign(info);
params.use_tls_sni = false;
if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, &params,
"", ssl_ctx_, route_result_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else if (this->init_success())
return;
this->init_failed();
}
template<class REQ, class RESP, typename CTX>
bool WFComplexClientTask<REQ, RESP, CTX>::set_port()
{
if (uri_.port)
{
int port = atoi(uri_.port);
if (port <= 0 || port > 65535)
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_PORT_INVALID;
return false;
}
return true;
}
if (uri_.scheme)
{
const char *port_str = WFGlobal::get_default_port(uri_.scheme);
if (port_str)
{
uri_.port = strdup(port_str);
if (uri_.port)
return true;
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
return false;
}
}
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
return false;
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::init_with_uri()
{
if (redirect_)
{
clear_prev_state();
ns_policy_ = WFGlobal::get_dns_resolver();
}
if (uri_.state == URI_STATE_SUCCESS)
{
if (this->set_port())
{
if (this->init_success())
return;
}
}
else if (uri_.state == URI_STATE_ERROR)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = uri_.error;
}
else
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_PARSE_FAILED;
}
this->init_failed();
}
template<class REQ, class RESP, typename CTX>
WFRouterTask *WFComplexClientTask<REQ, RESP, CTX>::route()
{
auto&& cb = std::bind(&WFComplexClientTask::router_callback,
this,
std::placeholders::_1);
struct WFNSParams params = {
.type = type_,
.uri = uri_,
.info = info_.c_str(),
.ssl_ctx = ssl_ctx_,
.fixed_addr = fixed_addr_,
.fixed_conn = fixed_conn_,
.retry_times = retry_times_,
.tracing = &tracing_,
};
if (!ns_policy_)
{
WFNameService *ns = WFGlobal::get_name_service();
ns_policy_ = ns->get_policy(uri_.host ? uri_.host : "");
}
return ns_policy_->create_router_task(&params, std::move(cb));
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::router_callback(void *t)
{
WFRouterTask *task = (WFRouterTask *)t;
this->state = task->get_state();
if (this->state == WFT_STATE_SUCCESS)
route_result_ = std::move(*task->get_result());
else if (this->state == WFT_STATE_UNDEFINED)
{
/* should not happend */
this->state = WFT_STATE_SYS_ERROR;
this->error = ENOSYS;
}
else
this->error = task->get_error();
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::dispatch()
{
switch (this->state)
{
case WFT_STATE_UNDEFINED:
if (this->check_request())
{
if (this->route_result_.request_object)
{
case WFT_STATE_SUCCESS:
this->set_request_object(route_result_.request_object);
this->WFClientTask<REQ, RESP>::dispatch();
return;
}
router_task_ = this->route();
series_of(this)->push_front(this);
series_of(this)->push_front(router_task_);
}
default:
break;
}
this->subtask_done();
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::switch_callback(void *t)
{
if (!redirect_)
{
if (this->state == WFT_STATE_SYS_ERROR && this->error < 0)
{
this->state = WFT_STATE_SSL_ERROR;
this->error = -this->error;
}
if (tracing_.deleter)
{
tracing_.deleter(tracing_.data);
tracing_.deleter = NULL;
}
if (this->callback)
this->callback(this);
}
if (redirect_)
{
redirect_ = false;
clear_resp();
this->target = NULL;
series_of(this)->push_front(this);
}
else
delete this;
}
template<class REQ, class RESP, typename CTX>
SubTask *WFComplexClientTask<REQ, RESP, CTX>::done()
{
SeriesWork *series = series_of(this);
if (router_task_)
{
router_task_ = NULL;
return series->pop();
}
bool is_user_request = this->finish_once();
if (ns_policy_)
{
if (this->state == WFT_STATE_SYS_ERROR ||
this->state == WFT_STATE_DNS_ERROR)
{
ns_policy_->failed(&route_result_, &tracing_, this->target);
}
else if (route_result_.request_object)
{
ns_policy_->success(&route_result_, &tracing_, this->target);
}
}
if (this->state == WFT_STATE_SUCCESS)
{
if (!is_user_request)
return this;
}
else if (this->state == WFT_STATE_SYS_ERROR)
{
if (retry_times_ < retry_max_)
{
redirect_ = true;
if (ns_policy_)
route_result_.clear();
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
this->timeout_reason = 0;
retry_times_++;
}
}
/* When the target or the connection is NULL, it's very likely that we are
* in the caller's thread. Running a timer will switch callback function to
* a handler thread, and this can prevent stack overflow. */
if (!this->target || !this->CommSession::get_connection())
{
auto&& cb = std::bind(&WFComplexClientTask::switch_callback,
this,
std::placeholders::_1);
WFTimerTask *timer;
timer = WFTaskFactory::create_timer_task(0, 0, std::move(cb));
series->push_front(timer);
}
else
this->switch_callback(NULL);
return series->pop();
}
/**********Template Network Factory**********/
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const std::string& host,
unsigned short port,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
ParsedURI uri;
char buf[32];
sprintf(buf, "%u", port);
uri.scheme = strdup("scheme");
uri.host = strdup(host.c_str());
uri.port = strdup(buf);
if (!uri.scheme || !uri.host || !uri.port)
{
uri.state = URI_STATE_ERROR;
uri.error = errno;
}
else
uri.state = URI_STATE_SUCCESS;
task->init(std::move(uri));
task->set_transport_type(type);
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const std::string& url,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
ParsedURI uri;
URIParser::parse(url, uri);
task->init(std::move(uri));
task->set_transport_type(type);
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const ParsedURI& uri,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
task->init(uri);
task->set_transport_type(type);
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
task->init(type, addr, addrlen, "");
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
SSL_CTX *ssl_ctx,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
task->init(type, addr, addrlen, "");
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_server_task(CommService *service,
std::function<void (WFNetworkTask<REQ, RESP> *)>& process)
{
return new WFServerTask<REQ, RESP>(service, WFGlobal::get_scheduler(), process);
}
/**********Server Factory**********/
class WFServerTaskFactory
{
public:
static WFDnsTask *create_dns_task(CommService *service,
std::function<void (WFDnsTask *)>& process);
static WFHttpTask *create_http_task(CommService *service,
std::function<void (WFHttpTask *)>& process);
static WFMySQLTask *create_mysql_task(CommService *service,
std::function<void (WFMySQLTask *)>& process);
};
/************Go Task Factory************/
class __WFGoTask : public WFGoTask
{
public:
void set_go_func(std::function<void ()> func)
{
this->go = std::move(func);
}
protected:
virtual void execute()
{
this->go();
}
protected:
std::function<void ()> go;
public:
__WFGoTask(ExecQueue *queue, Executor *executor,
std::function<void ()>&& func) :
WFGoTask(queue, executor),
go(std::move(func))
{
}
};
class __WFTimedGoTask : public __WFGoTask
{
protected:
virtual void dispatch();
virtual SubTask *done();
protected:
virtual void handle(int state, int error);
protected:
static void timer_callback(WFTimerTask *timer);
protected:
time_t seconds;
long nanoseconds;
std::atomic<int> ref;
public:
__WFTimedGoTask(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::function<void ()>&& func) :
__WFGoTask(queue, executor, std::move(func)),
ref(4)
{
this->seconds = seconds;
this->nanoseconds = nanoseconds;
}
};
template<class FUNC, class... ARGS>
WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name,
FUNC&& func, ARGS&&... args)
{
auto&& tmp = std::bind(std::forward<FUNC>(func),
std::forward<ARGS>(args)...);
return new __WFGoTask(WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
std::move(tmp));
}
template<class FUNC, class... ARGS>
WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds,
const std::string& queue_name,
FUNC&& func, ARGS&&... args)
{
auto&& tmp = std::bind(std::forward<FUNC>(func),
std::forward<ARGS>(args)...);
return new __WFTimedGoTask(seconds, nanoseconds,
WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
std::move(tmp));
}
template<class FUNC, class... ARGS>
WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor,
FUNC&& func, ARGS&&... args)
{
auto&& tmp = std::bind(std::forward<FUNC>(func),
std::forward<ARGS>(args)...);
return new __WFGoTask(queue, executor, std::move(tmp));
}
template<class FUNC, class... ARGS>
WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
FUNC&& func, ARGS&&... args)
{
auto&& tmp = std::bind(std::forward<FUNC>(func),
std::forward<ARGS>(args)...);
return new __WFTimedGoTask(seconds, nanoseconds,
queue, executor,
std::move(tmp));
}
template<class FUNC, class... ARGS>
void WFTaskFactory::reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... args)
{
auto&& tmp = std::bind(std::forward<FUNC>(func),
std::forward<ARGS>(args)...);
((__WFGoTask *)task)->set_go_func(std::move(tmp));
}
/**********Create go task with nullptr func**********/
template<> inline
WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name,
std::nullptr_t&& func)
{
return new __WFGoTask(WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
nullptr);
}
template<> inline
WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds,
const std::string& queue_name,
std::nullptr_t&& func)
{
return new __WFTimedGoTask(seconds, nanoseconds,
WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
nullptr);
}
template<> inline
WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor,
std::nullptr_t&& func)
{
return new __WFGoTask(queue, executor, nullptr);
}
template<> inline
WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::nullptr_t&& func)
{
return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, nullptr);
}
template<> inline
void WFTaskFactory::reset_go_task(WFGoTask *task, std::nullptr_t&& func)
{
((__WFGoTask *)task)->set_go_func(nullptr);
}
/**********Template Thread Task Factory**********/
template<class INPUT, class OUTPUT>
class __WFThreadTask : public WFThreadTask<INPUT, OUTPUT>
{
protected:
virtual void execute()
{
this->routine(&this->input, &this->output);
}
protected:
std::function<void (INPUT *, OUTPUT *)> routine;
public:
__WFThreadTask(ExecQueue *queue, Executor *executor,
std::function<void (INPUT *, OUTPUT *)>&& rt,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)>&& cb) :
WFThreadTask<INPUT, OUTPUT>(queue, executor, std::move(cb)),
routine(std::move(rt))
{
}
};
template<class INPUT, class OUTPUT>
class __WFTimedThreadTask : public __WFThreadTask<INPUT, OUTPUT>
{
protected:
virtual void dispatch();
virtual SubTask *done();
protected:
virtual void handle(int state, int error);
protected:
static void timer_callback(WFTimerTask *timer);
protected:
time_t seconds;
long nanoseconds;
std::atomic<int> ref;
public:
__WFTimedThreadTask(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::function<void (INPUT *, OUTPUT *)>&& rt,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)>&& cb) :
__WFThreadTask<INPUT, OUTPUT>(queue, executor, std::move(rt), std::move(cb)),
ref(4)
{
this->seconds = seconds;
this->nanoseconds = nanoseconds;
}
};
template<class INPUT, class OUTPUT>
void __WFTimedThreadTask<INPUT, OUTPUT>::dispatch()
{
WFTimerTask *timer;
timer = WFTaskFactory::create_timer_task(this->seconds, this->nanoseconds,
__WFTimedThreadTask::timer_callback);
timer->user_data = this;
this->__WFThreadTask<INPUT, OUTPUT>::dispatch();
timer->start();
}
template<class INPUT, class OUTPUT>
SubTask *__WFTimedThreadTask<INPUT, OUTPUT>::done()
{
if (this->callback)
this->callback(this);
return series_of(this)->pop();
}
template<class INPUT, class OUTPUT>
void __WFTimedThreadTask<INPUT, OUTPUT>::handle(int state, int error)
{
if (--this->ref == 3)
{
this->state = state;
this->error = error;
this->subtask_done();
}
if (--this->ref == 0)
delete this;
}
template<class INPUT, class OUTPUT>
void __WFTimedThreadTask<INPUT, OUTPUT>::timer_callback(WFTimerTask *timer)
{
auto *task = (__WFTimedThreadTask<INPUT, OUTPUT> *)timer->user_data;
if (--task->ref == 3)
{
if (timer->get_state() == WFT_STATE_SUCCESS)
{
task->state = WFT_STATE_SYS_ERROR;
task->error = ETIMEDOUT;
}
else
{
task->state = timer->get_state();
task->error = timer->get_error();
}
task->subtask_done();
}
if (--task->ref == 0)
delete task;
}
template<class INPUT, class OUTPUT>
WFThreadTask<INPUT, OUTPUT> *
WFThreadTaskFactory<INPUT, OUTPUT>::create_thread_task(const std::string& queue_name,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)> callback)
{
return new __WFThreadTask<INPUT, OUTPUT>(WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
std::move(routine),
std::move(callback));
}
template<class INPUT, class OUTPUT>
WFThreadTask<INPUT, OUTPUT> *
WFThreadTaskFactory<INPUT, OUTPUT>::create_thread_task(time_t seconds, long nanoseconds,
const std::string& queue_name,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)> callback)
{
return new __WFTimedThreadTask<INPUT, OUTPUT>(seconds, nanoseconds,
WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
std::move(routine),
std::move(callback));
}
template<class INPUT, class OUTPUT>
WFThreadTask<INPUT, OUTPUT> *
WFThreadTaskFactory<INPUT, OUTPUT>::create_thread_task(ExecQueue *queue, Executor *executor,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)> callback)
{
return new __WFThreadTask<INPUT, OUTPUT>(queue, executor,
std::move(routine),
std::move(callback));
}
template<class INPUT, class OUTPUT>
WFThreadTask<INPUT, OUTPUT> *
WFThreadTaskFactory<INPUT, OUTPUT>::create_thread_task(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::function<void (INPUT *, OUTPUT *)> routine,
std::function<void (WFThreadTask<INPUT, OUTPUT> *)> callback)
{
return new __WFTimedThreadTask<INPUT, OUTPUT>(seconds, nanoseconds,
queue, executor,
std::move(routine),
std::move(callback));
}

@ -0,0 +1,248 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <assert.h>
#include <stddef.h>
#include <string.h>
#include <utility>
#include <functional>
#include <mutex>
#include "Workflow.h"
SeriesWork::SeriesWork(SubTask *first, series_callback_t&& cb) :
callback(std::move(cb))
{
this->queue = this->buf;
this->queue_size = sizeof this->buf / sizeof *this->buf;
this->front = 0;
this->back = 0;
this->canceled = false;
this->finished = false;
assert(!series_of(first));
first->set_pointer(this);
this->first = first;
this->last = NULL;
this->context = NULL;
this->in_parallel = NULL;
}
SeriesWork::~SeriesWork()
{
if (this->queue != this->buf)
delete []this->queue;
}
void SeriesWork::dismiss_recursive()
{
SubTask *task = first;
this->callback = nullptr;
do
{
delete task;
task = this->pop_task();
} while (task);
}
void SeriesWork::expand_queue()
{
int size = 2 * this->queue_size;
SubTask **queue = new SubTask *[size];
int i, j;
i = 0;
j = this->front;
do
{
queue[i++] = this->queue[j++];
if (j == this->queue_size)
j = 0;
} while (j != this->back);
if (this->queue != this->buf)
delete []this->queue;
this->queue = queue;
this->queue_size = size;
this->front = 0;
this->back = i;
}
void SeriesWork::push_front(SubTask *task)
{
this->mutex.lock();
if (--this->front == -1)
this->front = this->queue_size - 1;
task->set_pointer(this);
this->queue[this->front] = task;
if (this->front == this->back)
this->expand_queue();
this->mutex.unlock();
}
void SeriesWork::push_back(SubTask *task)
{
this->mutex.lock();
task->set_pointer(this);
this->queue[this->back] = task;
if (++this->back == this->queue_size)
this->back = 0;
if (this->front == this->back)
this->expand_queue();
this->mutex.unlock();
}
SubTask *SeriesWork::pop()
{
bool canceled = this->canceled;
SubTask *task = this->pop_task();
if (!canceled)
return task;
while (task)
{
delete task;
task = this->pop_task();
}
return NULL;
}
SubTask *SeriesWork::pop_task()
{
SubTask *task;
this->mutex.lock();
if (this->front != this->back)
{
task = this->queue[this->front];
if (++this->front == this->queue_size)
this->front = 0;
}
else
{
task = this->last;
this->last = NULL;
}
this->mutex.unlock();
if (!task)
{
this->finished = true;
if (this->callback)
this->callback(this);
if (!this->in_parallel)
delete this;
}
return task;
}
ParallelWork::ParallelWork(parallel_callback_t&& cb) :
ParallelTask(new SubTask *[2 * 4], 0),
callback(std::move(cb))
{
this->buf_size = 4;
this->all_series = (SeriesWork **)&this->subtasks[this->buf_size];
this->context = NULL;
}
ParallelWork::ParallelWork(SeriesWork *const all_series[], size_t n,
parallel_callback_t&& cb) :
ParallelTask(new SubTask *[2 * (n > 4 ? n : 4)], n),
callback(std::move(cb))
{
size_t i;
this->buf_size = (n > 4 ? n : 4);
this->all_series = (SeriesWork **)&this->subtasks[this->buf_size];
for (i = 0; i < n; i++)
{
assert(!all_series[i]->in_parallel);
all_series[i]->in_parallel = this;
this->all_series[i] = all_series[i];
this->subtasks[i] = all_series[i]->first;
}
this->context = NULL;
}
void ParallelWork::expand_buf()
{
SubTask **buf;
size_t size;
this->buf_size *= 2;
buf = new SubTask *[2 * this->buf_size];
size = this->subtasks_nr * sizeof (void *);
memcpy(buf, this->subtasks, size);
memcpy(buf + this->buf_size, this->all_series, size);
delete []this->subtasks;
this->subtasks = buf;
this->all_series = (SeriesWork **)&buf[this->buf_size];
}
void ParallelWork::add_series(SeriesWork *series)
{
if (this->subtasks_nr == this->buf_size)
this->expand_buf();
assert(!series->in_parallel);
series->in_parallel = this;
this->all_series[this->subtasks_nr] = series;
this->subtasks[this->subtasks_nr] = series->first;
this->subtasks_nr++;
}
SubTask *ParallelWork::done()
{
SeriesWork *series = series_of(this);
size_t i;
if (this->callback)
this->callback(this);
for (i = 0; i < this->subtasks_nr; i++)
delete this->all_series[i];
this->subtasks_nr = 0;
delete this;
return series->pop();
}
ParallelWork::~ParallelWork()
{
size_t i;
for (i = 0; i < this->subtasks_nr; i++)
{
this->all_series[i]->in_parallel = NULL;
this->all_series[i]->dismiss_recursive();
}
delete []this->subtasks;
}

@ -0,0 +1,310 @@
/*
Copyright (c) 2019 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WORKFLOW_H_
#define _WORKFLOW_H_
#include <assert.h>
#include <stddef.h>
#include <utility>
#include <functional>
#include <mutex>
#include "SubTask.h"
class SeriesWork;
class ParallelWork;
using series_callback_t = std::function<void (const SeriesWork *)>;
using parallel_callback_t = std::function<void (const ParallelWork *)>;
class Workflow
{
public:
static SeriesWork *
create_series_work(SubTask *first, series_callback_t callback);
static void
start_series_work(SubTask *first, series_callback_t callback);
static ParallelWork *
create_parallel_work(parallel_callback_t callback);
static ParallelWork *
create_parallel_work(SeriesWork *const all_series[], size_t n,
parallel_callback_t callback);
static void
start_parallel_work(SeriesWork *const all_series[], size_t n,
parallel_callback_t callback);
public:
static SeriesWork *
create_series_work(SubTask *first, SubTask *last,
series_callback_t callback);
static void
start_series_work(SubTask *first, SubTask *last,
series_callback_t callback);
};
class SeriesWork
{
public:
void start()
{
assert(!this->in_parallel);
this->first->dispatch();
}
/* Call dismiss() only when you don't want to start a created series.
* This operation is recursive, so only call on the "root". */
void dismiss()
{
assert(!this->in_parallel);
this->dismiss_recursive();
}
public:
void push_back(SubTask *task);
void push_front(SubTask *task);
public:
void *get_context() const { return this->context; }
void set_context(void *context) { this->context = context; }
public:
/* Cancel a running series. Typically, called in the callback of a task
* that belongs to the series. All subsequent tasks in the series will be
* destroyed immediately and recursively (ParallelWork), without callback.
* But the callback of this canceled series will still be called. */
virtual void cancel() { this->canceled = true; }
/* Parallel work's callback may check the cancellation state of each
* sub-series, and cancel it's super-series recursively. */
bool is_canceled() const { return this->canceled; }
/* 'false' until the time of callback. Mainly for sub-class. */
bool is_finished() const { return this->finished; }
public:
void set_callback(series_callback_t callback)
{
this->callback = std::move(callback);
}
public:
virtual void *get_specific(const char *key) { return NULL; }
public:
/* The following functions are intended for task implementations only. */
SubTask *pop();
SubTask *get_last_task() const { return this->last; }
void set_last_task(SubTask *last)
{
last->set_pointer(this);
this->last = last;
}
void unset_last_task() { this->last = NULL; }
const ParallelTask *get_in_parallel() const { return this->in_parallel; }
protected:
void set_in_parallel(const ParallelTask *task) { this->in_parallel = task; }
void dismiss_recursive();
protected:
void *context;
series_callback_t callback;
private:
SubTask *pop_task();
void expand_queue();
private:
SubTask *buf[4];
SubTask *first;
SubTask *last;
SubTask **queue;
int queue_size;
int front;
int back;
bool canceled;
bool finished;
const ParallelTask *in_parallel;
std::mutex mutex;
protected:
SeriesWork(SubTask *first, series_callback_t&& callback);
virtual ~SeriesWork();
friend class ParallelWork;
friend class Workflow;
};
static inline SeriesWork *series_of(const SubTask *task)
{
return (SeriesWork *)task->get_pointer();
}
static inline SeriesWork& operator *(const SubTask& task)
{
return *series_of(&task);
}
static inline SeriesWork& operator << (SeriesWork& series, SubTask *task)
{
series.push_back(task);
return series;
}
inline SeriesWork *
Workflow::create_series_work(SubTask *first, series_callback_t callback)
{
return new SeriesWork(first, std::move(callback));
}
inline void
Workflow::start_series_work(SubTask *first, series_callback_t callback)
{
new SeriesWork(first, std::move(callback));
first->dispatch();
}
inline SeriesWork *
Workflow::create_series_work(SubTask *first, SubTask *last,
series_callback_t callback)
{
SeriesWork *series = new SeriesWork(first, std::move(callback));
series->set_last_task(last);
return series;
}
inline void
Workflow::start_series_work(SubTask *first, SubTask *last,
series_callback_t callback)
{
SeriesWork *series = new SeriesWork(first, std::move(callback));
series->set_last_task(last);
first->dispatch();
}
class ParallelWork : public ParallelTask
{
public:
void start()
{
assert(!series_of(this));
Workflow::start_series_work(this, nullptr);
}
void dismiss()
{
assert(!series_of(this));
delete this;
}
public:
void add_series(SeriesWork *series);
public:
void *get_context() const { return this->context; }
void set_context(void *context) { this->context = context; }
public:
SeriesWork *series_at(size_t index)
{
if (index < this->subtasks_nr)
return this->all_series[index];
else
return NULL;
}
const SeriesWork *series_at(size_t index) const
{
if (index < this->subtasks_nr)
return this->all_series[index];
else
return NULL;
}
SeriesWork& operator[] (size_t index)
{
return *this->series_at(index);
}
const SeriesWork& operator[] (size_t index) const
{
return *this->series_at(index);
}
size_t size() const { return this->subtasks_nr; }
public:
void set_callback(parallel_callback_t callback)
{
this->callback = std::move(callback);
}
protected:
virtual SubTask *done();
protected:
void *context;
parallel_callback_t callback;
private:
void expand_buf();
private:
size_t buf_size;
SeriesWork **all_series;
protected:
ParallelWork(parallel_callback_t&& callback);
ParallelWork(SeriesWork *const all_series[], size_t n,
parallel_callback_t&& callback);
virtual ~ParallelWork();
friend class Workflow;
};
inline ParallelWork *
Workflow::create_parallel_work(parallel_callback_t callback)
{
return new ParallelWork(std::move(callback));
}
inline ParallelWork *
Workflow::create_parallel_work(SeriesWork *const all_series[], size_t n,
parallel_callback_t callback)
{
return new ParallelWork(all_series, n, std::move(callback));
}
inline void
Workflow::start_parallel_work(SeriesWork *const all_series[], size_t n,
parallel_callback_t callback)
{
ParallelWork *p = new ParallelWork(all_series, n, std::move(callback));
Workflow::start_series_work(p, nullptr);
}
#endif

@ -0,0 +1,21 @@
target("factory")
add_files("*.cc")
set_kind("object")
if not has_config("mysql") then
remove_files("MySQLTaskImpl.cc")
end
if not has_config("redis") then
remove_files("RedisTaskImpl.cc")
end
remove_files("KafkaTaskImpl.cc")
target("kafka_factory")
if has_config("kafka") then
add_files("KafkaTaskImpl.cc")
set_kind("object")
add_cxxflags("-fno-rtti")
add_deps("factory")
add_packages("zlib", "snappy", "zstd", "lz4")
else
set_kind("phony")
end

@ -0,0 +1 @@
../../kernel/CommRequest.h

@ -0,0 +1 @@
../../kernel/CommScheduler.h

@ -0,0 +1 @@
../../kernel/Communicator.h

@ -0,0 +1 @@
../../protocol/ConsulDataTypes.h

@ -0,0 +1 @@
../../manager/DnsCache.h

@ -0,0 +1 @@
../../protocol/DnsMessage.h

@ -0,0 +1 @@
../../protocol/DnsUtil.h

@ -0,0 +1 @@
../../util/EncodeStream.h

@ -0,0 +1 @@
../../manager/EndpointParams.h

@ -0,0 +1 @@
../../kernel/ExecRequest.h

@ -0,0 +1 @@
../../kernel/Executor.h

@ -0,0 +1 @@
../../protocol/HttpMessage.h

@ -0,0 +1 @@
../../protocol/HttpUtil.h

@ -0,0 +1 @@
../../kernel/IORequest.h

@ -0,0 +1 @@
../../kernel/IOService_linux.h

@ -0,0 +1 @@
../../kernel/IOService_thread.h

@ -0,0 +1 @@
../../protocol/KafkaDataTypes.h

@ -0,0 +1 @@
../../protocol/KafkaMessage.h

@ -0,0 +1 @@
../../protocol/KafkaResult.h

@ -0,0 +1 @@
../../factory/KafkaTaskImpl.inl

@ -0,0 +1 @@
../../util/LRUCache.h

@ -0,0 +1 @@
../../protocol/MySQLMessage.h

@ -0,0 +1 @@
../../protocol/MySQLMessage.inl

@ -0,0 +1 @@
../../protocol/MySQLResult.h

@ -0,0 +1 @@
../../protocol/MySQLResult.inl

@ -0,0 +1 @@
../../protocol/MySQLUtil.h

@ -0,0 +1 @@
../../protocol/PackageWrapper.h

@ -0,0 +1 @@
../../protocol/ProtocolMessage.h

@ -0,0 +1 @@
../../protocol/RedisMessage.h

@ -0,0 +1 @@
../../factory/RedisTaskImpl.inl

@ -0,0 +1 @@
../../manager/RouteManager.h

@ -0,0 +1 @@
../../protocol/SSLWrapper.h

@ -0,0 +1 @@
../../kernel/SleepRequest.h

@ -0,0 +1 @@
../../util/StringUtil.h

@ -0,0 +1 @@
../../kernel/SubTask.h

@ -0,0 +1 @@
../../protocol/TLVMessage.h

@ -0,0 +1 @@
../../util/URIParser.h

@ -0,0 +1 @@
../../manager/UpstreamManager.h

@ -0,0 +1 @@
../../nameservice/UpstreamPolicies.h

@ -0,0 +1 @@
../../factory/WFAlgoTaskFactory.h

@ -0,0 +1 @@
../../factory/WFAlgoTaskFactory.inl

@ -0,0 +1 @@
../../factory/WFConnection.h

@ -0,0 +1 @@
../../client/WFConsulClient.h

@ -0,0 +1 @@
../../client/WFDnsClient.h

@ -0,0 +1 @@
../../nameservice/WFDnsResolver.h

@ -0,0 +1 @@
../../server/WFDnsServer.h

@ -0,0 +1 @@
../../manager/WFFacilities.h

@ -0,0 +1 @@
../../manager/WFFacilities.inl

@ -0,0 +1 @@
../../manager/WFFuture.h

@ -0,0 +1 @@
../../manager/WFGlobal.h

@ -0,0 +1 @@
../../factory/WFGraphTask.h

@ -0,0 +1 @@
../../server/WFHttpServer.h

@ -0,0 +1 @@
../../client/WFKafkaClient.h

@ -0,0 +1 @@
../../factory/WFMessageQueue.h

@ -0,0 +1 @@
../../client/WFMySQLConnection.h

@ -0,0 +1 @@
../../server/WFMySQLServer.h

@ -0,0 +1 @@
../../nameservice/WFNameService.h

@ -0,0 +1 @@
../../factory/WFOperator.h

@ -0,0 +1 @@
../../server/WFRedisServer.h

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save