You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

166 lines
5.5 KiB

import torch
from PIL import Image
import os
import pdb
import numpy as np
import cv2
from data.mytransforms import find_start_pos
def loader_func(path):
    """Open the image file at ``path`` with PIL and return the resulting image object."""
    image = Image.open(path)
    return image
class LaneTestDataset(torch.utils.data.Dataset):
    """Inference dataset that yields ``(image, name)`` pairs.

    Args:
        path: Root directory that the names in the list file are relative to.
        list_path: Text file with one sample per line; the first
            whitespace-separated token of each line is the image path.
        img_transform: Optional callable applied to the loaded image.
    """

    def __init__(self, path, list_path, img_transform=None):
        super(LaneTestDataset, self).__init__()
        self.path = path
        self.img_transform = img_transform
        with open(list_path, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line); they carry no
            # sample and would make __getitem__ fail on ``split()[0]``.
            entries = [line for line in f if line.strip()]
        # exclude the incorrect path prefix '/' of CULane, otherwise
        # os.path.join would discard ``self.path`` for such entries
        self.list = [line[1:] if line.startswith('/') else line for line in entries]

    def __getitem__(self, index):
        """Return the (optionally transformed) image and its relative name."""
        name = self.list[index].split()[0]
        img_path = os.path.join(self.path, name)
        img = loader_func(img_path)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return img, name

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.list)
class LaneClsDataset(torch.utils.data.Dataset):
    """Training dataset that turns lane label images into row-anchor class labels.

    Each line of the list file holds an image path followed by a label path,
    both relative to ``path``. For every row anchor and every lane the label
    image is sampled, and the lane's x-position is quantised into one of
    ``griding_num`` column cells.

    Args:
        path: Root directory that the listed paths are relative to.
        list_path: Text file with ``<image> <label> ...`` per line.
        img_transform: Optional callable applied to the image.
        target_transform: Stored on the instance; not used by this class.
        simu_transform: Optional callable ``(img, label) -> (img, label)``
            applied to both images before label extraction.
        griding_num: Number of column cells used for the classification label.
        load_name: If true (and ``use_aux`` is false), the image name is
            returned as a third element.
        row_anchor: Row positions to sample, expressed for a 288-pixel-high
            image; they are rescaled when the label has a different height.
        use_aux: If true, the transformed label image is returned as well.
        segment_transform: Callable applied to the label when ``use_aux`` is set.
        num_lanes: Number of lanes; lane ``k`` is the label pixels equal to ``k``.

    Raises:
        ValueError: If ``row_anchor`` is not provided.
    """

    def __init__(self, path, list_path, img_transform=None, target_transform=None, simu_transform=None,
                 griding_num=50, load_name=False, row_anchor=None, use_aux=False, segment_transform=None,
                 num_lanes=4):
        super(LaneClsDataset, self).__init__()
        if row_anchor is None:
            raise ValueError('row_anchor must be provided')

        self.img_transform = img_transform
        self.target_transform = target_transform
        self.segment_transform = segment_transform
        self.simu_transform = simu_transform
        self.path = path
        self.griding_num = griding_num
        self.load_name = load_name
        self.use_aux = use_aux
        self.num_lanes = num_lanes

        with open(list_path, 'r') as f:
            # Skip blank lines; they would fail in __getitem__ when the
            # image/label tokens are extracted.
            self.list = [line for line in f if line.strip()]

        # Keep a sorted copy so the caller's sequence is not reordered in place.
        self.row_anchor = sorted(row_anchor)

    def __getitem__(self, index):
        """Return ``(img, cls_label)`` plus the segmentation label or image name.

        With ``use_aux`` the result is ``(img, cls_label, seg_label)``;
        otherwise with ``load_name`` it is ``(img, cls_label, img_name)``.
        """
        l_info = self.list[index].split()
        img_name, label_name = l_info[0], l_info[1]
        # Strip the leading '/' from each path on its own merits; a leading
        # '/' would make os.path.join discard ``self.path``.
        if img_name.startswith('/'):
            img_name = img_name[1:]
        if label_name.startswith('/'):
            label_name = label_name[1:]

        label = loader_func(os.path.join(self.path, label_name))
        img = loader_func(os.path.join(self.path, img_name))

        if self.simu_transform is not None:
            img, label = self.simu_transform(img, label)

        # get the coordinates of lanes at row anchors
        lane_pts = self._get_index(label)

        # make the coordinates to classification label
        w, h = img.size
        cls_label = self._grid_pts(lane_pts, self.griding_num, w)

        if self.use_aux:
            assert self.segment_transform is not None
            seg_label = self.segment_transform(label)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.use_aux:
            return img, cls_label, seg_label
        if self.load_name:
            return img, cls_label, img_name
        return img, cls_label

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.list)

    def _grid_pts(self, pts, num_cols, w):
        """Quantise lane x-coordinates into column cell indices.

        Args:
            pts: Array of shape ``(num_lane, n, 2)``; ``pts[..., 1]`` is the
                x-coordinate, with ``-1`` meaning "no lane on this row".
            num_cols: Number of column cells spanning ``[0, w - 1]``.
            w: Image width in pixels.

        Returns:
            Integer array of shape ``(n, num_lane)`` holding the cell index,
            or ``num_cols`` where the lane is absent.
        """
        num_lane, n, n2 = pts.shape
        assert n2 == 2
        col_sample = np.linspace(0, w - 1, num_cols)
        cell_width = col_sample[1] - col_sample[0]

        xs = pts[:, :, 1].T  # (n, num_lane)
        # Absent points (-1) map to the extra "no lane" class ``num_cols``.
        cells = np.where(xs == -1, num_cols, xs // cell_width)
        return cells.astype(int)

    def _get_index(self, label):
        """Sample the lane x-position at every row anchor of a label image.

        Args:
            label: Label image exposing ``.size`` as ``(w, h)`` and convertible
                with ``np.asarray``; pixels equal to ``k`` belong to lane ``k``.

        Returns:
            Float array of shape ``(num_lanes, len(row_anchor), 2)`` holding
            ``(row, x)`` per lane and anchor, with ``x == -1`` where the lane
            is absent. Lanes that stop before the last anchor are extended
            towards it with a straight-line fit.
        """
        w, h = label.size

        if h != 288:
            # Row anchors are given for a 288-pixel-high image; rescale them.
            sample_tmp = [int((x * 1.0 / 288) * h) for x in self.row_anchor]
        else:
            # Already at the reference height: use the anchors unchanged.
            sample_tmp = list(self.row_anchor)

        # Convert once instead of once per row anchor.
        label_arr = np.asarray(label)

        all_idx = np.zeros((self.num_lanes, len(sample_tmp), 2))
        for i, r in enumerate(sample_tmp):
            label_r = label_arr[int(round(r))]
            for lane_idx in range(1, self.num_lanes + 1):
                pos = np.where(label_r == lane_idx)[0]
                all_idx[lane_idx - 1, i, 0] = r
                # Mean column of the lane's pixels on this row, or -1 if none.
                all_idx[lane_idx - 1, i, 1] = np.mean(pos) if len(pos) else -1

        # data augmentation: extend the lane to the boundary of image
        all_idx_cp = all_idx.copy()
        for i in range(self.num_lanes):
            # if there is no lane
            if np.all(all_idx_cp[i, :, 1] == -1):
                continue

            # get all valid lane points
            valid = all_idx_cp[i, :, 1] != -1
            valid_idx = all_idx_cp[i, valid, :]

            # if the last valid lane point's y-coordinate is already the last
            # y-coordinate of all rows, the lane has reached the bottom
            # boundary of the image, so we skip
            if valid_idx[-1, 0] == all_idx_cp[0, -1, 0]:
                continue

            # if the lane is too short to extend
            if len(valid_idx) < 6:
                continue

            # Fit a line x = f(row) through the later half of the valid points.
            valid_idx_half = valid_idx[len(valid_idx) // 2:, :]
            p = np.polyfit(valid_idx_half[:, 0], valid_idx_half[:, 1], deg=1)
            start_line = valid_idx_half[-1, 0]
            pos = find_start_pos(all_idx_cp[i, :, 0], start_line) + 1

            fitted = np.polyval(p, all_idx_cp[i, pos:, 0])
            # Extrapolated points that leave the image are marked absent.
            fitted = np.where((fitted < 0) | (fitted > w - 1), -1, fitted)

            assert np.all(all_idx_cp[i, pos:, 1] == -1)
            all_idx_cp[i, pos:, 1] = fitted

        if -1 in all_idx[:, :, 0]:
            # Previously dropped into pdb, which blocks DataLoader workers;
            # fail loudly instead.
            raise RuntimeError('row anchor resolved to invalid row index -1')
        return all_idx_cp