# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 13
import functools

import torch
import torch._C._onnx as _C_onnx
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply


@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::softmax`` as ONNX ``Softmax`` along ``dim``.

    When ``dtype`` is given and is not produced by a ``prim::Constant`` node,
    its constant value is read and the softmax result is cast to that type.
    """
    result = g.op("Softmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        result = g.op(
            "Cast", result, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )

    return result


@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::log_softmax`` as ONNX ``LogSoftmax`` along ``dim``.

    ``dtype`` is handled exactly as in :func:`softmax`: a non-``prim::Constant``
    dtype causes the result to be cast to that type.
    """
    result = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        result = g.op(
            "Cast", result, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return result


@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    """Export ``aten::frobenius_norm`` as ``Sqrt(ReduceSum(self * self, dim))``.

    A constant, empty ``dim`` list means "all axes" and is exported as a single
    ``ReduceL2`` with ``keepdims=0``.
    """
    dim_val = symbolic_helper._maybe_get_const(dim, "is")
    if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op("Mul", self, self)
    sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)


@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split`` (also used for ``split_with_sizes``).

    Non-static splits are exported through ``SplitToSequence``. If ``_outputs``
    is known, the result is instead either a list of explicit ``Slice`` nodes
    (when the sizes are a packed list of length ``_outputs``) or a list of
    ``SequenceAt`` lookups. Static splits become a single ``Split`` with
    constant sizes.

    Raises:
        errors.SymbolicValueError: for a scalar split size when both the size of
            ``dim`` and ``_outputs`` are unknown.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]

            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                # split_sizes is a list of same length as _outputs
                end = g.op("Add", start, split_sizes[i])
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]

    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        # A constant list of sizes can be fed to Split directly.
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            raise errors.SymbolicValueError(
                "Unknown dimension size not supported", self
            )
    # Full chunks of split_size, plus a smaller trailing chunk for any remainder.
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::split_with_sizes")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` by delegating to :func:`split`."""
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split")
@_beartype.beartype
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split`` by delegating to :func:`split`."""
    return split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split_with_sizes")
@_beartype.beartype
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split_with_sizes`` by delegating to :func:`split_with_sizes`."""
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::tensor_split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def tensor_split(
    g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
):
    """Export ``aten::tensor_split``.

    Three forms of ``indices_or_sections`` are handled:

    * Static input:
      * a constant index list becomes explicit ``Slice`` nodes;
      * a section count becomes one ``Split`` whose sizes are computed at
        export time.
    * A rank-1 index tensor becomes an ONNX ``Loop`` that slices between
      consecutive indices and collects the pieces into a sequence.
    * Otherwise, a scalar section count becomes ``Split``/``SplitToSequence``
      with the sizes computed in-graph.

    Raises:
        errors.SymbolicValueError: for a static section count when both the
            size of ``dim`` and ``_outputs`` are unknown.
    """
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    axis = opset11.unsqueeze(g, axis, 0)
    const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        split_val = symbolic_helper._node_get(indices_or_sections.node(), "value")

        if split_val.dim() > 0:
            # Constant index list: slices are [0:i0], [i0:i1], ..., [i_last:end].
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            res = []
            assert _outputs is not None
            for i in range(_outputs - 1):
                end = g.op(
                    "Gather",
                    indices_or_sections,
                    g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
                    axis_i=0,
                )
                res.append(g.op("Slice", self, start, end, axis))
                start = end

            end = symbolic_helper._size_helper(g, self, axis)
            res.append(g.op("Slice", self, start, end, axis))
            return res

        split_size = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        size = symbolic_helper._get_tensor_dim_size(self, dim)
        if size is None:
            if _outputs is not None:
                size = split_size * _outputs
            else:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )

        # The first (size % split_size) sections get one extra element each.
        min_split_size = size // split_size
        num_splits_one_extra = size % split_size

        splits = num_splits_one_extra * [min_split_size + 1]
        leftover = (split_size - num_splits_one_extra) * [min_split_size]

        splits = g.op(
            "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
        )
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

    if (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    ):
        loop_len = symbolic_helper._size_helper(
            g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
        )
        loop_len = opset11.unsqueeze(g, loop_len, 0)
        loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # To make the first slice in the below loop work,
        # we pad a zero to the first position so that it will be the initial start of slice.
        padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0)

        final_splits = g.op("SequenceEmpty")
        # Loop inputs
        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        # The ONNX Loop body must declare the condition input even though it is unused here.
        utils._add_input_to_block(loop_block)
        final_splits = utils._add_input_to_block(loop_block)

        start = loop_context.op(
            "Gather", indices_or_sections, block_input_iter, axis_i=0
        )
        end = loop_context.op(
            "Gather",
            indices_or_sections,
            loop_context.op("Add", block_input_iter, const_1),
            axis_i=0,
        )

        piece = loop_context.op("Slice", self, start, end, axis)
        final_splits = loop_context.op("SequenceInsert", final_splits, piece)

        # Loop outputs
        cond_out = loop_context.op("Identity", loop_condition)
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, final_splits)

        loop_out = loop.node().output()
        # The loop covers [idx_k : idx_{k+1}]; the tail [idx_last : end] is added here.
        start = g.op(
            "Gather",
            indices_or_sections,
            g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)),
            axis_i=0,
        )
        start = opset11.unsqueeze(g, start, 0)
        end = symbolic_helper._size_helper(g, self, axis)

        last_slice = g.op("Slice", self, start, end, axis)

        return g.op("SequenceInsert", loop_out, last_slice)

    else:  # scalar tensor
        # Same size rule as the static path, but computed in-graph:
        # (dim_size % n) sections of (dim_size // n + 1), then the rest of (dim_size // n).
        dim_size = symbolic_helper._size_helper(g, self, axis)
        min_split_size = g.op("Div", dim_size, indices_or_sections)
        min_split_size_plus_1 = g.op(
            "Add",
            min_split_size,
            const_1,
        )
        num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections)
        splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra)
        leftover = g.op(
            "Tile",
            min_split_size,
            g.op(
                "Sub",
                opset11.unsqueeze(g, indices_or_sections, 0),
                num_splits_one_extra,
            ),
        )

        splits = g.op("Concat", splits, leftover, axis_i=0)
        if _outputs is None:
            return g.op("SplitToSequence", self, splits, axis_i=dim)
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``: split ``self`` into size-1 pieces along ``dim`` and drop that axis.

    When ``_outputs`` is unknown a ``SplitToSequence`` (``keepdims=0``) is
    returned; otherwise a ``Split`` followed by one ``Squeeze`` per output.
    """
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
    outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
    # Normalize to a list; with a single output g.op presumably returns a bare value.
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [
        g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim])))
        for out in outputs
    ]
    return squeezed_outputs


@_onnx_symbolic("aten::nonzero_numpy")
# Emitted from `torch.nonzero(x, as_tuple=True)`
@_beartype.beartype
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    """Export ``aten::nonzero_numpy`` by unbinding the opset 9 ``nonzero`` result along dim 1."""
    return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)


@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    """Export ``aten::where``.

    ``condition`` is cast to BOOL if it is not already bool. With only a
    condition (``self is None``) this exports the index-returning overload via
    ``nonzero`` + unbind along dim 1; otherwise it maps to ONNX ``Where``.
    """
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    if self is None:
        condition = opset9.nonzero(g, condition)
        return symbolic_helper._unbind_helper(
            g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
        )
    return g.op("Where", condition, self, other)


@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
@_beartype.beartype
def fake_quantize_per_channel_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    axis,
    quant_min=-128,
    quant_max=127,
):
    """Export per-channel fake quantization as ``QuantizeLinear`` -> ``DequantizeLinear`` along ``axis``.

    ``zero_point`` is cast to UINT8 when ``quant_min == 0`` and to INT8
    otherwise. For the ``(0, 127)`` range the quantized value is additionally
    clipped to 127.

    Raises:
        errors.SymbolicValueError: if ``(quant_min, quant_max)`` is not one of
            ``(0, 127)``, ``(0, 255)`` or ``(-128, 127)``.
    """
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    # ONNX defines zero_point to be int8 or uint8
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)


@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Export per-tensor fake quantization as ``QuantizeLinear`` -> ``DequantizeLinear``.

    Same range rules and ``zero_point`` casting as the per-channel variant.
    In addition, ``scale`` is cast to FLOAT when its scalar type is not already
    FLOAT (or is unknown).

    Raises:
        errors.SymbolicValueError: if ``(quant_min, quant_max)`` is not one of
            ``(0, 127)``, ``(0, 255)`` or ``(-128, 127)``.
    """
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point)


