"""Utilities for super-resolution training: RGB/YCbCr color-space conversion,
PSNR computation, checkpoint save/load, and average/progress meters."""
import os
import shutil
from enum import Enum
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
def convert_rgb_to_y(img):
    """Extract the luma (Y) channel of an RGB image (BT.601, [16, 235] range).

    Args:
        img: HxWx3 ``np.ndarray`` or 3xHxW (optionally 1x3xHxW)
            ``torch.Tensor``; channel values are expected in [0, 255].

    Returns:
        The Y channel as an HxW array/tensor matching the input backend.

    Raises:
        Exception: If ``img`` is neither a NumPy array nor a torch tensor.
    """
    # isinstance instead of `type(x) == T` so ndarray/Tensor subclasses work too.
    if isinstance(img, np.ndarray):
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop a leading batch dimension of size 1 (1x3xHxW -> 3xHxW).
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))
|
|
|
|
|
|
def convert_rgb_to_ycbcr(img):
    """Convert an RGB image to YCbCr (BT.601 studio range).

    Args:
        img: HxWx3 ``np.ndarray`` or 3xHxW (optionally 1x3xHxW)
            ``torch.Tensor``; channel values are expected in [0, 255].

    Returns:
        HxWx3 array/tensor with channels ordered (Y, Cb, Cr).

    Raises:
        Exception: If ``img`` is neither a NumPy array nor a torch tensor.
    """
    if isinstance(img, np.ndarray):
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop a leading batch dimension of size 1 (1x3xHxW -> 3xHxW).
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # BUGFIX: the original used torch.cat, which concatenates the 2-D
        # channel planes into a (3H, W) tensor; permute(1, 2, 0) then raises.
        # torch.stack builds (3, H, W) so the permute yields HxWx3 like the
        # numpy branch.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))
|
|
|
|
|
|
def convert_ycbcr_to_rgb(img):
    """Convert a YCbCr image back to RGB (inverse of BT.601 studio range).

    Args:
        img: HxWx3 ``np.ndarray`` or 3xHxW (optionally 1x3xHxW)
            ``torch.Tensor`` with channels ordered (Y, Cb, Cr) in [0, 255].

    Returns:
        HxWx3 array/tensor with channels ordered (R, G, B); values are NOT
        clipped to [0, 255].

    Raises:
        Exception: If ``img`` is neither a NumPy array nor a torch tensor.
    """
    if isinstance(img, np.ndarray):
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop a leading batch dimension of size 1 (1x3xHxW -> 3xHxW).
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # BUGFIX: torch.cat on the 2-D channel planes produced a (3H, W)
        # tensor on which permute(1, 2, 0) raises; torch.stack builds
        # (3, H, W) so the result is HxWx3 like the numpy branch.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))
