@@ -20,15 +20,10 @@ except:
    print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
    mixed_precision = False  # not installed
wdir = 'weights' + os.sep  # weights dir
os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = 'results.txt'
# Hyperparameters
hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
       'momentum': 0.937,  # SGD momentum
hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD
       'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
       'momentum': 0.937,  # SGD momentum/Adam beta1
       'weight_decay': 5e-4,  # optimizer weight decay
       'giou': 0.05,  # giou loss gain
       'cls': 0.58,  # cls loss gain
@@ -45,21 +40,24 @@ hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
       'translate': 0.0,  # image translation (+/- fraction)
       'scale': 0.5,  # image scale (+/- gain)
       'shear': 0.0}  # image shear (+/- deg)
print(hyp)
# Overwrite hyp with hyp*.txt (optional)
f = glob.glob('hyp*.txt')
if f:
    print('Using %s' % f[0])
    for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
        hyp[k] = v
# Print focal loss if gamma > 0
if hyp['fl_gamma']:
    print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
def train(hyp):
    print(f'Hyperparameters {hyp}')
    log_dir = tb_writer.log_dir  # run directory
    wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory
    os.makedirs(wdir, exist_ok=True)
    last = wdir + 'last.pt'
    best = wdir + 'best.pt'
    results_file = log_dir + os.sep + 'results.txt'
    # Save run settings
    with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(Path(log_dir) / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)
def train(hyp):
    epochs = opt.epochs  # 300
    batch_size = opt.batch_size  # 64
    weights = opt.weights  # initial training weights
@@ -70,14 +68,15 @@ def train(hyp):
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names'])  # number classes, names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
    # Remove previous results
    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
        os.remove(f)
    # Create model
    model = Model(opt.cfg, nc=data_dict['nc']).to(device)
    model = Model(opt.cfg, nc=nc).to(device)
    # Image sizes
    gs = int(max(model.stride))  # grid size (max stride)
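    # gs is the largest model stride (typically 32 for YOLOv5); the train/test image sizes are expected to be
    # multiples of gs so that downsampled feature-map dimensions stay integral.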
@@ -97,15 +96,20 @@ def train(hyp):
            else:
                pg0.append(v)  # all else
    optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
        optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    if hyp['optimizer'] == 'adam':  # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
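    # Parameter grouping: pg2 holds biases, pg1 holds conv/linear weights, pg0 holds everything else
    # (e.g. BatchNorm weights); only pg1 receives weight_decay, so biases and norm parameters stay unregularized.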
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
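    # lf maps epoch x to a learning-rate multiplier along a half-cosine: 1.0 at x=0 down to 0.1 at x=epochs,
    # so the learning rate decays smoothly from lr0 to 0.1 * lr0 over the run.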
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2
    # plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
    # Load Model
    google_utils.attempt_download(weights)
@@ -147,12 +151,7 @@ def train(hyp):
    if mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
    scheduler.last_epoch = start_epoch - 1  # do not move
    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
    # plot_lr_scheduler(optimizer, scheduler, epochs)
    # Initialize distributed training
    # Distributed training
    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
        dist.init_process_group(backend='nccl',  # distributed backend
                                init_method='tcp://127.0.0.1:9999',  # init method
@@ -165,6 +164,7 @@ def train(hyp):
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
    # Testloader
@@ -177,15 +177,15 @@ def train(hyp):
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
    model.names = data_dict['names']
    model.names = names
    # Class frequency
    labels = np.concatenate(dataset.labels, 0)
    c = torch.tensor(labels[:, 0])  # classes
    # cf = torch.bincount(c.long(), minlength=nc) + 1.
    # model._initialize_biases(cf.to(device))
    plot_labels(labels, save_dir=log_dir)
    if tb_writer:
        plot_labels(labels)
        tb_writer.add_histogram('classes', c, 0)
    # Check anchors
@@ -193,14 +193,14 @@ def train(hyp):
        check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
    # Exponential moving average
    ema = torch_utils.ModelEMA(model)
    ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
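    # start_epoch * nb / accumulate approximates the number of optimizer updates already performed, so when
    # resuming, the EMA's update counter (and its update-count-dependent decay) continues rather than restarting at 0.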
    # Start training
    t0 = time.time()
    nb = len(dataloader)  # number of batches
    n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
    nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    scheduler.last_epoch = start_epoch - 1  # do not move
    print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
    print('Using %g dataloader workers' % dataloader.num_workers)
    print('Starting training for %g epochs...' % epochs)
@@ -225,9 +225,9 @@ def train(hyp):
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            # Burn-in
            if ni <= n_burn:
                xi = [0, n_burn]  # x interp
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
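                # Gradient accumulation is interpolated from 1 up to nbs / batch_size during warmup, so the
                # effective batch size ramps toward the nominal batch size nbs defined earlier in train.py.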
                for j, x in enumerate(optimizer.param_groups):
@@ -275,7 +275,7 @@ def train(hyp):
            # Plot
            if ni < 3:
                f = 'train_batch%g.jpg' % ni  # filename
                f = str(Path(log_dir) / ('train_batch%g.jpg' % ni))  # filename
                result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                if tb_writer and result is not None:
                    tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
@@ -296,7 +296,8 @@ def train(hyp):
                                             save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
                                             model=ema.ema,
                                             single_cls=opt.single_cls,
                                             dataloader=testloader)
                                             dataloader=testloader,
                                             save_dir=log_dir)
        # Write
        with open(results_file, 'a') as f:
@@ -348,7 +349,7 @@ def train(hyp):
    # Finish
    if not opt.evolve:
        plot_results()  # save as results.png
        plot_results(save_dir=log_dir)  # save as results.png
    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()
@@ -358,13 +359,15 @@ def train(hyp):
if __name__ == '__main__':
    check_git_status()
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
    parser.add_argument('--resume', nargs='?', const='get_last', default=False,
                        help='resume from given path/to/last.pt, or most recent run if blank.')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
@@ -374,13 +377,17 @@ if __name__ == '__main__':
    parser.add_argument('--weights', type=str, default='', help='initial weights path')
    parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--adam', action='store_true', help='use adam optimizer')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
    opt = parser.parse_args()
    last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
    if last and not opt.weights:
        print(f'Resuming training from {last}')
    opt.weights = last if opt.resume and not opt.weights else opt.weights
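    # With nargs='?' and const='get_last', `python train.py --resume` resumes the most recent run found by
    # get_latest_run(), while `python train.py --resume path/to/last.pt` resumes a specific checkpoint;
    # omitting the flag leaves opt.resume False and training starts fresh.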
    opt.cfg = check_file(opt.cfg)  # check file
    opt.data = check_file(opt.data)  # check file
    opt.hyp = check_file(opt.hyp) if opt.hyp else ''  # check file
    print(opt)
    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
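    # A single --img-size value is duplicated so both train and test sizes are set, e.g. [640] becomes [640, 640].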
    device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
@@ -389,8 +396,12 @@ if __name__ == '__main__':
    # Train
    if not opt.evolve:
        tb_writer = SummaryWriter(comment=opt.name)
        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
        tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
        if opt.hyp:  # update hyps
            with open(opt.hyp) as f:
                hyp.update(yaml.load(f, Loader=yaml.FullLoader))
        train(hyp)
    # Evolve hyperparameters (optional)