diff --git a/detect.py b/detect.py index a698429..3333d24 100644 --- a/detect.py +++ b/detect.py @@ -21,6 +21,8 @@ def detect(save_img=False): google_utils.attempt_download(weights) model = torch.load(weights, map_location=device)['model'] # torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning + # model.fuse() + model.to(device).eval() # Second-stage classifier classify = False @@ -29,12 +31,6 @@ def detect(save_img=False): modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights modelc.to(device).eval() - # Eval mode - model.to(device).eval() - - # Fuse Conv2d + BatchNorm2d layers - # model.fuse() - # Half precision half = half and device.type != 'cpu' # half precision only supported on CUDA if half: