From 496ec33a33dd18de0b8dc14de1e571955acf5135 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 25 Jun 2020 19:19:15 -0700 Subject: [PATCH] Update detect.py Added some recent updates that were missing, and updated the filename with an if else. --- detect.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/detect.py b/detect.py index 9c4990e..07631dd 100644 --- a/detect.py +++ b/detect.py @@ -46,7 +46,7 @@ def detect(save_img=False): dataset = LoadImages(source, img_size=imgsz) # Get names and colors - names = model.names if hasattr(model, 'names') else model.modules.names + names = model.module.names if hasattr(model, 'module') else model.names colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))] # Run inference @@ -80,6 +80,7 @@ def detect(save_img=False): p, s, im0 = path, '', im0s save_path = str(Path(out) / Path(p).name) + txt_path = save_path[:save_path.rfind('.')] + ('_%g' % dataset.frame if dataset.mode == 'video' else '') s += '%gx%g ' % img.shape[2:] # print string gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] #  normalization gain whwh if det is not None and len(det): @@ -95,12 +96,8 @@ def detect(save_img=False): for *xyxy, conf, cls in det: if save_txt: # Write to file xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh - if dataset.frame == 0: - with open(save_path[:save_path.rfind('.')] + '.txt', 'a') as f: - f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format - else: - with open(save_path[:save_path.rfind('.')] + '_' + str(dataset.frame) + '.txt', 'a') as f: - f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format + with open(txt_path + '.txt', 'a') as f: + f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format if save_img or view_img: # Add bbox to image label = '%s %.2f' % (names[int(cls)], conf) @@ -160,3 +157,8 @@ if __name__ == '__main__': with torch.no_grad(): detect() + + # Update all models + # for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']: + # detect() + # create_pretrained(opt.weights, opt.weights)