# base_model.py — abstract BaseModel shared by all models in this project.
import os, ntpath
|
|
import torch
|
|
from collections import OrderedDict
|
|
from util import util
|
|
from . import base_function
|
|
|
|
|
|
class BaseModel():
    """Abstract base class holding the shared state and bookkeeping for all models."""

    def __init__(self, opt):
        """Store the experiment options and initialize empty bookkeeping registries.

        Args:
            opt: parsed option namespace; must expose gpu_ids, isTrain,
                 pd_isTrain, checkpoints_dir and name.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.pd_isTrain = opt.pd_isTrain
        # all checkpoints for this experiment live under <checkpoints_dir>/<name>
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # registries that concrete subclasses are expected to fill in
        for registry in ('loss_names', 'model_names', 'visual_names',
                         'value_names', 'image_paths', 'optimizers',
                         'schedulers'):
            setattr(self, registry, [])
def name(self):
|
|
return 'BaseModel'
|
|
|
|
@staticmethod
|
|
def modify_options(parser, is_train):
|
|
"""Add new options and rewrite default values for existing options"""
|
|
return parser
|
|
|
|
def set_input(self, input):
|
|
"""Unpack input data from the dataloader and perform necessary pre-processing steps"""
|
|
pass
|
|
|
|
def setup(self, opt):
|
|
"""Load networks, create schedulers"""
|
|
if self.isTrain:
|
|
self.schedulers = [base_function.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
|
if not self.isTrain or opt.continue_train:
|
|
self.load_networks(opt.which_iter)
|
|
|
|
def pd_setup(self, opt):
|
|
"""Load networks, create schedulers"""
|
|
if self.pd_isTrain:
|
|
self.schedulers = [base_function.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
|
if not self.pd_isTrain or opt.continue_train:
|
|
self.load_networks(opt.which_iter)
|
|
|
|
|
|
def eval(self):
|
|
"""Make models eval mode during test time"""
|
|
for name in self.model_names:
|
|
if isinstance(name, str):
|
|
net = getattr(self, 'net_' + name)
|
|
net.eval()
|
|
|
|
def get_image_paths(self):
|
|
"""Return image paths that are used to load current data"""
|
|
return self.image_paths
|
|
|
|
def update_learning_rate(self):
|
|
"""Update learning rate"""
|
|
for scheduler in self.schedulers:
|
|
scheduler.step()
|
|
lr = self.optimizers[0].param_groups[0]['lr']
|
|
print('learning rate=%.7f' % lr)
|
|
|
|
def get_current_errors(self):
|
|
"""Return training loss"""
|
|
errors_ret = OrderedDict()
|
|
for name in self.loss_names:
|
|
if isinstance(name, str):
|
|
errors_ret[name] = getattr(self, 'loss_' + name).item()
|
|
return errors_ret
|
|
|
|
def get_current_visuals(self):
|
|
"""Return visualization images"""
|
|
visual_ret = OrderedDict()
|
|
for name in self.visual_names:
|
|
if isinstance(name, str):
|
|
value = getattr(self, name)
|
|
if isinstance(value, list):
|
|
# visual multi-scale ouputs
|
|
# for i in range(len(value)):
|
|
# visual_ret[name + str(i)] = util.tensor2im(value[i].data)
|
|
visual_ret[name] = util.tensor2im(value[-1].data)
|
|
else:
|
|
visual_ret[name] = util.tensor2im(value.data)
|
|
return visual_ret
|
|
|
|
def get_current_dis(self):
|
|
"""Return the distribution of encoder features"""
|
|
dis_ret = OrderedDict()
|
|
value = getattr(self, 'distribution')
|
|
for i in range(1):
|
|
for j, name in enumerate(self.value_names):
|
|
if isinstance(name, str):
|
|
dis_ret[name+str(i)] =util.tensor2array(value[i][j].data)
|
|
|
|
return dis_ret
|
|
|
|
# save model
|
|
def save_networks(self, which_epoch):
|
|
"""Save all the networks to the disk"""
|
|
for name in self.model_names:
|
|
if isinstance(name, str):
|
|
save_filename = '%s_net_%s.pth' % (which_epoch, name)
|
|
save_path = os.path.join(self.save_dir, save_filename)
|
|
net = getattr(self, 'net_' + name)
|
|
torch.save(net.cpu().state_dict(), save_path)
|
|
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
|
net.cuda()
|
|
|
|
# load models
|
|
def load_networks(self, which_epoch):
|
|
"""Load all the networks from the disk"""
|
|
for name in self.model_names:
|
|
if isinstance(name, str):
|
|
filename = '%s_net_%s.pth' % (which_epoch, name)
|
|
path = os.path.join(self.save_dir, filename)
|
|
net = getattr(self, 'net_' + name)
|
|
try:
|
|
net.load_state_dict(torch.load(path))
|
|
except:
|
|
pretrained_dict = torch.load(path)
|
|
model_dict = net.state_dict()
|
|
try:
|
|
pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict}
|
|
net.load_state_dict(pretrained_dict)
|
|
print('Pretrained network %s has excessive layers; Only loading layers that are used' % name)
|
|
except:
|
|
print('Pretrained network %s has fewer layers; The following are not initialized:' % name)
|
|
not_initialized = set()
|
|
for k, v in pretrained_dict.items():
|
|
if v.size() == model_dict[k].size():
|
|
model_dict[k] = v
|
|
|
|
for k, v in model_dict.items():
|
|
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
|
|
not_initialized.add(k.split('.')[0])
|
|
print(sorted(not_initialized))
|
|
net.load_state_dict(model_dict)
|
|
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
|
net.cuda()
|
|
if not self.isTrain:
|
|
net.eval()
|
|
|
|
def save_results(self, save_data, score=None, data_name='none'):
|
|
"""Save the training or testing results to disk"""
|
|
img_paths = self.get_image_paths()
|
|
|
|
for i in range(save_data.size(0)):
|
|
print('process image ...... %s' % img_paths[i])
|
|
short_path = ntpath.basename(img_paths[i]) # get image path
|
|
name = os.path.splitext(short_path)[0]
|
|
if type(score) == type(None):
|
|
img_name = '%s_%s.png' % (name, data_name)
|
|
else:
|
|
# d_score = score[i].mean()
|
|
# img_name = '%s_%s_%s.png' % (name, data_name, str(round(d_score.item(), 3)))
|
|
img_name = '%s_%s_%s.png' % (name, data_name, str(score))
|
|
# save predicted image with discriminator score
|
|
util.mkdir(self.opt.results_dir)
|
|
img_path = os.path.join(self.opt.results_dir, img_name)
|
|
img_numpy = util.tensor2im(save_data[i].data)
|
|
util.save_image(img_numpy, img_path)
|