import os

import torch
import torchvision.transforms as transforms

import data.mytransforms as mytransforms
from data.constant import tusimple_row_anchor, culane_row_anchor
from data.dataset import LaneClsDataset, LaneTestDataset

def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux, distributed, num_lanes):
    # Full-resolution segmentation masks, resized to the 288 x 800 network input.
    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((288, 800)),
        mytransforms.MaskToTensor(),
    ])
    # Masks for the auxiliary segmentation branch at 1/8 resolution (36 x 100).
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((36, 100)),
        mytransforms.MaskToTensor(),
    ])
    # Image preprocessing: resize, convert to tensor, normalize with ImageNet statistics.
    img_transform = transforms.Compose([
        transforms.Resize((288, 800)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Geometric augmentations applied jointly to the image and its labels.
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200),
    ])

    if dataset == 'CULane':
        train_dataset = LaneClsDataset(data_root,
                                       os.path.join(data_root, 'list/train_gt.txt'),
                                       img_transform=img_transform,
                                       target_transform=target_transform,
                                       simu_transform=simu_transform,
                                       segment_transform=segment_transform,
                                       row_anchor=culane_row_anchor,
                                       griding_num=griding_num,
                                       use_aux=use_aux,
                                       num_lanes=num_lanes)
        cls_num_per_lane = 18  # CULane uses 18 row anchors
    elif dataset == 'Tusimple':
        train_dataset = LaneClsDataset(data_root,
                                       os.path.join(data_root, 'train_gt.txt'),
                                       img_transform=img_transform,
                                       target_transform=target_transform,
                                       simu_transform=simu_transform,
                                       segment_transform=segment_transform,
                                       row_anchor=tusimple_row_anchor,
                                       griding_num=griding_num,
                                       use_aux=use_aux,
                                       num_lanes=num_lanes)
        cls_num_per_lane = 56  # Tusimple uses 56 row anchors
    else:
        raise NotImplementedError

    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)

    return train_loader, cls_num_per_lane
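
# Example usage (a sketch, not part of the original module; the data_root path
# and hyperparameter values below are illustrative assumptions, not repo defaults):
#
#   train_loader, cls_num_per_lane = get_train_loader(
#       batch_size=32, data_root='/path/to/CULane', griding_num=200,
#       dataset='CULane', use_aux=True, distributed=False, num_lanes=4)
#   for batch in train_loader:
#       ...  # one training step per batch
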
def get_test_loader(batch_size, data_root, dataset, distributed):
    # Test-time preprocessing: resize and normalize only, no augmentation.
    img_transforms = transforms.Compose([
        transforms.Resize((288, 800)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    if dataset == 'CULane':
        test_dataset = LaneTestDataset(data_root, os.path.join(data_root, 'list/test.txt'), img_transform=img_transforms)
        cls_num_per_lane = 18  # set for symmetry with get_train_loader; not returned
    elif dataset == 'Tusimple':
        test_dataset = LaneTestDataset(data_root, os.path.join(data_root, 'test.txt'), img_transform=img_transforms)
        cls_num_per_lane = 56
    else:
        raise NotImplementedError

    if distributed:
        # SeqDistributedSampler (defined below) keeps each rank's reads sequential.
        sampler = SeqDistributedSampler(test_dataset, shuffle=False)
    else:
        sampler = torch.utils.data.SequentialSampler(test_dataset)
    loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
    return loader
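
# Example usage (a sketch with illustrative values): in the distributed case the
# caller must have initialized torch.distributed beforehand, since
# SeqDistributedSampler infers num_replicas and rank from the process group.
#
#   test_loader = get_test_loader(batch_size=1, data_root='/path/to/Tusimple',
#                                 dataset='Tusimple', distributed=False)
#   for batch in test_loader:
#       ...  # run inference
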
class SeqDistributedSampler(torch.utils.data.distributed.DistributedSampler):
    '''
    Change the behavior of DistributedSampler to sequential distributed sampling.
    Sequential sampling helps the stability of multi-threaded testing, which
    needs multi-threaded file I/O: without it, the file I/O on one thread may
    interfere with the other threads.
    '''
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
        super().__init__(dataset, num_replicas, rank, shuffle)

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples so the index list divides evenly across replicas.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        num_per_rank = int(self.total_size // self.num_replicas)

        # Sequential sampling: each rank takes a contiguous chunk of indices
        # rather than the default interleaved (strided) assignment.
        indices = indices[num_per_rank * self.rank : num_per_rank * (self.rank + 1)]
        assert len(indices) == self.num_samples

        return iter(indices)
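

if __name__ == '__main__':
    # Quick illustration of the sequential split (a minimal sketch; the dummy
    # dataset below is an assumption for demonstration, and no process group is
    # needed because num_replicas and rank are given explicitly). The default
    # DistributedSampler would interleave indices across ranks (0, 2, 4, ... on
    # rank 0); this sampler hands each rank a contiguous chunk instead.
    dummy_dataset = list(range(10))
    for rank in range(2):
        sampler = SeqDistributedSampler(dummy_dataset, num_replicas=2, rank=rank, shuffle=False)
        print(rank, list(sampler))
    # Expected output:
    #   0 [0, 1, 2, 3, 4]
    #   1 [5, 6, 7, 8, 9]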