pull/2/head
xiangguo 7 months ago
parent 6201dc959e
commit ea602b4c9f

@@ -452,11 +452,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
    """
    def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
        # Initialize a multiprocessing queue for storing indices
        self.idx_queue = multiprocessing.Queue(queue_size)
        # If shared memory is enabled, use a shared queue; otherwise use a plain multiprocessing queue
        if get_enable_shared_mem():
            self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
        else:
            self.res_queue = multiprocessing.Queue(queue_size)
        # Set _joincancelled to True on both queues so they do not block when the process exits
        self.idx_queue._joincancelled = True  # pylint: disable=W0212
        self.res_queue._joincancelled = True  # pylint: disable=W0212
        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
@@ -465,6 +468,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
        """
        Put function for worker index queue. Never block. Raise queue.Full on failure.
        """
        # Put item into idx_queue without blocking; raise queue.Full on failure
        self.idx_queue.put_nowait(item)

    def get(self):
@@ -476,12 +480,19 @@ class _GeneratorWorkerMp(multiprocessing.Process):
        return self.res_queue.get(timeout=30)

    def queue_empty(self):
        # Check whether idx_queue is empty
        if not self.idx_queue.empty():
            # If it is not empty, log a warning and return False
            logger.warning("idx_queue is not empty.")
            return False
        # Check whether res_queue is empty
        if not self.res_queue.empty():
            # If it is not empty, log a warning and return False
            logger.warning("res_queue is not empty.")
            return False
        # Both queues are empty
        return True

    def __del__(self):
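Reviewer note: for readers new to this producer/consumer layout, the sketch below reproduces the index-in, result-out pattern with plain multiprocessing queues. It is a standalone illustration, not MindSpore code; `worker_loop`, the `None` sentinel, and the squared-index payload stand in for `_generator_worker_loop`, the `eof` event, and the real dataset rows.

```python
import multiprocessing

def worker_loop(idx_queue, res_queue):
    # Pull indices until the None sentinel arrives, push one result per index.
    while True:
        idx = idx_queue.get()
        if idx is None:              # sentinel instead of the real eof event
            break
        res_queue.put(idx * idx)     # stand-in for reading dataset[idx]

if __name__ == "__main__":
    idx_q = multiprocessing.Queue(4)
    res_q = multiprocessing.Queue(4)
    worker = multiprocessing.Process(target=worker_loop, args=(idx_q, res_q))
    worker.start()
    for i in range(3):
        idx_q.put_nowait(i)          # never blocks; raises queue.Full when the queue is full
    print([res_q.get(timeout=30) for _ in range(3)])
    idx_q.put(None)
    worker.join()
```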
@@ -632,14 +643,17 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
    def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
                 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
                 python_multiprocessing=True, max_rowsize=6):
        # Call the parent class initializer
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
        # If source is a zip object, convert it to a list
        if isinstance(source, builtins.zip):
            # Although zip is iterable, it does not support repeated iteration, so expand it into a list.
            self.source = [item for item in source]
        else:
            self.source = source
        self.prepared_source = None  # source to be sent to C++
        # If the operator_mixed attribute is True, force num_parallel_workers to 1
        if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
            self.num_parallel_workers = 1
            logger.warning(
@@ -650,56 +664,78 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        self.python_multiprocessing = python_multiprocessing

        # Convert column_names to a list
        self.column_names = to_list(column_names)

        # If column_types is given, convert it to a DE type list; otherwise use an empty list
        if column_types is not None:
            self.column_types = mstypelist_to_detypelist(column_types)
        else:
            self.column_types = []

        self.schema = schema
        # If schema is given, make sure self.schema ends up as a Schema object
        if schema is not None:
            self.schema = schema
            # Wrap non-Schema inputs in a Schema object
            if not isinstance(schema, Schema):
                self.schema = Schema(schema)
        # Move get dataset_size by len from parse to here, because self.source will
        # lose the '__len__' attribute after deepcopy.
        self.source_len = -1  # unknown
        # If self.source supports __len__, record its length
        if hasattr(self.source, "__len__"):
            self.source_len = len(self.source)

        # Maximum row size and the sampling function (filled in later)
        self.max_rowsize = max_rowsize
        self.sample_fn = None
    def __deepcopy__(self, memodict):
        # Deep-copy this object; memodict caches objects that have already been copied
        if id(self) in memodict:
            # Return the cached copy if this object was already copied
            return memodict[id(self)]
        # Otherwise copy via __safe_deepcopy__, excluding 'source' and '__transfer_dataset__'
        new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))

        sample_fn = None
        # If the new op has a sampler and self.source supports random access (__getitem__)
        if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
            # The reason why there is a try catch here is because when the new op is being constructed with shared
            # memory enabled, there will be an exception thrown if there is not enough shared memory available
            # A random access dataset requires '__len__'; raise RuntimeError if the length is unknown
            if self.source_len == -1:
                raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
            try:
                # With more than one parallel worker, validate memory usage first
                if new_op.num_parallel_workers > 1:
                    self.__validate_memory_usage()
                    # Create a SamplerFn object for parallel sampling
                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
                                          self.max_rowsize)
                    # Use the multiprocessing sampler function as the prepared source
                    new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
                else:
                    # Otherwise use the single-process sampler function
                    new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
                # Store the sampling function on the new op
                new_op.sample_fn = sample_fn
            except RuntimeError as e:
                # Re-raise RuntimeError as a generic Exception with the original message
                raise Exception(str(e))
        else:
            try:
                # No random access: clear the sampler and keep sample_fn as is
                new_op.sampler = None
                new_op.sample_fn = sample_fn
                # Clamp source_len to num_samples when num_samples is non-zero
                new_op.source_len = min(new_op.source_len,
                                        new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len
                # Verify that self.source is iterable
                iter(self.source)
            except TypeError:
                # Use generator function if input callable
@@ -711,19 +747,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        return new_op

    # Whether the dataset is shuffled
    def is_shuffled(self):
        return self.sampler.is_shuffled()

    # Whether the dataset is sharded
    def is_sharded(self):
        return self.sampler.is_sharded()

    # Parse the dataset into a GeneratorNode
    def parse(self, children=None):
        # Without a schema, build the GeneratorNode from column names and types
        if self.schema is None:
            return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
                                     self.sampler, self.num_parallel_workers)
        # Otherwise use the schema (its cpp_schema when it is a Schema object)
        schema = self.schema
        if isinstance(schema, Schema):
            schema = self.schema.cpp_schema
        return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
                                 self.num_parallel_workers)
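Reviewer note: as context for the sampler handling above, here is a minimal usage sketch of the random-access path. The source class, column name, and shapes are illustrative; only `GeneratorDataset` and `create_tuple_iterator` are actual `mindspore.dataset` APIs.

```python
import numpy as np
import mindspore.dataset as ds

class RandomAccessSource:
    """A random-access source: __getitem__ plus __len__, as the sampler path requires."""
    def __init__(self):
        self._data = np.random.sample((5, 2)).astype(np.float32)

    def __getitem__(self, index):
        return (self._data[index],)

    def __len__(self):
        return len(self._data)

dataset = ds.GeneratorDataset(source=RandomAccessSource(), column_names=["data"],
                              shuffle=False, num_parallel_workers=1)
for row in dataset.create_tuple_iterator(output_numpy=True):
    print(row[0].shape)   # (2,)
```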
@@ -735,24 +778,37 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        # If num_parallel_workers is too large when python_multiprocessing=True, it may cause
        # an OOM error, so first get the num_shards
        valid_num_shards = 1
        # If the sampler is a DistributedSampler, take num_shards from it
        if isinstance(self.sampler, samplers.DistributedSampler):
            valid_num_shards = self.sampler.num_shards
        # Otherwise fall back to self.num_shards when it is set
        elif self.num_shards is not None:
            valid_num_shards = self.num_shards

        # get process memory usage
        process = psutil.Process(os.getpid())
        process_memory = process.memory_info().rss
        # Get the amount of free system memory
        sys_memory_free = psutil.virtual_memory().free

        # Estimate the total memory that multiprocessing may use
        total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
        # If the estimate exceeds 85% of free system memory
        if total_memory_maybe_used / sys_memory_free > 0.85:
            # Compute a worker count that keeps the estimate within 85% of free memory, at least 1
            valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory)
            valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
            # Warn that num_parallel_workers is too large and suggest a smaller value
            info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \
                   "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \
                   "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers,
                                                                             valid_num_worker)
            logger.warning(info)
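Reviewer note: a back-of-the-envelope run of the check above, with made-up figures (a 2 GB parent process, 8 workers, 2 shards, 24 GB free):

```python
import math

# Illustrative numbers only; in the real check these come from psutil and the sampler.
process_memory = 2 * 1024 ** 3
num_parallel_workers = 8
valid_num_shards = 2
sys_memory_free = 24 * 1024 ** 3

total_memory_maybe_used = process_memory * num_parallel_workers * valid_num_shards
print(round(total_memory_maybe_used / sys_memory_free, 2))  # 1.33 > 0.85, so the warning fires

valid_num_worker = max(1, math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory))
print(valid_num_worker)  # 5 -> the suggested num_parallel_workers
```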
@@ -764,37 +820,55 @@ class _NumpySlicesDataset:
    def __init__(self, data, column_list=None):
        self.column_list = None
        # Convert dict data into tuple
        if isinstance(data, dict):
            data = self.process_dict(data)

        # If data is a tuple, convert each element into a NumPy array column
        if isinstance(data, tuple):
            self.data = ()
            data_len = len(data)
            for i in range(data_len):
                self.data = self.data + (np.array(data[i]),)
        else:
            # Otherwise wrap data itself as a single NumPy array column
            self.data = (np.array(data),)

        # check whether the data length in each column is equal
        data_len = [len(data_item) for data_item in self.data]
        # Raise ValueError if the column lengths differ
        if data_len[1:] != data_len[:-1]:
            raise ValueError("Data length in each column is not equal.")

        # Init column_name
        # Use the given column_list if provided; otherwise generate "column_0", "column_1", ...
        if column_list is not None:
            self.column_list = column_list
        elif self.column_list is None:
            self.column_list = []
            column_num = len(self.data)
            for i in range(column_num):
                self.column_list.append("column_" + str(i))

    def __getitem__(self, index):
        # Fetch the row at the given index from every column and return it as a tuple
        data_row = [d[index, ...] for d in self.data]
        data_res = tuple(data_row)
        return data_res

    def __len__(self):
        # The dataset length is the length of the first column
        return len(self.data[0])

    def process_dict(self, input_data):
@@ -802,24 +876,29 @@ class _NumpySlicesDataset:
        Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
        """
        # Convert pandas like dict(has "values" column) into General dict
        data_keys = list(input_data.keys())
        # Probe the first key's value to detect a pandas-style column
        data_col = input_data[data_keys[0]]
        # If the value has a 'values' attribute, replace every column with its underlying values
        if hasattr(data_col, "values"):
            new_dict = {}
            for key in data_keys:
                # Convert each column to its plain values
                item1 = input_data.pop(key)
                new_dict[key] = item1.values
            input_data = new_dict

        # Convert the data in dict into tuple
        data = ()                            # start with an empty tuple
        keys = list(input_data.keys())       # the dict keys become the column names
        self.column_list = keys
        for key in keys:
            value = input_data[key]
            data = data + (list(value),)     # append each column's values as a list
        return data
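Reviewer note: a short standalone walk-through of the dict-to-tuple conversion above. The column names "a" and "b" and their values are made up for the example.

```python
import numpy as np

input_data = {"a": [1, 2, 3], "b": [4, 5, 6]}

data = ()
keys = list(input_data.keys())            # ['a', 'b'] -> becomes column_list
for key in keys:
    data = data + (list(input_data[key]),)
print(data)                               # ([1, 2, 3], [4, 5, 6])

# __init__ then turns each column into a NumPy array and __getitem__ slices rows:
columns = tuple(np.array(col) for col in data)
print(columns[0][0, ...], columns[1][0, ...])   # row 0 of each column: 1 4
```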
class NumpySlicesDataset(GeneratorDataset):

@@ -909,7 +988,9 @@ class NumpySlicesDataset(GeneratorDataset):
    @check_numpyslicesdataset
    def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
                 num_shards=None, shard_id=None):
        # Wrap the raw data in a _NumpySlicesDataset with the given column names
        dataset = _NumpySlicesDataset(data, column_names)
        # Delegate to GeneratorDataset, passing the wrapped dataset and its column list along with the
        # num_samples, num_parallel_workers, shuffle, sampler, num_shards and shard_id arguments
        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
                         num_shards=num_shards, shard_id=shard_id)
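Reviewer note: a minimal usage sketch of NumpySlicesDataset; the data and column names are illustrative. It slices the input along the first dimension, one row per sample, with column names taken from the dict keys.

```python
import mindspore.dataset as ds

data = {"x": [[1, 2], [3, 4], [5, 6]], "y": [0, 1, 0]}
dataset = ds.NumpySlicesDataset(data, shuffle=False)

for row in dataset.create_dict_iterator(output_numpy=True):
    print(row["x"], row["y"])
# [1 2] 0
# [3 4] 1
# [5 6] 2
```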

@@ -256,36 +256,53 @@ class DatasetHelper:
    """
    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
        # Check that dataset_sink_mode is a bool
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        # Check that sink_size is an int
        Validator.check_is_int(sink_size)
        # sink_size must be -1 or a positive integer
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
        # sink_size == -1 means use the full dataset size
        if sink_size == -1:
            sink_size = dataset.get_dataset_size()

        # In sink mode, pick the iterator class based on the device target and execution mode
        if dataset_sink_mode:
            # Use the GE iterator when GE is enabled
            if context.get_context("enable_ge"):
                iterclass = _DatasetIterGE
            else:
                # In GRAPH_MODE, choose the iterator according to the role
                if context.get_context("mode") == context.GRAPH_MODE:
                    # Scheduler and parameter-server roles use the PS server iterator
                    if _is_role_sched() or _is_role_pserver():
                        iterclass = _DatasetIterPSServer
                    # A worker role in PS mode uses the PS worker iterator
                    elif _is_role_worker() and _is_ps_mode():
                        iterclass = _DatasetIterPSWork
                    # Ascend and GPU use the loop-sink iterator
                    elif (context.get_context("device_target") == "Ascend") or \
                            (context.get_context("device_target") == "GPU"):
                        iterclass = _DatasetIterMSLoopSink
                    # CPU does not support dataset sink mode
                    elif context.get_context("device_target") == "CPU":
                        raise RuntimeError("Currently dataset sink mode is not supported when the device "
                                           "target is CPU, please set dataset sink mode to False.")
                else:
                    # Outside GRAPH_MODE, use the PyNative iterator
                    iterclass = _DatasetIterPyNative
            # Create the iterator
            self.iter = iterclass(dataset, sink_size, epoch_num)
        else:
            # Without sink mode, use the normal iterator
            iterclass = _DatasetIterNormal
            self.iter = iterclass(dataset, epoch_num=epoch_num)

    def __iter__(self):
        # Delegate iteration to the underlying iterator
        return self.iter.__iter__()

    # A temp solution for loop sink. Delete later
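Reviewer note: a minimal usage sketch of DatasetHelper in non-sink mode. The toy generator and column name are illustrative, and the import path assumes DatasetHelper is exported from this file's module (mindspore.train.dataset_helper). In sink mode the helper would instead drive the device queue through the iterator classes selected above.

```python
import numpy as np
import mindspore.dataset as ds
from mindspore.train.dataset_helper import DatasetHelper

def gen():
    for i in range(4):
        yield (np.array([i], dtype=np.float32),)

dataset = ds.GeneratorDataset(gen, column_names=["data"])
# Non-sink mode: batches are yielded on the host, one per step.
helper = DatasetHelper(dataset, dataset_sink_mode=False, epoch_num=1)
for inputs in helper:
    print(inputs)
```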
@@ -301,6 +318,7 @@ class DatasetHelper:
        >>>
        >>> types, shapes = dataset_helper.types_shapes()
        """
        # Get the types and shapes from the currently configured dataset
        return self.iter.types_shapes()

    def sink_size(self):

@@ -316,18 +334,22 @@ class DatasetHelper:
        >>> # if sink_size==-1, then will return the full size of source dataset.
        >>> sink_size = dataset_helper.sink_size()
        """
        # Return the sink size of the underlying iterator
        return self.iter.get_sink_size()

    def stop_send(self):
        """Stop send data about data sink."""
        # Stop sending data to the data sink
        self.iter.stop_send()

    def release(self):
        """Free up resources about data sink."""
        # Release the data sink's resources
        self.iter.release()

    def continue_send(self):
        """Continue to send data to device at the beginning of epoch."""
        # Continue sending data to the device at the beginning of an epoch
        self.iter.continue_send()

    def _reset(self, step):

@@ -339,6 +361,7 @@ class DatasetHelper:
        In sink mode, it returns the types and shapes of the current data.
        Generally, it works in dynamic shape scenarios.
        """
        # Return the data info from the underlying iterator
        return self.iter.get_data_info()

    def dynamic_min_max_shapes(self):

@@ -355,6 +378,7 @@ class DatasetHelper:
        >>>
        >>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes()
        """
        # Delegate to the underlying iterator's dynamic_min_max_shapes
        return self.iter.dynamic_min_max_shapes()
@@ -362,20 +386,27 @@ class _DatasetIter:
    """Base iter for dataset helper"""
    def __init__(self, dataset, sink_size, epoch_num):
        # Record the dataset, the sink size and the sink count
        self.dataset = dataset
        self.sink_size = sink_size
        self.sink_count = self.get_sink_count(dataset)

        # If the dataset has not been transferred to the device yet
        if not hasattr(dataset, '__transfer_dataset__'):
            # If the dataset has a __loop_size__ attribute
            if hasattr(dataset, '__loop_size__'):
                # PS mode does not support loop sink and need get the real sink size.
                # Outside PS worker mode, take sink_size from the dataset's loop size
                if not (_is_role_worker() and _is_ps_mode()):
                    self.sink_size = dataset.__loop_size__
            # Create a data info queue only when sink_size and sink_count are 1, the dataset has more
            # than one step, and the device target is Ascend
            create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1
                                      and context.get_context("device_target") == "Ascend")
            # Execute the data graph, passing the sink size and the data info queue flag
            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
                                                           create_data_info_queue=create_data_info_queue)

            # Send data unless the dataset is marked with __no_send__
            if not hasattr(dataset, '__no_send__'):
                _send_data(dataset, epoch_num)
        else:
@@ -384,33 +415,48 @@ class _DatasetIter:
            _cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name)
            _send_data_no_flag(dataset, epoch_num)

        # Bind the transfer dataset's control methods (stop_send, release, continue_send, get_data_info)
        # and the dynamic shape metadata to this iterator
        self.stop_send = dataset.__transfer_dataset__.stop_send
        self.release = dataset.__transfer_dataset__.release
        self.continue_send = dataset.__transfer_dataset__.continue_send
        self.get_data_info = dataset.__transfer_dataset__.get_data_info
        self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
        # Cache the dataset's output types and shapes
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
        # Expose the transfer dataset's _reset method when it exists
        if hasattr(dataset.__transfer_dataset__, "_reset"):
            self._reset = dataset.__transfer_dataset__._reset  # pylint: disable=W0212

    def __iter__(self):
        # Reset the index to 0 and return the iterator itself
        self.index = 0
        return self

    def __next__(self):
        # Stop once sink_count steps have been produced
        if self.index >= self.sink_count:
            raise StopIteration()
        self.index += 1
        # Run one sink step and return its result
        return self.op()
    def types_shapes(self):
        """
        Return the types and shapes of the dataset. The type and shape of each data in the dataset
        should be consistent.
        """
        return self.dataset_types, self.dataset_shapes

    def get_sink_count(self, dataset):
        """
        Get the sink count of the dataset.

        :param dataset: the dataset object
        :return: the sink count
        """
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
@@ -421,7 +467,10 @@ class _DatasetIter:
        return sink_count

    def get_sink_size(self):
        """
        Get the sink size for the device.

        :return: the sink size
        """
        sink_size = 1
        if hasattr(self.dataset, '__loop_size__'):
            sink_size = self.dataset.__loop_size__
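Reviewer note: the lines elided between these hunks derive sink_count from the dataset size and the loop size. A rough illustration of the relationship with made-up numbers follows, assuming a simple ceiling division, which is not shown in this diff.

```python
import math

# Made-up numbers: a 1000-step dataset with a per-sink loop size of 100.
dataset_size = 1000
loop_size = 100          # would come from dataset.__loop_size__
sink_size = loop_size    # what get_sink_size returns when __loop_size__ exists
sink_count = math.ceil(dataset_size / loop_size)  # assumption, not the exact elided formula
print(sink_size, sink_count)   # 100 10 -> ten sink iterations of 100 steps each
```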

@@ -75,6 +75,7 @@ def _write_device_target(file):
def build_dependencies():
    """generate python file"""
    # Generate the version.py file
    version_file = os.path.join(pkg_dir, 'mindspore', 'version.py')
    with open(version_file, 'w') as f:
        _write_version(f)

@@ -83,6 +84,7 @@ def build_dependencies():
    with open(version_file, 'w') as f:
        _write_version(f)

    # Generate the default_config.py file
    config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
    with open(config_file, 'w') as f:
        _write_config(f)

@@ -91,6 +93,7 @@ def build_dependencies():
    with open(config_file, 'w') as f:
        _write_config(f)

    # Append the device_target to default_config.py
    target = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
    with open(target, 'a') as f:
        _write_device_target(f)

@@ -99,6 +102,7 @@ def build_dependencies():
    with open(target, 'a') as f:
        _write_device_target(f)

    # Append the package_name to default_config.py
    package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
    with open(package_info, 'a') as f:
        _write_package_name(f)

@@ -107,6 +111,7 @@ def build_dependencies():
    with open(package_info, 'a') as f:
        _write_package_name(f)

    # Generate the .commit_id file
    commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id')
    with open(commit_file, 'w') as f:
        _write_commit_file(f)
@@ -156,16 +161,24 @@ def update_permissions(path):
    Args:
        path (str): Target directory path.
    """
    # Do nothing on Windows
    if platform.system() == "Windows":
        return

    # Walk every directory and file under the target path
    for dirpath, dirnames, filenames in os.walk(path):
        # Directories: owner read/write/execute, group read/execute
        for dirname in dirnames:
            dir_fullpath = os.path.join(dirpath, dirname)
            os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE |
                     stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)
        # Files: owner read only
        for filename in filenames:
            file_fullpath = os.path.join(dirpath, filename)
            os.chmod(file_fullpath, stat.S_IREAD)
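Reviewer note: for reference, the modes used above resolve to the standard stat bit values, so directories end up as 0o750 and files as 0o400. A quick check:

```python
import stat

dir_mode = stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP
file_mode = stat.S_IREAD

print(oct(dir_mode))   # 0o750 -> rwxr-x--- for directories
print(oct(file_mode))  # 0o400 -> r-------- for files
```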
@@ -174,7 +187,9 @@ class EggInfo(egg_info):
    def run(self):
        super().run()
        # Locate the generated egg-info directory
        egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info')
        # Update its permissions
        update_permissions(egg_info_dir)

@@ -183,41 +198,64 @@ class BuildPy(build_py):
    def run(self):
        super().run()
        # Update permissions of build/lib/mindspore
        mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore')
        update_permissions(mindspore_dir)
        # Update permissions of build/lib/mindspore/_akg
        mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg')
        update_permissions(mindspore_dir)
setup(
    # Package name and version
    name=package_name,
    version=version,
    # Author information
    author='The MindSpore Authors',
    author_email='contact@mindspore.cn',
    # Project home page and download location
    url='https://www.mindspore.cn',
    download_url='https://github.com/mindspore-ai/mindspore/tags',
    # Source code and issue tracker URLs
    project_urls={
        'Sources': 'https://github.com/mindspore-ai/mindspore',
        'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues',
    },
    # Short description, with the readme used as the long description
    description='MindSpore is a new open source deep learning training/inference '
                'framework that could be used for mobile, edge and cloud scenarios.',
    long_description=readme,
    long_description_content_type="text/markdown",
    # Packages and package data to include
    packages=find_packages(),
    package_data=package_data,
    include_package_data=True,
    # Custom command classes
    cmdclass={
        'egg_info': EggInfo,
        'build_py': BuildPy,
    },
    # Console entry points
    entry_points={
        'console_scripts': [
            'cache_admin=mindspore.dataset.engine.cache_admin:main',
        ],
    },
    # Python version requirement and dependencies
    python_requires='>=3.7',
    install_requires=required_package,
    # Trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Environment :: Console',
@@ -234,6 +272,8 @@ setup(
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    # License and keywords
    license='Apache 2.0',
    keywords='mindspore machine learning',
)
