You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
1.7 KiB
71 lines
1.7 KiB
"""This file exports ONNX ops for opset 18.
|
|
|
|
Note [ONNX Operators that are added/updated in opset 18]
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
|
|
New/updated operators:
|
|
CenterCropPad
|
|
Col2Im
|
|
Mish
|
|
OptionalGetElement
|
|
OptionalHasElement
|
|
Pad
|
|
Resize
|
|
ScatterElements
|
|
ScatterND
|
|
"""
|
|
|
|
import functools
|
|
from typing import Sequence
|
|
|
|
from torch import _C
|
|
from torch.onnx import symbolic_helper
|
|
from torch.onnx._internal import _beartype, registration
|
|
|
|
# EDITING THIS FILE? READ THIS FIRST!
|
|
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
|
|
|
# Public names of this module: only the `col2im` symbolic function is listed.
__all__ = ["col2im"]
|
|
|
|
# `registration.onnx_symbolic` with `opset=18` pre-bound, so the decorators in
# this file only have to name the ATen op they handle (e.g. "aten::col2im").
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
|
|
|
|
|
|
@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
    g,
    input: _C.Value,
    output_size: _C.Value,
    kernel_size: _C.Value,
    dilation: Sequence[int],
    padding: Sequence[int],
    stride: Sequence[int],
):
    """Emit an ONNX ``Col2Im`` node for ``aten::col2im``.

    Args:
        g: Graph context; only its ``op`` method is used here.
        input: Value forwarded unchanged as the first ``Col2Im`` input.
        output_size: Value forwarded unchanged as the second ``Col2Im`` input.
            Its first dimension is read to size the defaults below, but only
            when at least one of ``dilation``/``padding``/``stride`` is empty.
        kernel_size: Value forwarded unchanged as the third ``Col2Im`` input.
        dilation: One integer per axis; defaults to all ones when empty.
        padding: One integer per axis, applied to both the beginning and the
            end of that axis; defaults to all zeros when empty.
        stride: One integer per axis; defaults to all ones when empty.

    Returns:
        The result of ``g.op("Col2Im", ...)``.
    """
    # ONNX lays ``pads`` out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...]:
    # all begin values first, then all end values. Each entry of ``padding``
    # applies to both sides of its axis, so both halves are ``padding`` itself.
    # (Interleaving as [p0, p0, p1, p1, ...] assigns the wrong amounts to the
    # axes whenever the per-axis paddings differ.)
    adjusted_padding = [*padding, *padding]

    if not (adjusted_padding and dilation and stride):
        # Only consult the shape of ``output_size`` when a default is really
        # needed, so that exports which provide every attribute do not rely on
        # static shape information being available.
        # NOTE(review): assumes ``output_size`` is a 1-D tensor whose length is
        # the number of spatial axes and that this length is statically known
        # here -- confirm against ``symbolic_helper._get_tensor_sizes``.
        num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]

        if not adjusted_padding:
            adjusted_padding = [0, 0] * num_dimensional_axis

        if not dilation:
            dilation = [1] * num_dimensional_axis

        if not stride:
            stride = [1] * num_dimensional_axis

    return g.op(
        "Col2Im",
        input,
        output_size,
        kernel_size,
        dilations_i=dilation,
        pads_i=adjusted_padding,
        strides_i=stride,
    )
|