You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
566 lines
22 KiB
566 lines
22 KiB
5 months ago
|
# mypy: ignore-errors
|
||
|
|
||
|
import enum
|
||
|
import dis
|
||
|
import copy
|
||
|
import sys
|
||
|
import torch
|
||
|
import inspect
|
||
|
import operator
|
||
|
import traceback
|
||
|
import collections
|
||
|
|
||
|
from dataclasses import is_dataclass, fields
|
||
|
|
||
|
|
||
|
from .graph import magic_methods, reflectable_magic_methods, Graph
|
||
|
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
|
||
|
from .node import Target, Node, Argument, base_types, map_aggregate
|
||
|
from ._compatibility import compatibility
|
||
|
from .operator_schemas import check_for_mutable_operation
|
||
|
import torch.fx.traceback as fx_traceback
|
||
|
|
||
|
# Public API of this module, one name per line.
__all__ = [
    'TracerBase',
    'GraphAppendingTracer',
    'TraceError',
    'Proxy',
    'Attribute',
    'ParameterProxy',
    'Scope',
    'ScopeContextManager',
]
|
||
|
|
||
|
|
||
|
@compatibility(is_backward_compatible=False)
class Scope:
    """ Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path: qualified path of the module (``""`` for the root).
        module_type: the module's type, or ``None`` when there is none.
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type
|
||
|
|
||
|
|
||
|
@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """ Context manager that temporarily re-points a shared ``Scope`` at a new module.

    The switch happens as soon as the manager is constructed: ``scope`` is updated
    in place with ``current_scope``'s ``module_path`` and ``module_type``.
    ``__enter__`` returns that same ``scope`` object, and ``__exit__`` writes the
    previous path and type back. Exceptions raised in the body are not suppressed.
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # Shallow snapshot of the outgoing scope; consumed by __exit__.
        self._prev_scope = copy.copy(scope)
        # Mutate the caller's scope object in place rather than replacing it,
        # so every holder of that object observes the new module.
        for attr in ('module_path', 'module_type'):
            setattr(scope, attr, getattr(current_scope, attr))
        self._scope = scope

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        saved = self._prev_scope
        target = self._scope
        target.module_path, target.module_type = saved.module_path, saved.module_type
        return None
