You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1234 lines
37 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations
import functools
import sys
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import (
_constants,
_type_utils,
errors,
symbolic_helper,
symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0
# released on 04/24/19
# Public symbolic functions exported by this module. Helpers whose names start
# with an underscore (e.g. the pooling and interpolate factories) are registered
# through the decorator below but are intentionally not listed here.
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv3d_relu",
    "quantized_conv1d",
    "quantized_conv2d",
    "quantized_conv3d",
    "quantized_conv_transpose1d",
    "quantized_conv_transpose2d",
    "quantized_conv_transpose3d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_linear_relu",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]
# Registration decorator pre-bound to opset 10; every symbolic in this file is
# registered through it.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``, dispatching on whether extra arguments were given.

    Without extra arguments the op is exported as true division via opset 9;
    otherwise the extra arguments (the rounding mode) are forwarded to
    ``_div_rounding_mode``.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Export division with an explicit ``rounding_mode``.

    ``"floor"`` is handled by this opset's ``_floor_divide``; every other mode
    is delegated to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Export floor division.

    Floating-point operands use ``Floor(true_divide(self, other))``. Integer
    operands use ``Div`` and then subtract one wherever the operands have
    opposite signs and the remainder is non-zero.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Floor", opset9.true_divide(g, self, other))

    # Integer division does truncation rounding, so start from the truncated
    # quotient and correct it below.
    truncated = g.op("Div", self, other)

    # The quotient is negative exactly when the operand signs differ.
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_negative = g.op("Less", self, zero)
    other_negative = g.op("Less", other, zero)
    signs_differ = g.op("Xor", self_negative, other_negative)

    # A negative quotient with a non-zero remainder was rounded up by the
    # truncation, so it needs to be lowered by one.
    remainder = g.op("Mod", self, other, fmod_i=0)
    remainder_is_zero = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", remainder_is_zero)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    lowered = g.op("Sub", truncated, one)
    return g.op("Where", needs_fixup, lowered, truncated)
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` by delegating to ``symbolic_helper._sort_helper``.

    The ``decending`` spelling is kept because it is the keyword name the
    helper is called with.
    """
    return symbolic_helper._sort_helper(
        g,
        self,
        dim,
        decending=decending,
        out=out,
    )
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` by delegating to ``symbolic_helper._topk_helper``."""
    helper_kwargs = {"largest": largest, "sorted": sorted, "out": out}
    return symbolic_helper._topk_helper(g, self, k, dim, **helper_kwargs)
def _aten_max_pool_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
) -> _C.Value:
    """Emit an ONNX ``MaxPool`` node and return only the pooled values.

    Args:
        g: Graph context used to create nodes.
        self: Input value to pool.
        kernel_shape: ``kernel_shape`` attribute of the ``MaxPool`` node.
        strides: ``strides`` attribute of the ``MaxPool`` node.
        pads: ``pads`` attribute of the ``MaxPool`` node.
        dilations: ``dilations`` attribute of the ``MaxPool`` node.
        ceil_mode: ``ceil_mode`` attribute of the ``MaxPool`` node.
        unbatched_rank: Rank the input has when it carries no batch dimension.

    Returns:
        The first output of ``MaxPool``; the indices output is discarded.
    """
    # The rank must be read statically: a graph Value produced by
    # Size(Shape(x)) never compares equal to a Python int, which would make the
    # unbatched handling below unreachable. An unknown rank (None) is treated
    # as batched.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # The helper picks the Unsqueeze form that matches the export opset
        # (``axes`` is an attribute, not an input, before opset 13).
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, _ = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )

    if is_unbatched:
        # Drop the batch dimension that was added above.
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])

    return pool_result
# For MaxPool
def _adjust_attributes_of_max_pool(
expand_size: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
padding: Union[Sequence[int], int],
dilation: Union[Sequence[int], int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
"""Adjust attributes of avg_pool to match ONNX specification."""
if isinstance(dilation, int):
dilation = [dilation] * expand_size
if isinstance(kernel_size, int):
kernel_shape = [kernel_size] * expand_size
else:
kernel_shape = kernel_size # type: ignore[assignment]
if isinstance(padding, int):
pads = [padding] * expand_size * 2 # type: ignore[operator, assignment]
elif len(padding) == 1:
pads = padding * expand_size * 2 # type: ignore[operator, assignment]
elif len(padding) == 2:
# 2D padding
pads = padding * 2 # type: ignore[operator, assignment]
elif len(padding) == 3:
# 3D padding
pads = padding * 2 # type: ignore[operator, assignment]
else:
# When padding is already done for all dimensions,
# we don't need to double it
# eg: (1, 1, 1, 1, 1, 1)
pads = padding # type: ignore[assignment]
if isinstance(stride, int):
strides = [stride] * expand_size
elif not stride:
strides = kernel_shape
else:
strides = stride # type: ignore[assignment]
return (kernel_shape, strides, pads, dilation)
def _aten_max_pool_with_indices_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
    n_dims_one: Sequence[int],
    n_dims_zero: Sequence[int],
    n_dims_axes: Sequence[int],
) -> Tuple[_C.Value, _C.Value]:
    """Emit ONNX ``MaxPool`` and return both pooled values and adjusted indices.

    Args:
        g: Graph context used to create nodes.
        self: Input value to pool.
        kernel_shape: ``kernel_shape`` attribute of the main ``MaxPool`` node.
        strides: ``strides`` attribute of the main ``MaxPool`` node.
        pads: ``pads`` attribute of the main ``MaxPool`` node.
        dilations: ``dilations`` attribute of both ``MaxPool`` nodes.
        ceil_mode: ``ceil_mode`` attribute of the main ``MaxPool`` node.
        unbatched_rank: Rank the input has when it carries no batch dimension.
        n_dims_one: Kernel/stride of the auxiliary pool and ``ends`` of the slice.
        n_dims_zero: ``starts`` of the slice.
        n_dims_axes: Axes the slice is taken along.

    Returns:
        ``(pool_result, indices)`` where ``indices`` has the per-slice offset
        subtracted.
    """
    # Read the rank statically; comparing a Size(Shape(x)) graph Value with a
    # Python int is always False. An unknown rank (None) is treated as batched.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # Version-aware helper: before opset 13 ``axes`` is an attribute.
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    # Auxiliary pool with kernel and stride n_dims_one; only its indices output
    # is used.
    # NOTE(review): presumably this converts ONNX's whole-tensor flattened
    # indices into the per-plane indices PyTorch reports -- confirm.
    _, flatten_indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        dilations_i=dilations,
        kernel_shape_i=n_dims_one,
        strides_i=n_dims_one,
    )

    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))

    # Index of the first element along the sliced axes, used as the offset.
    delta = g.op("Slice", flatten_indices, starts, ends, axes)
    indices = g.op("Sub", indices, delta)

    if is_unbatched:
        # Drop the batch dimension that was added above from both outputs.
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])
        indices = symbolic_helper._squeeze_helper(g, indices, [0])

    return (pool_result, indices)
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[_apply_params("max_pool1d_with_indices", 1, return_indices=True)],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[_apply_params("max_pool2d_with_indices", 2, return_indices=True)],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[_apply_params("max_pool3d_with_indices", 3, return_indices=True)],
)
@_beartype.beartype
def _max_pool(name: str, expand_size: int, return_indices: bool):
    """Create the symbolic function for one max-pool variant.

    Args:
        name: Name of the ATen op being exported (not used in the body).
        expand_size: Number of spatial dimensions (1, 2 or 3).
        return_indices: Whether the symbolic also returns the argmax indices.

    Returns:
        The symbolic function that is registered for the op.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(
        g: jit_utils.GraphContext,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        dilation: Sequence[int],
        ceil_mode: bool,
    ):
        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
            expand_size, kernel_size, stride, padding, dilation
        )
        # An input without a batch dimension has one channel dim plus the
        # spatial dims.
        unbatched_rank = expand_size + 1

        if not return_indices:
            return _aten_max_pool_onnx(
                g,
                input,
                kernel_shape,
                strides,
                pads,
                dilations,
                ceil_mode,
                unbatched_rank,
            )

        ones = [1] * expand_size
        zeros = [0] * expand_size
        # Spatial axes start after the batch and channel axes.
        spatial_axes = list(range(2, 2 + expand_size))
        return _aten_max_pool_with_indices_onnx(
            g,
            input,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            unbatched_rank,
            ones,
            zeros,
            spatial_axes,
        )

    return symbolic_fn
