import itertools
import unittest.mock
from contextlib import contextmanager
from typing import Iterator

import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree

__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

CROSSREF_FUNCTIONALIZE = False
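
# Illustrative usage (a sketch, not exercised by this module): the three
# bindings above are context-manager-style guards exposed from C++, so a
# typical call site looks roughly like the following, assuming `x` and `y`
# are ordinary tensors:
#
#   >>> x, y = torch.randn(3), torch.randn(3)
#   >>> with enable_python_dispatcher():
#   ...     z = torch.add(x, y)  # dispatch consults the Python dispatcher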


def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as that would include all operators on which we ran C++ static initializers or
    Python operator registration. This does not eagerly populate the list on
    torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't
    guaranteed to be complete either, as a subsequent lazy load of a library
    which triggers more registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
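
# Illustrative sketch of the laziness described above: merely touching an
# overload from Python is what makes it visible to this traversal, roughly:
#
#   >>> _ = torch.ops.aten.add.Tensor  # forces the OpOverload to be created
#   >>> any(op is torch.ops.aten.add.Tensor for op in all_py_loaded_overloads())
#   True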


@contextmanager
def suspend_functionalization():
    # Temporarily drop out of functionalization: remember whether the
    # Functionalize dispatch key was included (and the reapply_views TLS),
    # disable it for the duration of the block, and restore it afterwards.
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
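
# Illustrative sketch: inside the block the Functionalize key is excluded from
# the thread-local include set, whatever the state was outside:
#
#   >>> with suspend_functionalization():
#   ...     torch._C._dispatch_tls_is_dispatch_key_included(
#   ...         torch._C.DispatchKey.Functionalize
#   ...     )
#   False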


def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        same_strides
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
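
# Illustrative sketch: `desc` is a zero-argument callable rather than a string
# so the (potentially expensive) description is only rendered when a mismatch
# actually trips an assertion. With two metadata-identical tensors, nothing
# fires:
#
#   >>> a, b = torch.randn(2, 3), torch.randn(2, 3)
#   >>> check_metadata_matches([a], [b], lambda: "aten.add.Tensor")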


class Lit:
    # Wrapper whose repr is exactly the string it wraps; used below to splice
    # preformatted text (e.g. a torch.empty_strided call) into a repr.
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a
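
# Illustrative sketch of what `_fmt` produces for a tensor leaf (the exact
# sizes/strides depend on the tensor, of course):
#
#   >>> pytree.tree_map(_fmt, torch.randn(2, 3))
#   torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)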


def make_crossref_functionalize(op, final_key):
    # Build a handler that runs `op` twice: once on fake, defunctionalized
    # copies of the inputs under FakeTensorMode, and once through the real
    # kernel at `final_key`, then cross-checks that the output metadata
    # (sizes, dtypes, significant strides) agrees.
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler


# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
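
# Illustrative sketch (assuming some function `f` and tensor input `x` of your
# own): the cross-reference only has something to check when the Functionalize
# dispatch key is actually exercised, e.g. under torch.func.functionalize, and
# it is slow, so keep it to tests and debugging:
#
#   >>> with enable_crossref_functionalize():
#   ...     out = torch.func.functionalize(f)(x)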