You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
5.5 KiB
166 lines
5.5 KiB
3 years ago
|
import torch
|
||
|
from PIL import Image
|
||
|
import os
|
||
|
import pdb
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
from data.mytransforms import find_start_pos
|
||
|
|
||
|
|
||
|
def loader_func(path):
    """Open the image at ``path`` with PIL and decode it immediately.

    ``Image.open`` is lazy: the file stays open and pixel data is not read
    until the image is first used. Calling ``load()`` here reads the data
    right away, which lets Pillow close a single-frame file it opened itself
    instead of leaving the open handle to garbage collection.

    Args:
        path: filesystem path of the image to read.

    Returns:
        The loaded ``PIL.Image.Image``.
    """
    img = Image.open(path)
    img.load()
    return img
|
||
|
|
||
|
|
||
|
class LaneTestDataset(torch.utils.data.Dataset):
    """Image-only dataset used for inference.

    Each non-blank line of ``list_path`` names an image (its first
    whitespace-separated field) relative to ``path``. Items are
    ``(img, name)`` pairs.

    Args:
        path: dataset root directory.
        list_path: text file listing one image path per line.
        img_transform: optional callable applied to each loaded image.
    """

    def __init__(self, path, list_path, img_transform=None):
        super(LaneTestDataset, self).__init__()
        self.path = path
        self.img_transform = img_transform

        with open(list_path, 'r') as f:
            # Blank lines carry no image name and would make __getitem__
            # fail with IndexError, so they are dropped here.
            lines = [line for line in f if line.strip()]

        # Exclude the incorrect path prefix '/' of CULane; otherwise
        # os.path.join would discard self.path.
        self.list = [line[1:] if line.startswith('/') else line for line in lines]

    def __getitem__(self, index):
        """Return ``(img, name)`` for the ``index``-th listed image."""
        name = self.list[index].split()[0]
        img = loader_func(os.path.join(self.path, name))

        if self.img_transform is not None:
            img = self.img_transform(img)

        return img, name

    def __len__(self):
        return len(self.list)
