You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
89 lines
2.9 KiB
89 lines
2.9 KiB
"""Functionality for Python <-> C++ frontend inter-op."""
|
|
|
|
from torch import nn
|
|
|
|
|
|
class OrderedDictWrapper:
    """Live view onto an OrderedDict attribute of a bound C++ module.

    Instead of holding the dict itself, the wrapper remembers the module and
    the attribute name and re-reads the attribute on every access, so changes
    made on the C++ side are always visible. Reading e.g.
    ``cpp_module._parameters`` once would only yield a copy frozen at the time
    of that access. Properties are not an option because ``torch.nn.Module``
    reaches ``_parameters`` et al. through ``self.__dict__``.

    Only read accessors are provided: ``items``/``keys``/``values``,
    iteration, ``len``, membership and indexing.
    """

    def __init__(self, cpp_module, attr):
        """Remember which module (``cpp_module``) and attribute name (``attr``) to read."""
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Return the dict currently stored on the C++ module under ``attr``."""
        return getattr(self.cpp_module, self.attr)

    # Magic methods bypass ``getattr`` and cannot be assigned dynamically, so
    # each one is spelled out by hand and forwarded to the freshly fetched
    # dict.

    def __getitem__(self, key):
        return self.cpp_dict[key]

    def __contains__(self, key):
        return key in self.cpp_dict

    def __len__(self):
        return len(self.cpp_dict)

    def __iter__(self):
        return iter(self.cpp_dict)

    # Regular dict-style accessors, forwarded the same way.

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def items(self):
        return self.cpp_dict.items()
|
|
|
|
|
|
class ModuleWrapper(nn.Module):
    """``torch.nn.Module`` subclass that fronts a C++ frontend module and delegates all access to it."""

    def __init__(self, cpp_module):
        """Wrap ``cpp_module`` and mirror its public attributes onto this instance."""
        # ``cpp_module`` must be in place before the base constructor runs:
        # that constructor assigns ``self.training``, and the ``training``
        # setter below forwards to the wrapped module.
        self.cpp_module = cpp_module
        super().__init__()

        # Point the three dict attributes at live views of the C++ module's
        # own dicts.
        for dict_name in ("_parameters", "_buffers", "_modules"):
            setattr(self, dict_name, OrderedDictWrapper(cpp_module, dict_name))

        for name in dir(cpp_module):
            if name.startswith("_"):
                # Magic methods and the three dict attributes handled above.
                continue
            setattr(self, name, getattr(self.cpp_module, name))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to every parameter, its gradient (if any) and every buffer.

        ``recurse`` is accepted but not used by this override. Returns
        ``self``.
        """
        for tensor in self.parameters():
            # Module tensors are graph leaves; going through ``.data`` avoids
            # creating copy nodes.
            tensor.data = fn(tensor.data)
            grad = tensor._grad
            if grad is not None:
                grad.data = fn(grad.data)

        for buffer in self.buffers():
            buffer.data = fn(buffer.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        """Training flag, read from the wrapped C++ module."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        self.cpp_module.train(mode)

    def __repr__(self):
        return repr(self.cpp_module)
|