From 099e6f5ebd31416f33d047249382624ad5489550 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 12 Jun 2020 22:10:46 -0700 Subject: [PATCH] --img-size stride-multiple verification --- detect.py | 1 + test.py | 1 + train.py | 4 +--- utils/utils.py | 9 ++++++++- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/detect.py b/detect.py index 132d162..66f1522 100644 --- a/detect.py +++ b/detect.py @@ -156,6 +156,7 @@ if __name__ == '__main__': parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--augment', action='store_true', help='augmented inference') opt = parser.parse_args() + opt.img_size = check_img_size(opt.img_size) print(opt) with torch.no_grad(): diff --git a/test.py b/test.py index 72cbff6..3b52be2 100644 --- a/test.py +++ b/test.py @@ -245,6 +245,7 @@ if __name__ == '__main__': parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--verbose', action='store_true', help='report mAP by class') opt = parser.parse_args() + opt.img_size = check_img_size(opt.img_size) opt.save_json = opt.save_json or opt.data.endswith('coco.yaml') opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file print(opt) diff --git a/train.py b/train.py index d89e011..8748677 100644 --- a/train.py +++ b/train.py @@ -80,9 +80,7 @@ def train(hyp): # Image sizes gs = int(max(model.stride)) # grid size (max stride) - if any(x % gs != 0 for x in opt.img_size): - print('WARNING: --img-size %g,%g must be multiple of %s max stride %g' % (*opt.img_size, opt.cfg, gs)) - imgsz, imgsz_test = [make_divisible(x, gs) for x in opt.img_size] # image sizes (train, test) + imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples # Optimizer nbs = 64 # nominal batch size diff --git a/utils/utils.py b/utils/utils.py index 967c5b5..624d06b 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -37,13 +37,20 @@ def init_seeds(seed=0): def check_git_status(): + # Suggest 'git pull' if repo is out of date if platform in ['linux', 'darwin']: - # Suggest 'git pull' if repo is out of date s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') if 'Your branch is behind' in s: print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of stride s + if img_size % s != 0: + print('WARNING: --img-size %g must be multiple of max stride %g' % (img_size, s)) + return make_divisible(img_size, s) # nearest gs-multiple + + def make_divisible(x, divisor): # Returns x evenly divisble by divisor return math.ceil(x / divisor) * divisor