from __future__ import annotations

import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)


@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g: jit_utils.GraphContext, input, other):
    # make sure to cast other to self's type
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(input):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])


@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op: if the node's train param is set to
    # False, dropout just returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)


@_onnx_symbolic("aten::nll_loss")
@_beartype.beartype
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional
    # and has no default value, so the attribute must be set even when the caller
    # did not specify one (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    return nllloss
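

# Illustrative sketch (comment only, not executed): exporting a module that calls
# F.nll_loss is expected to lower to a single NegativeLogLikelihoodLoss node with
# the reduction string and ignore_index attribute resolved as above. The module
# and file names below are hypothetical.
#
#     import torch
#     import torch.nn.functional as F
#
#     class NLL(torch.nn.Module):
#         def forward(self, log_probs, target):
#             return F.nll_loss(log_probs, target, reduction="mean", ignore_index=-100)
#
#     torch.onnx.export(
#         NLL(),
#         (torch.randn(3, 5).log_softmax(dim=1), torch.tensor([0, 2, 4])),
#         "nll_loss.onnx",
#         opset_version=12,
#     )
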
@_onnx_symbolic("aten::nll_loss2d")
@_beartype.beartype
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::nll_loss_nd")
@_beartype.beartype
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::cross_entropy_loss")
@_beartype.beartype
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional
    # and has no default value, so the attribute must be set even when the caller
    # did not specify one (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    return celoss


@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )


@_onnx_symbolic("aten::celu")
@_beartype.beartype
def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # ONNX Celu is only defined for float inputs, so cast a double input to
    # float, run Celu, and cast the result back to double.
    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.DOUBLE
    ):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)
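

# Illustrative sketch (comment only, not executed): because ONNX Celu only
# supports float tensors, exporting CELU on a double input, e.g.
#
#     torch.onnx.export(
#         torch.nn.CELU(alpha=1.5),
#         (torch.randn(2, 3, dtype=torch.double),),
#         "celu.onnx",
#         opset_version=12,
#     )
#
# is expected to produce Cast(to float) -> Celu -> Cast(back to double) in the graph.
# The file name above is a placeholder.
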
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")


@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)


@_onnx_symbolic("aten::ge")
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


@_onnx_symbolic("aten::le")
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)


@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        assert ndim is not None
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        loop_len = g.op("Min", low_size, hi_size)

        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        # FIXME(justinchuby): cond is unused?
        cond = utils._add_input_to_block(loop_block)

        starts = loop_context.op("Gather", low_indices, block_input_iter)
        ends = loop_context.op("Gather", hi_indices, block_input_iter)
        axes = loop_context.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
        stack = loop_context.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)

        # Pass the cast ONNX attribute as `to_i` so Cast keeps a single input
        # and its required `to` attribute, matching the Cast call above.
        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, concat)

        loop_output = loop.node().output()
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze
    return symbolic_helper._unimplemented("Unfold", "input size not accessible")
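

# Illustrative sketch (comment only, not executed): when `size` or `step` is not a
# compile-time constant, aten::unfold above is lowered to a Loop of Slice ops.
# Semantically it matches torch.Tensor.unfold, e.g.
#
#     x = torch.arange(7)
#     x.unfold(0, 2, 2)  # dimension=0, size=2, step=2
#     # -> tensor([[0, 1], [2, 3], [4, 5]]): one row per window of length `size`,
#     #    advancing `step` elements at a time.
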
@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )

    dims_a = [
        (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
        for i in range(len(dims_a))
    ]
    dims_b = [
        (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
        for i in range(len(dims_b))
    ]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
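
    # At this point both operands have been flattened to rank 2: output_a has shape
    # (product of the kept dims of input_a, product of the contracted dims) and
    # output_b has shape (product of the contracted dims, product of the kept dims
    # of input_b), so the contraction reduces to one matrix multiplication,
    # expressed below via Einsum "ij,jk->ik".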
    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)
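

# Illustrative sketch (comment only, not executed): torch.tensordot contracts the
# listed dimension pairs and keeps the rest, e.g.
#
#     a = torch.randn(3, 4, 5)
#     b = torch.randn(4, 5, 6)
#     torch.tensordot(a, b, dims=([1, 2], [0, 1])).shape  # -> torch.Size([3, 6])
#
# which the symbolic above expresses as permute + reshape on each operand, an
# Einsum("ij,jk->ik") matrix multiply, and a final reshape to the kept sizes.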