You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

223 lines
8.0 KiB

import torch
import numpy as np
import os
import shutil
from enum import Enum
from typing import Any, Tuple, Union, Optional
import torch
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Union, Optional, Any, Tuple
def convert_rgb_to_y(img):
    """Extract the ITU-R BT.601 luma (Y) channel from an RGB image.

    Args:
        img: HxWx3 ``np.ndarray`` or 3xHxW ``torch.Tensor`` (a leading batch
            dimension of size 1 is squeezed off). Values are assumed to be in
            [0, 255] — TODO(review): confirm against callers.

    Returns:
        HxW array/tensor of Y values.

    Raises:
        Exception: if *img* is neither an ndarray nor a Tensor.
    """
    # isinstance() instead of `type(x) ==`: idiomatic and accepts subclasses.
    if isinstance(img, np.ndarray):
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop the batch dimension so channel indexing below works.
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))
def convert_rgb_to_ycbcr(img):
    """Convert an RGB image to YCbCr (ITU-R BT.601 coefficients).

    Args:
        img: HxWx3 ``np.ndarray`` or 3xHxW ``torch.Tensor`` (a leading batch
            dimension of size 1 is squeezed off). Values are assumed to be in
            [0, 255] — TODO(review): confirm against callers.

    Returns:
        HxWx3 array/tensor with channels ordered (Y, Cb, Cr).

    Raises:
        Exception: if *img* is neither an ndarray nor a Tensor.
    """
    if isinstance(img, np.ndarray):
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # BUG FIX: torch.cat along dim 0 concatenated the 2-D (H, W) channel
        # planes into a single (3H, W) tensor, so permute(1, 2, 0) raised.
        # torch.stack builds the (3, H, W) tensor the permute expects,
        # mirroring the numpy branch above.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))
def convert_ycbcr_to_rgb(img):
    """Convert a YCbCr image back to RGB (inverse BT.601 transform).

    Args:
        img: HxWx3 ``np.ndarray`` or 3xHxW ``torch.Tensor`` (a leading batch
            dimension of size 1 is squeezed off). Values are assumed to be in
            [0, 255] — TODO(review): confirm against callers.

    Returns:
        HxWx3 array/tensor with channels ordered (R, G, B).

    Raises:
        Exception: if *img* is neither an ndarray nor a Tensor.
    """
    if isinstance(img, np.ndarray):
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # BUG FIX: torch.cat on 2-D channel planes produced a (3H, W) tensor,
        # making permute(1, 2, 0) raise; torch.stack yields (3, H, W) so the
        # result matches the numpy branch's HxWx3 layout.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))
def calc_psnr(img1, img2):
    """Peak signal-to-noise ratio between two images whose values lie in [0, 1].

    Returns ``10 * log10(1 / MSE)`` as a scalar tensor; identical images yield
    +inf (division by a zero MSE).
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(1. / mse)
class AverageMeter(object):
    """Running-average tracker: last value, sum, count, and mean.

    NOTE(review): a second, richer ``AverageMeter`` is defined later in this
    file and shadows this one at import time, so this definition is
    effectively dead code — kept for compatibility.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold in *val* observed *n* times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def load_state_dict(
    model: nn.Module,
    model_weights_path: str,
    ema_model: Optional[nn.Module] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    load_mode: Optional[str] = None,
) -> Union[Tuple[nn.Module, Optional[nn.Module], Any, Any, Any, Optional[Optimizer], Optional[_LRScheduler]],
           Tuple[nn.Module, Any, Any, Any, Optional[Optimizer], Optional[_LRScheduler]],
           nn.Module]:
    """Load a checkpoint into *model* (and optionally its EMA/optimizer/scheduler).

    Args:
        model: Network to receive the checkpoint weights.
        model_weights_path: Path to a ``torch.save``-d checkpoint dict.
        ema_model: Optional EMA copy; restored from ``ema_state_dict`` in
            resume mode.
        optimizer: Restored from ``checkpoint["optimizer"]`` in resume mode;
            must not be None then.
        scheduler: Restored from ``checkpoint["scheduler"]`` in resume mode;
            must not be None then.
        load_mode: ``"resume"`` restores the full training state; anything
            else loads only compatible model weights.

    Returns:
        In resume mode with an EMA model:
            ``(model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler)``;
        in resume mode without one:
            ``(model, start_epoch, best_psnr, best_ssim, optimizer, scheduler)``;
        otherwise just ``model``.
    """
    # map_location keeps tensors on CPU regardless of where they were saved.
    checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
    if load_mode == "resume":
        # Restore the training progress recorded at checkpoint time.
        start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        best_ssim = checkpoint["best_ssim"]
        # Keep only weights that exist in the current model with matching
        # shapes (same filter as the non-resume path below, for consistency).
        model_state_dict = model.state_dict()
        state_dict = {k: v for k, v in checkpoint["state_dict"].items()
                      if k in model_state_dict and v.size() == model_state_dict[k].size()}
        # Overwrite the current (base model) weights with the checkpoint's.
        model_state_dict.update(state_dict)
        model.load_state_dict(model_state_dict)
        # Restore optimizer and LR-scheduler state.
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        if ema_model is not None:
            # Same shape-filtered load for the EMA weights.
            ema_model_state_dict = ema_model.state_dict()
            ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items()
                              if k in ema_model_state_dict and v.size() == ema_model_state_dict[k].size()}
            ema_model_state_dict.update(ema_state_dict)
            ema_model.load_state_dict(ema_model_state_dict)
            return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
        return model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
    else:
        # Weights-only load: take every checkpoint tensor whose key and shape
        # match the current model; silently skip the rest.
        model_state_dict = model.state_dict()
        state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
                      k in model_state_dict and v.size() == model_state_dict[k].size()}
        model_state_dict.update(state_dict)
        model.load_state_dict(model_state_dict)
        return model
