You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

166 lines
5.5 KiB

3 years ago
import torch
from PIL import Image
import os
import pdb
import numpy as np
import cv2
from data.mytransforms import find_start_pos
def loader_func(path):
    """Open the image file at ``path`` with ``PIL.Image.open`` and return it.

    Both dataset classes below load their images and label maps through this
    single helper.
    """
    opened = Image.open(path)
    return opened
class LaneTestDataset(torch.utils.data.Dataset):
    """Inference dataset yielding ``(image, name)`` pairs from a list file.

    Each non-blank line of ``list_path`` starts with an image path relative to
    ``path``; any further whitespace-separated tokens on the line are ignored.

    Args:
        path: Dataset root directory joined in front of every listed name.
        list_path: Text file listing the images, one per line.
        img_transform: Optional callable applied to each loaded image.
    """

    def __init__(self, path, list_path, img_transform=None):
        super().__init__()
        self.path = path
        self.img_transform = img_transform
        with open(list_path, 'r') as f:
            # Blank/whitespace-only lines have no first token and would make
            # __getitem__ raise IndexError, so they are skipped here.
            lines = [line for line in f if line.strip()]
        # Exclude the incorrect leading '/' of CULane list entries; otherwise
        # os.path.join would treat the name as absolute and discard self.path.
        self.list = [line[1:] if line.startswith('/') else line for line in lines]

    def __getitem__(self, index):
        """Return ``(img, name)`` for entry ``index``.

        ``name`` is the first token of the list line; ``img`` is passed through
        ``img_transform`` when one was given.
        """
        name = self.list[index].split()[0]
        img = loader_func(os.path.join(self.path, name))
        if self.img_transform is not None:
            img = self.img_transform(img)
        return img, name

    def __len__(self):
        """Return the number of listed images."""
        return len(self.list)
