You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
6.5 KiB

5 months ago
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
# Operators named in this module are treated as pointwise (elementwise) and
# can therefore go through a faster pathway.

# NOTE(review): this list is not referenced anywhere in this file and its
# second entry is an empty string. It is kept unchanged because external code
# may import it -- confirm whether it is still needed.
tensor_magic_methods = ["add", ""]

# Binary operators that also have a reflected ("r"-prefixed) magic method,
# e.g. "add" expands to both "__add__" and "__radd__" below.
pointwise_magic_methods_with_reverse = (
    "add",
    "sub",
    "mul",
    "floordiv",
    "div",
    "truediv",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "or",
    "xor",
)

# Bare (undecorated) names of every pointwise magic method. Every entry must be
# unique: a duplicate would be repeated in `pointwise_methods` and in every
# table derived from it.
pointwise_magic_methods = (
    # Each binary operator immediately followed by its reflected variant:
    # "add", "radd", "sub", "rsub", ...
    *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
    # Comparisons ("gt" used to be listed twice here).
    "eq",
    "gt",
    "le",
    "lt",
    "ge",
    "ne",
    # Unary operators.
    "neg",
    "pos",
    "abs",
    "invert",
    # In-place binary operators.
    "iadd",
    "isub",
    "imul",
    "ifloordiv",
    "idiv",
    "itruediv",
    "imod",
    "ipow",
    "ilshift",
    "irshift",
    "iand",
    "ior",
    "ixor",
    # Scalar conversions.
    "int",
    "long",
    "float",
    "complex",
)

