You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

119 lines
4.3 KiB

4 years ago
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger
from runner.registry import EVALUATOR
import json
import os
import cv2
from .lane import LaneEval
def split_path(path):
    """Split a filesystem path into the list of its components.

    The root (e.g. ``"/"`` or a Windows drive such as ``"C:\\"``), if any,
    is kept as the first element. Trailing separators are ignored, so
    ``"a/b/"`` and ``"a/b"`` both give ``["a", "b"]``. An empty path gives
    ``[]``.

    Args:
        path: path string to split.

    Returns:
        List of path components, outermost first.
    """
    folders = []
    while True:
        head, tail = os.path.split(path)
        if tail:
            folders.insert(0, tail)
        elif head == path:
            # os.path.split made no progress: we are at the root (or the
            # path is empty), so nothing further can be peeled off.
            if head:
                folders.insert(0, head)
            break
        # When tail is empty but head differs, path only had trailing
        # separators stripped; keep splitting the shorter head.
        path = head
    return folders
@EVALUATOR.register_module
class Tusimple(nn.Module):
    """Evaluator that records lane predictions and scores them with LaneEval.

    For every image it:
      * writes a ``*.lines.txt`` coordinate file under
        ``<work_dir>/output/coord_output``;
      * buffers one JSON record (``lanes``/``h_sample``/``raw_file``/
        ``run_time``) for the whole-split submission file;
      * optionally (``cfg.view``) renders a visualisation into
        ``<work_dir>/vis``.

    :meth:`summarize` writes the buffered records to ``predict_test.json``
    and scores them against ``cfg.test_json_file`` via
    ``LaneEval.bench_one_submit``.
    """

    # Existence score above which a lane counts as present. Presumably a
    # probability from the model's existence head -- confirm upstream.
    EXIST_THRESH = 0.5

    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        self.out_path = os.path.join(exp_dir, "coord_output")
        # makedirs(exist_ok=True) also creates missing parents (including
        # work_dir itself) and avoids the exists()/mkdir() race.
        os.makedirs(self.out_path, exist_ok=True)
        self.dump_to_json = []  # one JSON-encoded record per evaluated image
        self.thresh = cfg.evaluator.thresh
        self.logger = get_logger('resa')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')
            os.makedirs(self.view_dir, exist_ok=True)

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        """Turn one batch of network outputs into lane predictions and record them.

        Args:
            dataset: dataset object providing ``probmap2lane`` and ``view``.
            seg_pred: per-image segmentation output, indexed by batch position
                (a numpy array when called from :meth:`evaluate`).
            exist_pred: per-image lane-existence scores, indexed
                ``exist_pred[b, i]`` for lane ``i``.
            batch: dict whose ``batch['meta']`` holds ``img_name`` and
                ``full_img_path`` sequences aligned with the batch.
        """
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        num_lanes = self.cfg.num_classes - 1
        for b in range(len(seg_pred)):
            exist = [1 if exist_pred[b, i] > self.EXIST_THRESH else 0
                     for i in range(num_lanes)]
            lane_coords = dataset.probmap2lane(seg_pred[b], exist, thresh=self.thresh)
            # Order every lane's points by their second coordinate (y).
            lane_coords = [sorted(lane, key=lambda pair: pair[1])
                           for lane in lane_coords]

            path_tree = split_path(img_name[b])
            self._write_lines_file(path_tree, lane_coords)
            record = self._to_json_record(path_tree, lane_coords)
            self.dump_to_json.append(json.dumps(record))

            if self.cfg.view:
                self._visualize(dataset, img_path[b], img_name[b], lane_coords)

    def _write_lines_file(self, path_tree, lane_coords):
        """Write ``lane_coords`` as ``"x y "`` pairs, one lane per text line.

        The file is placed under ``self.out_path``, mirroring the image's two
        nearest parent directories, and named ``<image stem>.lines.txt``.
        """
        save_dir = os.path.join(self.out_path, *path_tree[-3:-1])
        os.makedirs(save_dir, exist_ok=True)
        # splitext instead of slicing off 3 chars, so extensions of any
        # length (".jpeg", ...) are handled; ".jpg"/".png" give the same name.
        stem = os.path.splitext(path_tree[-1])[0]
        save_name = os.path.join(save_dir, stem + ".lines.txt")
        with open(save_name, "w") as f:
            for lane in lane_coords:
                for (x, y) in lane:
                    print("{} {}".format(x, y), end=" ", file=f)
                print(file=f)

    @staticmethod
    def _to_json_record(path_tree, lane_coords):
        """Build the JSON-serialisable prediction record for one image.

        ``lanes`` holds the integer x values of every non-empty lane;
        ``h_sample`` holds the y values of the first lane (empty when no
        lanes were returned, instead of raising IndexError).
        """
        lanes = [[int(x) for (x, _) in lane]
                 for lane in lane_coords if len(lane) > 0]
        h_sample = [y for (_, y) in lane_coords[0]] if len(lane_coords) > 0 else []
        return {
            'lanes': lanes,
            'h_sample': h_sample,
            'raw_file': os.path.join(*path_tree[-4:]),
            'run_time': 0,
        }

    def _visualize(self, dataset, img_file, img_name, lane_coords):
        """Dump non-negative lane points to text and render an overlay image."""
        # Kept inside view_dir rather than a machine-specific absolute path.
        # The file is rewritten for every image, so it holds the coordinates
        # of the most recently visualised image.
        coord_file = os.path.join(self.view_dir, "coordinates.txt")
        with open(coord_file, "w", encoding="utf8") as fp:
            for lane in lane_coords:
                for x, y in lane:
                    if x >= 0 and y >= 0:
                        fp.write(str(x) + "," + str(y) + "\t")
        img = cv2.imread(img_file)
        save_path = os.path.join(self.view_dir, img_name.replace('/', '_'))
        dataset.view(img, lane_coords, save_path)

    def evaluate(self, dataset, output, batch):
        """Post-process raw model output for one batch and record predictions.

        ``output['seg']`` is softmax-normalised along dim 1 (assumed to be
        the class/lane channel -- confirm against the model), and both
        tensors are moved to CPU numpy arrays before lane extraction.
        """
        seg_pred = F.softmax(output['seg'], dim=1).detach().cpu().numpy()
        exist_pred = output['exist'].detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        """Write buffered predictions to disk, score them, and reset the buffer.

        Returns:
            The accuracy reported by ``LaneEval.bench_one_submit``, clamped
            below at 0.
        """
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)
        eval_result, acc = LaneEval.bench_one_submit(output_file,
                                                     self.cfg.test_json_file)
        self.logger.info(eval_result)
        self.dump_to_json = []
        return max(acc, 0)