You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
1.7 KiB
56 lines
1.7 KiB
import argparse
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
import numpy as np
|
|
import PIL.Image as pil_image
|
|
|
|
from model_SRCNN import SRCNN
|
|
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
|
|
|
|
def _load_srcnn_weights(model, weights_file):
    """Copy every tensor stored in ``weights_file`` into ``model`` in place.

    Args:
        model: Module whose ``state_dict()`` entries receive the weights.
        weights_file: Checkpoint location, handed to ``torch.load`` as-is.

    Raises:
        KeyError: If the checkpoint contains a name the model does not have.
            Entries copied before the offending name stay copied.
    """
    own_state = model.state_dict()
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be passed in.
    # The map_location callable returns each storage unchanged, which keeps
    # the loaded tensors on the CPU; copy_() then writes them into the
    # model's own tensors on whatever device those live on.
    checkpoint = torch.load(weights_file, map_location=lambda storage, loc: storage)
    for name, tensor in checkpoint.items():
        if name not in own_state:
            raise KeyError(name)
        own_state[name].copy_(tensor)


def super_resolution(weights, image, in_scale=2):
    """Upscale an image by ``in_scale`` with SRCNN and return the result.

    The image is first shrunk to dimensions that are multiples of
    ``in_scale``, then enlarged ``in_scale`` times with bicubic resampling.
    Only the first YCbCr channel is passed through the network; the other two
    channels are taken from the bicubic result. Nothing is written to disk.

    Args:
        weights: Checkpoint with the SRCNN weights, passed to ``torch.load``.
        image: Input image, passed to ``PIL.Image.open``.
        in_scale: Positive integer upscaling factor.

    Returns:
        A PIL image built from the clipped ``uint8`` RGB array.

    Raises:
        ValueError: If ``in_scale`` is smaller than 1, or the image is
            smaller than ``in_scale`` along either axis.
        KeyError: If the checkpoint holds a parameter name unknown to SRCNN.

    Side effects:
        Sets ``cudnn.benchmark = True`` and prints the PSNR value.
    """
    # Fail early: a zero scale would otherwise surface as ZeroDivisionError
    # only after the model and weights have already been loaded.
    if in_scale < 1:
        raise ValueError('in_scale must be a positive integer, got {!r}'.format(in_scale))

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)
    _load_srcnn_weights(model, weights)
    model.eval()

    # convert() returns an independent image, so the source (and any file
    # handle PIL opened for it) can be released immediately.
    with pil_image.open(image) as source:
        rgb = source.convert('RGB')

    width = (rgb.width // in_scale) * in_scale
    height = (rgb.height // in_scale) * in_scale
    if width == 0 or height == 0:
        raise ValueError(
            'image of size {}x{} is too small for in_scale={}'.format(
                rgb.width, rgb.height, in_scale))

    rgb = rgb.resize((width, height), resample=pil_image.BICUBIC)
    rgb = rgb.resize((width * in_scale, height * in_scale), resample=pil_image.BICUBIC)

    pixels = np.array(rgb).astype(np.float32)
    # Assumes convert_rgb_to_ycbcr returns a channel-last NumPy array with
    # Y, Cb, Cr in that order -- TODO confirm against utils.
    ycbcr = convert_rgb_to_ycbcr(pixels)

    # Out-of-place division leaves ycbcr untouched; the network input is the
    # first channel scaled from [0, 255] down by 255.
    luma = ycbcr[..., 0] / 255.
    y = torch.from_numpy(luma).to(device)
    y = y.unsqueeze(0).unsqueeze(0)  # add two leading singleton dimensions

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # There is no ground-truth image here: the value compares the network
    # output with its own bicubic-upscaled input.
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the predicted first channel with the two bicubic chroma
    # channels, moving the channel axis to the end.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    return pil_image.fromarray(output)