You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.6 KiB
66 lines
2.6 KiB
import torch, os, cv2
|
|
from model.model import parsingNet
|
|
from utils.common import merge_config
|
|
from utils.dist_utils import dist_print
|
|
import torch
|
|
import scipy.special, tqdm
|
|
import numpy as np
|
|
import torchvision.transforms as transforms
|
|
from data.dataset import LaneTestDataset
|
|
from data.constant import culane_row_anchor, tusimple_row_anchor
|
|
|
|
def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from its keys.

    Only keys that START with *prefix* are rewritten; all other keys are kept
    unchanged. (Presumably the prefix comes from checkpoints saved while the
    model was wrapped in ``torch.nn.DataParallel`` /
    ``DistributedDataParallel`` -- confirm against the training script.)

    Args:
        state_dict: mapping of parameter name -> value (usually tensors).
        prefix: key prefix to strip; defaults to ``'module.'``.

    Returns:
        A new ``dict``; the input mapping is not modified.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def decode_lane_locations(out_j, griding_num):
    """Turn one image's classification logits into per-row lane column positions.

    Args:
        out_j: numpy array of shape ``(griding_num + 1, num_rows, num_lanes)``,
            matching the network's ``cls_dim``. The last entry along axis 0
            (index ``griding_num``) is treated as "no lane point on this row".
        griding_num: number of horizontal grid cells.

    Returns:
        Float array of shape ``(num_rows, num_lanes)``. An entry is 0 where the
        argmax over all classes is the "no lane point" class. Otherwise it is
        the softmax-weighted expected grid index, which lies in
        ``[1, griding_num]``.
    """
    # Softmax only over the real grid cells; the trailing class is excluded.
    prob = scipy.special.softmax(out_j[:-1], axis=0)
    idx = np.arange(1, griding_num + 1).reshape(-1, 1, 1)
    loc = np.sum(prob * idx, axis=0)
    loc[np.argmax(out_j, axis=0) == griding_num] = 0
    return loc


def main():
    """Run the lane detector over every test split and save an annotated video per split.

    Reads the configuration via ``merge_config()`` and loads ``cfg.test_model``.
    For each split, writes ``<split name>.avi`` to the current working
    directory, with detected lane points drawn as filled (0, 255, 0) circles.

    Raises:
        ValueError: if ``cfg.backbone`` is not a supported backbone.
        NotImplementedError: if ``cfg.dataset`` is neither 'CULane' nor 'Tusimple'.
        FileNotFoundError: if a test image cannot be read by OpenCV.
    """
    torch.backends.cudnn.benchmark = True

    _, cfg = merge_config()

    dist_print('start testing...')
    supported_backbones = ('18', '34', '50', '101', '152', '50next', '101next', '50wide', '101wide')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if cfg.backbone not in supported_backbones:
        raise ValueError('unsupported backbone {!r}; expected one of {}'.format(
            cfg.backbone, supported_backbones))

    # All dataset-specific settings live in one dispatch so they cannot drift apart.
    if cfg.dataset == 'CULane':
        cls_num_per_lane = 18
        splits = ['test0_normal.txt', 'test1_crowd.txt', 'test2_hlight.txt', 'test3_shadow.txt',
                  'test4_noline.txt', 'test5_arrow.txt', 'test6_curve.txt', 'test7_cross.txt',
                  'test8_night.txt']
        list_dir = os.path.join(cfg.data_root, 'list', 'test_split')
        img_w, img_h = 1640, 590
        row_anchor = culane_row_anchor
    elif cfg.dataset == 'Tusimple':
        cls_num_per_lane = 56
        splits = ['predict.txt']
        list_dir = cfg.data_root
        img_w, img_h = 1280, 720
        row_anchor = tusimple_row_anchor
    else:
        raise NotImplementedError

    # We don't need the auxiliary segmentation head in testing.
    net = parsingNet(pretrained=False, backbone=cfg.backbone,
                     cls_dim=(cfg.griding_num + 1, cls_num_per_lane, 4),
                     use_aux=False).cuda()

    state_dict = torch.load(cfg.test_model, map_location='cpu')['model']
    # NOTE(review): strict=False silently ignores missing/unexpected keys, so a
    # mismatched checkpoint loads without error -- consider strict=True.
    net.load_state_dict(strip_module_prefix(state_dict), strict=False)
    net.eval()

    img_transforms = transforms.Compose([
        transforms.Resize((288, 800)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Grid cells span the 800-px-wide network input produced by Resize above.
    col_sample = np.linspace(0, 800 - 1, cfg.griding_num)
    col_sample_w = col_sample[1] - col_sample[0]
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')

    for split in splits:
        dataset = LaneTestDataset(cfg.data_root, os.path.join(list_dir, split),
                                  img_transform=img_transforms)
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
        video_path = os.path.splitext(split)[0] + '.avi'
        dist_print(video_path)
        vout = cv2.VideoWriter(video_path, fourcc, 30.0, (img_w, img_h))
        try:
            # assumes LaneTestDataset yields (image_tensor, relative_image_path) -- TODO confirm
            for imgs, names in tqdm.tqdm(loader):
                with torch.no_grad():
                    out = net(imgs.cuda())
                # batch_size=1, so out[0] is the only sample in the batch.
                loc = decode_lane_locations(out[0].cpu().numpy(), cfg.griding_num)

                image_path = os.path.join(cfg.data_root, names[0])
                vis = cv2.imread(image_path)
                if vis is None:  # cv2.imread signals failure by returning None
                    raise FileNotFoundError(image_path)

                for lane in range(loc.shape[1]):
                    # Skip lanes with two or fewer detected points.
                    if np.count_nonzero(loc[:, lane]) <= 2:
                        continue
                    for row in range(loc.shape[0]):
                        if loc[row, lane] > 0:
                            # Map grid position / row anchor from the 800x288 network
                            # input back to original image pixels. Presumably row_anchor
                            # holds y-coordinates in the 288-px-high input -- confirm.
                            x = int(loc[row, lane] * col_sample_w * img_w / 800) - 1
                            y = int(img_h * (row_anchor[row] / 288)) - 1
                            cv2.circle(vis, (x, y), 5, (0, 255, 0), -1)
                vout.write(vis)
        finally:
            vout.release()


if __name__ == "__main__":
    main()