You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

44 lines
1.4 KiB

import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger
def save_model(net, optim, scheduler, recorder, is_best=False):
    """Serialize a training checkpoint into ``<recorder.work_dir>/output``.

    The checkpoint is a dict with keys ``'net'``, ``'optim'``,
    ``'scheduler'``, ``'recorder'`` (each the object's ``state_dict()``) and
    ``'epoch'``. The file is named ``best.pth`` when ``is_best`` is true,
    otherwise ``<recorder.epoch>.pth``.

    Args:
        net: model whose ``state_dict()`` is stored under ``'net'``.
        optim: optimizer whose ``state_dict()`` is stored under ``'optim'``.
        scheduler: LR scheduler whose ``state_dict()`` is stored under
            ``'scheduler'``.
        recorder: object exposing ``work_dir``, ``epoch`` and
            ``state_dict()``.
        is_best: write to ``best.pth`` instead of the per-epoch file.
    """
    model_dir = os.path.join(recorder.work_dir, 'output')
    # os.makedirs instead of shelling out to `mkdir -p`: portable, correct for
    # paths containing spaces or shell metacharacters, and not injectable
    # through work_dir.
    os.makedirs(model_dir, exist_ok=True)
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))
def load_network_specified(net, model_dir, logger=None):
    """Leniently copy matching weights from a checkpoint into ``net``.

    Reads the ``'net'`` state dict from the checkpoint. Each entry is copied
    only if its key exists in ``net`` and the tensor sizes match. Every
    other entry is skipped and reported through ``logger`` when one is
    given. Parameters of ``net`` not covered by the filtered dict keep their
    current values (``strict=False``).

    Args:
        net: module to load weights into.
        model_dir: path to the checkpoint *file* (not a directory) containing
            a ``'net'`` key, such as the files written by ``save_model``.
        logger: optional object with an ``info(str)`` method.
    """
    # map_location='cpu': a checkpoint saved during GPU training would
    # otherwise fail to load on a CPU-only host. load_state_dict copies the
    # values onto net's own device regardless.
    pretrained_net = torch.load(model_dir, map_location='cpu')['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        # Skip keys the model does not have and tensors whose shape differs
        # (e.g. a re-sized classification head when fine-tuning).
        if k not in net_state or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)
def load_network(net, model_dir, finetune_from=None, logger=None):
    """Load checkpoint weights into ``net``.

    There are two modes:

    * If ``finetune_from`` is truthy, the weights are loaded leniently from
      that path via ``load_network_specified``, and ``model_dir`` is ignored.
    * Otherwise the ``'net'`` state dict stored in the checkpoint *file*
      ``model_dir`` is loaded with ``strict=True``, so any key or shape
      mismatch raises.

    Args:
        net: module to load weights into.
        model_dir: path to the checkpoint file used in strict mode.
        finetune_from: optional checkpoint path for lenient fine-tune loading.
        logger: optional object with an ``info(str)`` method.
    """
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    # map_location='cpu': a checkpoint saved during GPU training would
    # otherwise fail to load on a CPU-only host. load_state_dict copies the
    # values onto net's own device regardless.
    pretrained_model = torch.load(model_dir, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=True)