import cmath
import math
import warnings
from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.backends.cudnn as cudnn

from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple

_builtin_table: Optional[Dict[int, str]] = None

_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special)  # type: ignore[attr-defined] # noqa: B950

_builtin_ops = [
    # Pairs of (function, op_name)
    (_pair, "aten::_pair"),
    (_quadruple, "aten::_quadruple"),
    (_single, "aten::_single"),
    (_triple, "aten::_triple"),
    (_list_with_default, "aten::list_with_default"),
    (OrderedDict, "aten::dict"),
    (dict, "aten::dict"),
    (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
    (math.ceil, "aten::ceil"),
    (math.copysign, "aten::copysign"),
    (math.erf, "aten::erf"),
    (math.erfc, "aten::erfc"),
    (math.exp, "aten::exp"),
    (math.expm1, "aten::expm1"),
    (math.fabs, "aten::fabs"),
    (math.floor, "aten::floor"),
    (math.gamma, "aten::gamma"),
    (math.lgamma, "aten::lgamma"),
    (math.log, "aten::log"),
    (math.log10, "aten::log10"),
    (math.log1p, "aten::log1p"),
    (math.pow, "aten::pow"),
    (math.sqrt, "aten::sqrt"),
    (math.isnan, "aten::isnan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.cosh, "aten::cosh"),
    (math.sinh, "aten::sinh"),
    (math.tanh, "aten::tanh"),
    (math.acos, "aten::acos"),
    (math.asin, "aten::asin"),
    (math.atan, "aten::atan"),
    (math.atan2, "aten::atan2"),
    (math.cos, "aten::cos"),
    (math.sin, "aten::sin"),
    (math.tan, "aten::tan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.acosh, "aten::acosh"),
    (math.fmod, "aten::fmod"),
    (math.modf, "aten::modf"),
    (math.factorial, "aten::factorial"),
    (math.frexp, "aten::frexp"),
    (math.isinf, "aten::isinf"),
    (math.degrees, "aten::degrees"),
    (math.radians, "aten::radians"),
    (cmath.isnan, "aten::isnan"),
    (cmath.isfinite, "aten::isfinite"),
    (cmath.isinf, "aten::isinf"),
    (cmath.phase, "aten::angle"),
    (cmath.rect, "aten::polar"),
    (cmath.log, "aten::log"),
    (cmath.log10, "aten::log10"),
    (cmath.sqrt, "aten::sqrt"),
    (cmath.exp, "aten::exp"),
    (cmath.sin, "aten::sin"),
    (cmath.tan, "aten::tan"),
    (cmath.cos, "aten::cos"),
    (cmath.asin, "aten::asin"),
    (cmath.acos, "aten::acos"),
    (cmath.atan, "aten::atan"),
    (cmath.sinh, "aten::sinh"),
    (cmath.cosh, "aten::cosh"),
    (cmath.tanh, "aten::tanh"),
    (cmath.asinh, "aten::asinh"),
    (cmath.acosh, "aten::acosh"),
    (cmath.atanh, "aten::atanh"),
    (math.ldexp, "aten::ldexp"),
    (torch._assert, "aten::_assert"),
    (torch.autograd.grad, "aten::grad"),
    (torch.autograd.backward, "aten::backward"),
    (torch._C._infer_size, "aten::_infer_size"),
    (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"),  # type: ignore[attr-defined]
    (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
    (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
    (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
    (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
    (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
    (torch._C._get_tracing_state, "aten::_get_tracing_state"),
    (torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
    (warnings.warn, "aten::warn"),
    (torch._VF.stft, "aten::stft"),  # type: ignore[attr-defined]
    (torch._VF.istft, "aten::istft"),  # type: ignore[attr-defined]
    (torch._VF.cdist, "aten::cdist"),  # type: ignore[attr-defined]
    (torch._VF.norm, "aten::norm"),  # type: ignore[attr-defined]
    (torch._VF.unique_dim, "aten::unique_dim"),
"aten::unique_consecutive"), # type: ignore[attr-defined] (torch._VF.nuclear_norm, "aten::nuclear_norm"), (torch._VF.frobenius_norm, "aten::frobenius_norm"), (torch._VF.tensordot, "aten::tensordot"), # type: ignore[attr-defined] ] # ops in torch.functional are bound to torch # in these cases, we want to resolve the function to their python implementation # instead looking up a builtin "aten::" schema def _gen_torch_functional_registered_ops(): # eventually ops should encompass all of torch/functional.py, (torch.functional.__all__) # but we are currently only able to compile some of the functions. additionally, # some functions directly map to their aten:: implementations. # TODO: add support for more ops ops = [ "stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot", ] return {getattr(torch.functional, name) for name in ops} _functional_registered_ops = _gen_torch_functional_registered_ops() def _is_special_functional_bound_op(fn): return fn in _functional_registered_ops # lazily built to ensure the correct initialization order def _get_builtin_table(): global _builtin_table if _builtin_table is not None: return _builtin_table _builtin_table = {} def register_all(mod): for name in dir(mod): v = getattr(mod, name) if ( callable(v) and not _is_special_functional_bound_op(v) and v is not torch.no_grad and v is not torch.autocast ): # Fixup inconsistency in segment_reduce if name == "_segment_reduce": name = name[1:] _builtin_ops.append((v, "aten::" + name)) for mod in _modules_containing_builtins: register_all(mod) _builtin_ops.append((math.gcd, "aten::gcd")) _builtin_ops.append((math.isfinite, "aten::isfinite")) _builtin_ops.append((math.remainder, "aten::mathremainder")) # type: ignore[attr-defined] import torch.distributed.autograd as dist_autograd if dist_autograd.is_available(): _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients")) _builtin_ops.append((dist_autograd.backward, "aten::dist_backward")) # populate the _builtin_table from _builtin_ops for builtin, aten_op in _builtin_ops: _builtin_table[id(builtin)] = aten_op return _builtin_table def _register_builtin(fn, op): _get_builtin_table()[id(fn)] = op def _find_builtin(fn): return _get_builtin_table().get(id(fn))