# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
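

# ---------------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a hypothetical module showing
# how the `__data_parallel_replicate__` protocol above is typically implemented. The
# master copy (copy_id == 0) creates some shared state and publishes it on the
# context; the other copies pick up a handle to the same object. The names
# `_ExampleSyncModule` and `ctx.shared_state` are assumptions for illustration only.
# ---------------------------------------------------------------------------------
from torch import nn  # imported here only for the sketch below


class _ExampleSyncModule(nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        # `ctx` is a CallbackContext shared by every copy of this sub-module;
        # `copy_id` is the index of this replica (0 is the master and runs first).
        self._is_master = (copy_id == 0)
        if self._is_master:
            ctx.shared_state = {}               # created once, on the master copy
        self._shared_state = ctx.shared_state   # every copy holds the same object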


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is
    created by the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
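

# ---------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): driving the replication
# callbacks by hand, without any GPU, using the hypothetical `_ExampleSyncModule`
# defined above. This mimics what `DataParallelWithCallback.replicate` does right
# after the stock `replicate` function has produced one copy of the network per
# device.
# ---------------------------------------------------------------------------------
if __name__ == '__main__':
    import copy

    master = _ExampleSyncModule()
    replicas = [master, copy.deepcopy(master)]   # stand-ins for per-device copies
    execute_replication_callbacks(replicas)

    # The master created the shared state; the slave copy received the same object.
    assert replicas[0]._is_master and not replicas[1]._is_master
    assert replicas[1]._shared_state is replicas[0]._shared_state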