You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

54 lines
1.6 KiB

3 years ago
import argparse
import cv2
from model import SCNN
from utils.prob2lines import getLane
from utils.transforms import *
# Size handed to both the Resize transform and the SCNN constructor.
# NOTE(review): presumably (width, height) -- confirm against Resize/SCNN.
test_size = (512, 288)

# Per-channel normalization statistics (CULane mean and std).
mean = (0.3598, 0.3653, 0.3662)
std = (0.2573, 0.2663, 0.2756)

# Pre-processing is split in two stages: resize the raw image first, then
# convert it to a normalized tensor that can be fed to the network.
transform_img = Resize(test_size)
transform_to_net = Compose(ToTensor(), Normalize(mean=mean, std=std))

# The network is created once at import time with pretrained=False;
# main() loads the checkpoint weights into it before inference.
net = SCNN(input_size=test_size, pretrained=False)
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the demo script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of calling ``parse_args()`` with no arguments.

    Returns:
        A namespace with two string attributes:
        ``img_path`` (``--img_path`` / ``-i``, default ``"demo/demo.jpg"``)
        and ``weight_path`` (``--weight_path`` / ``-w``, default
        ``"exp0_best.pth"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_path", '-i',
        type=str,
        default="demo/demo.jpg",
        help="Path to demo img",
    )
    parser.add_argument(
        "--weight_path", '-w',
        type=str,
        default=r"exp0_best.pth",
        help="Path to model weights",
    )
    return parser.parse_args(argv)
def main():
    """Run the SCNN model on a single image and dump lane points to disk.

    The image named by ``--img_path`` is loaded, converted from BGR to RGB,
    passed through ``transform_img`` and ``transform_to_net``, and fed to the
    module-level ``net`` after the checkpoint at ``--weight_path`` has been
    loaded into it. The lane point lists returned by
    ``getLane.prob2lines_CULane`` are written to ``./coordinates.txt`` as
    ``x,y`` pairs, each followed by a tab character.

    Raises:
        FileNotFoundError: if OpenCV could not read the input image.
    """
    args = parse_args()
    img_path = args.img_path
    weight_path = args.weight_path

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of a cryptic cvtColor error.
        raise FileNotFoundError(
            "Could not read image (missing or unsupported file): {!r}".format(img_path)
        )
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = transform_img({'img': img})['img']
    batch = transform_to_net({'img': img})['img']
    batch.unsqueeze_(0)  # in-place: add a leading dimension of size 1

    # NOTE(review): torch.load unpickles the checkpoint, so only load weight
    # files from a trusted source.
    save_dict = torch.load(weight_path, map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()

    # Inference only: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        seg_pred, exist_pred = net(batch)[:2]
    seg_pred = seg_pred.detach().cpu().numpy()
    exist_pred = exist_pred.detach().cpu().numpy()
    seg_pred = seg_pred[0]  # keep the first (only) item along the leading axis

    # Binarize the first four existence scores with a 0.5 threshold.
    exist = [1 if exist_pred[0, i] > 0.5 else 0 for i in range(4)]

    with open("./coordinates.txt", "w", encoding="utf8") as fp:
        for coordinates in getLane.prob2lines_CULane(seg_pred, exist):
            # One write per lane; the emitted text is "x,y\t" for every
            # point, with no separator added between lanes.
            fp.write("".join("{},{}\t".format(px, py) for px, py in coordinates))
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not when it
    # is imported as a module.
    main()