You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.6 KiB
54 lines
1.6 KiB
3 years ago
|
import argparse

import cv2
import torch

from model import SCNN
from utils.prob2lines import getLane
from utils.transforms import *
|
||
|
|
||
|
# Input resolution shared by the network constructor and the resize transform.
# NOTE(review): presumably (width, height) as consumed by Resize -- confirm.
test_size = (512, 288)

# The network is constructed once at import time; main() later loads the
# checkpoint weights into this same object.
net = SCNN(input_size=test_size, pretrained=False)

# CULane mean, std (per-channel normalization statistics).
mean, std = (0.3598, 0.3653, 0.3662), (0.2573, 0.2663, 0.2756)

# Pipeline: resize the RGB image first, then convert to a tensor and normalize.
transform_img = Resize(test_size)
transform_to_net = Compose(ToTensor(), Normalize(mean=mean, std=std))
|
||
|
|
||
|
|
||
|
def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with two string attributes:
        ``img_path`` (``--img_path`` / ``-i``, default ``demo/demo.jpg``) and
        ``weight_path`` (``--weight_path`` / ``-w``, default ``exp0_best.pth``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_path", '-i',
        type=str,
        default="demo/demo.jpg",
        help="Path to demo img",
    )
    parser.add_argument(
        "--weight_path", '-w',
        type=str,
        default=r"exp0_best.pth",
        help="Path to model weights",
    )
    return parser.parse_args()
|
||
|
|
||
|
|
||
|
def _write_lanes(fp, lanes):
    """Write *lanes* to the text stream *fp*, one lane per line.

    Each lane is an iterable of ``(x, y)`` pairs; every point is written as
    ``x,y`` and points within a lane are separated by tabs.
    """
    for lane in lanes:
        fp.write("\t".join(str(px) + "," + str(py) for px, py in lane))
        # Terminate each lane so consumers can tell lanes apart.
        fp.write("\n")


def main():
    """Run SCNN lane detection on one image and dump the lane coordinates.

    Reads the image at ``--img_path``, converts it from BGR to RGB, applies the
    module-level ``transform_img`` and ``transform_to_net`` transforms, loads
    the checkpoint at ``--weight_path`` into the module-level ``net``, and
    writes the lanes returned by ``getLane.prob2lines_CULane`` to
    ``./coordinates.txt`` -- one lane per line, points as tab-separated
    ``x,y`` pairs.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image at ``--img_path``.
    """
    args = parse_args()
    img_path = args.img_path
    weight_path = args.weight_path

    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None rather than raising;
    # fail here with a clear message instead of deep inside cvtColor.
    if img is None:
        raise FileNotFoundError("Could not read image: {}".format(img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = transform_img({'img': img})['img']
    x = transform_to_net({'img': img})['img']
    x = x.unsqueeze(0)  # add a leading batch dimension of size 1

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    save_dict = torch.load(weight_path, map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        seg_pred, exist_pred = net(x)[:2]
    seg_pred = seg_pred.cpu().numpy()[0]  # drop the batch dimension
    exist_pred = exist_pred.cpu().numpy()

    # Binarize the first four lane-existence scores at 0.5.
    exist = [1 if exist_pred[0, i] > 0.5 else 0 for i in range(4)]

    lanes = getLane.prob2lines_CULane(seg_pred, exist)
    with open("./coordinates.txt", "w", encoding="utf8") as fp:
        _write_lanes(fp, lanes)
|
||
|
|
||
|
|
||
|
# Run the demo only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()
|