|
||
|
|
||
|
|
||
|
# node.meta keys that TracerBase.create_node copies (shallowly, via copy.copy)
# from the preserved current meta onto each newly created node.
_COPY_META_FIELDS = [
    "nn_module_stack",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "from_node",
    "quantization_tag",
]
|
||
|
|
||
|
|
||
|
@compatibility(is_backward_compatible=True)
class TracerBase:
    """Base class holding the node/proxy construction logic shared by FX tracers.

    ``graph``, ``scope``, ``module_stack`` and ``node_name_to_scope`` are only
    declared here; subclasses are expected to assign them before tracing.
    """
    graph: Graph
    # When True, create_proxy attaches the user's stack trace to newly created
    # nodes that do not already carry one.
    record_stack_traces : bool = False
    # Feature flag for mutable schema checking
    # Enable by default in 1.12
    check_mutable_operations : bool = False
    # Feature flag for assert tracing
    trace_asserts : bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes : bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # Scope (module path and type) of the module currently being traced;
    # create_node records it in ``node_name_to_scope`` for every new node.
    scope : Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.
        """
        if kind == 'call_function' and self.check_mutable_operations:
            check_for_mutable_operation(target, args, kwargs)

        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
        # node.meta['nn_module_stack']
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        # Optionally set stack trace on the created Node for debugging purposes
        if fx_traceback.has_preserved_node_meta():
            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

            stack_trace = current_meta.get("stack_trace")
            if stack_trace:
                node.stack_trace = stack_trace
            # Explicitly copy the fields listed in _COPY_META_FIELDS
            # (nn_module_stack, source_fn_stack, ...) onto node.meta.
            # If other meta fields are needed, they can be added to that list.
            for field in _COPY_META_FIELDS:
                if field in current_meta:
                    node.meta[field] = copy.copy(current_meta[field])

            # Here we decrement to account for the sequence_nr having
            # just been incremented while tracing this lowered aten op.
            new_seq_nr = torch.autograd._get_sequence_nr() - 1
            # The sequence_nr increments every time a new autograd Node
            # is created. During the FWD pass we store the sequence_nr
            # corresponding to the last autograd Node created on this fx
            # node's meta. A single aten op can create multiple autograd
            # nodes as is the case with in-place foreach ops. During the
            # BWD pass we retrieve the sequence_nr stored on the current
            # executing autograd Node. See NOTE [ Sequence Number ].
            if current_meta.get("in_grad_fn", 0) > 0:
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
            node.meta["seq_nr"] = new_seq_nr

        elif self.module_stack:
            node.meta['nn_module_stack'] = copy.copy(self.module_stack)
        return node

    @compatibility(is_backward_compatible=True)
    def proxy(self, node: Node) -> 'Proxy':
        return Proxy(node, self)

    @compatibility(is_backward_compatible=True)
    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Optional[Callable[[Node], 'Proxy']] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.

        ``proxy_factory_fn``, when given, is used instead of ``self.proxy``
        to wrap the new Node.
        '''

        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        if self.record_stack_traces and not proxy.node.stack_trace:
            user_frame = self._find_user_frame()
            if user_frame:
                summary = traceback.extract_stack(user_frame)
                tb_lines = summary.format()
                # stack_trace would have innermost frame at the bottom
                proxy.node.stack_trace = ''.join(tb_lines)

        return proxy

    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.

        Returns the first caller frame whose file is not one of the torch
        internals listed below, or ``None`` if every frame is internal.
        """
        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame executing
        # the user code during tracing.
        frame = inspect.currentframe()

        # '/'-separated path suffixes of torch-internal files to skip.
        pt_files = ('torch/fx/proxy.py',
                    'torch/fx/_symbolic_trace.py',
                    'torch/fx/experimental/proxy_tensor.py',
                    'torch/_ops.py',
                    'torch/_tensor.py',
                    'torch/utils/_python_dispatch.py',
                    'torch/_prims_common/wrappers.py',
                    'torch/_refs/__init__.py',
                    'torch/_refs/nn/functional/__init__.py',
                    'torch/utils/_stats.py',
                    )
        while frame:
            frame = frame.f_back
            # co_filename uses the platform separator ('\\' on Windows);
            # normalise it so the '/'-based suffixes above match everywhere.
            if frame and not frame.f_code.co_filename.replace('\\', '/').endswith(pt_files):
                break

        if not frame:
            return None

        return frame

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.
        """
        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
            return a.__fx_create_arg__(self)
        # aggregates
        elif isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors don't seem to like getting a generator
            # expression as an argument to their constructor, so build this
            # intermediate tuple and unpack it into the NamedTuple constructor
            args = tuple(self.create_arg(elem) for elem in a)
            return type(a)(*args)  # type: ignore[arg-type]
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)

                def no_node(arg):
                    if isinstance(arg, Node):
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                           f"Node. Got key: {k}")
                map_aggregate(k, no_node)

                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, torch._ops.OpOverload):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node

        if is_dataclass(a):
            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
            return self.create_node("call_function", a.__class__, (), kwargs)

        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
            return a
        raise NotImplementedError(f"argument of type: {type(a)}")

    @compatibility(is_backward_compatible=True)
    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    @compatibility(is_backward_compatible=True)
    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.
        """
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as a *args or **kwargs function argument. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: 'Proxy') -> Any:
        """Called when a proxy object has the keys() method called.
        This is what happens when ** is called on a proxy. This should return an
        iterator if ** is supposed to work in your custom tracer.
        """
        return Attribute(obj, 'keys')()
