import pickle

import torch
import torch.distributed as dist


def get_world_size():
    # Fall back to a world size of 1 when torch.distributed is unavailable
    # or the default process group has not been initialised.
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def to_python_float(t):
    # Convert a zero-dim tensor (or a one-element sequence) to a Python number.
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def can_log():
    return is_main_process()


def dist_print(*args, **kwargs):
    # Print only on the main process to avoid duplicated console output.
    if can_log():
        print(*args, **kwargs)


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


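# Illustrative usage sketch (not part of the original file): synchronize() is
# typically called when every rank must wait for one rank to finish a step,
# e.g. rank 0 writing a checkpoint that the other ranks then load.
# `net` and the checkpoint path below are hypothetical.
#
#     if is_main_process():
#         torch.save(net.state_dict(), 'checkpoint.pth')
#     synchronize()                                   # all ranks wait here
#     net.load_state_dict(torch.load('checkpoint.pth'))

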
def dist_cat_reduce_tensor(tensor):
    # Gather this tensor from every rank and concatenate along dim 0.
    # Note: all ranks must pass tensors of identical shape.
    if not dist.is_available():
        return tensor
    if not dist.is_initialized():
        return tensor
    rt = tensor.clone()
    all_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(all_list, rt)
    return torch.cat(all_list, dim=0)


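# Illustrative usage sketch (not part of the original file): concatenating
# per-rank model outputs so every rank ends up with the predictions for the
# whole distributed batch. `net` and `images` are hypothetical.
#
#     preds = net(images)                          # per-rank output, shape (N, ...)
#     all_preds = dist_cat_reduce_tensor(preds)    # shape (world_size * N, ...)

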
def dist_sum_reduce_tensor(tensor):
    # Sum the tensor over all ranks; non-tensor inputs are returned unchanged.
    if not dist.is_available():
        return tensor
    if not dist.is_initialized():
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt


def dist_mean_reduce_tensor(tensor):
    # Average the tensor over all ranks.
    rt = dist_sum_reduce_tensor(tensor)
    rt /= get_world_size()
    return rt


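# Illustrative usage sketch (not part of the original file): reporting a loss
# averaged over all ranks so every process logs the same number.
# `criterion`, `net`, `x` and `y` are hypothetical.
#
#     loss = criterion(net(x), y)                          # per-rank scalar tensor
#     avg_loss = dist_mean_reduce_tensor(loss.detach())
#     dist_print('avg loss:', to_python_float(avg_loss))   # printed on rank 0 only

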
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object to a byte tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving tensors from all ranks:
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


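# Illustrative usage sketch (not part of the original file): unlike the tensor
# helpers above, all_gather works on any picklable object, e.g. collecting
# per-rank evaluation results on every process. It stages data on "cuda", so a
# GPU is assumed. `scores` is hypothetical.
#
#     local_results = [{'index': i, 'score': s} for i, s in enumerate(scores)]
#     per_rank_results = all_gather(local_results)      # one list entry per rank
#     merged = [r for rank_results in per_rank_results for r in rank_results]

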
from torch.utils.tensorboard import SummaryWriter


class DistSummaryWriter(SummaryWriter):
    # SummaryWriter that only writes logs on the main process;
    # every method is a no-op on the other ranks.
    def __init__(self, *args, **kwargs):
        if can_log():
            super(DistSummaryWriter, self).__init__(*args, **kwargs)

    def add_scalar(self, *args, **kwargs):
        if can_log():
            super(DistSummaryWriter, self).add_scalar(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        if can_log():
            super(DistSummaryWriter, self).add_figure(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        if can_log():
            super(DistSummaryWriter, self).add_graph(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        if can_log():
            super(DistSummaryWriter, self).add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        if can_log():
            super(DistSummaryWriter, self).add_image(*args, **kwargs)

    def close(self):
        if can_log():
            super(DistSummaryWriter, self).close()


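# Illustrative usage sketch (not part of the original file): DistSummaryWriter
# can be constructed and called unconditionally on every rank; only rank 0
# actually creates event files. The log directory below is hypothetical.
#
#     writer = DistSummaryWriter('runs/experiment')
#     writer.add_scalar('train/loss', to_python_float(avg_loss), global_step)
#     writer.close()

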
import tqdm


def dist_tqdm(obj, *args, **kwargs):
    # Wrap the iterable in a tqdm progress bar only on the main process;
    # other ranks iterate over the plain object.
    if can_log():
        return tqdm.tqdm(obj, *args, **kwargs)
    else:
        return obj

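# Illustrative usage sketch (not part of the original file): showing a training
# progress bar on rank 0 only. `train_loader` is hypothetical.
#
#     for batch in dist_tqdm(train_loader, desc='train'):
#         ...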