"""Run SCNN lane detection on a single image and dump lane coordinates.

The image is resized to ``test_size``, normalized with CULane statistics,
passed through an SCNN model loaded from a checkpoint, and the resulting
lanes (as produced by ``getLane.prob2lines_CULane``) are written to a text
file: one line per lane, each point formatted as ``x,y`` and separated by
tabs.
"""
import argparse

import cv2
import torch

from model import SCNN
from utils.prob2lines import getLane
# NOTE(review): presumably provides Resize, Compose, ToTensor, Normalize used below.
from utils.transforms import *

test_size = (512, 288)
net = SCNN(input_size=test_size, pretrained=False)
mean = (0.3598, 0.3653, 0.3662)  # CULane mean, std
std = (0.2573, 0.2663, 0.2756)
transform_img = Resize(test_size)
transform_to_net = Compose(ToTensor(), Normalize(mean=mean, std=std))

# Existence score above which a lane is treated as present.
EXIST_THRESHOLD = 0.5


def parse_args():
    """Parse command-line options for the demo.

    Returns:
        argparse.Namespace with ``img_path``, ``weight_path`` and ``out_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path", '-i', type=str, default="demo/demo.jpg", help="Path to demo img")
    parser.add_argument("--weight_path", '-w', type=str, help="Path to model weights", default=r"exp0_best.pth")
    parser.add_argument("--out_path", '-o', type=str, default="./coordinates.txt",
                        help="Where to write lane coordinates (one lane per line)")
    args = parser.parse_args()
    return args


def main():
    """Load the model, run it on one image, and write detected lanes to disk.

    Raises:
        FileNotFoundError: if the input image cannot be read by OpenCV.
    """
    args = parse_args()
    img_path = args.img_path
    weight_path = args.weight_path

    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = transform_img({'img': img})['img']
    x = transform_to_net({'img': img})['img']
    x.unsqueeze_(0)  # add a leading batch dimension of size 1

    # torch.load unpickles the file: only load checkpoints you trust.
    save_dict = torch.load(weight_path, map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        seg_pred, exist_pred = net(x)[:2]
    seg_pred = seg_pred.cpu().numpy()[0]
    exist_pred = exist_pred.cpu().numpy()

    # Binary lane-present flag for every existence score the model emits.
    exist = [1 if score > EXIST_THRESHOLD else 0 for score in exist_pred[0]]

    lanes = getLane.prob2lines_CULane(seg_pred, exist)
    with open(args.out_path, "w", encoding="utf8") as fp:
        for coordinates in lanes:
            # One lane per line so individual lanes remain distinguishable.
            fp.write("\t".join(str(px) + "," + str(py) for px, py in coordinates) + "\n")


if __name__ == "__main__":
    main()