|
|
|
@ -22,22 +22,47 @@ class Gather(Expander):
|
|
|
|
|
"""Expand Gather"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入张量进行扩展操作。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 扩展后的张量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入和索引
|
|
|
|
|
inputs, indices = self.inputs
|
|
|
|
|
# 获取轴
|
|
|
|
|
axis = self.attrs['axis']
|
|
|
|
|
# 如果轴小于0,则将其转换为正数
|
|
|
|
|
if axis < 0:
|
|
|
|
|
axis += len(inputs.shape)
|
|
|
|
|
# 如果索引的维度为1,则直接进行Gather操作
|
|
|
|
|
if len(indices.shape) == 1:
|
|
|
|
|
result = graph_builder.emit('Gather', [inputs, indices], attrs={'axis': axis})
|
|
|
|
|
# 否则,对索引进行Reshape操作,然后进行Gather操作,最后再进行Reshape操作
|
|
|
|
|
else:
|
|
|
|
|
# 获取原始索引的形状
|
|
|
|
|
ori_indices_shape = indices.shape
|
|
|
|
|
# 计算索引的形状的乘积
|
|
|
|
|
indices_shape_one_dim = 1
|
|
|
|
|
for dim in ori_indices_shape:
|
|
|
|
|
indices_shape_one_dim *= dim
|
|
|
|
|
# 构造新的索引形状
|
|
|
|
|
new_indices_shape = [indices_shape_one_dim]
|
|
|
|
|
# 对索引进行Reshape操作
|
|
|
|
|
reshape_indices = graph_builder.emit('Reshape', [indices], attrs={'shape': new_indices_shape})
|
|
|
|
|
# 对输入和Reshape后的索引进行Gather操作
|
|
|
|
|
tmp_result = graph_builder.emit('Gather', [inputs, reshape_indices], attrs={'axis': axis})
|
|
|
|
|
# 获取输出的形状
|
|
|
|
|
output_shape = list(inputs.shape)
|
|
|
|
|
# 将索引的形状插入到输出的形状中
|
|
|
|
|
output_shape[axis:axis] = ori_indices_shape
|
|
|
|
|
# 删除输出的形状中多余的维度
|
|
|
|
|
del output_shape[axis + len(ori_indices_shape)]
|
|
|
|
|
# 对Gather操作的结果进行Reshape操作
|
|
|
|
|
result = graph_builder.emit('Reshape', [tmp_result], attrs={'shape': output_shape})
|
|
|
|
|
# 返回结果
|
|
|
|
|
return result
|
|
|
|
|