You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.4 KiB
44 lines
1.4 KiB
4 years ago
|
import torch
|
||
|
import os
|
||
|
from torch import nn
|
||
|
import numpy as np
|
||
|
import torch.nn.functional
|
||
|
from termcolor import colored
|
||
|
from .logger import get_logger
|
||
|
|
||
|
def save_model(net, optim, scheduler, recorder, is_best=False):
    """Serialize a training checkpoint into ``<recorder.work_dir>/output``.

    The checkpoint is a dict holding the ``state_dict()`` of the network,
    optimizer, scheduler and recorder, plus the current epoch number.

    Args:
        net: Module whose ``state_dict()`` is stored under ``'net'``.
        optim: Optimizer whose ``state_dict()`` is stored under ``'optim'``.
        scheduler: LR scheduler whose ``state_dict()`` is stored under
            ``'scheduler'``.
        recorder: Object exposing ``work_dir``, ``epoch`` and
            ``state_dict()``; it supplies the output location and the epoch.
        is_best: When True the file is named ``best.pth``; otherwise it is
            named ``<epoch>.pth``.
    """
    model_dir = os.path.join(recorder.work_dir, 'output')
    # Create the directory in-process rather than shelling out to
    # `mkdir -p`. This works on every OS, handles paths with spaces or
    # shell metacharacters, and cannot be abused for shell injection
    # through work_dir.
    os.makedirs(model_dir, exist_ok=True)
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch,
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))
|
||
|
|
||
|
|
||
|
def load_network_specified(net, model_dir, logger=None):
    """Partially load ``net`` from a checkpoint, skipping incompatible weights.

    Only entries of the checkpoint's ``'net'`` state dict whose key exists in
    ``net.state_dict()`` and whose tensor size matches the local one are
    loaded. Every other entry is skipped and, when ``logger`` is given,
    reported as ``'skip weights: <key>'``. ``load_network`` calls this for
    its fine-tuning path.

    Args:
        net: Module to load weights into (modified in place).
        model_dir: Path to a checkpoint *file* containing a ``'net'`` entry.
            Despite its name, this is not a directory.
        logger: Optional object with an ``info(str)`` method.
    """
    # map_location='cpu' lets a checkpoint saved on GPU be read on a
    # CPU-only machine and avoids materializing an extra copy on the GPU.
    # load_state_dict copies values into net's existing tensors, so the
    # device net lives on is unaffected.
    pretrained_net = torch.load(model_dir, map_location='cpu')['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    # strict=False: parameters of net that were skipped or absent from the
    # checkpoint keep their current values instead of raising.
    net.load_state_dict(state, strict=False)
|
||
|
|
||
|
|
||
|
def load_network(net, model_dir, finetune_from=None, logger=None):
    """Load weights into ``net`` from a checkpoint file.

    If ``finetune_from`` is truthy, weights are loaded leniently from that
    file via ``load_network_specified``, which skips unknown or size-mismatched
    keys, and ``model_dir`` is ignored. Otherwise the ``'net'`` entry of the
    checkpoint at ``model_dir`` is loaded with ``strict=True``. In that case
    PyTorch raises ``RuntimeError`` on any missing, unexpected or
    size-mismatched key.

    Args:
        net: Module to load weights into (modified in place).
        model_dir: Path to a checkpoint *file* containing a ``'net'`` entry.
        finetune_from: Optional checkpoint path for lenient (fine-tune)
            loading; takes precedence over ``model_dir``.
        logger: Optional object with an ``info(str)`` method.
    """
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # load_state_dict copies into net's own tensors, so net keeps its device.
    pretrained_model = torch.load(model_dir, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=True)
|