branch-Zhangruiqin
liuwenhao 7 months ago
parent bce0ac63ed
commit 65ca9afacc

@ -1,3 +1,6 @@
这段代码是一个Python类的实现名为`MindData`它是一个用于模拟MindSpore框架中数据集处理的桩Stub下面是对这段代码的逐行注释
```python
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -12,77 +15,94 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
'''Remove after MindData merge to MindSpore ''' '''Remove after MindData merge to MindSpore '''
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor
class MindData: class MindData:
""" Stub for MindData """ """ Stub for MindData """
# 构造函数初始化MindData类的实例
def __init__(self, size=1, batch_size=None, repeat_count=1, def __init__(self, size=1, batch_size=None, repeat_count=1,
np_types=None, output_shapes=None, input_indexs=()): np_types=None, output_shapes=None, input_indexs=()):
self._size = size self._size = size # 数据集的大小
self._batch_size = batch_size self._batch_size = batch_size # 批处理大小
self._repeat_count = repeat_count self._repeat_count = repeat_count # 重复次数
self._np_types = np_types self._np_types = np_types # NumPy数据类型
self._output_shapes = output_shapes self._output_shapes = output_shapes # 输出形状
self._input_indexs = input_indexs self._input_indexs = input_indexs # 输入索引
self._iter_num = 0 self._iter_num = 0 # 迭代次数计数器
self.dynamic_setting = [False, None] self.dynamic_setting = [False, None] # 动态设置标志和值
# 获取数据集大小
def get_dataset_size(self): def get_dataset_size(self):
return self._size return self._size
# 获取重复次数
def get_repeat_count(self): def get_repeat_count(self):
return self._repeat_count return self._repeat_count
# 获取批处理大小
def get_batch_size(self): def get_batch_size(self):
return self._batch_size return self._batch_size
# 获取输出数据类型
def output_types(self): def output_types(self):
return self._np_types return self._np_types
# 获取输出形状
def output_shapes(self): def output_shapes(self):
return self._output_shapes return self._output_shapes
# 输入索引属性
@property @property
def input_indexs(self): def input_indexs(self):
return self._input_indexs return self._input_indexs
# 设备队列设置
def device_que(self, send_epoch_end=True, create_data_info_queue=False): def device_que(self, send_epoch_end=True, create_data_info_queue=False):
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
self.send_epoch_end = send_epoch_end self.send_epoch_end = send_epoch_end
return self return self
# 创建元组迭代器
def create_tuple_iterator(self, num_epochs=-1, do_copy=True): def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self.__iter__() return self.__iter__()
# 发送数据
def send(self, num_epochs=-1): def send(self, num_epochs=-1):
pass pass
# 停止发送数据
def stop_send(self): def stop_send(self):
pass pass
# 释放资源
def release(self): def release(self):
pass pass
# 继续发送数据
def continue_send(self): def continue_send(self):
pass pass
# 获取数据信息
def get_data_info(self): def get_data_info(self):
pass pass
# 动态最小最大形状
def dynamic_min_max_shapes(self): def dynamic_min_max_shapes(self):
pass pass
# 获取长度
def __len__(self): def __len__(self):
return self._size return self._size
# 迭代器
def __iter__(self): def __iter__(self):
return self return self
# 获取下一个元素
def __next__(self): def __next__(self):
if self._size < self._iter_num: if self._size < self._iter_num:
raise StopIteration raise StopIteration
@ -90,11 +110,13 @@ class MindData:
next_value = [] next_value = []
for shape, typ in zip(self._output_shapes, self._np_types): for shape, typ in zip(self._output_shapes, self._np_types):
next_value.append(Tensor(np.ndarray(shape, typ))) next_value.append(Tensor(np.ndarray(shape, typ)))
return tuple(next_value) return tuple(next_value)
# 下一个元素
def next(self): def next(self):
return self.__next__() return self.__next__()
# 重置迭代器
def reset(self): def reset(self):
self._iter_num = 0 self._iter_num = 0

Loading…
Cancel
Save