You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
2.2 KiB
61 lines
2.2 KiB
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
import os
|
|
import shutil
|
|
import torch
|
|
import os.path as osp
|
|
from yolov6.utils.events import LOGGER
|
|
from yolov6.utils.torch_utils import fuse_model
|
|
|
|
|
|
def load_state_dict(weights, model, map_location=None):
    """Copy compatible weights from a checkpoint file into ``model``.

    An entry is copied only when its key also exists in ``model``'s own
    state dict and its tensor shape matches. Everything else is skipped
    silently, because the final load is non-strict.

    Args:
        weights: path passed to ``torch.load``. The loaded object must
            support ``['model']``, and that entry must be a module
            (``.float().state_dict()`` is called on it).
        model: module that receives the matching weights in place.
        map_location: forwarded unchanged to ``torch.load``.

    Returns:
        The same ``model`` object that was passed in.
    """
    checkpoint = torch.load(weights, map_location=map_location)
    source_weights = checkpoint['model'].float().state_dict()
    target_weights = model.state_dict()

    def _compatible(name, tensor):
        # The key must exist in the target and the shapes must agree.
        return name in target_weights and tensor.shape == target_weights[name].shape

    matched = {}
    for name, tensor in source_weights.items():
        if _compatible(name, tensor):
            matched[name] = tensor

    model.load_state_dict(matched, strict=False)
    # Release local references to the loaded checkpoint before returning.
    del checkpoint, matched, target_weights
    return model
|
|
|
|
|
|
def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
    """Load a model object from a checkpoint file and switch it to eval mode.

    The ``'ema'`` entry is used when it is truthy; otherwise the ``'model'``
    entry is used. The selected module has ``.float()`` applied and, when
    ``fuse`` is true, is passed through ``fuse_model``.

    Args:
        weights: path to the checkpoint file; it is also logged.
        map_location: forwarded unchanged to ``torch.load``.
        inplace: accepted for API compatibility; this function does not use it.
        fuse: whether to run ``fuse_model`` on the module before returning.

    Returns:
        The loaded module after ``.eval()``.
    """
    LOGGER.info("Loading checkpoint from {}".format(weights))
    checkpoint = torch.load(weights, map_location=map_location)

    # Prefer the EMA copy when the checkpoint carries a usable one.
    source_key = 'ema' if checkpoint.get('ema') else 'model'
    net = checkpoint[source_key].float()

    if fuse:
        LOGGER.info("\nFusing model...")
        net = fuse_model(net)
    return net.eval()
|
|
|
|
|
|
def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
    """Save a checkpoint to disk and optionally mirror it as ``best_ckpt.pt``.

    Args:
        ckpt: object to serialize with ``torch.save``.
        is_best: when true, the written file is also copied to
            ``<save_dir>/best_ckpt.pt``.
        save_dir: target directory. It is created, including missing
            parents, if it does not exist.
        model_name: file stem. The checkpoint is written to
            ``<save_dir>/<model_name>.pt``.
    """
    # exist_ok avoids the check-then-create race. Without it, concurrent
    # writers (e.g. several training processes) could all see the directory
    # missing, and all but one would fail with FileExistsError.
    os.makedirs(save_dir, exist_ok=True)

    filename = osp.join(save_dir, model_name + '.pt')
    torch.save(ckpt, filename)

    if is_best:
        best_filename = osp.join(save_dir, 'best_ckpt.pt')
        try:
            shutil.copyfile(filename, best_filename)
        except shutil.SameFileError:
            # model_name was already 'best_ckpt', so the best file has just
            # been written and there is nothing to copy.
            pass
|
|
|
|
|
|
def strip_optimizer(ckpt_dir, epoch):
    """Slim down ``best_ckpt.pt`` and ``last_ckpt.pt`` in ``ckpt_dir`` in place.

    Files that do not exist are skipped. Each file that does exist is
    processed as follows:

    - The file is loaded onto the CPU.
    - ``'model'`` is replaced by the ``'ema'`` entry when that entry is
      truthy.
    - ``'optimizer'``, ``'ema'`` and ``'updates'`` are set to ``None``, and
      ``'epoch'`` is set to ``epoch``.
    - The model is converted with ``.half()`` and ``requires_grad`` is turned
      off on all of its parameters.
    - The result is saved back over the same path.

    Args:
        ckpt_dir: directory containing the checkpoint files.
        epoch: value written into each checkpoint's ``'epoch'`` entry.
    """
    cpu = torch.device('cpu')
    candidates = [osp.join(ckpt_dir, '{}_ckpt.pt'.format(tag)) for tag in ('best', 'last')]

    # filter() is lazy, so each path is checked just before it is processed.
    for path in filter(osp.exists, candidates):
        state = torch.load(path, map_location=cpu)

        if state.get('ema'):
            state['model'] = state['ema']  # keep the EMA weights as the model

        # Drop training-only state.
        state.update(optimizer=None, ema=None, updates=None)
        state['epoch'] = epoch

        net = state['model']
        net.half()  # FP16 to shrink the file
        for param in net.parameters():
            param.requires_grad = False

        torch.save(state, path)
|