class LaneClsDataset(torch.utils.data.Dataset):
    """Training dataset yielding images with row-anchor classification labels.

    Each non-blank line of ``list_path`` holds ``<image> <label> ...`` paths
    relative to ``path``. In the label map, pixel value ``k`` for
    ``k in 1..num_lanes`` marks lane ``k``. For every row anchor, the mean
    column of each lane is quantised into one of ``griding_num`` horizontal
    cells. The extra class ``griding_num`` means "lane not present on this row".

    Args:
        path: Dataset root directory joined in front of every listed name.
        list_path: Text file listing ``<image> <label>`` pairs, one per line.
        img_transform: Optional callable applied to the image just before return.
        target_transform: Stored on the instance; not used by this class.
        simu_transform: Optional joint ``(img, label) -> (img, label)`` callable
            applied before the labels are derived.
        griding_num: Number of horizontal grid cells for the class label.
        load_name: When true (and ``use_aux`` is false), also return the image name.
        row_anchor: Required sequence of row coordinates given for a
            288-pixel-high image. A sorted copy is stored; the caller's
            sequence is not modified.
        use_aux: When true, also return ``segment_transform(label)``.
        segment_transform: Callable building the auxiliary target; required
            when ``use_aux`` is true.
        num_lanes: Number of lane ids (1..num_lanes) looked up in the label map.

    Raises:
        ValueError: If ``row_anchor`` is not provided.
    """

    def __init__(self, path, list_path, img_transform=None, target_transform=None,
                 simu_transform=None, griding_num=50, load_name=False,
                 row_anchor=None, use_aux=False, segment_transform=None, num_lanes=4):
        super().__init__()
        if row_anchor is None:
            raise ValueError('row_anchor must be provided')
        self.img_transform = img_transform
        self.target_transform = target_transform
        self.segment_transform = segment_transform
        self.simu_transform = simu_transform
        self.path = path
        self.griding_num = griding_num
        self.load_name = load_name
        self.use_aux = use_aux
        self.num_lanes = num_lanes
        with open(list_path, 'r') as f:
            # Blank lines have no tokens and would make __getitem__ fail; skip them.
            self.list = [line for line in f if line.strip()]
        # sorted() returns a copy, so an anchor list shared by other code
        # (e.g. a config constant) is never mutated in place.
        self.row_anchor = sorted(row_anchor)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(img, cls_label, seg_label)`` when ``use_aux`` is set, otherwise
            ``(img, cls_label, img_name)`` when ``load_name`` is set, otherwise
            ``(img, cls_label)``. ``cls_label`` is an int array of shape
            ``(len(row_anchor), num_lanes)``.
        """
        l_info = self.list[index].split()
        img_name, label_name = l_info[0], l_info[1]
        # Drop the leading '/' so os.path.join keeps self.path. The decision is
        # made from img_name alone and applied to both names.
        if img_name[0] == '/':
            img_name = img_name[1:]
            label_name = label_name[1:]

        label = loader_func(os.path.join(self.path, label_name))
        img = loader_func(os.path.join(self.path, img_name))

        if self.simu_transform is not None:
            img, label = self.simu_transform(img, label)

        # (row, col) of every lane at every row anchor, measured on the label.
        lane_pts = self._get_index(label)
        # Quantise columns using the image width.
        # NOTE(review): assumes img and label have the same width — confirm.
        w = img.size[0]
        cls_label = self._grid_pts(lane_pts, self.griding_num, w)

        if self.use_aux:
            assert self.segment_transform is not None
            seg_label = self.segment_transform(label)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.use_aux:
            return img, cls_label, seg_label
        if self.load_name:
            return img, cls_label, img_name
        return img, cls_label

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.list)

    def _grid_pts(self, pts, num_cols, w):
        """Convert lane column coordinates into grid-cell class indices.

        Args:
            pts: Array of shape ``(num_lanes, n, 2)`` holding ``(row, col)``
                pairs; ``col == -1`` marks a missing point.
            num_cols: Number of grid cells; cell boundaries come from
                ``np.linspace(0, w - 1, num_cols)``.
            w: Image width in pixels.

        Returns:
            Int array of shape ``(n, num_lanes)``. Missing points map to ``num_cols``.
        """
        _, _, n2 = pts.shape
        assert n2 == 2
        col_sample = np.linspace(0, w - 1, num_cols)
        cell_width = col_sample[1] - col_sample[0]
        cols = pts[:, :, 1]
        # Floor-divide valid columns by the cell width; missing ones get the extra class.
        cells = np.where(cols != -1, np.floor_divide(cols, cell_width), num_cols)
        return cells.T.astype(int)

    def _get_index(self, label):
        """Locate every lane at every row anchor of ``label``.

        Args:
            label: Label image exposing ``.size == (w, h)`` and convertible with
                ``np.asarray`` to an ``(h, w)`` array of lane ids.

        Returns:
            Float array of shape ``(num_lanes, len(row_anchor), 2)`` of
            ``(row, col)`` pairs. ``col`` is the mean column of that lane's
            pixels on the row (possibly extrapolated by ``_extend_lanes``), or
            -1 when the lane is absent.

        Raises:
            RuntimeError: If a sampled row coordinate equals -1. This is a
                sanity check that replaces a leftover debugger breakpoint.
        """
        w, h = label.size
        rows = self._scaled_row_anchors(h)
        # Convert the label once instead of once per row anchor.
        all_idx = self._sample_lane_points(np.asarray(label), rows)
        if np.any(all_idx[:, :, 0] == -1):
            raise RuntimeError('row anchor mapped to -1 while sampling lane points')
        return self._extend_lanes(all_idx, w)

    def _scaled_row_anchors(self, h):
        """Map the 288-pixel-based row anchors onto an image of height ``h``."""
        if h == 288:
            # Anchors are already expressed at this height; use them unchanged.
            return list(self.row_anchor)
        return [int((r * 1.0 / 288) * h) for r in self.row_anchor]

    def _sample_lane_points(self, label_arr, rows):
        """Return ``(row, mean column)`` per lane id per row; -1 column if absent."""
        all_idx = np.zeros((self.num_lanes, len(rows), 2))
        for i, r in enumerate(rows):
            label_r = label_arr[int(round(r))]
            for lane_idx in range(1, self.num_lanes + 1):
                pos = np.where(label_r == lane_idx)[0]
                all_idx[lane_idx - 1, i, 0] = r
                all_idx[lane_idx - 1, i, 1] = np.mean(pos) if len(pos) else -1
        return all_idx

    def _extend_lanes(self, all_idx, w):
        """Data augmentation: extend lanes towards the last row anchor.

        Some lanes have at least 6 valid points and end before the last anchor
        row. Each such lane is fitted with a straight line on the second half
        of its valid points (those at larger row indices), and that line is
        extrapolated over the remaining rows. Extrapolated columns outside
        ``[0, w - 1]`` are set to -1. Returns a new array; ``all_idx`` is not
        modified.
        """
        all_idx_cp = all_idx.copy()
        for i in range(self.num_lanes):
            lane = all_idx_cp[i]  # view: writes below update all_idx_cp
            valid = lane[:, 1] != -1
            if not valid.any():
                continue  # no lane on any row
            valid_idx = lane[valid, :]
            if valid_idx[-1, 0] == lane[-1, 0]:
                continue  # lane already reaches the last anchor row
            if len(valid_idx) < 6:
                continue  # lane too short to extend
            half = valid_idx[len(valid_idx) // 2:, :]
            p = np.polyfit(half[:, 0], half[:, 1], deg=1)
            start_line = half[-1, 0]
            # presumably the index of start_line among the anchor rows — see data.mytransforms
            pos = find_start_pos(lane[:, 0], start_line) + 1
            fitted = np.polyval(p, lane[pos:, 0])
            fitted = np.where((fitted < 0) | (fitted > w - 1), -1, fitted)
            assert np.all(lane[pos:, 1] == -1)
            lane[pos:, 1] = fitted
        return all_idx_cp