|
|
@ -1,11 +1,13 @@
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.optim as optim
|
|
|
|
import torch.optim as optim
|
|
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
|
|
import torch.utils.data
|
|
|
|
import torch.utils.data
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
|
|
|
|
|
|
|
|
import test # import test.py to get mAP after each epoch
|
|
|
|
import test # import test.py to get mAP after each epoch
|
|
|
|
from models.yolo import Model
|
|
|
|
from models.yolo import Model
|
|
|
@ -42,7 +44,7 @@ hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
|
|
|
|
'shear': 0.0} # image shear (+/- deg)
|
|
|
|
'shear': 0.0} # image shear (+/- deg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(hyp):
|
|
|
|
def train(hyp, tb_writer, opt, device):
|
|
|
|
print(f'Hyperparameters {hyp}')
|
|
|
|
print(f'Hyperparameters {hyp}')
|
|
|
|
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
|
|
|
|
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
|
|
|
|
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
|
|
|
|
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
|
|
|
@ -59,11 +61,16 @@ def train(hyp):
|
|
|
|
yaml.dump(vars(opt), f, sort_keys=False)
|
|
|
|
yaml.dump(vars(opt), f, sort_keys=False)
|
|
|
|
|
|
|
|
|
|
|
|
epochs = opt.epochs # 300
|
|
|
|
epochs = opt.epochs # 300
|
|
|
|
batch_size = opt.batch_size # 64
|
|
|
|
batch_size = opt.batch_size # batch size per process.
|
|
|
|
|
|
|
|
total_batch_size = opt.total_batch_size
|
|
|
|
weights = opt.weights # initial training weights
|
|
|
|
weights = opt.weights # initial training weights
|
|
|
|
|
|
|
|
local_rank = opt.local_rank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Init DDP logging. Only the first process is allowed to log.
|
|
|
|
|
|
|
|
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
|
|
|
|
|
|
|
|
|
|
|
|
# Configure
|
|
|
|
# Configure
|
|
|
|
init_seeds(1)
|
|
|
|
init_seeds(2+local_rank)
|
|
|
|
with open(opt.data) as f:
|
|
|
|
with open(opt.data) as f:
|
|
|
|
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
|
|
|
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
|
|
|
train_path = data_dict['train']
|
|
|
|
train_path = data_dict['train']
|
|
|
@ -72,6 +79,7 @@ def train(hyp):
|
|
|
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
|
|
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
|
|
|
|
|
|
|
|
|
|
|
# Remove previous results
|
|
|
|
# Remove previous results
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
|
|
|
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
|
|
|
os.remove(f)
|
|
|
|
os.remove(f)
|
|
|
|
|
|
|
|
|
|
|
@ -84,8 +92,15 @@ def train(hyp):
|
|
|
|
|
|
|
|
|
|
|
|
# Optimizer
|
|
|
|
# Optimizer
|
|
|
|
nbs = 64 # nominal batch size
|
|
|
|
nbs = 64 # nominal batch size
|
|
|
|
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
|
|
|
|
# the default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
|
|
|
|
hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
|
|
|
|
# all-reduce operation is carried out during loss.backward().
|
|
|
|
|
|
|
|
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
|
|
|
|
|
|
|
|
# which means, the result is still right but the training speed gets slower.
|
|
|
|
|
|
|
|
# TODO: If acceleration is needed, there is an implementation of allreduce_post_accumulation
|
|
|
|
|
|
|
|
# in https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/run_pretraining.py
|
|
|
|
|
|
|
|
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
|
|
|
|
|
|
|
|
hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
|
|
|
|
|
|
|
|
|
|
|
|
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
|
|
|
|
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
|
|
|
|
for k, v in model.named_parameters():
|
|
|
|
for k, v in model.named_parameters():
|
|
|
|
if v.requires_grad:
|
|
|
|
if v.requires_grad:
|
|
|
@ -106,12 +121,9 @@ def train(hyp):
|
|
|
|
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
|
|
|
|
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
|
|
|
|
del pg0, pg1, pg2
|
|
|
|
del pg0, pg1, pg2
|
|
|
|
|
|
|
|
|
|
|
|
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
|
|
|
|
|
|
|
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
|
|
|
|
|
|
|
|
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
|
|
|
|
|
|
|
# plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load Model
|
|
|
|
# Load Model
|
|
|
|
|
|
|
|
# Avoid multiple downloads.
|
|
|
|
|
|
|
|
with torch_distributed_zero_first(local_rank):
|
|
|
|
google_utils.attempt_download(weights)
|
|
|
|
google_utils.attempt_download(weights)
|
|
|
|
start_epoch, best_fitness = 0, 0.0
|
|
|
|
start_epoch, best_fitness = 0, 0.0
|
|
|
|
if weights.endswith('.pt'): # pytorch format
|
|
|
|
if weights.endswith('.pt'): # pytorch format
|
|
|
@ -124,7 +136,7 @@ def train(hyp):
|
|
|
|
except KeyError as e:
|
|
|
|
except KeyError as e:
|
|
|
|
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
|
|
|
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
|
|
|
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
|
|
|
|
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
|
|
|
|
% (opt.weights, opt.cfg, opt.weights, opt.weights)
|
|
|
|
% (weights, opt.cfg, weights, weights)
|
|
|
|
raise KeyError(s) from e
|
|
|
|
raise KeyError(s) from e
|
|
|
|
|
|
|
|
|
|
|
|
# load optimizer
|
|
|
|
# load optimizer
|
|
|
@ -141,7 +153,7 @@ def train(hyp):
|
|
|
|
start_epoch = ckpt['epoch'] + 1
|
|
|
|
start_epoch = ckpt['epoch'] + 1
|
|
|
|
if epochs < start_epoch:
|
|
|
|
if epochs < start_epoch:
|
|
|
|
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
|
|
|
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
|
|
|
(opt.weights, ckpt['epoch'], epochs))
|
|
|
|
(weights, ckpt['epoch'], epochs))
|
|
|
|
epochs += ckpt['epoch'] # finetune additional epochs
|
|
|
|
epochs += ckpt['epoch'] # finetune additional epochs
|
|
|
|
|
|
|
|
|
|
|
|
del ckpt
|
|
|
|
del ckpt
|
|
|
@ -150,25 +162,41 @@ def train(hyp):
|
|
|
|
if mixed_precision:
|
|
|
|
if mixed_precision:
|
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
|
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
|
|
|
|
|
|
|
|
|
|
|
|
# Distributed training
|
|
|
|
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
|
|
|
if device.type != 'cpu' and torch.cuda.device_count() > 1 and dist.is_available():
|
|
|
|
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
|
|
|
|
dist.init_process_group(backend='nccl', # distributed backend
|
|
|
|
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
|
|
|
init_method='tcp://127.0.0.1:9999', # init method
|
|
|
|
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
|
|
|
world_size=1, # number of nodes
|
|
|
|
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
|
|
|
rank=0) # node rank
|
|
|
|
|
|
|
|
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) # requires world_size > 1
|
|
|
|
# DP mode
|
|
|
|
model = torch.nn.parallel.DistributedDataParallel(model)
|
|
|
|
if device.type != 'cpu' and local_rank == -1 and torch.cuda.device_count() > 1:
|
|
|
|
|
|
|
|
model = torch.nn.DataParallel(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Exponential moving average
|
|
|
|
|
|
|
|
# From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
|
|
|
|
|
|
|
|
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
|
|
|
|
|
|
|
|
# chenyzsjtu: ema should be placed before after SyncBN. As SyncBN introduces new modules.
|
|
|
|
|
|
|
|
if opt.sync_bn and device.type != 'cpu' and local_rank != -1:
|
|
|
|
|
|
|
|
print("SyncBN activated!")
|
|
|
|
|
|
|
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
|
|
|
|
|
|
|
|
ema = torch_utils.ModelEMA(model) if local_rank in [-1, 0] else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# DDP mode
|
|
|
|
|
|
|
|
if device.type != 'cpu' and local_rank != -1:
|
|
|
|
|
|
|
|
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
|
|
|
|
|
|
|
|
|
|
|
# Trainloader
|
|
|
|
# Trainloader
|
|
|
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
|
|
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
|
|
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
|
|
|
cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size)
|
|
|
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
|
|
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
|
|
|
nb = len(dataloader) # number of batches
|
|
|
|
nb = len(dataloader) # number of batches
|
|
|
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
|
|
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
|
|
|
|
|
|
|
|
|
|
|
# Testloader
|
|
|
|
# Testloader
|
|
|
|
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
|
|
|
|
# local_rank is set to -1. Because only the first process is expected to do evaluation.
|
|
|
|
|
|
|
|
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
|
|
|
|
|
|
|
|
cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
|
|
|
|
|
|
|
|
|
|
|
|
# Model parameters
|
|
|
|
# Model parameters
|
|
|
|
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
|
|
|
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
|
|
@ -179,6 +207,8 @@ def train(hyp):
|
|
|
|
model.names = names
|
|
|
|
model.names = names
|
|
|
|
|
|
|
|
|
|
|
|
# Class frequency
|
|
|
|
# Class frequency
|
|
|
|
|
|
|
|
# Only one check and log is needed.
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
labels = np.concatenate(dataset.labels, 0)
|
|
|
|
labels = np.concatenate(dataset.labels, 0)
|
|
|
|
c = torch.tensor(labels[:, 0]) # classes
|
|
|
|
c = torch.tensor(labels[:, 0]) # classes
|
|
|
|
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
|
|
|
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
|
|
@ -191,16 +221,13 @@ def train(hyp):
|
|
|
|
# Check anchors
|
|
|
|
# Check anchors
|
|
|
|
if not opt.noautoanchor:
|
|
|
|
if not opt.noautoanchor:
|
|
|
|
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
|
|
|
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
|
|
|
|
|
|
|
|
|
|
|
# Exponential moving average
|
|
|
|
|
|
|
|
ema = torch_utils.ModelEMA(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start training
|
|
|
|
# Start training
|
|
|
|
t0 = time.time()
|
|
|
|
t0 = time.time()
|
|
|
|
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
|
|
|
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
|
|
|
maps = np.zeros(nc) # mAP per class
|
|
|
|
maps = np.zeros(nc) # mAP per class
|
|
|
|
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
|
|
|
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
|
|
|
scheduler.last_epoch = start_epoch - 1 # do not move
|
|
|
|
scheduler.last_epoch = start_epoch - 1 # do not move
|
|
|
|
|
|
|
|
if local_rank in [0, -1]:
|
|
|
|
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
|
|
|
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
|
|
|
print('Using %g dataloader workers' % dataloader.num_workers)
|
|
|
|
print('Using %g dataloader workers' % dataloader.num_workers)
|
|
|
|
print('Starting training for %g epochs...' % epochs)
|
|
|
|
print('Starting training for %g epochs...' % epochs)
|
|
|
@ -209,18 +236,34 @@ def train(hyp):
|
|
|
|
model.train()
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
|
|
# Update image weights (optional)
|
|
|
|
# Update image weights (optional)
|
|
|
|
|
|
|
|
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
|
|
|
|
if dataset.image_weights:
|
|
|
|
if dataset.image_weights:
|
|
|
|
|
|
|
|
# Generate indices.
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
|
|
|
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
|
|
|
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
|
|
|
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
|
|
|
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
|
|
|
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
|
|
|
|
|
|
|
# Broadcast.
|
|
|
|
|
|
|
|
if local_rank != -1:
|
|
|
|
|
|
|
|
indices = torch.zeros([dataset.n], dtype=torch.int)
|
|
|
|
|
|
|
|
if local_rank == 0:
|
|
|
|
|
|
|
|
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
|
|
|
|
|
|
|
|
dist.broadcast(indices, 0)
|
|
|
|
|
|
|
|
if local_rank != 0:
|
|
|
|
|
|
|
|
dataset.indices = indices.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
# Update mosaic border
|
|
|
|
# Update mosaic border
|
|
|
|
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
|
|
|
|
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
|
|
|
|
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
|
|
|
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
|
|
|
|
|
|
|
|
|
|
|
mloss = torch.zeros(4, device=device) # mean losses
|
|
|
|
mloss = torch.zeros(4, device=device) # mean losses
|
|
|
|
|
|
|
|
if local_rank != -1:
|
|
|
|
|
|
|
|
dataloader.sampler.set_epoch(epoch)
|
|
|
|
|
|
|
|
pbar = enumerate(dataloader)
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
|
|
|
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
|
|
|
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
|
|
|
pbar = tqdm(pbar, total=nb) # progress bar
|
|
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
|
|
|
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
|
|
|
ni = i + nb * epoch # number integrated batches (since train start)
|
|
|
|
ni = i + nb * epoch # number integrated batches (since train start)
|
|
|
|
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
|
|
|
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
|
|
@ -229,7 +272,7 @@ def train(hyp):
|
|
|
|
if ni <= nw:
|
|
|
|
if ni <= nw:
|
|
|
|
xi = [0, nw] # x interp
|
|
|
|
xi = [0, nw] # x interp
|
|
|
|
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
|
|
|
|
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
|
|
|
|
accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
|
|
|
|
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
|
|
|
|
for j, x in enumerate(optimizer.param_groups):
|
|
|
|
for j, x in enumerate(optimizer.param_groups):
|
|
|
|
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
|
|
|
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
|
|
|
x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
|
|
|
|
x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
|
|
|
@ -249,6 +292,9 @@ def train(hyp):
|
|
|
|
|
|
|
|
|
|
|
|
# Loss
|
|
|
|
# Loss
|
|
|
|
loss, loss_items = compute_loss(pred, targets.to(device), model)
|
|
|
|
loss, loss_items = compute_loss(pred, targets.to(device), model)
|
|
|
|
|
|
|
|
# loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
|
|
|
|
|
|
|
|
if local_rank != -1:
|
|
|
|
|
|
|
|
loss *= opt.world_size
|
|
|
|
if not torch.isfinite(loss):
|
|
|
|
if not torch.isfinite(loss):
|
|
|
|
print('WARNING: non-finite loss, ending training ', loss_items)
|
|
|
|
print('WARNING: non-finite loss, ending training ', loss_items)
|
|
|
|
return results
|
|
|
|
return results
|
|
|
@ -264,9 +310,11 @@ def train(hyp):
|
|
|
|
if ni % accumulate == 0:
|
|
|
|
if ni % accumulate == 0:
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
if ema is not None:
|
|
|
|
ema.update(model)
|
|
|
|
ema.update(model)
|
|
|
|
|
|
|
|
|
|
|
|
# Print
|
|
|
|
# Print
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
|
|
|
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
|
|
|
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
|
|
|
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
|
|
|
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
|
|
|
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
|
|
@ -286,29 +334,32 @@ def train(hyp):
|
|
|
|
# Scheduler
|
|
|
|
# Scheduler
|
|
|
|
scheduler.step()
|
|
|
|
scheduler.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only the first process in DDP mode is allowed to log or save checkpoints.
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
# mAP
|
|
|
|
# mAP
|
|
|
|
|
|
|
|
if ema is not None:
|
|
|
|
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
|
|
|
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
|
|
|
final_epoch = epoch + 1 == epochs
|
|
|
|
final_epoch = epoch + 1 == epochs
|
|
|
|
if not opt.notest or final_epoch: # Calculate mAP
|
|
|
|
if not opt.notest or final_epoch: # Calculate mAP
|
|
|
|
results, maps, times = test.test(opt.data,
|
|
|
|
results, maps, times = test.test(opt.data,
|
|
|
|
batch_size=batch_size,
|
|
|
|
batch_size=total_batch_size,
|
|
|
|
imgsz=imgsz_test,
|
|
|
|
imgsz=imgsz_test,
|
|
|
|
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
|
|
|
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
|
|
|
model=ema.ema,
|
|
|
|
model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
|
|
|
|
single_cls=opt.single_cls,
|
|
|
|
single_cls=opt.single_cls,
|
|
|
|
dataloader=testloader,
|
|
|
|
dataloader=testloader,
|
|
|
|
save_dir=log_dir)
|
|
|
|
save_dir=log_dir)
|
|
|
|
|
|
|
|
# Explicitly keep the shape.
|
|
|
|
# Write
|
|
|
|
# Write
|
|
|
|
with open(results_file, 'a') as f:
|
|
|
|
with open(results_file, 'a') as f:
|
|
|
|
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
|
|
|
|
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
|
|
|
|
if len(opt.name) and opt.bucket:
|
|
|
|
if len(opt.name) and opt.bucket:
|
|
|
|
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
|
|
|
|
os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
|
|
|
|
|
|
|
|
|
|
|
|
# Tensorboard
|
|
|
|
# Tensorboard
|
|
|
|
if tb_writer:
|
|
|
|
if tb_writer:
|
|
|
|
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
|
|
|
|
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
|
|
|
|
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
|
|
|
|
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
|
|
|
|
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
|
|
|
|
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
|
|
|
|
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
|
|
|
|
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
|
|
|
|
tb_writer.add_scalar(tag, x, epoch)
|
|
|
|
tb_writer.add_scalar(tag, x, epoch)
|
|
|
@ -325,7 +376,7 @@ def train(hyp):
|
|
|
|
ckpt = {'epoch': epoch,
|
|
|
|
ckpt = {'epoch': epoch,
|
|
|
|
'best_fitness': best_fitness,
|
|
|
|
'best_fitness': best_fitness,
|
|
|
|
'training_results': f.read(),
|
|
|
|
'training_results': f.read(),
|
|
|
|
'model': ema.ema,
|
|
|
|
'model': ema.ema.module if hasattr(ema, 'module') else ema.ema,
|
|
|
|
'optimizer': None if final_epoch else optimizer.state_dict()}
|
|
|
|
'optimizer': None if final_epoch else optimizer.state_dict()}
|
|
|
|
|
|
|
|
|
|
|
|
# Save last, best and delete
|
|
|
|
# Save last, best and delete
|
|
|
@ -333,10 +384,10 @@ def train(hyp):
|
|
|
|
if (best_fitness == fi) and not final_epoch:
|
|
|
|
if (best_fitness == fi) and not final_epoch:
|
|
|
|
torch.save(ckpt, best)
|
|
|
|
torch.save(ckpt, best)
|
|
|
|
del ckpt
|
|
|
|
del ckpt
|
|
|
|
|
|
|
|
|
|
|
|
# end epoch ----------------------------------------------------------------------------------------------------
|
|
|
|
# end epoch ----------------------------------------------------------------------------------------------------
|
|
|
|
# end training
|
|
|
|
# end training
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
# Strip optimizers
|
|
|
|
# Strip optimizers
|
|
|
|
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
|
|
|
|
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
|
|
|
|
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
|
|
|
|
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
|
|
|
@ -346,24 +397,23 @@ def train(hyp):
|
|
|
|
ispt = f2.endswith('.pt') # is *.pt
|
|
|
|
ispt = f2.endswith('.pt') # is *.pt
|
|
|
|
strip_optimizer(f2) if ispt else None # strip optimizer
|
|
|
|
strip_optimizer(f2) if ispt else None # strip optimizer
|
|
|
|
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
|
|
|
|
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
|
|
|
|
|
|
|
|
|
|
|
|
# Finish
|
|
|
|
# Finish
|
|
|
|
if not opt.evolve:
|
|
|
|
if not opt.evolve:
|
|
|
|
plot_results(save_dir=log_dir) # save as results.png
|
|
|
|
plot_results() # save as results.png
|
|
|
|
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
|
|
|
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
|
|
|
dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
|
|
|
|
|
|
|
|
|
|
|
|
dist.destroy_process_group() if local_rank not in [-1,0] else None
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
return results
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
check_git_status()
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
|
|
|
|
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
|
|
|
|
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
|
|
|
|
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
|
|
|
|
parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
|
|
|
|
parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
|
|
|
|
parser.add_argument('--epochs', type=int, default=300)
|
|
|
|
parser.add_argument('--epochs', type=int, default=300)
|
|
|
|
parser.add_argument('--batch-size', type=int, default=16)
|
|
|
|
parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all gpus.")
|
|
|
|
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
|
|
|
|
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
|
|
|
|
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
|
|
|
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
|
|
|
parser.add_argument('--resume', nargs='?', const='get_last', default=False,
|
|
|
|
parser.add_argument('--resume', nargs='?', const='get_last', default=False,
|
|
|
@ -379,32 +429,54 @@ if __name__ == '__main__':
|
|
|
|
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
|
|
|
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
|
|
|
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
|
|
|
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
|
|
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
|
|
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
|
|
|
|
|
|
|
parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
|
|
|
|
|
|
|
|
# Parameter For DDP.
|
|
|
|
|
|
|
|
parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
|
|
|
|
opt = parser.parse_args()
|
|
|
|
opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|
|
|
|
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|
|
|
|
if last and not opt.weights:
|
|
|
|
if last and not opt.weights:
|
|
|
|
print(f'Resuming training from {last}')
|
|
|
|
print(f'Resuming training from {last}')
|
|
|
|
opt.weights = last if opt.resume and not opt.weights else opt.weights
|
|
|
|
opt.weights = last if opt.resume and not opt.weights else opt.weights
|
|
|
|
|
|
|
|
if opt.local_rank in [-1, 0]:
|
|
|
|
|
|
|
|
check_git_status()
|
|
|
|
opt.cfg = check_file(opt.cfg) # check file
|
|
|
|
opt.cfg = check_file(opt.cfg) # check file
|
|
|
|
opt.data = check_file(opt.data) # check file
|
|
|
|
opt.data = check_file(opt.data) # check file
|
|
|
|
if opt.hyp: # update hyps
|
|
|
|
if opt.hyp: # update hyps
|
|
|
|
opt.hyp = check_file(opt.hyp) # check file
|
|
|
|
opt.hyp = check_file(opt.hyp) # check file
|
|
|
|
with open(opt.hyp) as f:
|
|
|
|
with open(opt.hyp) as f:
|
|
|
|
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
|
|
|
|
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
|
|
|
|
print(opt)
|
|
|
|
|
|
|
|
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
|
|
|
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
|
|
|
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
|
|
|
|
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
|
|
|
|
|
|
|
|
opt.total_batch_size = opt.batch_size
|
|
|
|
|
|
|
|
opt.world_size = 1
|
|
|
|
if device.type == 'cpu':
|
|
|
|
if device.type == 'cpu':
|
|
|
|
mixed_precision = False
|
|
|
|
mixed_precision = False
|
|
|
|
|
|
|
|
elif opt.local_rank != -1:
|
|
|
|
|
|
|
|
# DDP mode
|
|
|
|
|
|
|
|
assert torch.cuda.device_count() > opt.local_rank
|
|
|
|
|
|
|
|
torch.cuda.set_device(opt.local_rank)
|
|
|
|
|
|
|
|
device = torch.device("cuda", opt.local_rank)
|
|
|
|
|
|
|
|
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
opt.world_size = dist.get_world_size()
|
|
|
|
|
|
|
|
assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
|
|
|
|
|
|
|
|
opt.batch_size = opt.total_batch_size // opt.world_size
|
|
|
|
|
|
|
|
print(opt)
|
|
|
|
|
|
|
|
|
|
|
|
# Train
|
|
|
|
# Train
|
|
|
|
if not opt.evolve:
|
|
|
|
if not opt.evolve:
|
|
|
|
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
|
|
|
|
if opt.local_rank in [-1, 0]:
|
|
|
|
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
|
|
|
|
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
|
|
|
|
train(hyp)
|
|
|
|
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
tb_writer = None
|
|
|
|
|
|
|
|
train(hyp, tb_writer, opt, device)
|
|
|
|
|
|
|
|
|
|
|
|
# Evolve hyperparameters (optional)
|
|
|
|
# Evolve hyperparameters (optional)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|
|
|
|
assert opt.local_rank == -1, "DDP mode currently not implemented for Evolve!"
|
|
|
|
|
|
|
|
|
|
|
|
tb_writer = None
|
|
|
|
tb_writer = None
|
|
|
|
opt.notest, opt.nosave = True, True # only test/save final epoch
|
|
|
|
opt.notest, opt.nosave = True, True # only test/save final epoch
|
|
|
|
if opt.bucket:
|
|
|
|
if opt.bucket:
|
|
|
@ -443,7 +515,7 @@ if __name__ == '__main__':
|
|
|
|
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
|
|
|
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
|
|
|
|
|
|
|
|
|
|
|
# Train mutation
|
|
|
|
# Train mutation
|
|
|
|
results = train(hyp.copy())
|
|
|
|
results = train(hyp.copy(), tb_writer, opt, device)
|
|
|
|
|
|
|
|
|
|
|
|
# Write mutation results
|
|
|
|
# Write mutation results
|
|
|
|
print_mutation(hyp, results, opt.bucket)
|
|
|
|
print_mutation(hyp, results, opt.bucket)
|
|
|
|