diff --git a/utils/general.py b/utils/general.py index 239683d..8f29d1a 100755 --- a/utils/general.py +++ b/utils/general.py @@ -9,7 +9,7 @@ import logging from contextlib import contextmanager from copy import copy from pathlib import Path -from sys import platform +import platform import cv2 import matplotlib @@ -66,7 +66,7 @@ def get_latest_run(search_dir='./runs'): def check_git_status(): # Suggest 'git pull' if repo is out of date - if platform in ['linux', 'darwin'] and not os.path.isfile('/.dockerenv'): + if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'): s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') if 'Your branch is behind' in s: print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') @@ -126,7 +126,7 @@ def check_anchor_order(m): def check_file(file): - # Searches for file if not found locally + # Search for file if not found if os.path.isfile(file) or file == '': return file else: @@ -137,21 +137,25 @@ def check_file(file): def check_dataset(dict): # Download dataset if not found - train, val = os.path.abspath(dict['train']), os.path.abspath(dict['val']) # data paths - if not (os.path.exists(train) and os.path.exists(val)): - print('\nWARNING: Dataset not found, nonexistant paths: %s' % [train, val]) - if 'download' in dict: - s = dict['download'] - print('Attempting autodownload from: %s' % s) - if s.startswith('http') and s.endswith('.zip'): # URL - f = Path(s).name # filename - torch.hub.download_url_to_file(s, f) - r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) - else: # bash script - r = os.system(s) - print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value - else: - Exception('Dataset autodownload unavailable.') + val, s = dict.get('val'), dict.get('download') + if val and len(val): + val = [os.path.abspath(x) for x in (val if isinstance(val, list) else [val])] # val path + if not all(os.path.exists(x) for x in val): + print('\nWARNING: Dataset not found, nonexistant paths: %s' % [*val]) + if s and len(s): # download script + print('Attempting autodownload from: %s' % s) + if s.startswith('http') and s.endswith('.zip'): # URL + f = Path(s).name # filename + if platform.system() == 'Darwin': # avoid MacOS python requests certificate error + os.system('curl -L %s -o %s' % (s, f)) + else: + torch.hub.download_url_to_file(s, f) + r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip + else: # bash script + r = os.system(s) + print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value + else: + raise Exception('Dataset not found.') def make_divisible(x, divisor):