def make_directory(dir_path: str) -> None:
    """Create *dir_path* (including parents) if it does not already exist.

    ``exist_ok=True`` replaces the exists-then-create pattern, which raced:
    another process creating the directory between the check and the call
    would crash ``os.makedirs``.
    """
    os.makedirs(dir_path, exist_ok=True)
def save_checkpoint(
    state_dict: dict,
    file_name: str,
    samples_dir: str,
    results_dir: str,
    is_best: bool = False,
    is_last: bool = False,
    best_file_name: str = "LSRGAN_x2.pth.tar",
    last_file_name: str = "last.pth.tar",
) -> None:
    """Save a checkpoint and optionally copy it as the best/last snapshot.

    Args:
        state_dict: Checkpoint payload passed straight to ``torch.save``.
        file_name: File name of the checkpoint inside *samples_dir*.
        samples_dir: Directory holding per-epoch checkpoints.
        results_dir: Directory holding the best/last copies.
        is_best: Copy the checkpoint to ``results_dir/best_file_name``.
        is_last: Copy the checkpoint to ``results_dir/last_file_name``.
        best_file_name: Name of the best-checkpoint copy (generalized from a
            hard-coded "LSRGAN_x2.pth.tar"; default keeps old behavior).
        last_file_name: Name of the last-checkpoint copy.
    """
    checkpoint_path = os.path.join(samples_dir, file_name)
    torch.save(state_dict, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name))
    if is_last:
        shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name))
class Summary(Enum):
    """Statistic that ``AverageMeter.summary()`` reports at the end of a run."""
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Tracks the latest value plus running sum/count/average of a metric."""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)

    def summary(self):
        """Render the end-of-run statistic selected by ``summary_type``."""
        # Dispatch table instead of an if/elif chain; enum members hash by
        # identity so lookup matches the original `is` comparisons.
        templates = {
            Summary.NONE: "",
            Summary.AVERAGE: "{name} {avg:.2f}",
            Summary.SUM: "{name} {sum:.2f}",
            Summary.COUNT: "{name} {count:.2f}",
        }
        if self.summary_type not in templates:
            raise ValueError(f"Invalid summary type {self.summary_type}")
        return templates[self.summary_type].format(**self.__dict__)
class ProgressMeter(object):
    """Prints a ``[batch/total]`` prefix followed by each meter's string form."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one tab-separated progress line for *batch*."""
        line = [self.prefix + self.batch_fmtstr.format(batch)]
        line.extend(str(meter) for meter in self.meters)
        print("\t".join(line))

    def display_summary(self):
        """Print a final space-separated summary of every meter."""
        parts = [" *"] + [meter.summary() for meter in self.meters]
        print(" ".join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Width of the largest batch index, so counters line up in columns.
        # NOTE(review): "// 1" is a no-op for the int batch counts this
        # receives; kept as-is to preserve behavior exactly.
        width = len(str(num_batches // 1))
        fmt = "{:" + str(width) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"