# Full dunder spellings, e.g. "add" -> "__add__", in the same order as above.
pointwise_methods = tuple(f"__{m}__" for m in pointwise_magic_methods)
# Callables treated as pointwise (elementwise) operations. The table contains:
#   * every torch.Tensor magic method named in `pointwise_methods`, and
#   * the function and method spellings of elementwise torch /
#     torch.nn.functional ops, listed roughly alphabetically.
# Each callable should appear only once. torch.nn.functional.dropout used to be
# listed twice; it now appears only next to torch.dropout.
pointwise = (
    *(getattr(torch.Tensor, m) for m in pointwise_methods),
    torch.where,
    torch.Tensor.abs,
    torch.abs,
    torch.Tensor.acos,
    torch.acos,
    torch.Tensor.acosh,
    torch.acosh,
    torch.Tensor.add,
    torch.add,
    torch.Tensor.addcdiv,
    torch.addcdiv,
    torch.Tensor.addcmul,
    torch.addcmul,
    torch.Tensor.addr,
    torch.addr,
    torch.Tensor.angle,
    torch.angle,
    torch.Tensor.asin,
    torch.asin,
    torch.Tensor.asinh,
    torch.asinh,
    torch.Tensor.atan,
    torch.atan,
    torch.Tensor.atan2,
    torch.atan2,
    torch.Tensor.atanh,
    torch.atanh,
    torch.Tensor.bitwise_and,
    torch.bitwise_and,
    torch.Tensor.bitwise_left_shift,
    torch.bitwise_left_shift,
    torch.Tensor.bitwise_not,
    torch.bitwise_not,
    torch.Tensor.bitwise_or,
    torch.bitwise_or,
    torch.Tensor.bitwise_right_shift,
    torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor,
    torch.bitwise_xor,
    torch.Tensor.ceil,
    torch.ceil,
    torch.celu,
    torch.nn.functional.celu,
    torch.Tensor.clamp,
    torch.clamp,
    torch.Tensor.clamp_max,
    torch.clamp_max,
    torch.Tensor.clamp_min,
    torch.clamp_min,
    torch.Tensor.copysign,
    torch.copysign,
    torch.Tensor.cos,
    torch.cos,
    torch.Tensor.cosh,
    torch.cosh,
    torch.Tensor.deg2rad,
    torch.deg2rad,
    torch.Tensor.digamma,
    torch.digamma,
    torch.Tensor.div,
    torch.div,
    torch.dropout,
    torch.nn.functional.dropout,
    torch.nn.functional.elu,
    torch.Tensor.eq,
    torch.eq,
    torch.Tensor.erf,
    torch.erf,
    torch.Tensor.erfc,
    torch.erfc,
    torch.Tensor.erfinv,
    torch.erfinv,
    torch.Tensor.exp,
    torch.exp,
    torch.Tensor.exp2,
    torch.exp2,
    torch.Tensor.expm1,
    torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power,
    torch.float_power,
    torch.Tensor.floor,
    torch.floor,
    torch.Tensor.floor_divide,
    torch.floor_divide,
    torch.Tensor.fmod,
    torch.fmod,
    torch.Tensor.frac,
    torch.frac,
    torch.Tensor.frexp,
    torch.frexp,
    torch.Tensor.gcd,
    torch.gcd,
    torch.Tensor.ge,
    torch.ge,
    torch.nn.functional.gelu,
    # NOTE(review): per the PyTorch docs, glu halves its input along `dim`,
    # so it does not look elementwise -- confirm it belongs here.
    torch.nn.functional.glu,
    torch.Tensor.gt,
    torch.gt,
    torch.Tensor.hardshrink,
    torch.hardshrink,
    torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside,
    torch.heaviside,
    torch.Tensor.hypot,
    torch.hypot,
    torch.Tensor.i0,
    torch.i0,
    torch.Tensor.igamma,
    torch.igamma,
    torch.Tensor.igammac,
    torch.igammac,
    torch.Tensor.isclose,
    torch.isclose,
    torch.Tensor.isfinite,
    torch.isfinite,
    torch.Tensor.isinf,
    torch.isinf,
    torch.Tensor.isnan,
    torch.isnan,
    torch.Tensor.isneginf,
    torch.isneginf,
    torch.Tensor.isposinf,
    torch.isposinf,
    torch.Tensor.isreal,
    torch.isreal,
    # NOTE(review): per the PyTorch docs, kron produces a larger output than
    # its inputs, so it does not look elementwise -- confirm it belongs here.
    torch.Tensor.kron,
    torch.kron,
    torch.Tensor.lcm,
    torch.lcm,
    torch.Tensor.ldexp,
    torch.ldexp,
    torch.Tensor.le,
    torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp,
    torch.lerp,
    torch.Tensor.lgamma,
    torch.lgamma,
    torch.Tensor.log,
    torch.log,
    torch.Tensor.log10,
    torch.log10,
    torch.Tensor.log1p,
    torch.log1p,
    torch.Tensor.log2,
    torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and,
    torch.logical_and,
    torch.Tensor.logical_not,
    torch.logical_not,
    torch.Tensor.logical_or,
    torch.logical_or,
    torch.Tensor.logical_xor,
    torch.logical_xor,
    torch.Tensor.logit,
    torch.logit,
    torch.Tensor.lt,
    torch.lt,
    torch.Tensor.maximum,
    torch.maximum,
    torch.Tensor.minimum,
    torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma,
    torch.mvlgamma,
    torch.Tensor.nan_to_num,
    torch.nan_to_num,
    torch.Tensor.ne,
    torch.ne,
    torch.Tensor.neg,
    torch.neg,
    torch.Tensor.nextafter,
    torch.nextafter,
    # NOTE(review): per the PyTorch docs, outer builds a 2-D result from 1-D
    # inputs, so it does not look elementwise -- confirm it belongs here.
    torch.Tensor.outer,
    torch.outer,
    torch.polar,
    torch.Tensor.polygamma,
    torch.polygamma,
    torch.Tensor.positive,
    torch.positive,
    torch.Tensor.pow,
    torch.pow,
    torch.Tensor.prelu,
    torch.prelu,
    torch.nn.functional.prelu,
    torch.Tensor.rad2deg,
    torch.rad2deg,
    torch.Tensor.reciprocal,
    torch.reciprocal,
    torch.Tensor.relu,
    torch.relu,
    torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder,
    torch.remainder,
    torch.Tensor.round,
    torch.round,
    torch.rrelu,
    torch.nn.functional.rrelu,
    torch.Tensor.rsqrt,
    torch.rsqrt,
    torch.rsub,
    torch.selu,
    torch.nn.functional.selu,
    torch.Tensor.sgn,
    torch.sgn,
    torch.Tensor.sigmoid,
    torch.sigmoid,
    torch.nn.functional.sigmoid,
    torch.Tensor.sign,
    torch.sign,
    torch.Tensor.signbit,
    torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin,
    torch.sin,
    torch.Tensor.sinc,
    torch.sinc,
    torch.Tensor.sinh,
    torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt,
    torch.sqrt,
    torch.Tensor.square,
    torch.square,
    torch.Tensor.sub,
    torch.sub,
    torch.Tensor.tan,
    torch.tan,
    torch.Tensor.tanh,
    torch.tanh,
    torch.nn.functional.tanh,
    torch.threshold,
    torch.nn.functional.threshold,
    # NOTE(review): per the PyTorch docs, trapz integrates (reduces) along
    # `dim`, so it does not look elementwise -- confirm it belongs here.
    torch.trapz,
    torch.Tensor.true_divide,
    torch.true_divide,
    torch.Tensor.trunc,
    torch.trunc,
    torch.Tensor.xlogy,
    torch.xlogy,
    # NOTE(review): rand_like only uses its input's shape and dtype, not its
    # values -- presumably listed so shape handling is shared; confirm.
    torch.rand_like,
)