|
||
|
|
||
|
|
||
|
# used in Proxy object when just appending to the graph while not tracing.
@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
    """Tracer that appends nodes straight onto an existing ``graph``.

    It starts with the root scope (path ``""``, type ``None``), an empty module
    stack, and an empty node-name-to-scope mapping.
    """

    def __init__(self, graph: Graph):
        super().__init__()
        # Fresh bookkeeping: no enclosing module and nothing recorded yet.
        self.node_name_to_scope = dict()
        self.module_stack = collections.OrderedDict()
        self.scope = Scope(module_path="", module_type=None)
        self.graph = graph
|
||
|
|
||
|
@compatibility(is_backward_compatible=False)
def assert_fn(x):
    """Assert that ``x`` is truthy.

    Used as the ``call_function`` target that ``Proxy.__bool__`` records when a
    traced value is tested by an ``assert`` statement (with ``trace_asserts`` on).
    """
    assert x
|
||
|
|
||
|
@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
    """Raised when a traced value is used in a way symbolic tracing cannot record,
    e.g. as a control-flow condition or when iterated."""
|
||
|
|
||
|
def _calling_instruction(calling_frame):
    """Locate the bytecode instruction ``calling_frame`` is currently executing.

    Returns ``(instructions, index)``, where ``instructions`` is the list from
    ``dis.get_instructions`` for the frame's code object. ``index`` is the position
    of the instruction at ``calling_frame.f_lasti``.
    """
    insts = list(dis.get_instructions(calling_frame.f_code))
    if sys.version_info >= (3, 11):
        # get_instructions() omits inline CACHE entries on 3.11+, so list position
        # is not offset // 2; search by offset instead.
        from bisect import bisect_left
        idx = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
    else:
        # Before 3.11 the list index is taken to be f_lasti // 2 (2-byte units).
        idx = calling_frame.f_lasti // 2
    return insts, idx


@compatibility(is_backward_compatible=True)
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap your own ``Proxy``
    method around a raw ``Node`` so that you can use the overloaded
    operators to add additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
    use ``fx.wrap`` on it.
    2. If the control flow is static (i.e. the loop trip count is
    based on some hyperparameter), the code can be kept in its original
    position and refactored into something like::

        for i in range(self.some_hyperparameter):
            indexed_item = proxied_value[i]

    For a more detailed description into the Proxy internals, check out
    the "Proxy" section in `torch/fx/OVERVIEW.md`
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        # Reaching here for 'tracer' means the instance has no tracer at all
        # (e.g. a partially built copy/unpickle target). Building an Attribute
        # would need self.tracer again and recurse forever, so fail cleanly.
        if k == 'tracer':
            raise AttributeError(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __call__(self, *args, **kwargs) -> 'Proxy':
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

    def __iter__(self) -> Iterator['Proxy']:
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        inst_list, inst_idx = _calling_instruction(calling_frame)
        inst = inst_list[inst_idx]
        # ``a, b = proxy`` unpacks a statically known number of elements, so it
        # can be traced as ``proxy[0], proxy[1], ...``.
        if inst.opname == 'UNPACK_SEQUENCE':
            return (self[i] for i in range(inst.argval))  # type: ignore[index]

        return self.tracer.iter(self)

    def __abs__(self):
        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})

    def __bool__(self) -> bool:
        if self.tracer.trace_asserts:
            # check if this boolean is used in an assertion, i.e. the caller is at
            #     <jump-if-true> target; <load AssertionError>; ...; RAISE_VARARGS; target:
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts, cur = _calling_instruction(calling_frame)
            inst = insts[cur]
            # Newer interpreters (3.13+) convert the operand with an explicit
            # TO_BOOL placed right before the conditional jump.
            if inst.opname == 'TO_BOOL':
                cur += 1
                inst = insts[cur]

            # POP_JUMP_FORWARD_IF_TRUE is the 3.11 spelling of this opcode.
            if inst.opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_FORWARD_IF_TRUE'):
                first = insts[cur + 1]
                # ``argval`` holds the absolute byte offset of the jump target.
                # ``arg`` does not on every version: it is an instruction count on
                # 3.10 and relative on 3.11+. So resolve the target by offset.
                target_idx = next((i for i, x in enumerate(insts) if x.offset == inst.argval), None)
                if target_idx is not None and target_idx > 0:
                    last = insts[target_idx - 1]
                    starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                          or first.opname == 'LOAD_ASSERTION_ERROR')
                    if starts_with_assert and last.opname == 'RAISE_VARARGS':
                        self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                        return True

        return self.tracer.to_bool(self)

    @compatibility(is_backward_compatible=True)
    def keys(self):
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        tracers : Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None
        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                               f'trying to trace operations {orig_method}')
        tracer = next(iter(tracers.keys()))

        if isinstance(orig_method, torch._C.ScriptMethod):
            args = (orig_method.owner,) + args
            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            if isinstance(orig_method, torch._ops.HigherOrderOperator):
                # TODO: Define how to symbolically trace HigherOrderOperators
                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
            return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                       name=tracer.graph._target_to_str(orig_method.__name__))
