refactor dataloader

pull/1/head
Glenn Jocher 5 years ago
parent 97b5186fa0
commit 22fb2b0c25

test.py
@@ -1,8 +1,6 @@
 import argparse
 import json
 
-from torch.utils.data import DataLoader
-
 from utils import google_utils
 from utils.datasets import *
 from utils.utils import *
@@ -56,30 +54,16 @@ def test(data,
         data = yaml.load(f, Loader=yaml.FullLoader)  # model dict
     nc = 1 if single_cls else int(data['nc'])  # number of classes
     iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
-    # iouv = iouv[0].view(1)  # comment for mAP@0.5:0.95
     niou = iouv.numel()
 
     # Dataloader
     if dataloader is None:  # not training
+        merge = opt.merge  # use Merge NMS
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
-        merge = opt.merge  # use Merge NMS
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-        dataset = LoadImagesAndLabels(path,
-                                      imgsz,
-                                      batch_size,
-                                      rect=True,  # rectangular inference
-                                      single_cls=opt.single_cls,  # single class mode
-                                      stride=int(max(model.stride)),  # model stride
-                                      pad=0.5)  # padding
-        batch_size = min(batch_size, len(dataset))
-        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
-        dataloader = DataLoader(dataset,
-                                batch_size=batch_size,
-                                num_workers=nw,
-                                pin_memory=True,
-                                collate_fn=dataset.collate_fn)
+        dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
+                                       hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
 
     seen = 0
     names = model.names if hasattr(model, 'names') else model.module.names
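Since create_dataloader (defined in utils/datasets.py below) returns a (dataloader, dataset) tuple, the test-time call keeps only the loader via [0]. A usage sketch, assuming the script's existing opt namespace and that LoadImagesAndLabels.collate_fn yields (imgs, targets, paths, shapes) batches; the path, sizes, and stride are placeholders:

# Sketch only: 'data/val', 640, 32, 32 are illustrative, not from this commit
loader = create_dataloader('data/val', 640, 32, 32, opt,
                           hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
for imgs, targets, paths, shapes in loader:
    # imgs: (N, 3, H, W) image batch; targets: (M, 6) rows of
    # (batch index, class, x, y, w, h) with boxes in normalized xywh
    pass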

train.py
@@ -155,38 +155,15 @@ def train(hyp):
         model = torch.nn.parallel.DistributedDataParallel(model)
         # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
 
-    # Dataset
-    dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
-                                  augment=True,
-                                  hyp=hyp,  # augmentation hyperparameters
-                                  rect=opt.rect,  # rectangular training
-                                  cache_images=opt.cache_images,
-                                  single_cls=opt.single_cls,
-                                  stride=gs)
+    # Trainloader
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
 
-    # Dataloader
-    batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
-    dataloader = torch.utils.data.DataLoader(dataset,
-                                             batch_size=batch_size,
-                                             num_workers=nw,
-                                             shuffle=not opt.rect,  # Shuffle=True unless rectangular training is used
-                                             pin_memory=True,
-                                             collate_fn=dataset.collate_fn)
-
     # Testloader
-    testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
-                                                                 hyp=hyp,
-                                                                 rect=True,
-                                                                 cache_images=opt.cache_images,
-                                                                 single_cls=opt.single_cls,
-                                                                 stride=gs),
-                                             batch_size=batch_size,
-                                             num_workers=nw,
-                                             pin_memory=True,
-                                             collate_fn=dataset.collate_fn)
+    testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
+                                   hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
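One behavioral note on the consolidation: the old trainloader passed shuffle=not opt.rect, while create_dataloader passes no shuffle argument at all, so PyTorch's default of shuffle=False now applies to training as well. A minimal self-contained illustration of the difference (TensorDataset stands in for LoadImagesAndLabels):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))
rect = False  # stand-in for opt.rect
old_loader = DataLoader(dataset, batch_size=4, shuffle=not rect)  # pre-refactor: shuffled when not rect
new_loader = DataLoader(dataset, batch_size=4)                    # post-refactor: shuffle defaults to False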
@@ -218,7 +195,7 @@ def train(hyp):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
     print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
-    print('Using %g dataloader workers' % nw)
+    print('Using %g dataloader workers' % dataloader.num_workers)
     print('Starting training for %g epochs...' % epochs)
     # torch.autograd.set_detect_anomaly(True)
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------

utils/datasets.py
@@ -41,6 +41,26 @@ def exif_size(img):
     return s
 
 
+def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
+    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
+                                  augment=augment,  # augment images
+                                  hyp=hyp,  # augmentation hyperparameters
+                                  rect=rect,  # rectangular training
+                                  cache_images=cache,
+                                  single_cls=opt.single_cls,
+                                  stride=stride,
+                                  pad=pad)
+
+    batch_size = min(batch_size, len(dataset))
+    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             num_workers=nw,
+                                             pin_memory=True,
+                                             collate_fn=LoadImagesAndLabels.collate_fn)
+    return dataloader, dataset
+
+
 class LoadImages:  # for inference
     def __init__(self, path, img_size=416):
         path = str(Path(path))  # os-agnostic
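The worker-count formula above never allocates more workers than CPUs, never more than the batch size, caps the count at 8, and falls back to main-process loading (0 workers) when batch_size <= 1. A quick standalone check of the arithmetic:

import os

def num_workers(batch_size, cap=8):
    # mirrors: nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, cap])
    return min([os.cpu_count(), batch_size if batch_size > 1 else 0, cap])

print(num_workers(64))  # 8 on any machine with at least 8 CPUs
print(num_workers(4))   # 4 on a machine with at least 4 CPUs: batch size binds
print(num_workers(1))   # 0: single-image batches load in the main process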
@@ -712,7 +732,7 @@ def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10,
         area = w * h
         area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
         ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio
-        i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)
+        i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)
 
         targets = targets[i]
         targets[:, 1:5] = xy[i]
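The relaxed thresholds keep smaller and more elongated boxes after the affine warp: each side now only needs to exceed 2 px (was 4) and the aspect-ratio cutoff rises from 10 to 20, while the 20% area-retention test is unchanged. A standalone numeric check of the mask with made-up candidate boxes (area0 * s is folded into area0_s for illustration):

import numpy as np

w = np.array([3., 30., 10.])   # warped widths in px
h = np.array([3., 1., 10.])    # warped heights in px
area = w * h
area0_s = area / 0.5           # pretend each box kept 50% of its original area
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio

old = (w > 4) & (h > 4) & (area / (area0_s + 1e-16) > 0.2) & (ar < 10)
new = (w > 2) & (h > 2) & (area / (area0_s + 1e-16) > 0.2) & (ar < 20)
print(old.tolist())  # [False, False, True]: the 3x3 box fails the old 4 px floor
print(new.tolist())  # [True, False, True]: it survives the new 2 px floor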
