@@ -452,11 +452,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
    """

    def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
        # Initialize a multiprocessing queue used to hold sample indices
        self.idx_queue = multiprocessing.Queue(queue_size)
        # If shared memory is enabled, use a shared queue for results; otherwise use a multiprocessing queue
        if get_enable_shared_mem():
            self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
        else:
            self.res_queue = multiprocessing.Queue(queue_size)
        # Set _joincancelled to True so the queues do not block when the process exits
        self.idx_queue._joincancelled = True  # pylint: disable=W0212
        self.res_queue._joincancelled = True  # pylint: disable=W0212
        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
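For context, here is a minimal, self-contained sketch of the index-queue/result-queue pattern this worker is built on. It is not MindSpore's actual `_generator_worker_loop` (which also handles the `eof` event, shared memory, and parent-process liveness via `ppid`); the `None` sentinel and the `_toy_worker_loop` name are illustrative assumptions.

```python
import multiprocessing

def _toy_worker_loop(dataset, idx_queue, res_queue):
    # Pull sample indices until a None sentinel arrives, pushing each
    # fetched row onto the result queue (simplified stand-in for the real loop).
    while True:
        idx = idx_queue.get()
        if idx is None:
            break
        res_queue.put(dataset[idx])

if __name__ == "__main__":
    data = ["row-%d" % i for i in range(4)]
    idx_q = multiprocessing.Queue(4)
    res_q = multiprocessing.Queue(4)
    worker = multiprocessing.Process(target=_toy_worker_loop, args=(data, idx_q, res_q))
    worker.start()
    for i in range(4):
        idx_q.put_nowait(i)  # mirrors put() below: never block
    idx_q.put(None)          # sentinel, stands in for the real eof event
    for _ in range(4):
        print(res_q.get(timeout=30))
    worker.join()
```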
@@ -465,6 +468,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
        """
        Put function for worker index queue. Never block. Raise queue.Full on failure.
        """
        # Enqueue the item without blocking; queue.Full propagates to the caller on failure
        self.idx_queue.put_nowait(item)

    def get(self):
@@ -476,12 +480,19 @@ class _GeneratorWorkerMp(multiprocessing.Process):
        return self.res_queue.get(timeout=30)

    def queue_empty(self):
        # Check whether the index queue has been drained
        if not self.idx_queue.empty():
            logger.warning("idx_queue is not empty.")
            return False
        # Check whether the result queue has been drained
        if not self.res_queue.empty():
            logger.warning("res_queue is not empty.")
            return False
        # Both queues are empty
        return True

    def __del__(self):
@@ -632,14 +643,17 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
    def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
                 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
                 python_multiprocessing=True, max_rowsize=6):
        # Call the parent class initializer
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
        if isinstance(source, builtins.zip):
            # Although zip is iterable, it cannot be iterated repeatedly, so copy it into a list.
            self.source = [item for item in source]
        else:
            self.source = source
        self.prepared_source = None  # source to be sent to C++
        # If the operator_mixed attribute is True, fall back to a single parallel worker
        if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
            self.num_parallel_workers = 1
            logger.warning(
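The list conversion above exists because a `zip` object is a one-shot iterator; a quick illustration:

```python
pairs = zip([1, 2], ["a", "b"])
print(list(pairs))  # [(1, 'a'), (2, 'b')]
print(list(pairs))  # [] -- the zip object is already exhausted
# Materializing it first allows repeated iteration:
pairs = list(zip([1, 2], ["a", "b"]))
print(list(pairs))  # [(1, 'a'), (2, 'b')] on every pass
```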
@@ -650,56 +664,78 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        self.python_multiprocessing = python_multiprocessing

        # Normalize column_names into a list
        self.column_names = to_list(column_names)

        # If column_types is given, convert it to a DE type list; otherwise use an empty list
        if column_types is not None:
            self.column_types = mstypelist_to_detypelist(column_types)
        else:
            self.column_types = []

        self.schema = schema
        # If a schema is given, wrap it in a Schema object when necessary
        if schema is not None:
            self.schema = schema
            if not isinstance(schema, Schema):
                self.schema = Schema(schema)
        # Move get dataset_size by len from parse to here, because self.source will
        # lose attribution of '__len__' after deepcopy.
        self.source_len = -1  # unknown
        # If self.source supports __len__, record its length
        if hasattr(self.source, "__len__"):
            self.source_len = len(self.source)

        self.max_rowsize = max_rowsize
        self.sample_fn = None

    def __deepcopy__(self, memodict):
        # If this object has already been copied, return the cached copy from memodict
        if id(self) in memodict:
            return memodict[id(self)]
        # Otherwise perform a safe deepcopy, excluding the source and the transfer dataset
        new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))

        sample_fn = None
        # Random-access path: the new op has a sampler and the source supports __getitem__
        if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
            # The reason why there is a try catch here is because when the new op is being constructed with shared
            # memory enabled, there will be an exception thrown if there is not enough shared memory available
            # A random access dataset requires __len__; raise if the source length is unknown
            if self.source_len == -1:
                raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
            try:
                # With more than one worker, validate the expected memory usage first
                if new_op.num_parallel_workers > 1:
                    self.__validate_memory_usage()

                    # Create a SamplerFn for parallel sampling
                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
                                          self.max_rowsize)
                    # Use the multiprocess sampler function for prepared_source
                    new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
                else:
                    # Single worker: sample directly from the source
                    new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
                new_op.sample_fn = sample_fn
            except RuntimeError as e:
                raise Exception(str(e))
        else:
            try:
                # Iterable path: clear the sampler and keep sample_fn as None
                new_op.sampler = None
                new_op.sample_fn = sample_fn
                # Cap source_len at num_samples when num_samples is non-zero
                new_op.source_len = min(new_op.source_len,
                                        new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len
                # Verify that the source is iterable
                iter(self.source)
            except TypeError:
                # Use generator function if input callable
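The branch above distinguishes random-access sources (which expose `__getitem__`/`__len__` and go through a sampler) from pure iterables (where the sampler is dropped). A usage sketch of both flavors against the public `GeneratorDataset` API; the source class and generator names are illustrative:

```python
import numpy as np
import mindspore.dataset as ds

class RandomAccessSource:
    """Random-access source: sampled by index through __getitem__."""
    def __init__(self):
        self._data = np.arange(10, dtype=np.int64)
    def __getitem__(self, index):
        return (self._data[index],)
    def __len__(self):
        return len(self._data)

def iterable_source():
    """Iterable source: consumed in order, no sampler involved."""
    for i in range(10):
        yield (np.array(i, dtype=np.int64),)

random_ds = ds.GeneratorDataset(RandomAccessSource(), column_names=["data"], shuffle=True)
iter_ds = ds.GeneratorDataset(iterable_source, column_names=["data"], shuffle=False)
```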
@@ -711,19 +747,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        return new_op

    def is_shuffled(self):
        # Whether the data is shuffled is delegated to the sampler
        return self.sampler.is_shuffled()

    def is_sharded(self):
        # Whether the data is sharded is delegated to the sampler
        return self.sampler.is_sharded()

    def parse(self, children=None):
        # Without a schema, build the GeneratorNode from column names and types
        if self.schema is None:
            return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
                                     self.sampler, self.num_parallel_workers)
        schema = self.schema
        # If schema is a Schema object, pass its underlying C++ schema
        if isinstance(schema, Schema):
            schema = self.schema.cpp_schema
        return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
                                 self.num_parallel_workers)
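As `parse` shows, a `Schema` can stand in for explicit `column_names`/`column_types`. A small sketch, assuming the standard `Schema.add_column(name, de_type, shape=None)` signature from `mindspore.dataset`:

```python
import numpy as np
import mindspore.dataset as ds
from mindspore import dtype as mstype

schema = ds.Schema()
schema.add_column(name="image", de_type=mstype.uint8, shape=[28, 28])
schema.add_column(name="label", de_type=mstype.int64)

def source():
    # Yield rows matching the declared columns
    for _ in range(4):
        yield (np.zeros((28, 28), dtype=np.uint8), np.array(0, dtype=np.int64))

dataset = ds.GeneratorDataset(source, schema=schema)
```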
@@ -735,24 +778,37 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        # If num_parallel_workers is too large when python_multiprocessing=True, it can cause
        # an OOM error; first determine the number of shards
        valid_num_shards = 1
        if isinstance(self.sampler, samplers.DistributedSampler):
            # A DistributedSampler carries its own shard count
            valid_num_shards = self.sampler.num_shards
        elif self.num_shards is not None:
            valid_num_shards = self.num_shards

        # get process memory usage
        process = psutil.Process(os.getpid())
        process_memory = process.memory_info().rss
        sys_memory_free = psutil.virtual_memory().free

        # Estimate the total memory that multiprocessing may use
        total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
        # Warn if the estimate exceeds 85% of free system memory
        if total_memory_maybe_used / sys_memory_free > 0.85:
            # Compute a worker count that stays within the 85% budget, with a floor of 1
            valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory)
            valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
            info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \
                   "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \
                   "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers,
                                                                             valid_num_worker)
            logger.warning(info)
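The arithmetic above reads as a standalone heuristic: every worker in every shard is assumed to clone the current process's resident set. A minimal sketch of the same calculation; the `recommended_workers` helper is illustrative, not part of the source:

```python
import math
import os

import psutil

def recommended_workers(requested_workers, num_shards=1, budget=0.85):
    """Mirror the heuristic above: assume each worker in each shard
    duplicates the current process's RSS, and cap the estimate at
    `budget` of free system memory."""
    rss = psutil.Process(os.getpid()).memory_info().rss
    free = psutil.virtual_memory().free
    estimated = rss * requested_workers * num_shards
    if estimated / free <= budget:
        return requested_workers
    return max(1, math.floor(free * budget / num_shards / rss))

print(recommended_workers(requested_workers=8))
```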
@@ -764,37 +820,55 @@ class _NumpySlicesDataset:
    def __init__(self, data, column_list=None):
        self.column_list = None
        # Convert dict data into tuple
        if isinstance(data, dict):
            data = self.process_dict(data)

        if isinstance(data, tuple):
            # For tuple input, convert each element to a NumPy array and collect the columns
            self.data = ()
            data_len = len(data)
            for i in range(data_len):
                self.data = self.data + (np.array(data[i]),)
        else:
            # Otherwise treat the whole input as a single column
            self.data = (np.array(data),)

        # check whether the data length in each column is equal
        data_len = [len(data_item) for data_item in self.data]
        if data_len[1:] != data_len[:-1]:
            raise ValueError("Data length in each column is not equal.")

        # Init column_name
        if column_list is not None:
            self.column_list = column_list
        elif self.column_list is None:
            # No names given: generate default names "column_0", "column_1", ...
            self.column_list = []
            column_num = len(self.data)
            for i in range(column_num):
                self.column_list.append("column_" + str(i))

    def __getitem__(self, index):
        # Fetch the row at the given index from every column and return it as a tuple
        data_row = [d[index, ...] for d in self.data]
        data_res = tuple(data_row)
        return data_res

    def __len__(self):
        # All columns have the same length, so use the first one
        return len(self.data[0])

    def process_dict(self, input_data):
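A short illustration of the column/row layout `_NumpySlicesDataset` builds, in plain NumPy, mirroring the logic above:

```python
import numpy as np

data = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}
# Columns: one NumPy array per key, as in __init__ above
columns = tuple(np.array(v) for v in data.values())
# A row: the index-th element of every column, as in __getitem__ above
row = tuple(col[0, ...] for col in columns)
print(list(data.keys()), row)  # row 0 across both columns: (1, 4.0)
```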
@@ -802,24 +876,29 @@ class _NumpySlicesDataset:
        Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
        """
        # Convert pandas like dict (has "values" attribute) into a general dict
        data_keys = list(input_data.keys())
        data_col = input_data[data_keys[0]]
        if hasattr(data_col, "values"):
            # Unwrap each column's underlying values (e.g. a pandas Series)
            new_dict = {}
            for key in data_keys:
                item1 = input_data.pop(key)
                new_dict[key] = item1.values
            input_data = new_dict

        # Convert the data in dict into tuple
        data = ()
        keys = list(input_data.keys())
        # Use the dict keys as the column names
        self.column_list = keys
        for key in keys:
            # Convert each column's values to a list and append it to the tuple
            value = input_data[key]
            data = data + (list(value),)

        return data


class NumpySlicesDataset(GeneratorDataset):
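Before moving on: the `.values` check in `process_dict` above is what lets a pandas DataFrame pass through as a dict of Series. A quick illustration, assuming pandas is available:

```python
import pandas as pd

df = pd.DataFrame({"x": [1, 2], "y": [3.0, 4.0]})
as_dict = dict(df)  # {'x': Series, 'y': Series}
first = next(iter(as_dict.values()))
print(hasattr(first, "values"))  # True, which triggers the unwrap branch
print({k: list(v.values) for k, v in as_dict.items()})  # {'x': [1, 2], 'y': [3.0, 4.0]}
```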
@@ -909,7 +988,9 @@ class NumpySlicesDataset(GeneratorDataset):
    @check_numpyslicesdataset
    def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
                 num_shards=None, shard_id=None):
        # Wrap the raw data in a _NumpySlicesDataset, then delegate to GeneratorDataset
        # with the column list derived from the data
        dataset = _NumpySlicesDataset(data, column_names)
        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
                         num_shards=num_shards, shard_id=shard_id)
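Finally, an end-to-end usage sketch of the public `NumpySlicesDataset` API with dict input; the data values are illustrative:

```python
import mindspore.dataset as ds

data = {"a": [1, 2], "b": [3, 4]}
dataset = ds.NumpySlicesDataset(data=data, shuffle=False)
for row in dataset.create_dict_iterator(output_numpy=True):
    print(row)  # {'a': 1, 'b': 3}, then {'a': 2, 'b': 4}
```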