# For AvgPool
def _adjust_attributes_of_avg_pool(
expand_size: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
padding: Union[Sequence[int], int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]:
"""Adjust attributes of avg_pool to match ONNX specification."""
if isinstance(kernel_size, int):
kernel_shape = [kernel_size] * expand_size
else:
kernel_shape = kernel_size # type: ignore[assignment]
if isinstance(padding, int):
pads = [padding] * expand_size * 2
elif len(padding) == 1:
pads = padding * expand_size * 2 # type: ignore[operator, assignment]
elif len(padding) == 2:
pads = padding * expand_size # type: ignore[operator, assignment]
else:
pads = padding * 2 # type: ignore[operator, assignment]
if isinstance(stride, int):
strides = [stride] * expand_size
elif not stride:
strides = kernel_shape
else:
strides = stride # type: ignore[assignment]
return (kernel_shape, strides, pads)
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", 1)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", 2)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", 3)],
)
@_beartype.beartype
def _avg_pool(name, expand_size):
    """Create the symbolic function for one average-pool variant.

    Args:
        name: Name of the ATen op being exported (not used in the body).
        expand_size: Number of spatial dimensions (1, 2 or 3).

    Returns:
        The symbolic function that is registered for the op.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # ``divisor_override`` is accepted but not used by this symbolic.
        onnx_attrs = _adjust_attributes_of_avg_pool(
            expand_size, kernel_size, stride, padding
        )
        kernel_shape, strides, pads = onnx_attrs
        return g.op(
            "AveragePool",
            input,
            ceil_mode_i=ceil_mode,
            count_include_pad_i=count_include_pad,
            kernel_shape_i=kernel_shape,
            pads_i=pads,
            strides_i=strides,
        )

    return symbolic_fn
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    """Create the symbolic function for one upsample variant.

    Args:
        name: Op name, used when reporting an unimplemented configuration.
        dim: Value forwarded to ``_interpolate_size_to_scales`` when scales
            have to be derived from ``output_size``.
        interpolate_mode: ``mode`` attribute of the emitted ``Resize`` node.

    Returns:
        The symbolic function that is registered for the op.
    """

    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        attributes = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        scales, align_corners = attributes
        symbolic_helper._interpolate_warning(interpolate_mode)

        # align_corners=True is reported as unimplemented by this symbolic.
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        if scales is None:
            # No explicit scales were given: derive them from the output size.
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ONNX ``Resize`` node.

    Scales and the resize mode come from
    ``symbolic_helper._interpolate_get_scales_and_mode``.
    ``recompute_scale_factor`` and ``antialias`` are accepted but not used.
    """
    resolved = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    scales, resize_mode = resolved
    return g.op("Resize", input, scales, mode_s=resize_mode)
