_extends\graph_kernel\expanders\gather.py

branch-yixin
yixin 2 months ago
parent efae72cc05
commit 07b3545276

@ -22,22 +22,47 @@ class Gather(Expander):
"""Expand Gather"""
def _expand(self, graph_builder):
"""
对输入张量进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的张量
"""
# 获取输入和索引
inputs, indices = self.inputs
# 获取轴
axis = self.attrs['axis']
# 如果轴小于0则将其转换为正数
if axis < 0:
axis += len(inputs.shape)
# 如果索引的维度为1则直接进行Gather操作
if len(indices.shape) == 1:
result = graph_builder.emit('Gather', [inputs, indices], attrs={'axis': axis})
# 否则对索引进行Reshape操作然后进行Gather操作最后再进行Reshape操作
else:
# 获取原始索引的形状
ori_indices_shape = indices.shape
# 计算索引的形状的乘积
indices_shape_one_dim = 1
for dim in ori_indices_shape:
indices_shape_one_dim *= dim
# 构造新的索引形状
new_indices_shape = [indices_shape_one_dim]
# 对索引进行Reshape操作
reshape_indices = graph_builder.emit('Reshape', [indices], attrs={'shape': new_indices_shape})
# 对输入和Reshape后的索引进行Gather操作
tmp_result = graph_builder.emit('Gather', [inputs, reshape_indices], attrs={'axis': axis})
# 获取输出的形状
output_shape = list(inputs.shape)
# 将索引的形状插入到输出的形状中
output_shape[axis:axis] = ori_indices_shape
# 删除输出的形状中多余的维度
del output_shape[axis + len(ori_indices_shape)]
# 对Gather操作的结果进行Reshape操作
result = graph_builder.emit('Reshape', [tmp_result], attrs={'shape': output_shape})
# 返回结果
return result

Loading…
Cancel
Save