|
|
@ -1,13 +1,12 @@
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.optim as optim
|
|
|
|
import torch.optim as optim
|
|
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
|
|
import torch.utils.data
|
|
|
|
import torch.utils.data
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
|
|
|
|
|
import test # import test.py to get mAP after each epoch
|
|
|
|
import test # import test.py to get mAP after each epoch
|
|
|
|
from models.yolo import Model
|
|
|
|
from models.yolo import Model
|
|
|
@ -131,7 +130,8 @@ def train(hyp, tb_writer, opt, device):
|
|
|
|
|
|
|
|
|
|
|
|
# load model
|
|
|
|
# load model
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() if k in model.state_dict()}
|
|
|
|
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
|
|
|
|
|
|
|
|
if k in model.state_dict() and model.state_dict()[k].shape == v.shape}
|
|
|
|
model.load_state_dict(ckpt['model'], strict=False)
|
|
|
|
model.load_state_dict(ckpt['model'], strict=False)
|
|
|
|
except KeyError as e:
|
|
|
|
except KeyError as e:
|
|
|
|
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
|
|
|
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
|
|
@ -187,7 +187,8 @@ def train(hyp, tb_writer, opt, device):
|
|
|
|
|
|
|
|
|
|
|
|
# Trainloader
|
|
|
|
# Trainloader
|
|
|
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
|
|
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
|
|
|
cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size)
|
|
|
|
cache=opt.cache_images, rect=opt.rect, local_rank=local_rank,
|
|
|
|
|
|
|
|
world_size=opt.world_size)
|
|
|
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
|
|
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
|
|
|
nb = len(dataloader) # number of batches
|
|
|
|
nb = len(dataloader) # number of batches
|
|
|
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
|
|
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
|
|
@ -242,7 +243,8 @@ def train(hyp, tb_writer, opt, device):
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
if local_rank in [-1, 0]:
|
|
|
|
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
|
|
|
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
|
|
|
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
|
|
|
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
|
|
|
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
|
|
|
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
|
|
|
|
|
|
|
|
k=dataset.n) # rand weighted idx
|
|
|
|
# Broadcast.
|
|
|
|
# Broadcast.
|
|
|
|
if local_rank != -1:
|
|
|
|
if local_rank != -1:
|
|
|
|
indices = torch.zeros([dataset.n], dtype=torch.int)
|
|
|
|
indices = torch.zeros([dataset.n], dtype=torch.int)
|
|
|
@ -431,7 +433,8 @@ if __name__ == '__main__':
|
|
|
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
|
|
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
|
|
|
parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
|
|
|
|
parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
|
|
|
|
# Parameter For DDP.
|
|
|
|
# Parameter For DDP.
|
|
|
|
parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
|
|
|
|
parser.add_argument('--local_rank', type=int, default=-1,
|
|
|
|
|
|
|
|
help="Extra parameter for DDP implementation. Don't use it manually.")
|
|
|
|
opt = parser.parse_args()
|
|
|
|
opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|
|
|
|
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|
|
|
|