@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    axes: Union[List, torch.Tensor, torch._C.Value],
    starts: Union[List, torch.Tensor, torch._C.Value],
    ends: Union[List, torch.Tensor, torch._C.Value],
    steps: Optional[Union[List, torch.Tensor, torch._C.Value]] = None,
):
    """Emit an ONNX ``Slice`` node, or return ``input`` when the slice is a no-op.

    Args:
        g: Graph context used to create nodes.
        input: Value to slice.
        axes: Axes to slice along, as a list, tensor or graph value.
        starts: Start indices; a ``None`` constant value defaults to 0.
        ends: End indices; a ``None`` constant value defaults to INT64_MAX.
        steps: Optional step sizes; a ``None`` constant value defaults to 1.

    Raises:
        errors.SymbolicValueError: If a graph-value argument has a rank other
            than 0 or 1.
    """

    def _is_none(candidate):
        # Either a Python None or a prim::Constant graph value of NoneType.
        return candidate is None or (
            isinstance(candidate, torch._C.Value)
            and candidate.node().kind() == "prim::Constant"
            and isinstance(candidate.type(), _C.NoneType)
        )

    def _as_1d_value(arg, fallback=None):
        # Convert input param into a 1D torch.Value.
        if fallback is not None and _is_none(arg):
            arg = [fallback]

        if isinstance(arg, (list, torch.Tensor)):
            return g.op("Constant", value_t=torch.tensor(arg))

        rank = symbolic_helper._get_tensor_rank(arg)
        if rank == 1:
            return arg
        if rank == 0:
            return symbolic_helper._unsqueeze_helper(g, arg, [0])
        raise errors.SymbolicValueError(f"Rank must be 0 or 1, not {rank}", arg)

    def _single_const(arg):
        # Lists/tensors only yield a value when they hold exactly one element.
        if not isinstance(arg, (list, torch.Tensor)):
            return symbolic_helper._maybe_get_const(arg, "i")
        return arg[0] if len(arg) == 1 else None

    # Check if slice is a no-op
    slice_is_noop = (
        _single_const(starts) == 0
        and _single_const(ends) == _constants.INT64_MAX
        and (steps is None or _single_const(steps) == 1)
    )
    if slice_is_noop:
        return input

    axes_value = _as_1d_value(axes)
    starts_value = _as_1d_value(starts, fallback=0)
    ends_value = _as_1d_value(ends, fallback=_constants.INT64_MAX)

    slice_inputs = [input, starts_value, ends_value, axes_value]
    if steps is not None:
        slice_inputs.append(_as_1d_value(steps, fallback=1))
    return g.op("Slice", *slice_inputs)
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::slice`` for both the tensor and the list overload.

    Raises:
        errors.SymbolicValueError: If the number of extra arguments matches
            neither overload.
    """
    arg_count = len(args)
    if arg_count == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        dims = [0]
        start, end, step = args
    elif arg_count == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dims, start, end, step = args
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    return symbolic_helper._slice_helper(
        g, self, axes=dims, starts=start, ends=end, steps=step
    )
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
    """Export ``aten::flip`` as a slice with step -1 along each axis in ``dims``.

    Each axis is sliced from -1 to ``-INT64_MAX``.
    """
    axis_count = len(dims)
    backwards = [-1] * axis_count
    return symbolic_helper._slice_helper(
        g,
        input,
        axes=dims,
        starts=list(backwards),
        ends=[-_constants.INT64_MAX] * axis_count,
        steps=list(backwards),
    )
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
    """Export ``aten::fmod`` as ONNX ``Mod`` with the ``fmod`` attribute set to 1."""
    operands = (input, other)
    return g.op("Mod", *operands, fmod_i=1)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` by unrolling one slice/gather/reduce per bag.

    The number of bags must be known statically from dimension 0 of
    ``offsets``. ``mode`` 0 reduces each bag with a sum, 1 with ``ReduceMean``
    and any other value with ``ReduceMax``. ``sparse`` is not used.

    Returns:
        ``(output, None, None, None)``.

    Raises:
        RuntimeError: If ``padding_idx`` is a non-negative integer.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )

    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        # The final offset already closes the last bag.
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        # Append a very large end offset so the last bag also has an end.
        offset_len = offsets_dim_0
        offsets_extended = [
            offsets,
            g.op("Constant", value_t=torch.tensor([sys.maxsize])),
        ]
        offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)

    # Whether weights were supplied does not change between bags.
    has_per_sample_weights = not symbolic_helper._is_none(per_sample_weights)

    list_ = []
    for i in range(offset_len):
        # Bag i covers indices[offsets[i]:offsets[i + 1]].
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)),
            [0],
        )
        axes_ = g.op("Constant", value_t=torch.tensor([0]))
        indices_row = g.op("Slice", indices, start_, end_, axes_)
        embeddings = g.op("Gather", embedding_matrix, indices_row)

        if has_per_sample_weights:
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)

        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        # Restore a leading axis so the per-bag results can be concatenated.
        embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
        list_.append(embeddings)

    output = g.op("Concat", *list_, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Export fake quantization as ``QuantizeLinear`` followed by ``DequantizeLinear``.

    Only the ranges (0, 255) and (-128, 127) are accepted. The zero point is
    cast to UINT8 when ``quant_min`` is 0 and to INT8 otherwise.

    Raises:
        errors.SymbolicValueError: If the quantization range is not one of the
            two accepted ranges.
    """
    quant_range = (quant_min, quant_max)

    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if quant_range == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if quant_range not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )

    scale = symbolic_helper._maybe_get_scalar(scale)
    if scale is None:
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type

    if quant_min == 0:
        zero_point_type = _C_onnx.TensorProtoDataType.UINT8
    else:
        zero_point_type = _C_onnx.TensorProtoDataType.INT8
    zero_point = g.op("Cast", zero_point, to_i=zero_point_type)

    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    return g.op("DequantizeLinear", quantized, scale, zero_point)
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
    """Export ``aten::isinf`` as ``IsInf`` applied to the input cast to DOUBLE."""
    as_double = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
    return g.op("IsInf", as_double)
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
    """Export ``aten::isfinite`` as ``not (isinf(input) or isnan(input))``."""
    is_infinite = isinf(g, input)
    is_nan = opset9.isnan(g, input)
    not_finite = opset9.__or_(g, is_infinite, is_nan)
    return opset9.__not_(g, not_finite)
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Export ``aten::quantize_per_tensor`` through ``symbolic_helper.quantize_helper``.

    The zero point is cast to the ONNX type corresponding to ``dtype`` and the
    scale is cast to FLOAT before quantizing.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point_type = _type_utils.JitScalarType(dtype).onnx_type()
    zero_point = g.op("Cast", zero_point, to_i=zero_point_type)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
    """Export ``aten::dequantize``, returning the first item from ``dequantize_helper``."""
    helper_outputs = symbolic_helper.dequantize_helper(g, input)
    return helper_outputs[0]
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Export ``aten::nan_to_num`` as three chained ``Where`` replacements.

    NaN, then +inf, then -inf are replaced in that order. ``nan`` defaults to
    0.0; ``posinf`` / ``neginf`` default to the largest / lowest finite value
    of the input dtype. Non floating-point inputs are returned unchanged.
    """
    # Cannot create a int type tensor with inf/nan values, so we simply
    # return the original tensor
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()

    def _replacement(value):
        # One-element constant of the input dtype holding the fill value.
        return g.op("Constant", value_t=torch.tensor([value], dtype=input_dtype))

    if nan is None:
        nan = 0.0
    nan_mask = opset9.isnan(g, input)
    without_nan = g.op("Where", nan_mask, _replacement(nan), input)

    # For None values of posinf, neginf we use the greatest/lowest finite
    # value representable by inputs dtype.
    finfo = torch.finfo(input_dtype)

    if posinf is None:
        posinf = finfo.max
    is_inf = isinf(g, without_nan)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_positive = opset9.gt(g, without_nan, zero)
    posinf_mask = opset9.logical_and(g, is_inf, is_positive)
    without_posinf = g.op("Where", posinf_mask, _replacement(posinf), without_nan)

    if neginf is None:
        neginf = finfo.min
    is_inf = isinf(g, without_posinf)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_negative = opset9.lt(g, without_posinf, zero)
    neginf_mask = opset9.logical_and(g, is_inf, is_negative)
    return g.op("Where", neginf_mask, _replacement(neginf), without_posinf)
# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear``: dequantize, run opset 9 ``linear``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    result = opset9.linear(g, x, w, float_bias)
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear_relu``: dequantize, ``linear`` + ``relu``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    activated = opset9.relu(g, opset9.linear(g, x, w, float_bias))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add``: dequantize both operands, add, requantize."""
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    total = opset9.add(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, total, op_scale, op_zero_point)
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add_relu``: dequantize, add, apply relu, requantize."""
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    activated = opset9.relu(g, opset9.add(g, lhs, rhs))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::mul``: dequantize both operands, multiply, requantize."""
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    product = opset9.mul(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, product, op_scale, op_zero_point)
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish``: dequantize, apply hardswish, requantize."""
    dequantized = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.hardswish(g, dequantized)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::sigmoid``: dequantize, apply sigmoid, requantize."""
    dequantized = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.sigmoid(g, dequantized)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Export ``quantized::leaky_relu``: dequantize, apply leaky_relu, requantize."""
    dequantized = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.leaky_relu(g, dequantized, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::layer_norm``: dequantize, opset 9 ``layer_norm``, requantize.

    The final argument of the opset 9 call is fixed to ``False``.
    """
    dequantized = symbolic_helper.dequantize_helper(g, x)[0]
    normalized = opset9.layer_norm(
        g, dequantized, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::group_norm``: dequantize, opset 9 ``group_norm``, requantize.

    The final argument of the opset 9 call is fixed to ``False``.
    """
    dequantized = symbolic_helper.dequantize_helper(g, x)[0]
    normalized = opset9.group_norm(
        g, dequantized, num_groups, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::instance_norm``: dequantize, opset 9 ``instance_norm``, requantize.

    The opset 9 call receives ``None, None, False, 0.0`` between ``bias`` and
    ``eps`` and ``False`` as its last argument.
    """
    dequantized = symbolic_helper.dequantize_helper(g, q_input)[0]
    normalized = opset9.instance_norm(
        g, dequantized, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d_relu``: dequantize, ``conv1d`` + ``relu``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv1d(g, x, w, float_bias, stride, padding, dilation, groups)
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d_relu``: dequantize, ``conv2d`` + ``relu``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv2d(g, x, w, float_bias, stride, padding, dilation, groups)
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv3d_relu")
@_beartype.beartype
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d_relu``: dequantize, ``conv3d`` + ``relu``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv3d(g, x, w, float_bias, stride, padding, dilation, groups)
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv1d")
@_beartype.beartype
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d``: dequantize, opset 9 ``conv1d``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv1d(g, x, w, float_bias, stride, padding, dilation, groups)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d``: dequantize, opset 9 ``conv2d``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv2d(g, x, w, float_bias, stride, padding, dilation, groups)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv3d")
@_beartype.beartype
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d``: dequantize, opset 9 ``conv3d``, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv3d(g, x, w, float_bias, stride, padding, dilation, groups)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv_transpose1d")
@_beartype.beartype
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose1d``: dequantize, transposed conv, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use. Note that ``groups`` is passed before ``dilation`` to the
    opset 9 symbolic, which differs from this function's parameter order.
    """
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    # Use the 1d symbolic so this op matches its 2d/3d siblings, each of which
    # calls the opset 9 function of its own dimensionality.
    output = opset9.conv_transpose1d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv_transpose2d")
@_beartype.beartype
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose2d``: dequantize, transposed conv, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use. ``groups`` is passed before ``dilation`` to the opset 9
    symbolic.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    conv_args = (x, w, float_bias, stride, padding, output_padding, groups, dilation)
    convolved = opset9.conv_transpose2d(g, *conv_args)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv_transpose3d")
@_beartype.beartype
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose3d``: dequantize, transposed conv, requantize.

    The bias is requantized from the input and weight scales and dequantized
    again before use. ``groups`` is passed before ``dilation`` to the opset 9
    symbolic.
    """
    x, x_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    w, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(g, bias, x_scale, w_scale)
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    conv_args = (x, w, float_bias, stride, padding, output_padding, groups, dilation)
    convolved = opset9.conv_transpose3d(g, *conv_args)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Export ``quantized::cat``: dequantize each list element, ``Concat``, requantize.

    Args:
        g: Graph context used to create nodes.
        q_inputs: List value holding the quantized tensors to concatenate.
        dim: Axis passed to ``Concat``.
        op_scale: Scale used to requantize the result.
        op_zero_point: Zero point used to requantize the result.
    """
    float_inputs = []
    for q_tensor in symbolic_helper._unpack_list(q_inputs):
        float_inputs.append(symbolic_helper.dequantize_helper(g, q_tensor)[0])

    joined = g.op("Concat", *float_inputs, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)