@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name):
    """Build a reduce symbolic for ``onnx_op_name`` that passes ``dim`` as an input.

    Without ``dim`` the reduction goes through ``_handle_reduce_dim_none``;
    otherwise ``dim`` is the axes input and the constant ``keepdim`` becomes
    the ``keepdims`` attribute. The input is first passed through
    ``opset9._maybe_cast_reduce_op_input``.
    """

    @_beartype.beartype
    def symbolic(g, self, dim=None, keepdim=None):
        self = opset9._maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)

    return symbolic


@_onnx_symbolic(
    "aten::sum",
    decorate=[_apply_params("ReduceSum", "sum")],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
    """Build the ``aten::sum`` symbolic, dispatching on argument count.

    ``reduce_nodim`` handles ``(self, dtype)`` and ``reduce_dim`` handles
    ``(self, dim, keepdim, dtype)``. When ``dtype`` comes from an
    ``onnx::Constant`` node:

    * the input is cast to that type before reducing;
    * the result is cast again if its type differs afterwards.

    A dtype produced by any node other than ``onnx::Constant`` or
    ``prim::Constant`` is reported as unimplemented for ``name``.
    """
    symbolic = _reduce_op_symbolic(onnx_op)

    @opset9.overload_by_arg_count
    @_beartype.beartype
    def reduce(g, *args, **kwargs):
        @symbolic_helper.parse_args("v", "none")
        @_beartype.beartype
        def reduce_nodim(g, self, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            result = symbolic(g, self)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        @symbolic_helper.parse_args("v", "v", "i", "none")
        @_beartype.beartype
        def reduce_dim(g, self, dim, keepdim, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            result = symbolic(g, self, dim, keepdim)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        return reduce_nodim, reduce_dim

    return reduce


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097
# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ...
@_onnx_symbolic("aten::unflatten")
@_beartype.beartype
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
    """Export ``aten::unflatten`` as a ``Reshape``.

    The target shape is ``shape[:dim] ++ unflattened_size ++ shape[dim + 1:]``.
    It is computed in-graph, with ``dim`` normalized as ``(rank + dim) % rank``
    so negative values work. The input rank must be known at export time.
    """
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        return symbolic_helper._unimplemented(
            "unflatten",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    # dim could be negative
    input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))
    dim = g.op("Add", input_dim, dim)
    dim = g.op("Mod", dim, input_dim)

    input_size = g.op("Shape", input)

    # Leading part of the shape: shape[0:dim].
    head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
    head_end_idx = g.op(
        "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)

    # Trailing part of the shape: shape[dim + 1:].
    dim_plus_one = g.op(
        "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    tail_start_idx = g.op(
        "Reshape",
        dim_plus_one,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),
    )
    tail_end_idx = g.op(
        "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
    )
    tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)

    final_shape = g.op(
        "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0
    )

    return symbolic_helper._reshape_helper(g, input, final_shape)


@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    """Export ``aten::unsafe_chunk`` as a ``Split`` with sizes fixed at export time.

    Each chunk has ``ceil(size / chunks)`` elements along ``dim``. A smaller
    final chunk holds any remainder. Reports unimplemented when the size of
    ``dim`` is unknown.
    """
    if _outputs is None:
        # NOTE(review): this path splits into size-1 pieces and drops ``dim``
        # (keepdims=0) without using ``chunks`` -- i.e. unbind semantics, which
        # looks inconsistent with chunking. Confirm before relying on it.
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request of dynamics in any
    # user's modules.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::tile")
@_beartype.beartype
def tile(g: jit_utils.GraphContext, self, dims):
    """Export ``aten::tile`` as ONNX ``Tile`` after aligning ranks.

    ONNX ``Tile`` needs one repeat entry per input axis, so:

    1. if ``dims`` is shorter than ``self.shape``, ``dims`` is left-padded with 1s;
    2. if ``dims`` is longer, ``self`` is reshaped with leading size-1 axes.

    Both adjustments use ``If`` nodes because the rank comparison happens
    in-graph. ``dims`` is finally cast to INT64.
    """
    self_shape = g.op("Shape", self)
    self_rank = g.op("Size", self_shape)
    dims_rank = g.op("Size", dims)
    diff = g.op("Sub", self_rank, dims_rank)
    const_zero = g.op("Constant", value_t=torch.tensor([0]))

    # 1. If dims is shorter than self.shape pad dims with 1
    dims_shorter_than_self_shape = g.op("Greater", diff, const_zero)
    (
        if_op_greater,
        (if_context_greater, else_context_greater),
        _,
    ) = jit_utils.add_op_with_blocks(
        g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1
    )
    const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1]))
    diff_1d_greater = if_context_greater.op("Reshape", diff, const_one)
    expand_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater)
    dims_ = if_context_greater.op("Concat", expand_ones_greater, dims, axis_i=0)
    utils._add_output_to_block(if_context_greater.block, dims_)
    identity_dim = else_context_greater.op("Identity", dims)
    utils._add_output_to_block(else_context_greater.block, identity_dim)
    dims_final = if_op_greater.node().output()

    # 2. If dims is longer than self.shape pad self.shape with 1
    dims_longer_than_self_shape = g.op("Less", diff, const_zero)
    (
        if_op_less,
        (if_context_less, else_context_less),
        _,
    ) = jit_utils.add_op_with_blocks(
        g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1
    )
    const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1]))
    diff_1d_less = if_context_less.op(
        "Reshape",
        if_context_less.op("Abs", diff),
        const_one,
    )
    expand_ones_less = if_context_less.op("Expand", const_one, diff_1d_less)
    self_final_shape = if_context_less.op(
        "Concat", expand_ones_less, self_shape, axis_i=0
    )
    self_ = if_context_less.op("Reshape", self, self_final_shape)
    utils._add_output_to_block(if_context_less.block, self_)
    identity_self = else_context_less.op("Identity", self)
    utils._add_output_to_block(else_context_less.block, identity_self)
    self_final = if_op_less.node().output()

    dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("Tile", self_final, dims_final)


