diff --git a/test.py b/test.py index 1f937a7..aa5cde3 100644 --- a/test.py +++ b/test.py @@ -1,7 +1,6 @@ import argparse import json -import yaml from torch.utils.data import DataLoader from utils.datasets import * @@ -40,8 +39,9 @@ def test(data, if half: model.half() # to FP16 - if device.type != 'cpu' and torch.cuda.device_count() > 1: - model = nn.DataParallel(model) + # Multi-GPU disabled, incompatible with .half() + # if device.type != 'cpu' and torch.cuda.device_count() > 1: + # model = nn.DataParallel(model) else: # called by train.py training = True