import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger
from runner.registry import EVALUATOR
import json
import os
import cv2
from .lane import LaneEval


def split_path(path):
    """Split a path into its component folders, e.g. 'a/b/c' -> ['a', 'b', 'c']."""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders


@EVALUATOR.register_module
class Tusimple(nn.Module):
    """TuSimple lane-detection evaluator.

    Converts per-pixel segmentation probability maps into per-lane
    coordinate lists, dumps them in the TuSimple JSON submission format,
    and scores the accumulated predictions with the official ``LaneEval``
    benchmark.
    """

    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        # makedirs(exist_ok=True) replaces the racy exists()+mkdir pattern
        # and also creates any missing parent directories.
        os.makedirs(exp_dir, exist_ok=True)
        self.out_path = os.path.join(exp_dir, "coord_output")
        os.makedirs(self.out_path, exist_ok=True)
        self.dump_to_json = []  # one TuSimple-format JSON string per image
        self.thresh = cfg.evaluator.thresh  # probability threshold for probmap2lane
        self.logger = get_logger('resa')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')
            os.makedirs(self.view_dir, exist_ok=True)

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        """Decode one batch of predictions into lane coordinates.

        Writes a per-image ``*lines.txt`` file under ``self.out_path`` and
        appends one TuSimple JSON record per image to ``self.dump_to_json``.

        Args:
            dataset: dataset object providing ``probmap2lane`` (and ``view``
                when visualization is enabled).
            seg_pred: (B, C, H, W) softmax probability maps (numpy array).
            exist_pred: (B, C-1) lane-existence scores (numpy array).
            batch: batch dict; ``batch['meta']`` holds image names and paths.
        """
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for b in range(len(seg_pred)):
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] > 0.5 else 0
                     for i in range(self.cfg.num_classes - 1)]
            lane_coords = dataset.probmap2lane(seg, exist, thresh=self.thresh)
            # Sort each lane's points top-to-bottom by their y coordinate.
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(
                    lane_coords[i], key=lambda pair: pair[1])

            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            # NOTE: assumes a 3-character extension (e.g. "jpg") — the
            # original did the same.
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            os.makedirs(save_dir, exist_ok=True)
            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_sample'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            # BUGFIX: guard against an image with zero detected lanes —
            # indexing lane_coords[0] raised IndexError in that case.
            if lane_coords:
                for (x, y) in lane_coords[0]:
                    json_dict['h_sample'].append(y)
            self.dump_to_json.append(json.dumps(json_dict))

            if self.cfg.view:
                # BUGFIX: coordinates.txt was written to a hard-coded
                # absolute Windows path ("E:/pythonProject/ENet_SAD");
                # write it alongside the visualizations instead.
                coord_file = os.path.join(self.view_dir, "coordinates.txt")
                with open(coord_file, "w", encoding="utf8") as fp:
                    for coord in lane_coords:
                        for x, y in coord:
                            if x >= 0 and y >= 0:
                                fp.write(str(x) + "," + str(y) + "\t")
                img = cv2.imread(img_path[b])
                new_img_name = img_name[b].replace('/', '_')
                save_dir = os.path.join(self.view_dir, new_img_name)
                dataset.view(img, lane_coords, save_dir)

    def evaluate(self, dataset, output, batch):
        """Post-process one batch of raw network outputs.

        Applies softmax over the class dimension, moves predictions to
        CPU/numpy, and delegates to :meth:`evaluate_pred`.
        """
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        """Dump accumulated predictions and run the official benchmark.

        Returns:
            float: accuracy reported by ``LaneEval.bench_one_submit``,
            clamped at >= 0 (same result as the original
            ``best_acc = max(acc, 0)`` bookkeeping).
        """
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)

        eval_result, acc = LaneEval.bench_one_submit(
            output_file, self.cfg.test_json_file)
        self.logger.info(eval_result)
        self.dump_to_json = []  # reset accumulator for the next round
        return max(acc, 0)