@_onnx_symbolic("aten::repeat_interleave")
@_beartype.beartype
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    """Export ``aten::repeat_interleave``.

    The export strategy depends on the inputs:

    * Scalar or single-element ``repeats`` go through
      ``_repeat_interleave_single_value_repeat_helper``.
    * If neither the input size along ``dim`` nor the length of ``repeats`` is
      dynamic, the opset 9 implementation is reused.
    * Otherwise an ONNX ``Loop`` expands each slice along ``dim`` by its repeat
      count, and ``ConcatFromSequence`` joins the results.

    ``output_size`` is accepted for signature compatibility and not used here.

    Raises:
        errors.SymbolicValueError: if the rank or sizes of ``repeats`` or the
            sizes of ``self`` are unknown.
    """
    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(self)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            self,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            self,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            self,
        )

    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if symbolic_helper._is_none(dim):
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
        # NOTE(review): input_sizes below still describes the un-flattened
        # input in this case; confirm the dynamic-loop path is correct for dim=None.
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown sizes are recorded as 0 in output_sizes (0 at `dim` is used below as
    # the "dynamic" marker) and as -1 in input_sizes.
    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    # Check if all indices should be repeated the same number of times.
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = symbolic_helper._size_helper(g, self, dim)
        reps = opset11.unsqueeze(g, reps, 0)

        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If the length of repeats (repeat_dim) is 1, expand it to `reps` elements;
        # otherwise use the original tensor.
        if cond_dynamic_repeats:
            repeat_dim = symbolic_helper._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))
            )
            repeat_cond = g.op(
                "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))
            )
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
    else:
        # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
        # provided along one of the dynamic axes provided. A simple example would be
        # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
        # Now, repeat interleaving can be performed in pytorch when the value of * matches
        # with the number of elements in repeat, for example if * -> 2, number of repeats
        # should be 2 as well.
        return opset9.repeat_interleave(g, self, repeats, final_dim)

    # Split both `repeats` and `self` into size-1 pieces (along 0 and `dim`).
    reps_like = g.op(
        "ConstantOfShape",
        g.op("Shape", repeats),
        value_t=torch.tensor([1], dtype=torch.long),
    )
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, self, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")

    # Loop inputs
    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1
    )

    loop_block = loop_context.block
    block_input_iter = utils._add_input_to_block(loop_block)
    # The ONNX Loop body must declare the condition input even though it is unused here.
    utils._add_input_to_block(loop_block)
    final_splits = utils._add_input_to_block(loop_block)

    r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_context.op("SequenceAt", i_splits, block_input_iter)

    # Insert a new axis after `dim`, expand it to the repeat count, then fold it back.
    i_split = opset11.unsqueeze(loop_context, i_split, dim + 1)
    r_concat = [
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])),
        r_split,
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])),
    ]
    r_concat = loop_context.op("Concat", *r_concat, axis_i=0)
    i_split = opset9.expand(loop_context, i_split, r_concat, None)
    i_split = symbolic_helper._reshape_helper(
        loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))
    )
    final_splits = loop_context.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, final_splits)

    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out


