You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

163 lines
4.7 KiB

class Module:
    """
    Modules form a tree that stores parameters and other
    submodules. They make up the basis of neural network stacks.

    Assigning a :class:`Parameter` or a :class:`Module` to an attribute
    registers it in ``_parameters`` / ``_modules`` instead of storing it as a
    plain instance attribute; ``__getattr__`` makes it readable again under
    the same name.

    Attributes:
        _modules (dict of name x :class:`Module`): Storage of the child modules
        _parameters (dict of name x :class:`Parameter`): Storage of the module's parameters
        training (bool): Whether the module is in training mode or evaluation mode
    """

    def __init__(self):
        self._modules = {}
        self._parameters = {}
        self.training = True

    def modules(self):
        "Return the direct child modules of this module."
        return self.__dict__["_modules"].values()

    def _set_training(self, mode):
        "Set `training` to `mode` on this module and every descendent module."
        # Explicit stack instead of recursion: same set of modules visited,
        # without depending on the interpreter recursion limit.
        pending = [self]
        while pending:
            module = pending.pop()
            module.training = mode
            pending.extend(module.modules())

    def train(self):
        "Set the mode of this module and all descendent modules to `train`."
        self._set_training(True)

    def eval(self):
        "Set the mode of this module and all descendent modules to `eval`."
        self._set_training(False)

    def named_parameters(self):
        """
        Collect all the parameters of this module and its descendents.

        A parameter owned by a descendent is named with the dotted path of
        child names leading to it, e.g. ``"layer.bias"``. This module's own
        parameters come first, followed by each child's (depth-first).

        Returns:
            list of pairs: Contains the name and :class:`Parameter` of each descendent parameter.
        """

        def _walk(module, prefix=""):
            for name, param in module._parameters.items():
                yield prefix + name, param
            # Distinct loop names: do not rebind `module` while iterating it.
            for child_name, child in module._modules.items():
                yield from _walk(child, prefix + child_name + ".")

        return list(_walk(self))

    def parameters(self):
        "Enumerate over all the parameters of this module and its descendents."
        return [param for _, param in self.named_parameters()]

    def add_parameter(self, k, v):
        """
        Manually add a parameter. Useful helper for scalar parameters.

        Args:
            k (str): Local name of the parameter.
            v (value): Value for the parameter.

        Returns:
            Parameter: Newly created parameter.
        """
        val = Parameter(v, k)
        self.__dict__["_parameters"][k] = val
        return val

    def __setattr__(self, key, val):
        """
        Route parameters and submodules into their registries.

        Any other value is stored as a normal instance attribute.
        """
        # NOTE: a name that already exists as a plain instance attribute keeps
        # shadowing a Parameter/Module registered later under the same name,
        # because __getattr__ only runs when normal lookup fails.
        if isinstance(val, Parameter):
            self.__dict__["_parameters"][key] = val
        elif isinstance(val, Module):
            self.__dict__["_modules"][key] = val
        else:
            super().__setattr__(key, val)

    def __getattr__(self, key):
        """
        Resolve registered parameters and submodules by name.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If `key` is neither a parameter nor a submodule.
        """
        # `.get` keeps this safe on an instance whose __dict__ has not been
        # populated yet (e.g. while being rebuilt by copy/pickle): the lookup
        # must end in AttributeError, never KeyError, for hasattr() to work.
        state = self.__dict__
        for store in ("_parameters", "_modules"):
            registry = state.get(store, {})
            if key in registry:
                return registry[key]
        raise AttributeError(
            "{!r} object has no attribute {!r}".format(type(self).__name__, key)
        )

    def __call__(self, *args, **kwargs):
        "Invoke :meth:`forward` with the given arguments."
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """
        Compute the module's output. Subclasses must override this.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # A real exception rather than `assert False`: asserts are removed
        # under `python -O`, which would make this silently return None.
        raise NotImplementedError("Not Implemented")

    def __repr__(self):
        """
        Render the module tree, one ``(name): Class(...)`` line per child,
        with each nesting level indented by two spaces.
        """
        indent = "  "

        def _indent_tail(text):
            # Indent every line except the first, so a multi-line child repr
            # lines up underneath its "(name): " entry.
            first, *rest = text.split("\n")
            if not rest:
                return text
            return "\n".join([first] + [indent + line for line in rest])

        child_lines = [
            "(" + key + "): " + _indent_tail(repr(module))
            for key, module in self._modules.items()
        ]

        main_str = self.__class__.__name__ + "("
        if child_lines:
            # One child per line; a module without children stays "Name()".
            main_str += "\n" + indent + ("\n" + indent).join(child_lines) + "\n"
        main_str += ")"
        return main_str
class Parameter:
    """
    A Parameter is a special container stored in a :class:`Module`.

    It is meant to wrap a :class:`Variable`, but any value is accepted so
    that it can be exercised in tests. When the wrapped value exposes
    ``requires_grad_``, that method is called with ``True`` and, if the
    parameter has a truthy name, the name is copied onto the value.
    """

    def __init__(self, x=None, name=None):
        self.name = name
        self.value = x
        # Plain values (ints, floats, None, ...) need no further set-up.
        if not hasattr(x, "requires_grad_"):
            return
        x.requires_grad_(True)
        if name:
            x.name = name

    def update(self, x):
        "Update the parameter value."
        self.value = x
        # Plain values (ints, floats, None, ...) need no further set-up.
        if not hasattr(x, "requires_grad_"):
            return
        x.requires_grad_(True)
        label = self.name
        if label:
            x.name = label

    def __repr__(self):
        "Return the repr of the wrapped value."
        return repr(self.value)

    def __str__(self):
        "Return the str of the wrapped value."
        return str(self.value)