|
||
|
|
||
|
|
||
|
@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
    """Proxy standing for ``root.attr``.

    Calling it records a ``call_method`` node on ``root``. Its own ``getattr``
    node is only created if ``.node`` is actually read.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, root: Proxy, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # Created lazily: most attribute proxies are immediately invoked as
        # methods (see __call__), which never needs the getattr node.
        cached = self._node
        if cached is not None:
            return cached
        self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        receiver_and_args = (self.root,) + args
        return self.tracer.create_proxy('call_method', self.attr, receiver_and_args, kwargs)
|
||
|
|
||
|
|
||
|
@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
    """
    A special proxy which lets "shape", "size", "dim", and a few other
    attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw exception during tracing

    Args:
        tracer: the tracer this proxy records into.
        node: the graph node this proxy wraps.
        name: name used in ``repr``.
        param: the concrete ``torch.nn.Parameter`` queried for metadata.
    """
    def __init__(self, tracer: TracerBase, node: Node, name, param):
        super().__init__(node, tracer)
        assert isinstance(param, torch.nn.Parameter)
        self.param = param
        self.name = name

    def __repr__(self) -> str:
        return f'ParameterProxy({self.name})'

    @property
    def shape(self):
        return self.param.shape

    def size(self, dim=None):
        # Mirror Tensor.size: no argument gives the full size, an int gives
        # the extent of that dimension (e.g. ``weight.size(0)``).
        if dim is None:
            return self.param.size()
        return self.param.size(dim)

    def dim(self):
        return self.param.dim()

    @property
    def ndim(self):
        return self.param.ndim

    def numel(self):
        return self.param.numel()

    def nelement(self):
        return self.param.nelement()
|
||
|
|
||
|
|
||
|
def _scope(method):
    # Build and install Proxy.__<method>__ for one operator name. A separate
    # function gives each impl its own ``method`` binding.
    def impl(*args, **kwargs):
        # The first operand is the Proxy whose tracer records the call; the
        # operator function is resolved only when the method is invoked.
        recorder = args[0].tracer
        op_fn = getattr(operator, method)
        return recorder.create_proxy('call_function', op_fn, args, kwargs)

    impl.__name__ = method
    dunder_name = '__' + method.strip('_') + '__'
    setattr(Proxy, dunder_name, impl)


for method in magic_methods:
    _scope(method)
|
||
|
|
||
|
def _define_reflectable(orig_method_name):
    """Install ``Proxy.__r<name>__`` for the operator ``orig_method_name``.

    The installed method records ``operator.<name>(rhs, self)``, i.e. it swaps
    the operands so the Proxy ends up on the right-hand side.
    """
    method_name = '__r' + orig_method_name.strip('_') + '__'

    def impl(self, rhs):
        op_fn = getattr(operator, orig_method_name)
        swapped = (rhs, self)
        return self.tracer.create_proxy('call_function', op_fn, swapped, {})

    impl.__name__ = impl.__qualname__ = method_name
    setattr(Proxy, method_name, impl)
|
||
|
|
||
|
# Install a reflected operator method (__r<name>__) on Proxy for every name in
# reflectable_magic_methods, so expressions like ``2 - proxy`` are traced.
for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
|