import ntpath
import os
from collections import OrderedDict

import torch

from util import util
from . import base_function


class BaseModel:
    """Common bookkeeping shared by models: schedulers, checkpoints, results.

    Subclasses are expected to populate the name lists created in
    ``__init__`` and to store the matching objects as attributes:

    * ``model_names``  -> networks stored as ``self.net_<name>``
    * ``loss_names``   -> losses stored as ``self.loss_<name>``
    * ``visual_names`` -> tensors (or lists of tensors) stored as ``self.<name>``
    * ``value_names``  -> labels for the entries of ``self.distribution``
    """

    def __init__(self, opt):
        """Store the options object and initialise the empty bookkeeping lists.

        Args:
            opt: options object; this constructor reads ``gpu_ids``,
                ``isTrain``, ``pd_isTrain``, ``checkpoints_dir`` and ``name``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.pd_isTrain = opt.pd_isTrain
        # Checkpoints are written to / read from <checkpoints_dir>/<name>.
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.value_names = []
        self.image_paths = []
        self.optimizers = []
        self.schedulers = []

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    @staticmethod
    def modify_options(parser, is_train):
        """Add new options and rewrite default values for existing options.

        The base implementation returns ``parser`` unchanged.
        """
        return parser

    def set_input(self, input):
        """Unpack input data from the dataloader and pre-process it.

        The base implementation does nothing.
        """
        pass

    def _named_networks(self):
        """Yield ``(name, network)`` for every string entry of ``model_names``.

        The network is looked up as the attribute ``net_<name>``; entries of
        ``model_names`` that are not strings are skipped.
        """
        for name in self.model_names:
            if isinstance(name, str):
                yield name, getattr(self, 'net_' + name)

    def _setup(self, opt, is_train):
        """Create schedulers and/or load networks for the given training flag.

        Schedulers are created (one per optimizer) only when ``is_train`` is
        true. Networks are loaded from ``opt.which_iter`` when not training
        or when ``opt.continue_train`` is set.
        """
        if is_train:
            self.schedulers = [
                base_function.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]
        if not is_train or opt.continue_train:
            self.load_networks(opt.which_iter)

    def setup(self, opt):
        """Load networks, create schedulers (driven by ``self.isTrain``)."""
        self._setup(opt, self.isTrain)

    def pd_setup(self, opt):
        """Load networks, create schedulers (driven by ``self.pd_isTrain``)."""
        self._setup(opt, self.pd_isTrain)

    def eval(self):
        """Make models eval mode during test time."""
        for _, net in self._named_networks():
            net.eval()

    def get_image_paths(self):
        """Return image paths that are used to load current data."""
        return self.image_paths

    def update_learning_rate(self):
        """Step every scheduler and print the resulting learning rate.

        The printed value is read from the first parameter group of the
        first optimizer.
        """
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate=%.7f' % lr)

    def get_current_errors(self):
        """Return the current losses as an ``OrderedDict`` of name -> number.

        Each value is ``self.loss_<name>.item()`` for the string entries of
        ``loss_names``.
        """
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = getattr(self, 'loss_' + name).item()
        return errors_ret

    def get_current_visuals(self):
        """Return visualization images as an ``OrderedDict`` of name -> image.

        Each string entry of ``visual_names`` is read from ``self.<name>`` and
        converted with ``util.tensor2im``.
        """
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if not isinstance(name, str):
                continue
            value = getattr(self, name)
            if isinstance(value, list):
                # A list holds multi-scale outputs; only its last entry is
                # visualized.
                value = value[-1]
            visual_ret[name] = util.tensor2im(value.data)
        return visual_ret

    def get_current_dis(self):
        """Return the distribution of encoder features.

        Reads ``self.distribution[0][j]`` for the j-th entry of
        ``value_names`` and converts it with ``util.tensor2array``. Keys are
        the value name followed by the group index, which is always ``0``
        because only the first group is exported.
        """
        dis_ret = OrderedDict()
        value = getattr(self, 'distribution')
        group = 0
        for j, name in enumerate(self.value_names):
            if isinstance(name, str):
                dis_ret[name + str(group)] = util.tensor2array(value[group][j].data)
        return dis_ret

    def save_networks(self, which_epoch):
        """Save all the networks to the disk.

        Each network is written to
        ``<save_dir>/<which_epoch>_net_<name>.pth``.

        Args:
            which_epoch: label used as the file-name prefix.
        """
        for name, net in self._named_networks():
            save_filename = '%s_net_%s.pth' % (which_epoch, name)
            save_path = os.path.join(self.save_dir, save_filename)
            # The state dict is saved from the CPU copy of the network ...
            torch.save(net.cpu().state_dict(), save_path)
            # ... so the network has to be moved back to the GPU afterwards.
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                net.cuda()

    @staticmethod
    def _load_state_dict_tolerant(net, pretrained_dict, name):
        """Load ``pretrained_dict`` into ``net``, tolerating layer mismatches.

        Three attempts are made, in order:

        1. a plain ``load_state_dict``;
        2. the checkpoint restricted to keys the network actually has
           (checkpoint has excessive layers);
        3. the network's own state dict, overwritten only with checkpoint
           entries of matching size (checkpoint has fewer or differently
           shaped layers); the top-level names of the entries that were not
           overwritten are printed.

        Raises:
            RuntimeError: if even the third attempt is rejected by
                ``load_state_dict``.
        """
        try:
            net.load_state_dict(pretrained_dict)
        except RuntimeError:
            # load_state_dict reports key/shape mismatches as RuntimeError;
            # anything else (e.g. KeyboardInterrupt) must propagate.
            model_dict = net.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items() if k in model_dict
            }
            try:
                net.load_state_dict(pretrained_dict)
            except RuntimeError:
                print('Pretrained network %s has fewer layers; The following are not initialized:' % name)
                not_initialized = set()
                for k, v in pretrained_dict.items():
                    if v.size() == model_dict[k].size():
                        model_dict[k] = v
                for k, v in model_dict.items():
                    if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                        not_initialized.add(k.split('.')[0])
                print(sorted(not_initialized))
                net.load_state_dict(model_dict)
            else:
                print('Pretrained network %s has excessive layers; Only loading layers that are used' % name)

    def load_networks(self, which_epoch):
        """Load all the networks from the disk.

        Each network is read from
        ``<save_dir>/<which_epoch>_net_<name>.pth``. After loading, the
        network is moved to the GPU when ``gpu_ids`` is non-empty and CUDA is
        available, and switched to eval mode when ``self.isTrain`` is false.

        Args:
            which_epoch: label used as the file-name prefix.

        Raises:
            Whatever ``torch.load`` raises when the file is missing or
            unreadable, and ``RuntimeError`` when the checkpoint cannot be
            reconciled with the network.
        """
        for name, net in self._named_networks():
            filename = '%s_net_%s.pth' % (which_epoch, name)
            path = os.path.join(self.save_dir, filename)
            # Read the checkpoint once, onto the CPU, so it can be opened on
            # a machine without CUDA; load_state_dict copies the values into
            # the network's existing parameters.
            pretrained_dict = torch.load(path, map_location='cpu')
            self._load_state_dict_tolerant(net, pretrained_dict, name)
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                net.cuda()
            if not self.isTrain:
                net.eval()

    def save_results(self, save_data, score=None, data_name='none'):
        """Save the training or testing results to disk.

        One PNG per item along the first dimension of ``save_data`` is
        written into ``self.opt.results_dir``. The file name is
        ``<image-stem>_<data_name>.png``, or
        ``<image-stem>_<data_name>_<score>.png`` when ``score`` is given.

        Args:
            save_data: batch of images; indexed along dimension 0 and
                converted with ``util.tensor2im``.
            score: optional value appended to every file name via ``str()``.
            data_name: tag inserted into the file name.
        """
        img_paths = self.get_image_paths()
        # The output directory does not depend on the image, so it is
        # prepared once instead of once per image.
        util.mkdir(self.opt.results_dir)

        for i in range(save_data.size(0)):
            print('process image ...... %s' % img_paths[i])
            short_path = ntpath.basename(img_paths[i])  # file name without directories
            name = os.path.splitext(short_path)[0]
            if score is None:
                img_name = '%s_%s.png' % (name, data_name)
            else:
                # The same score string is used for every image in the batch.
                img_name = '%s_%s_%s.png' % (name, data_name, str(score))
            img_path = os.path.join(self.opt.results_dir, img_name)
            img_numpy = util.tensor2im(save_data[i].data)
            util.save_image(img_numpy, img_path)