@_onnx_symbolic("aten::diagonal")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
    """Export ``aten::diagonal`` by masking with ``EyeLike`` and gathering.

    Steps:

    1. Move ``dim1``/``dim2`` to the last two axes.
    2. Multiply by an ``EyeLike(k=offset)`` mask and sum over the last axis.
    3. Gather the diagonal entries with ``GatherND``.

    An ``If`` node returns zeros of the gather shape when the computed
    diagonal size is 0. Reports unimplemented when the input rank is unknown.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    # Replace negative indexing when rank is known
    if rank is not None:
        dim1 = dim1 if dim1 >= 0 else dim1 + rank
        dim2 = dim2 if dim2 >= 0 else dim2 + rank

    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )
    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)
    # dim1 and dim2 appended as a dimension at the end of the shape

    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply input and mask to calculate values along diagonal
    # The mask consists of one values where diagonal values are to be calculated
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims
    # If offset is non-negative, set offset to zero as this aids in
    # calculation of selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected
    # So in this example, it is [1, 2]
    # (4 is the ScalarType index passed to opset9.ones/zeros; presumably int64.)
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where offset value is greater than number of rows/columns
    # and might cause the diagonal to overrun and as a result of this, diag_size would be zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )

    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", overrun_cond, n_blocks=2
    )

    gather_indices_if_block = if_context.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_context, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun = if_context.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    # NOTE(review): ScalarType index 6 is hard-coded here (presumably float32),
    # regardless of the input dtype -- confirm this is intended.
    final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None)
    utils._add_output_to_block(if_context.block, final_non_overrun)
    utils._add_output_to_block(else_context.block, final_overrun)
    return if_op


# Quantized ops


@_beartype.beartype
def _dequantize_inputs_and_bias(g: jit_utils.GraphContext, q_input, q_weight, bias):
    """Shared prologue of the ``quantized::linear`` / ``quantized::conv*`` symbolics.

    Dequantizes ``q_input`` and ``q_weight``. ``bias`` is requantized with
    ``requantize_bias_helper``, using the input/weight scales and the axis
    returned for the weight, and then dequantized.

    Returns:
        ``(input, weight, bias)`` as dequantized graph values, ready for the
        float symbolic.
    """
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
    return input, weight, bias


@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear`` as dequantize -> linear -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear_relu`` as dequantize -> linear -> relu -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.linear(g, input, weight, bias)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d_relu`` as dequantize -> conv1d -> relu -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d_relu`` as dequantize -> conv2d -> relu -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d_relu")
@_beartype.beartype
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d_relu`` as dequantize -> conv3d -> relu -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d")
@_beartype.beartype
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d`` as dequantize -> conv1d -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d`` as dequantize -> conv2d -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d")
@_beartype.beartype
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d`` as dequantize -> conv3d -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose1d")
@_beartype.beartype
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose1d`` as dequantize -> conv_transpose1d -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    # The opset 9 symbolic takes (output_padding, groups, dilation), in a different
    # order from this op's (output_padding, dilation, groups).
    output = opset9.conv_transpose1d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose2d")
@_beartype.beartype
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose2d`` as dequantize -> conv_transpose2d -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    # The opset 9 symbolic takes (output_padding, groups, dilation), in a different
    # order from this op's (output_padding, dilation, groups).
    output = opset9.conv_transpose2d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose3d")
@_beartype.beartype
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose3d`` as dequantize -> conv_transpose3d -> quantize."""
    input, weight, bias = _dequantize_inputs_and_bias(g, q_input, q_weight, bias)

    # The opset 9 symbolic takes (output_padding, groups, dilation), in a different
    # order from this op's (output_padding, dilation, groups).
    output = opset9.conv_transpose3d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)