You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
4.0 KiB
95 lines
4.0 KiB
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
import argparse
|
|
import os
|
|
import os.path as osp
|
|
import torch
|
|
import torch.distributed as dist
|
|
import sys
|
|
|
|
# Put the current working directory on the import path (once) so that the
# `yolov6` package imported below can be resolved from it.
ROOT = os.getcwd()
already_importable = ROOT in sys.path
if not already_importable:
    sys.path.append(ROOT)
|
|
|
|
from yolov6.core.engine import Trainer
|
|
from yolov6.utils.config import Config
|
|
from yolov6.utils.events import LOGGER, save_yaml
|
|
from yolov6.utils.envs import get_envs, select_device, set_random_seed
|
|
from yolov6.utils.general import increment_name
|
|
|
|
|
|
def get_args_parser(add_help=True):
    """Build the command-line parser for YOLOv6 training.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; controls whether
            the automatic ``-h/--help`` option is registered.

    Returns:
        argparse.ArgumentParser: parser with every training option registered.
        Option names map to namespace attributes with dashes replaced by
        underscores (e.g. ``--data-path`` -> ``args.data_path``).
    """
    parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Training', add_help=add_help)

    # Data / model configuration
    parser.add_argument('--data-path', default='./data/coco.yaml', type=str, help='path of dataset')
    parser.add_argument('--conf-file', default='./configs/yolov6s.py', type=str, help='experiments description file')
    parser.add_argument('--img-size', default=640, type=int, help='train, val image size (pixels)')

    # Training schedule and resources
    parser.add_argument('--batch-size', default=32, type=int, help='total batch size for all GPUs')
    parser.add_argument('--epochs', default=400, type=int, help='number of total epochs to run')
    parser.add_argument('--workers', default=8, type=int, help='number of data loading workers (default: 8)')
    parser.add_argument('--device', default='0', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    # Evaluation cadence
    parser.add_argument('--eval-interval', default=20, type=int, help='evaluate at every interval epochs')
    parser.add_argument('--eval-final-only', action='store_true', help='only evaluate at the final epoch')
    parser.add_argument('--heavy-eval-range', default=50, type=int,
                        help='evaluating every epoch for last such epochs (can be jointly used with --eval-interval)')

    # Dataset sanity checks
    parser.add_argument('--check-images', action='store_true', help='check images when initializing datasets')
    parser.add_argument('--check-labels', action='store_true', help='check label files when initializing datasets')

    # Output location
    parser.add_argument('--output-dir', default='./runs/train', type=str, help='path to save outputs')
    parser.add_argument('--name', default='exp', type=str, help='experiment name, saved to output_dir/name')

    # Distributed training
    parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--gpu_count', type=int, default=0)
    # Accept both spellings: older launchers pass ``--local_rank`` while
    # PyTorch >= 2.0 ``torch.distributed.launch`` passes ``--local-rank``.
    # The first long option determines the dest, so it stays ``local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=-1, help='DDP parameter')

    # Checkpoint resume
    parser.add_argument('--resume', type=str, default=None, help='resume the corresponding ckpt')

    return parser
|
|
|
|
|
|
def check_and_init(args):
    '''Resolve the run directory, load the config, select the device and seed.

    Sets ``args.save_dir`` as a side effect. On the master process the save
    directory is also created and the arguments are dumped to ``args.yaml``
    inside it.

    Returns:
        tuple: ``(cfg, device)`` as produced by ``Config.fromfile`` and
        ``select_device``.
    '''
    # The master is rank 0 when several processes take part, and the
    # rank -1 process otherwise.
    master_rank = 0 if args.world_size > 1 else -1
    is_master = args.rank == master_rank

    # Resolve where this run writes its outputs, then load the experiment config.
    requested_dir = osp.join(args.output_dir, args.name)
    args.save_dir = str(increment_name(requested_dir, is_master))
    config = Config.fromfile(args.conf_file)

    # Pick the compute device.
    chosen_device = select_device(args.device)

    # Seed is offset by the rank; deterministic mode only for rank -1.
    seed = 1 + args.rank
    set_random_seed(seed, deterministic=(args.rank == -1))

    # Only the master persists the arguments.
    if not is_master:
        return config, chosen_device

    os.makedirs(args.save_dir)
    args_dump = osp.join(args.save_dir, 'args.yaml')
    save_yaml(vars(args), args_dump)
    return config, chosen_device
|
|
|
|
|
|
def main(args):
    '''Run one training session from parsed command-line arguments.

    Steps: read rank / local rank / world size via ``get_envs()``, prepare the
    run with ``check_and_init``, join the distributed process group when a
    local rank is set, train with ``Trainer``, then tear the process group
    down on every process that initialized it.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``. It is
            mutated in place (``rank``, ``local_rank``, ``world_size`` here,
            ``save_dir`` inside ``check_and_init``).
    '''
    # Setup
    args.rank, args.local_rank, args.world_size = get_envs()
    LOGGER.info(f'training args are: {args}\n')
    cfg, device = check_and_init(args)

    if args.local_rank != -1:  # if DDP mode
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        LOGGER.info('Initializing process group... ')
        # init_process_group expects the *global* rank (unique across all
        # nodes); the local rank only identifies the GPU on this node and
        # repeats on every node. Fall back to the local rank when no global
        # rank is known (rank == -1), which matches the previous behavior.
        global_rank = args.rank if args.rank != -1 else args.local_rank
        dist.init_process_group(
            backend="nccl" if dist.is_nccl_available() else "gloo",
            init_method=args.dist_url,
            rank=global_rank,
            world_size=args.world_size,
        )

    # Start
    trainer = Trainer(args, cfg, device)
    trainer.train()

    # End
    # Every process that joined the group must leave it, not only rank 0.
    # Checking is_initialized() also avoids destroying a group that was
    # never created.
    if dist.is_available() and dist.is_initialized():
        if args.rank <= 0:
            LOGGER.info('Destroying process group... ')
        dist.destroy_process_group()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, then start training.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)
|