|
||
|
|
||
|
|
||
|
class LaneClsDataset(torch.utils.data.Dataset):
    """Lane dataset producing row-anchor classification targets.

    Each non-blank line of ``list_path`` holds an image path followed by a
    label-map path, both relative to ``path``. A single leading '/' is
    dropped from each. For every sample the label map is read at the row
    anchors, and the mean column of each lane id ``1..num_lanes`` on that
    row is computed. Those columns are then quantized into ``griding_num``
    cells; rows where a lane is absent get class ``griding_num``.

    Args:
        path: dataset root directory.
        list_path: text file listing ``<image> <label> ...`` per line.
        img_transform: optional callable applied to the image last.
        target_transform: stored on the instance; not applied within this class.
        simu_transform: optional joint callable ``(img, label) -> (img, label)``.
        griding_num: number of column cells in the classification target.
        load_name: if True (and ``use_aux`` is False), items also carry the
            image name.
        row_anchor: row coordinates expressed for a 288-pixel-high image;
            required. A sorted copy is stored, so the caller's sequence is not
            modified.
        use_aux: if True, items also carry ``segment_transform(label)``.
        segment_transform: callable that must be set when ``use_aux`` is True.
        num_lanes: number of lane ids (1..num_lanes) read from the label map.

    Items are ``(img, cls_label)``; ``(img, cls_label, seg_label)`` when
    ``use_aux``; or ``(img, cls_label, img_name)`` when ``load_name``.
    ``cls_label`` is an int array of shape ``(len(row_anchor), num_lanes)``.

    Raises:
        ValueError: if ``row_anchor`` is not given.
    """

    # Image height in which ``row_anchor`` values are expressed.
    _ANCHOR_BASE_HEIGHT = 288

    def __init__(self, path, list_path, img_transform=None, target_transform=None,
                 simu_transform=None, griding_num=50, load_name=False,
                 row_anchor=None, use_aux=False, segment_transform=None, num_lanes=4):
        super(LaneClsDataset, self).__init__()
        if row_anchor is None:
            raise ValueError('row_anchor must be provided')

        self.img_transform = img_transform
        self.target_transform = target_transform
        self.segment_transform = segment_transform
        self.simu_transform = simu_transform
        self.path = path
        self.griding_num = griding_num
        self.load_name = load_name
        self.use_aux = use_aux
        self.num_lanes = num_lanes

        with open(list_path, 'r') as f:
            # Blank lines have no fields and would make __getitem__ fail.
            self.list = [line for line in f if line.strip()]

        # Keep a sorted copy; sorting in place would mutate the caller's list.
        self.row_anchor = sorted(row_anchor)

    def __getitem__(self, index):
        l_info = self.list[index].split()
        img_name, label_name = l_info[0], l_info[1]
        # CULane list files prefix paths with '/', which would make
        # os.path.join discard self.path; strip each path independently.
        if img_name.startswith('/'):
            img_name = img_name[1:]
        if label_name.startswith('/'):
            label_name = label_name[1:]

        label = loader_func(os.path.join(self.path, label_name))
        img = loader_func(os.path.join(self.path, img_name))

        if self.simu_transform is not None:
            img, label = self.simu_transform(img, label)

        # Lane coordinates at the row anchors, then turned into class labels.
        lane_pts = self._get_index(label)
        w, _ = img.size
        cls_label = self._grid_pts(lane_pts, self.griding_num, w)

        if self.use_aux:
            assert self.segment_transform is not None
            seg_label = self.segment_transform(label)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.use_aux:
            return img, cls_label, seg_label
        if self.load_name:
            return img, cls_label, img_name
        return img, cls_label

    def __len__(self):
        return len(self.list)

    def _grid_pts(self, pts, num_cols, w):
        """Quantize lane column coordinates into ``num_cols`` cells.

        Args:
            pts: array of shape (num_lane, n, 2) holding (row, col) pairs;
                col == -1 marks an absent point.
            num_cols: number of cells spanning columns ``[0, w - 1]``.
            w: image width in pixels.

        Returns:
            int array of shape (n, num_lane); absent points map to ``num_cols``.
        """
        num_lane, n, n2 = pts.shape
        assert n2 == 2
        col_sample = np.linspace(0, w - 1, num_cols)
        step = col_sample[1] - col_sample[0]

        cols = pts[:, :, 1]
        cells = np.where(cols != -1, cols // step, num_cols)
        return cells.T.astype(int, order='C')

    def _get_index(self, label):
        """Locate each lane on every row anchor of ``label``.

        Row anchors are rescaled from the 288-pixel base height to the label's
        height. On each anchor row, lane ``k`` (label value ``k``) is located
        at the mean column of its pixels, or -1 if absent. Lanes that stop
        above the last anchor are then extended downward with a line fitted
        to their lower half (data augmentation).

        Args:
            label: label image exposing ``.size == (w, h)`` and convertible
                with ``np.asarray`` to an (h, w) array of lane ids.

        Returns:
            float array of shape (num_lanes, len(row_anchor), 2) holding
            (row, column) pairs, with column -1 where the lane is absent.
        """
        w, h = label.size
        base = self._ANCHOR_BASE_HEIGHT

        if h != base:
            sample_tmp = [int((x * 1.0 / base) * h) for x in self.row_anchor]
        else:
            # Anchors are already expressed at this height; no rescaling.
            sample_tmp = [int(x) for x in self.row_anchor]

        label_arr = np.asarray(label)  # convert once, not per anchor row
        all_idx = np.zeros((self.num_lanes, len(sample_tmp), 2))
        for i, r in enumerate(sample_tmp):
            label_r = label_arr[int(round(r))]
            for lane_idx in range(1, self.num_lanes + 1):
                pos = np.where(label_r == lane_idx)[0]
                all_idx[lane_idx - 1, i, 0] = r
                all_idx[lane_idx - 1, i, 1] = np.mean(pos) if len(pos) > 0 else -1

        # data augmentation: extend the lane to the boundary of image
        all_idx_cp = all_idx.copy()
        for i in range(self.num_lanes):
            valid = all_idx_cp[i, :, 1] != -1
            if not valid.any():
                continue  # there is no such lane in this image

            valid_idx = all_idx_cp[i, valid, :]
            if valid_idx[-1, 0] == all_idx_cp[0, -1, 0]:
                # The last valid point is already on the last (bottom-most)
                # row anchor, so there is nothing to extend.
                continue
            if len(valid_idx) < 6:
                continue  # the lane is too short to extend

            valid_idx_half = valid_idx[len(valid_idx) // 2:, :]
            p = np.polyfit(valid_idx_half[:, 0], valid_idx_half[:, 1], deg=1)
            start_line = valid_idx_half[-1, 0]
            pos = find_start_pos(all_idx_cp[i, :, 0], start_line) + 1

            fitted = np.polyval(p, all_idx_cp[i, pos:, 0])
            # Extrapolated points that fall outside the image width are
            # marked absent.
            fitted = np.array([-1 if y < 0 or y > w - 1 else y for y in fitted])

            assert np.all(all_idx_cp[i, pos:, 1] == -1)
            all_idx_cp[i, pos:, 1] = fitted

        # Sanity check: row coordinates must never equal the -1 absence marker.
        if -1 in all_idx[:, :, 0]:
            raise ValueError('row coordinate -1 found in lane index array')
        return all_idx_cp
|