import torch
from ..modules import Module
from . import comm
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast
from torch._utils import _get_device_index

from collections import OrderedDict

if TYPE_CHECKING:
    import torch.jit
    import torch.jit._state

__all__ = ['replicate']

def _is_script_module(module: Module) -> bool:
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "torch.jit.ScriptModule":
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "torch.jit._state.EnabledProxy":
    import torch.jit._state
    return torch.jit._state._enabled

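
# Note (added): the helpers above import torch.jit / torch.jit._state inside
# each function body instead of at module scope. A plausible reason is to
# defer the TorchScript import until it is actually needed and to avoid
# circular imports, though this rationale is an assumption, not something
# stated in the original source.
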
# Check whether we can safely replicate the module.
# There are two types of modules:
# 1. Python modules
# 2. ScriptModules
#
# Currently a module cannot be replicated properly if the descendants of
# any ScriptModule contain a Python module (type 1 above).
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:

    # module.modules() yields the module itself as the first element
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
        # Since any unreplicatable module causes the check to return False
        # early, modules already visited here can safely be skipped.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True

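
# --- Illustrative sketch (added; not part of the original file) ---
# A tree of plain Python modules always passes the check; a ScriptModule
# only passes if every descendant is itself a ScriptModule. For example
# (assuming TorchScript is enabled):
#
#     import torch.nn as nn
#     _replicatable_module(nn.Sequential(nn.Linear(4, 4)))     # -> True
#     _replicatable_module(torch.jit.script(nn.Linear(4, 4)))  # -> True
#
# The check only fails when a ScriptModule somewhere in the tree has a
# plain Python module among its descendants.
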
def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function Broadcast when not detaching, so that
        # gradients can flow back to the original tensors.
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []

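
# --- Illustrative sketch (added; not part of the original file) ---
# Broadcast.apply returns a single flat tuple laid out device-major:
#     (t0@dev0, t1@dev0, ..., t0@dev1, t1@dev1, ...)
# and the slicing above regroups it into one list of copies per device.
# The same regrouping, shown on plain Python values:
#
#     copies = ['t0@d0', 't1@d0', 't0@d1', 't1@d1']
#     n = 2  # len(tensors)
#     [copies[i:i + n] for i in range(0, len(copies), n)]
#     # -> [['t0@d0', 't1@d0'], ['t0@d1', 't1@d1']]
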

T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
    """Replicate ``network`` onto each device in ``devices``, returning one
    replica per device. If ``detach`` is False, replicated parameters stay
    connected to the originals through autograd."""
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

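    # Note (added): buffers are split into two groups above. Buffers that
    # require grad are broadcast the same way as parameters (differentiably,
    # unless `detach=True`), while all remaining buffers are always broadcast
    # with `detach=True`, i.e. as plain copies outside the autograd graph.
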
    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param_copy = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param_copy)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param_copy
        for key, buf in module._buffers.items():  # type: ignore[assignment]
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]
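

# --- Illustrative usage sketch (added; not part of the original file) ---
# `replicate` is the building block torch.nn.DataParallel uses to place a
# copy of a module on every device before the parallel forward pass. A
# minimal sketch, assuming at least two CUDA devices are available and the
# source module already lives on device 0:
#
#     import torch.nn as nn
#     from torch.nn.parallel import replicate
#
#     net = nn.Linear(8, 4).cuda(0)
#     replicas = replicate(net, devices=[0, 1])  # one replica per device
#     # With detach=False (the default), replica parameters are non-leaf
#     # tensors, so gradients flow back to `net`'s original parameters.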