Fix list paths (#721)

* Add list paths on check_dataset

* missing raise statement

* Update general.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Authored by Hatovix 5 years ago, committed by GitHub
parent 0892c44bc4
commit 56c2c344ff

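Context for the change, as a minimal sketch that is not part of the commit (the paths below are hypothetical): dataset YAMLs may give 'val' as a single path or as a list of paths, and the old check passed the raw value straight to os.path.abspath(), which only accepts one path.

import os

val = ['../coco/val2017.txt', '../coco/test-dev2017.txt']  # hypothetical multi-path 'val' entry

# Old behaviour: os.path.abspath() expects a single path string, so a list raises TypeError.
try:
    os.path.abspath(val)
except TypeError as e:
    print(e)  # e.g. "expected str, bytes or os.PathLike object, not list"

# Patched behaviour: normalize a scalar to a one-element list, then resolve and check each entry.
val = [os.path.abspath(x) for x in (val if isinstance(val, list) else [val])]
print(all(os.path.exists(x) for x in val))  # True only if every listed path exists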
@@ -9,7 +9,7 @@ import logging
 from contextlib import contextmanager
 from copy import copy
 from pathlib import Path
-from sys import platform
+import platform
 
 import cv2
 import matplotlib
@@ -66,7 +66,7 @@ def get_latest_run(search_dir='./runs'):
 
 def check_git_status():
     # Suggest 'git pull' if repo is out of date
-    if platform in ['linux', 'darwin'] and not os.path.isfile('/.dockerenv'):
+    if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
         s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
         if 'Your branch is behind' in s:
             print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
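As a side note (a sketch for reference, not part of the commit): sys.platform reports lowercase identifiers such as 'linux', 'darwin' or 'win32', while platform.system() reports 'Linux', 'Darwin' or 'Windows'; the patched checks compare against the latter form, and the same platform module is reused for the Darwin case in check_dataset() below.

import platform
import sys

print(sys.platform)       # e.g. 'linux', 'darwin' or 'win32'
print(platform.system())  # e.g. 'Linux', 'Darwin' or 'Windows'

if platform.system() in ['Linux', 'Darwin']:
    print('Unix-like host: the git status check applies')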
@@ -126,7 +126,7 @@ def check_anchor_order(m):
 
 
 def check_file(file):
-    # Searches for file if not found locally
+    # Search for file if not found
     if os.path.isfile(file) or file == '':
         return file
     else:
@@ -137,21 +137,25 @@ def check_file(file):
 
 def check_dataset(dict):
     # Download dataset if not found
-    train, val = os.path.abspath(dict['train']), os.path.abspath(dict['val'])  # data paths
-    if not (os.path.exists(train) and os.path.exists(val)):
-        print('\nWARNING: Dataset not found, nonexistant paths: %s' % [train, val])
-        if 'download' in dict:
-            s = dict['download']
-            print('Attempting autodownload from: %s' % s)
-            if s.startswith('http') and s.endswith('.zip'):  # URL
-                f = Path(s).name  # filename
-                torch.hub.download_url_to_file(s, f)
-                r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))
-            else:  # bash script
-                r = os.system(s)
-            print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
-        else:
-            Exception('Dataset autodownload unavailable.')
+    val, s = dict.get('val'), dict.get('download')
+    if val and len(val):
+        val = [os.path.abspath(x) for x in (val if isinstance(val, list) else [val])]  # val path
+        if not all(os.path.exists(x) for x in val):
+            print('\nWARNING: Dataset not found, nonexistant paths: %s' % [*val])
+            if s and len(s):  # download script
+                print('Attempting autodownload from: %s' % s)
+                if s.startswith('http') and s.endswith('.zip'):  # URL
+                    f = Path(s).name  # filename
+                    if platform.system() == 'Darwin':  # avoid MacOS python requests certificate error
+                        os.system('curl -L %s -o %s' % (s, f))
+                    else:
+                        torch.hub.download_url_to_file(s, f)
+                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
+                else:  # bash script
+                    r = os.system(s)
+                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
+            else:
+                raise Exception('Dataset not found.')
 
 
 def make_divisible(x, divisor):
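A usage sketch under assumptions (the YAML file name and loader call are illustrative, not taken from the commit): with this change, check_dataset() accepts a 'val' entry that is either a single path or a list of paths, as parsed from a dataset YAML.

import yaml

from utils.general import check_dataset  # repo-relative import, assumed layout

with open('data/coco.yaml') as f:
    data_dict = yaml.safe_load(f)  # e.g. {'train': ..., 'val': '../coco/val2017.txt', 'download': ...}

check_dataset(data_dict)  # if any 'val' path is missing: warns, runs the 'download' entry if present, otherwise raises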
