You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

83 lines
2.9 KiB

5 months ago
"""This file exports ONNX ops for opset 15.
Note [ONNX operators that are added/updated in opset 15]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
New operators:
    Bernoulli
    CastLike
    Optional
    OptionalGetElement
    OptionalHasElement

Updated operators:
    BatchNormalization https://github.com/onnx/onnx/pull/3545
        Backwards compatible
        TODO: test coverage for mixed types inputs.
    Pow https://github.com/onnx/onnx/pull/3412
        Backwards compatible
        TODO: bfloat16 support.
    Shape https://github.com/onnx/onnx/pull/3580
        Backwards compatible
        TODO: optional start/end attribute.
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import _beartype, jit_utils, registration
# Pre-bind opset=15 so the decorators below only need to pass the operator name.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
@_onnx_symbolic("aten::__is_")
@_beartype.beartype
def aten__is_(g: jit_utils.GraphContext, self, other):
    """Build the graph nodes for ``aten::__is_`` (``self is other``).

    Args:
        g: Graph context used to emit ONNX ops.
        self: Left-hand value of the identity test.
        other: Right-hand value of the identity test.

    Returns:
        The value produced by the emitted op:
        * ``other`` is not None: whatever ``opset9.eq`` returns.
        * ``other`` is None and ``self`` has an Optional type:
          ``Not(OptionalHasElement(self))``.
        * ``other`` is None and ``self`` is not Optional: a constant
          boolean tensor holding a single false value.
    """
    # Comparison against anything other than None is delegated to opset 9 eq.
    if not symbolic_helper._is_none(other):
        return opset9.eq(g, self, other)

    # A value whose type is not Optional is never None.
    if not isinstance(self.type(), _C.OptionalType):
        return g.op("Constant", value_t=torch.BoolTensor([0]))

    # An Optional "is None" exactly when it holds no element.
    has_element = g.op("OptionalHasElement", self)
    return g.op("Not", has_element)
@_onnx_symbolic("aten::__isnot_")
@opset9.wrap_logical_op_with_negation  # type: ignore[has-type]
@_beartype.beartype
def aten__isnot_(g: jit_utils.GraphContext, self, other):
    """Build the graph nodes for ``aten::__isnot_`` (``self is not other``).

    The body itself only produces the ``aten__is_`` result for the same
    operands.
    NOTE(review): the inversion is expected to come from the
    ``opset9.wrap_logical_op_with_negation`` decorator, whose implementation
    is not visible in this file -- confirm there.

    Args:
        g: Graph context used to emit ONNX ops.
        self: Left-hand value of the identity test.
        other: Right-hand value of the identity test.
    """
    is_result = aten__is_(g, self, other)
    return is_result
@_onnx_symbolic("aten::bernoulli")
@_beartype.beartype
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
    """Build the graph nodes for ``aten::bernoulli``.

    Args:
        g: Graph context used to emit ONNX ops.
        input: Value passed to the ONNX ``Bernoulli`` op.
        p: Optional explicit probability; when supplied the call is handed
            over to ``opset9.bernoulli``.
        generator: Not supported; reported via
            ``symbolic_helper._unimplemented`` when supplied.
        out: Not supported; reported via
            ``symbolic_helper._unimplemented`` when supplied.

    Returns:
        The ``Bernoulli`` op output when ``p`` is absent, otherwise the
        result of ``opset9.bernoulli``.
    """
    # Checked in the same order as before: `out` first, then `generator`.
    unsupported_args = (
        (out, "out parameter is not supported for bernoulli"),
        (generator, "generator is not supported for bernoulli"),
    )
    for arg, reason in unsupported_args:
        arg_absent = arg is None or symbolic_helper._is_none(arg)
        if not arg_absent:
            # Execution continues past this point if the helper returns.
            symbolic_helper._unimplemented("Bernoulli", reason, input)

    # An explicit probability goes to the opset 9 implementation.
    if p is not None and not symbolic_helper._is_none(p):
        return opset9.bernoulli(g, input, p, generator, out)

    return g.op("Bernoulli", input)
@_onnx_symbolic("prim::unchecked_cast")
@_beartype.beartype
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    """Build the graph nodes for ``prim::unchecked_cast``.

    ``unchecked_cast`` exists to refine the type of a Value: if ``x`` is
    ``Optional[Tensor]``, it casts ``x`` to ``Tensor`` so the rest of the
    graph knows that ``x`` is a Tensor.

    Args:
        g: Graph context used to emit ONNX ops.
        self: The value being cast.

    Returns:
        ``OptionalGetElement(self)`` when ``self`` has an Optional type;
        otherwise ``self`` unchanged.
    """
    value_type = self.type()
    if not isinstance(value_type, _C.OptionalType):
        # Nothing to unwrap.
        return self
    return g.op("OptionalGetElement", self)