|
|
|
|
|
|
def calc_psnr(img1, img2):
    """Peak signal-to-noise ratio between two tensors assumed to lie in [0, 1]."""
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(1. / mse)
|
|
|
|
|
|
class AverageMeter(object):
    """Running-average tracker: last value, sum, count, and mean.

    NOTE(review): this class is redefined later in this file (with
    name/format support), so this definition is shadowed at import time.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold in a new observation ``val`` that represents ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
|
|
|
|
|
|
|
def load_state_dict(
        model: nn.Module,
        model_weights_path: str,
        ema_model: Optional[nn.Module] = None,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[_LRScheduler] = None,
        load_mode: Optional[str] = None,
) -> Union[Tuple[nn.Module, Optional[nn.Module], Any, Any, Any, Optional[Optimizer], Optional[_LRScheduler]],
           Tuple[nn.Module, Any, Any, Any, Optional[Optimizer], Optional[_LRScheduler]],
           nn.Module]:
    """Load model weights (and, in resume mode, full training state) from a checkpoint.

    Args:
        model: Base model whose matching weights are overwritten in place.
        model_weights_path: Path to a ``torch.save`` checkpoint. In resume mode
            it must contain ``epoch``, ``best_psnr``, ``best_ssim``,
            ``state_dict``, ``optimizer`` and ``scheduler`` keys (plus
            ``ema_state_dict`` when ``ema_model`` is given).
        ema_model: Optional EMA copy of the model, restored the same way.
        optimizer: Optimizer to restore in resume mode. NOTE(review): must not
            be None when ``load_mode == "resume"`` or this raises AttributeError.
        scheduler: LR scheduler to restore in resume mode (same caveat).
        load_mode: ``"resume"`` restores full training state; any other value
            only copies compatible weights into ``model``.

    Returns:
        ``(model, ema_model, start_epoch, best_psnr, best_ssim, optimizer,
        scheduler)`` when resuming with an EMA model;
        ``(model, start_epoch, best_psnr, best_ssim, optimizer, scheduler)``
        when resuming without one; just ``model`` otherwise.
    """
    # Load model weights onto CPU regardless of where they were saved.
    # NOTE(review): torch.load unpickles arbitrary objects — only feed it
    # trusted checkpoint files.
    checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)

    if load_mode == "resume":
        # Restore the parameters in the training node to this point
        start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        best_ssim = checkpoint["best_ssim"]
        # Load model state dict. Extract the fitted model weights
        model_state_dict = model.state_dict()
        # NOTE(review): unlike the non-resume branch below, this filter does
        # not compare tensor sizes, so a shape mismatch will raise inside
        # model.load_state_dict.
        state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()}
        # Overwrite the model weights to the current model (base model)
        model_state_dict.update(state_dict)
        model.load_state_dict(model_state_dict)
        # Load the optimizer state
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Load the scheduler state
        scheduler.load_state_dict(checkpoint["scheduler"])

        if ema_model is not None:
            # Load ema model state dict. Extract the fitted model weights
            ema_model_state_dict = ema_model.state_dict()
            ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()}
            # Overwrite the model weights to the current model (ema model)
            ema_model_state_dict.update(ema_state_dict)
            ema_model.load_state_dict(ema_model_state_dict)
            return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler

        return model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
    else:
        # Load model state dict. Extract the fitted model weights
        model_state_dict = model.state_dict()
        # Keep only keys present in the current model AND with matching tensor
        # sizes, so partially-compatible checkpoints still load.
        state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
                      k in model_state_dict.keys() and v.size() == model_state_dict[k].size()}
        # Overwrite the model weights to the current model
        model_state_dict.update(state_dict)
        model.load_state_dict(model_state_dict)

        return model
|
|
|
|
|
|
def make_directory(dir_path: str) -> None:
    """Create ``dir_path`` (including parents) if it does not already exist.

    Uses ``exist_ok=True`` instead of the original exists-then-create check,
    which was racy when two processes created the directory concurrently.
    """
    os.makedirs(dir_path, exist_ok=True)
|
|
|
|
|
|
def save_checkpoint(
        state_dict: dict,
        file_name: str,
        samples_dir: str,
        results_dir: str,
        is_best: bool = False,
        is_last: bool = False,
) -> None:
    """Persist a training checkpoint and optionally copy it to the results dir.

    The checkpoint is always written to ``samples_dir/file_name``. When
    ``is_best`` it is duplicated as ``results_dir/LSRGAN_x2.pth.tar`` and when
    ``is_last`` as ``results_dir/last.pth.tar``.
    """
    saved_path = os.path.join(samples_dir, file_name)
    torch.save(state_dict, saved_path)

    # Copy the freshly written checkpoint under its well-known alias names.
    aliases = []
    if is_best:
        aliases.append("LSRGAN_x2.pth.tar")
    if is_last:
        aliases.append("last.pth.tar")
    for alias in aliases:
        shutil.copyfile(saved_path, os.path.join(results_dir, alias))
|
|
|
|
|
|
class Summary(Enum):
    """Selects what an AverageMeter's ``summary()`` line reports."""
    NONE = 0      # no summary output
    AVERAGE = 1   # report the running mean
    SUM = 2       # report the accumulated sum
    COUNT = 3     # report the number of samples
|
|
|
|
|
|
class AverageMeter(object):
    """Track the latest value of a metric plus its running sum, count and mean.

    ``val`` holds the most recent observation, ``avg`` the running mean,
    ``sum``/``count`` the accumulators; ``summary_type`` selects what
    ``summary()`` reports.
    """

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold in an observation ``val`` standing for ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)

    def summary(self):
        """Render the end-of-run summary according to ``summary_type``."""
        templates = {
            Summary.NONE: "",
            Summary.AVERAGE: "{name} {avg:.2f}",
            Summary.SUM: "{name} {sum:.2f}",
            Summary.COUNT: "{name} {count:.2f}",
        }
        if self.summary_type not in templates:
            raise ValueError(f"Invalid summary type {self.summary_type}")
        return templates[self.summary_type].format(**self.__dict__)
|
|
|
|
|
|
class ProgressMeter(object):
    """Print per-batch progress lines built from a collection of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        # Pre-build the "[  7/250]"-style counter format once.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/total]`` followed by every meter, tab-separated."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        """Print a one-line ``*``-prefixed summary of all meters."""
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Width the batch counter to the digit count of the total so columns
        # line up, e.g. 250 batches -> "[{:3d}/250]".
        # (Removed the original's no-op `num_batches // 1`.)
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|