|
After Width: | Height: | Size: 226 KiB |
|
After Width: | Height: | Size: 246 KiB |
|
After Width: | Height: | Size: 182 KiB |
|
After Width: | Height: | Size: 161 KiB |
|
After Width: | Height: | Size: 142 KiB |
@ -0,0 +1,83 @@
|
||||
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import lsrgan_config
|
||||
import model_LSRGAN
|
||||
from utils import make_directory
|
||||
|
||||
def preprocess_image(image_path, device):
    """Load an image from disk and convert it into a normalized CHW batch tensor.

    Args:
        image_path: Path to an image file readable by OpenCV.
        device: Torch device the resulting tensor is moved to.

    Returns:
        A float tensor of shape (1, C, H, W) with values in [0, 1], on ``device``.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Failed to load image: {image_path}")

    # OpenCV loads BGR; the model expects RGB channel order.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # ToTensor: HWC uint8 [0, 255] -> CHW float [0, 1].
    chw = transforms.ToTensor()(rgb)

    # Prepend the batch dimension and move to the target device.
    return chw.unsqueeze(0).to(device)
|
||||
|
||||
def main() -> None:
    """Run single-image LSRGAN inference.

    Builds the generator described by ``lsrgan_config``, restores its trained
    weights, super-resolves ``eye_lr.jpg`` and writes the result into the
    configured SR output directory.
    """
    # Look up the generator class by name from the config module and build it.
    generator = model_LSRGAN.__dict__[lsrgan_config.g_arch_name](
        in_channels=lsrgan_config.in_channels,
        out_channels=lsrgan_config.out_channels,
        channels=lsrgan_config.channels,
        growth_channels=lsrgan_config.growth_channels,
        num_blocks=lsrgan_config.num_blocks,
    )
    generator = generator.to(device=lsrgan_config.device)
    print(f"Build `{lsrgan_config.g_arch_name}` model successfully.")

    # Restore trained weights; CPU map_location lets CUDA checkpoints load anywhere.
    checkpoint = torch.load(lsrgan_config.g_model_weights_path, map_location=lambda storage, loc: storage)
    generator.load_state_dict(checkpoint["state_dict"])
    print(f"Load `{lsrgan_config.g_arch_name}` model weights "
          f"`{os.path.abspath(lsrgan_config.g_model_weights_path)}` successfully.")

    # Make sure the output directory exists, then switch to inference mode.
    make_directory(lsrgan_config.sr_dir)
    generator.eval()

    # Fixed demo input/output paths.
    lr_image_path = "eye_lr.jpg"
    sr_image_path = os.path.join(lsrgan_config.sr_dir, "eye_LSRGAN_x2.jpg")

    print(f"Processing `{os.path.abspath(lr_image_path)}`...")
    lr_tensor = preprocess_image(lr_image_path, lsrgan_config.device)

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        sr_tensor = generator(lr_tensor)

    # Tensor (1, C, H, W) in [0, 1] -> uint8 HWC, RGB back to BGR for OpenCV, save.
    sr_array = sr_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    sr_array = (sr_array * 255.0).clip(0, 255).astype("uint8")
    sr_array = cv2.cvtColor(sr_array, cv2.COLOR_RGB2BGR)
    cv2.imwrite(sr_image_path, sr_array)
    print(f"Super-resolved image saved to `{os.path.abspath(sr_image_path)}`")


if __name__ == "__main__":
    main()
|
||||
@ -0,0 +1,56 @@
|
||||
import argparse
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import numpy as np
|
||||
import PIL.Image as pil_image
|
||||
|
||||
from model_SRCNN import SRCNN
|
||||
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
|
||||
|
||||
def super_resolution(weights, image, in_scale=2):
    """Super-resolve an image with a pretrained SRCNN operating on the Y channel.

    Args:
        weights: Path to the SRCNN state-dict file.
        image: Path to the input image.
        in_scale: Upscaling factor; the image is first cropped to a multiple
            of the scale, then bicubic-upscaled before SRCNN refinement.

    Returns:
        The reconstructed image as a PIL Image.
    """
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Build the network and copy the saved parameters in, key by key.
    model = SRCNN().to(device)
    state_dict = model.state_dict()
    for name, param in torch.load(weights, map_location=lambda storage, loc: storage).items():
        if name not in state_dict.keys():
            # Unknown key means the checkpoint does not match this architecture.
            raise KeyError(name)
        state_dict[name].copy_(param)

    model.eval()

    # Crop to a multiple of the scale, then bicubic-upscale back to full size.
    img = pil_image.open(image).convert('RGB')
    cropped_width = (img.width // in_scale) * in_scale
    cropped_height = (img.height // in_scale) * in_scale
    img = img.resize((cropped_width, cropped_height), resample=pil_image.BICUBIC)
    img = img.resize((img.width * in_scale, img.height * in_scale), resample=pil_image.BICUBIC)

    # SRCNN runs on the luma (Y) channel only; Cb/Cr are kept from the input.
    pixels = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(pixels)
    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # PSNR of the prediction against the bicubic-upscaled input (both in [0, 1]).
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    # Recombine predicted Y with original Cb/Cr and convert back to RGB.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    merged = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    merged = np.clip(convert_ycbcr_to_rgb(merged), 0.0, 255.0).astype(np.uint8)
    # Saving to disk is intentionally left to the caller.
    return pil_image.fromarray(merged)
|
||||
|
After Width: | Height: | Size: 436 KiB |
|
After Width: | Height: | Size: 389 KiB |
|
After Width: | Height: | Size: 338 KiB |
|
After Width: | Height: | Size: 728 KiB |
|
After Width: | Height: | Size: 28 KiB |
@ -0,0 +1,99 @@
|
||||
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.backends import cudnn
|
||||
|
||||
# Random seeds to keep results reproducible across runs.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
# Device used for training/inference.
# NOTE(review): hardcodes the first CUDA GPU; this fails on CPU-only machines — confirm intended.
device = torch.device("cuda", 0)
# cudnn autotuner: speeds things up when input sizes do not change during training.
cudnn.benchmark = True
# When evaluating SR quality, verify only the Y (luma) channel image data.
only_test_y_channel = True
# Architecture names, looked up by consumers in model_LSRGAN.__dict__.
d_arch_name = "discriminator"
g_arch_name = "lsrgan_x2"
# Generator architecture hyperparameters.
in_channels = 3
out_channels = 3
channels = 64
growth_channels = 32
num_blocks = 23
upscale_factor = 2
# Which parameter section below is active: "train" or "test".
mode = "test"
# Experiment name, used to build weight/log/output paths.
exp_name = "LSRGAN_x2"

if mode == "train":
    # Dataset locations.
    train_gt_images_dir = f"./data/DIV2K/LSRGAN/train"

    test_gt_images_dir = f"./data/Set5/GTmod12"
    test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}"

    # Training patch size and loader settings.
    gt_image_size = 128
    batch_size = 16
    num_workers = 4

    # Paths to pretrained model weights (empty = train from scratch).
    pretrained_d_model_weights_path = ""
    pretrained_g_model_weights_path = ""

    # Checkpoints to resume incremental/transfer training from.
    resume_d = ""
    resume_g = ""

    # Total number of epochs (approx. 500,000 iterations).
    epochs = 512

    # Feature-extraction (perceptual loss) layer configuration.
    feature_model_extractor_node = "features.34"
    feature_model_normalize_mean = [0.485, 0.456, 0.406]
    feature_model_normalize_std = [0.229, 0.224, 0.225]

    # Loss function weights.
    pixel_weight = 0.01
    content_weight = 1.0
    adversarial_weight = 0.005

    # Optimizer (Adam) parameters.
    model_lr = 1e-4
    model_betas = (0.9, 0.999)
    model_eps = 1e-8

    # Exponential moving average decay for model weights.
    model_ema_decay = 0.99998

    # Learning-rate schedule milestones (approx. every 200,000 iterations).
    lr_scheduler_milestones = [int(epochs * 0.1), int(epochs * 0.2), int(epochs * 0.4), int(epochs * 0.6)]
    lr_scheduler_gamma = 0.5

    # Logging frequency (iterations).
    train_print_frequency = 100
    valid_print_frequency = 1

if mode == "test":
    # Test data locations.
    # NOTE(review): Windows-style backslash paths; non-portable on POSIX — confirm target platform.
    test_gt_images_dir = "data\\hr"
    test_lr_images_dir = f"data\\lr"
    sr_dir = f"data\\{exp_name}"

    # Generator checkpoint used for inference.
    g_model_weights_path = "LSRGAN_x2.pth.tar"
|
||||
@ -0,0 +1,439 @@
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog, messagebox, simpledialog, Toplevel, Label, Button, Radiobutton, IntVar
|
||||
from PIL import Image, ImageTk, ImageEnhance, ImageFilter
|
||||
import os
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import numpy as np
|
||||
import PIL.Image as pil_image
|
||||
from model_SRCNN import SRCNN
|
||||
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
|
||||
import cv2
|
||||
from restinal import *
|
||||
import random
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import lsrgan_config
|
||||
import model_LSRGAN
|
||||
from utils import make_directory
|
||||
|
||||
class ImageProcessorApp:
    """Tkinter GUI for basic image processing and super-resolution.

    Holds two PIL images: ``original_image`` (as loaded) and
    ``processed_image`` (the working copy every operation writes to),
    displayed side by side.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("Image Processor")
        self.root.geometry("1000x800")

        # Image display areas (PIL images + the Tk labels that show them).
        self.original_image = None
        self.processed_image = None
        self.original_image_label = tk.Label(self.root)
        self.processed_image_label = tk.Label(self.root)

        # Button frame along the bottom of the window.
        button_frame = tk.Frame(self.root)
        button_frame.pack(side=tk.BOTTOM, pady=20)

        # One button per operation; each opens an options dialog or acts directly.
        self.btn_open = tk.Button(button_frame, text="Open Image", command=self.open_image)
        self.btn_reset = tk.Button(button_frame, text="Reset", command=self.reset_images)
        self.btn_save = tk.Button(button_frame, text="Save Image", command=self.save_image)
        self.btn_enhance = tk.Button(button_frame, text="Enhance", command=self.show_custom_options2)
        self.btn_add_noise = tk.Button(button_frame, text="Add Noise", command=self.show_custom_options3)
        self.btn_edge_detection = tk.Button(button_frame, text="Edge Detection", command=self.show_custom_options)
        self.btn_sr = tk.Button(button_frame, text="Super Resolution", command=self.show_super_resolution_options)
        self.btn_restinal = tk.Button(button_frame, text="Detect Blood Vessels", command=self.restinal)
        self.btn_noise_reduction = tk.Button(button_frame, text="Noise Reduction", command=self.show_noise_reduction_options)

        # Lay the buttons out in a single row.
        self.btn_open.grid(row=0, column=0, padx=10, pady=5)
        self.btn_reset.grid(row=0, column=1, padx=10, pady=5)
        self.btn_save.grid(row=0, column=2, padx=10, pady=5)
        self.btn_enhance.grid(row=0, column=3, padx=10, pady=5)
        self.btn_add_noise.grid(row=0, column=4, padx=10, pady=5)
        self.btn_edge_detection.grid(row=0, column=5, padx=10, pady=5)
        self.btn_sr.grid(row=0, column=6, padx=10, pady=5)
        self.btn_restinal.grid(row=0, column=7, padx=10, pady=5)
        self.btn_noise_reduction.grid(row=0, column=8, padx=10, pady=5)

        # Original on the left, processed on the right.
        self.original_image_label.pack(side=tk.LEFT, padx=20, pady=20)
        self.processed_image_label.pack(side=tk.RIGHT, padx=20, pady=20)

    def show_super_resolution_options(self):
        """Open a dialog to choose a super-resolution method (Bicubic/SRCNN/LSRGAN)."""
        sr_window = Toplevel(self.root)
        sr_window.title("Super Resolution Options")

        Label(sr_window, text="Choose a super resolution method:").pack(pady=10)

        self.sr_method = IntVar()
        self.sr_method.set(1)

        Radiobutton(sr_window, text="Bicubic", variable=self.sr_method, value=1).pack(anchor=tk.W)
        Radiobutton(sr_window, text="SRCNN", variable=self.sr_method, value=2).pack(anchor=tk.W)
        Radiobutton(sr_window, text="LSRGAN", variable=self.sr_method, value=3).pack(anchor=tk.W)

        Button(sr_window, text="Apply", command=self.apply_super_resolution).pack(pady=20)

    def apply_super_resolution(self):
        """Dispatch to the SR implementation selected in the options dialog."""
        method = self.sr_method.get()
        if method == 1:
            self.upscale_image_with_bicubic()
        elif method == 2:
            self.super_resolution()
        elif method == 3:
            self.LSRGAN_SR()

    def upscale_image_with_bicubic(self, scale_factor=2):
        """Upscale the working image by ``scale_factor`` with bicubic resampling."""
        # NOTE(review): guards on original_image but resizes processed_image — confirm intended.
        if self.original_image:
            upscaled_image = self.processed_image.resize(
                (self.processed_image.width * scale_factor, self.processed_image.height * scale_factor),
                resample=Image.BICUBIC)
            self.processed_image = upscaled_image
            self.update_images()

    def super_resolution(self, weights='SRCNN_x2.pth'):
        """Super-resolve the working image (x2) with SRCNN on the Y channel."""
        try:
            weights_file = weights
            image = self.processed_image
            scale = 2

            cudnn.benchmark = True
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

            model = SRCNN().to(device)

            # Copy checkpoint parameters into the model, key by key.
            state_dict = model.state_dict()
            for n, p in torch.load(weights_file, map_location=lambda storage, loc: storage).items():
                if n in state_dict.keys():
                    state_dict[n].copy_(p)
                else:
                    raise KeyError(n)

            model.eval()

            if not isinstance(image, pil_image.Image):
                raise ValueError("self.processed_image is not a valid PIL Image")

            # Crop to a multiple of the scale, then bicubic-upscale back.
            image_width = (image.width // scale) * scale
            image_height = (image.height // scale) * scale
            image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
            image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC)

            # SRCNN refines the luma (Y) channel only.
            image = np.array(image).astype(np.float32)
            ycbcr = convert_rgb_to_ycbcr(image)

            y = ycbcr[..., 0]
            y /= 255.
            y = torch.from_numpy(y).to(device)
            y = y.unsqueeze(0).unsqueeze(0)

            with torch.no_grad():
                preds = model(y).clamp(0.0, 1.0)

            preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

            # Recombine predicted Y with original Cb/Cr, convert back to RGB.
            output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
            output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
            output = pil_image.fromarray(output)

            self.processed_image = output
            self.update_images()
        except Exception as e:
            messagebox.showerror("Error", f"Failed to process image: {str(e)}")

    def open_image(self):
        """Prompt for an image file, load it, and reset the working copy."""
        filename = filedialog.askopenfilename(initialdir=os.getcwd(), title="Select Image File",
                                              filetypes=(("JPEG files", "*.jpg"), ("PNG files", "*.png"),
                                                         ("All files", "*.*")))
        if filename:
            try:
                self.original_image = Image.open(filename)
                self.processed_image = self.original_image.copy()
                self.update_images()
            except Exception as e:
                messagebox.showerror("Error", f"Failed to open image: {str(e)}")

    def show_custom_options(self):
        """Open a dialog to choose an edge-detection operator."""
        edge_detection_window = Toplevel(self.root)
        edge_detection_window.title("Edge Detection Options")

        Label(edge_detection_window, text="Choose an edge detection method:").pack(pady=10)

        self.edge_detection_method = IntVar()
        self.edge_detection_method.set(1)

        Radiobutton(edge_detection_window, text="Sobel", variable=self.edge_detection_method, value=1).pack(anchor=tk.W)
        Radiobutton(edge_detection_window, text="Robert", variable=self.edge_detection_method, value=2).pack(anchor=tk.W)
        Radiobutton(edge_detection_window, text="Laplacian", variable=self.edge_detection_method, value=3).pack(anchor=tk.W)

        Button(edge_detection_window, text="Apply", command=self.apply_edge_detection).pack(pady=20)

    def apply_edge_detection(self):
        """Dispatch to the edge-detection operator selected in the dialog."""
        method = self.edge_detection_method.get()
        if self.original_image:
            if method == 1:
                self.sobel_operator()
            elif method == 2:
                self.ropert_operator()
            elif method == 3:
                self.laplacian_operator()

    def sobel_operator(self):
        """Edge detection via Sobel gradients (magnitude of x/y derivatives)."""
        if self.original_image:
            img_np = np.array(self.original_image)
            img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            sobelx = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)
            sobely = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)
            sobel_edges = cv2.magnitude(sobelx, sobely)
            # Normalize to 0-255 for display.
            sobel_edges = np.uint8(255 * sobel_edges / np.max(sobel_edges))
            self.processed_image = Image.fromarray(cv2.cvtColor(sobel_edges, cv2.COLOR_GRAY2RGB))
            self.update_images()

    def ropert_operator(self):
        """Edge detection via the Roberts cross operator (2x2 diagonal kernels)."""
        # NOTE(review): method name has a typo ("ropert"); kept for caller compatibility.
        if self.original_image:
            img_np = np.array(self.original_image)
            img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            kernelx = np.array([[-1, 0], [0, 1]], dtype=int)
            kernely = np.array([[0, -1], [1, 0]], dtype=int)
            x = cv2.filter2D(img_gray, cv2.CV_16S, kernelx)
            y = cv2.filter2D(img_gray, cv2.CV_16S, kernely)
            absX = cv2.convertScaleAbs(x)
            absY = cv2.convertScaleAbs(y)
            # Combine the two diagonal responses with equal weight.
            Roberts = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
            self.processed_image = Image.fromarray(cv2.cvtColor(Roberts, cv2.COLOR_GRAY2RGB))
            self.update_images()

    def laplacian_operator(self):
        """Edge detection via Gaussian blur followed by the Laplacian operator."""
        if self.original_image:
            img_np = np.array(self.original_image)
            img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            # Blur first to suppress noise amplified by the second derivative.
            grayImage = cv2.GaussianBlur(img_gray, (5, 5), 0, 0)
            dst = cv2.Laplacian(grayImage, cv2.CV_16S, ksize=3)
            Laplacian = cv2.convertScaleAbs(dst)
            self.processed_image = Image.fromarray(cv2.cvtColor(Laplacian, cv2.COLOR_GRAY2RGB))
            self.update_images()

    def update_images(self):
        """Refresh both Tk labels from the current PIL images."""
        if self.original_image:
            self.original_image_tk = ImageTk.PhotoImage(self.original_image)
            self.processed_image_tk = ImageTk.PhotoImage(self.processed_image)
            self.original_image_label.config(image=self.original_image_tk)
            # Keep references so Tk does not garbage-collect the PhotoImages.
            self.original_image_label.image = self.original_image_tk
            self.processed_image_label.config(image=self.processed_image_tk)
            self.processed_image_label.image = self.processed_image_tk

    def show_custom_options2(self):
        """Open a dialog to choose an enhancement method."""
        enhance_window = Toplevel(self.root)
        enhance_window.title("Enhance Options")

        Label(enhance_window, text="Choose an enhancement method:").pack(pady=10)

        self.enhance_method = IntVar()
        self.enhance_method.set(1)

        Radiobutton(enhance_window, text="Brightness Enhancement", variable=self.enhance_method, value=1).pack(anchor=tk.W)
        Radiobutton(enhance_window, text="Histogram Equalization", variable=self.enhance_method, value=2).pack(anchor=tk.W)

        Button(enhance_window, text="Apply", command=self.apply_enhancement).pack(pady=20)

    def apply_enhancement(self):
        """Dispatch to the enhancement method selected in the dialog."""
        method = self.enhance_method.get()
        if self.original_image:
            if method == 1:
                self.enhance_image()
            elif method == 2:
                self.balance_image()

    def balance_image(self):
        """Histogram equalization on the grayscale version of the original image."""
        if self.original_image:
            img_cv = cv2.cvtColor(np.array(self.original_image), cv2.COLOR_RGB2BGR)
            img_gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
            equ = cv2.equalizeHist(img_gray)
            # NOTE(review): `equ` is single-channel but COLOR_BGR2RGB expects 3 channels —
            # confirm this conversion works; COLOR_GRAY2RGB looks like the intended flag.
            self.processed_image = Image.fromarray(cv2.cvtColor(equ, cv2.COLOR_BGR2RGB))
            self.update_images()

    def enhance_image(self):
        """Boost contrast of the working image by a fixed factor of 1.5."""
        # NOTE(review): button is labeled "Brightness Enhancement" but this adjusts
        # contrast, not brightness — confirm intended.
        if self.original_image:
            enhancer = ImageEnhance.Contrast(self.processed_image)
            self.processed_image = enhancer.enhance(1.5)
            self.update_images()

    def show_custom_options3(self):
        """Open a dialog to choose a noise type to add."""
        noise_window = Toplevel(self.root)
        noise_window.title("Add Noise Options")

        Label(noise_window, text="Choose a noise type:").pack(pady=10)

        self.noise_type = IntVar()
        self.noise_type.set(1)

        Radiobutton(noise_window, text="Gaussian Noise", variable=self.noise_type, value=1).pack(anchor=tk.W)
        Radiobutton(noise_window, text="Salt and Pepper Noise", variable=self.noise_type, value=2).pack(anchor=tk.W)

        Button(noise_window, text="Apply", command=self.apply_noise).pack(pady=20)

    def apply_noise(self):
        """Dispatch to the noise generator selected in the dialog."""
        noise_type = self.noise_type.get()
        if noise_type == 1:
            self.add_Gaussian_noise()
        elif noise_type == 2:
            self.add_salt_pepper_noise()

    def add_Gaussian_noise(self):
        """Add zero-mean Gaussian noise to the original image."""
        if self.original_image:
            img_np = np.array(self.original_image)
            mean = 0
            # NOTE(review): sigma = sqrt(0.01) = 0.1 on a 0-255 pixel scale is nearly
            # invisible — confirm whether the image was meant to be normalized to [0, 1].
            var = 0.01
            sigma = var**0.5
            gauss = np.random.normal(mean, sigma, img_np.shape)
            noisy = img_np + gauss
            self.processed_image = Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))
            self.update_images()

    def add_salt_pepper_noise(self):
        """Add salt-and-pepper noise: ~5% black pixels and ~5% white pixels."""
        if self.original_image:
            img_np = np.array(self.original_image)
            prob = 0.05
            noisy = img_np.copy()
            black = np.random.rand(*img_np.shape[:2]) < prob
            white = np.random.rand(*img_np.shape[:2]) < prob
            noisy[black] = 0
            noisy[white] = 255
            self.processed_image = Image.fromarray(noisy)
            self.update_images()

    def show_noise_reduction_options(self):
        """Open a dialog to choose a noise-reduction filter."""
        noise_reduction_window = Toplevel(self.root)
        noise_reduction_window.title("Noise Reduction Options")

        Label(noise_reduction_window, text="Choose a noise reduction method:").pack(pady=10)

        self.noise_reduction_method = IntVar()
        self.noise_reduction_method.set(1)

        Radiobutton(noise_reduction_window, text="Mean Filter", variable=self.noise_reduction_method, value=1).pack(anchor=tk.W)
        Radiobutton(noise_reduction_window, text="Median Filter", variable=self.noise_reduction_method, value=2).pack(anchor=tk.W)
        Radiobutton(noise_reduction_window, text="Bilateral Filter", variable=self.noise_reduction_method, value=3).pack(anchor=tk.W)

        Button(noise_reduction_window, text="Apply", command=self.apply_noise_reduction).pack(pady=20)

    def apply_noise_reduction(self):
        """Apply the selected denoising filter to the ORIGINAL image."""
        method = self.noise_reduction_method.get()
        if self.original_image:
            img_np = np.array(self.original_image)
            if method == 1:
                result = cv2.blur(img_np, (5, 5))
            elif method == 2:
                result = cv2.medianBlur(img_np, 5)
            elif method == 3:
                result = cv2.bilateralFilter(img_np, 9, 75, 75)
            self.processed_image = Image.fromarray(result)
            self.update_images()

    def edge_detection(self):
        """PIL FIND_EDGES filter on the working image."""
        # NOTE(review): no button is wired to this method; appears unused — confirm.
        if self.original_image:
            self.processed_image = self.processed_image.filter(ImageFilter.FIND_EDGES)
            self.update_images()

    def reset_images(self):
        """Discard all edits: reset the working image to the original."""
        if self.original_image:
            self.processed_image = self.original_image.copy()
            self.update_images()

    def save_image(self):
        """Prompt for a destination path and save the working image."""
        if self.processed_image:
            try:
                filetypes = (("PNG files", "*.png"), ("JPEG files", "*.jpg"), ("All files", "*.*"))
                filepath = filedialog.asksaveasfilename(defaultextension=".png", filetypes=filetypes)
                if filepath:
                    self.processed_image.save(filepath)
                    messagebox.showinfo("Image Saved", f"Image saved successfully at {filepath}")
            except Exception as e:
                messagebox.showerror("Error", f"Failed to save image: {str(e)}")

    def restinal(self):
        """Segment blood vessels in a retinal image and show input/result side by side.

        Relies on helpers star-imported from the `restinal` module (homofilter,
        adjust_gamma, build_filters2, process, pass_mask, grayStretch) — their
        exact semantics are defined there, not here.
        """
        try:
            srcImg = np.array(self.processed_image)
            grayImg = cv2.cvtColor(srcImg, cv2.COLOR_RGB2GRAY)
            # Field-of-view mask: threshold then erode to trim the boundary ring.
            ret0, th0 = cv2.threshold(grayImg, 30, 255, cv2.THRESH_BINARY)
            mask = cv2.erode(th0, np.ones((7, 7), np.uint8))
            blurImg = cv2.GaussianBlur(grayImg, (5, 5), 0)
            # NOTE(review): heImg and homoImg are computed but never used below — confirm.
            heImg = cv2.equalizeHist(blurImg)
            clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(10, 10))
            claheImg = clahe.apply(blurImg)
            homoImg = homofilter(blurImg)
            # Gamma correction, then matched filtering and masking.
            preMFImg = adjust_gamma(claheImg, gamma=1.5)
            filters = build_filters2()
            gaussMFImg = process(preMFImg, filters)
            gaussMFImg_mask = pass_mask(mask, gaussMFImg)
            grayStretchImg = grayStretch(gaussMFImg_mask, m=30.0 / 255, e=8)
            # Otsu threshold produces the final binary vessel map.
            ret1, th1 = cv2.threshold(grayStretchImg, 30, 255, cv2.THRESH_OTSU)
            predictImg = th1.copy()
            # Show source and prediction side by side.
            wtf = np.hstack([srcImg, cv2.cvtColor(predictImg, cv2.COLOR_GRAY2BGR)])
            wtf_pil = Image.fromarray(wtf)
            self.processed_image = wtf_pil
            self.update_images()
        except Exception as e:
            messagebox.showerror("Error", f"Failed to process image: {str(e)}")

    def LSRGAN_SR(self):
        """Super-resolve the working image (x2) with the LSRGAN generator."""
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Build the generator from the configuration module.
        msrn_model = model_LSRGAN.__dict__[lsrgan_config.g_arch_name](in_channels=lsrgan_config.in_channels,
                                                                     out_channels=lsrgan_config.out_channels,
                                                                     channels=lsrgan_config.channels,
                                                                     growth_channels=lsrgan_config.growth_channels,
                                                                     num_blocks=lsrgan_config.num_blocks)
        # NOTE(review): the model moves to lsrgan_config.device while the input tensor
        # below moves to the local `device` — these may differ; confirm.
        msrn_model = msrn_model.to(device=lsrgan_config.device)

        # Load the super-resolution model weights.
        checkpoint = torch.load(lsrgan_config.g_model_weights_path, map_location=lambda storage, loc: storage)
        msrn_model.load_state_dict(checkpoint["state_dict"])

        # Switch to inference mode.
        msrn_model.eval()

        # Convert the working PIL image to a batched CHW tensor in [0, 1].
        lr_image = self.processed_image
        lr_image = np.array(lr_image)
        # NOTE(review): the PIL array is RGB; BGR2RGB here swaps channels and RGB2BGR
        # below swaps them back, so the net effect cancels — confirm intended.
        image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)

        transform = transforms.ToTensor()
        tensor = transform(image)

        tensor = tensor.unsqueeze(0).to(device)

        lr_tensor = tensor

        with torch.no_grad():
            sr_tensor = msrn_model(lr_tensor)

        # Tensor (1, C, H, W) in [0, 1] -> uint8 HWC array.
        sr_image = sr_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        sr_image = (sr_image * 255.0).clip(0, 255).astype("uint8")
        sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)

        # Convert the numpy array back to a PIL image.
        # NOTE(review): local name `pil_image` shadows the module alias `PIL.Image as pil_image`.
        pil_image = Image.fromarray(sr_image)
        self.processed_image = pil_image
        self.update_images()
|
||||
|
||||
if __name__ == "__main__":
    # Launch the image-processing GUI and enter the Tk event loop.
    root = tk.Tk()
    app = ImageProcessorApp(root)
    root.mainloop()
|
||||
@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
|
||||
|
||||
def calculate_psnr_ssim_y_channel(original_image_path, reconstructed_image_path):
    """Compute PSNR and SSIM on the luma (Y) channel of two images.

    Args:
        original_image_path: Path to the reference image.
        reconstructed_image_path: Path to the reconstructed/SR image.

    Returns:
        Tuple ``(psnr, ssim)`` computed on the Y channel only.
    """
    # Read both images (OpenCV loads them as BGR).
    reference = cv2.imread(original_image_path)
    candidate = cv2.imread(reconstructed_image_path)

    # Convert to YCrCb and keep only the first plane (Y / luma).
    reference_y = cv2.cvtColor(reference, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
    candidate_y = cv2.cvtColor(candidate, cv2.COLOR_BGR2YCR_CB)[:, :, 0]

    # Scalar quality metrics on the luma plane.
    psnr = peak_signal_noise_ratio(reference_y, candidate_y)
    ssim = structural_similarity(reference_y, candidate_y)

    return psnr, ssim
|
||||
|
||||
# Example paths: reference HR image vs. the LSRGAN x2 reconstruction.
original_image_path = 'eye_hr.png'
reconstructed_image_path = 'eye_LSRGAN_x2.png'

# Compute and report Y-channel quality metrics.
psnr, ssim = calculate_psnr_ssim_y_channel(original_image_path, reconstructed_image_path)
print(f'PSNR (Y channel): {psnr:.2f} dB')
print(f'SSIM (Y channel): {ssim:.4f}')
|
||||
@ -0,0 +1,302 @@
|
||||
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.models.feature_extraction import create_feature_extractor
|
||||
|
||||
# Public API of this module: the model classes, their factory functions,
# and the perceptual-loss helpers.
__all__ = [
    "Discriminator", "LSRGAN", "ContentLoss",
    "discriminator", "lsrgan_x2", "lsrgan_x4", "content_loss", "content_loss_for_vgg19_34", "lsrgan_x8"

]
|
||||
|
||||
|
||||
class _ResidualDenseBlock(nn.Module):
|
||||
"""Achieves densely connected convolutional layers.
|
||||
`Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels in the input image.
|
||||
growth_channels (int): The number of channels that increase in each layer of convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, growth_channels: int) -> None:
|
||||
super(_ResidualDenseBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(channels + growth_channels * 0, growth_channels, (3, 3), (1, 1), (1, 1))
|
||||
self.conv2 = nn.Conv2d(channels + growth_channels * 1, growth_channels, (3, 3), (1, 1), (1, 1))
|
||||
self.conv3 = nn.Conv2d(channels + growth_channels * 2, growth_channels, (3, 3), (1, 1), (1, 1))
|
||||
self.conv4 = nn.Conv2d(channels + growth_channels * 3, growth_channels, (3, 3), (1, 1), (1, 1))
|
||||
self.conv5 = nn.Conv2d(channels + growth_channels * 4, channels, (3, 3), (1, 1), (1, 1))
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU(0.2, True)
|
||||
self.identity = nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = x
|
||||
|
||||
out1 = self.leaky_relu(self.conv1(x))
|
||||
out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
|
||||
out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
|
||||
out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
|
||||
out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
|
||||
out = torch.mul(out5, 0.2)
|
||||
out = torch.add(out, identity)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class _ResidualResidualDenseBlock(nn.Module):
    """Multi-layer residual dense convolution block.

    Chains three residual dense blocks and adds a scaled long skip connection.

    Args:
        channels (int): The number of channels in the input image.
        growth_channels (int): The number of channels that increase in each layer of convolution.
    """

    def __init__(self, channels: int, growth_channels: int) -> None:
        super(_ResidualResidualDenseBlock, self).__init__()
        self.rdb1 = _ResidualDenseBlock(channels, growth_channels)
        self.rdb2 = _ResidualDenseBlock(channels, growth_channels)
        self.rdb3 = _ResidualDenseBlock(channels, growth_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        # Residual scaling (x0.2) before adding the identity path.
        return torch.add(torch.mul(out, 0.2), x)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(Discriminator, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
# input size. (3) x 128 x 128
|
||||
nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# state size. (64) x 64 x 64
|
||||
nn.Conv2d(64, 64, (4, 4), (2, 2), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# state size. (128) x 32 x 32
|
||||
nn.Conv2d(128, 128, (4, 4), (2, 2), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# state size. (256) x 16 x 16
|
||||
nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# state size. (512) x 8 x 8
|
||||
nn.Conv2d(512, 512, (4, 4), (2, 2), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# state size. (512) x 4 x 4
|
||||
nn.Conv2d(512, 512, (4, 4), (2, 2), (1, 1), bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
)
|
||||
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 4 * 4, 100),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Linear(100, 1)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.features(x)
|
||||
out = torch.flatten(out, 1)
|
||||
out = self.classifier(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class LSRGAN(nn.Module):
    """Super-resolution generator built from residual-in-residual dense blocks.

    Args:
        in_channels (int): Channels of the low-resolution input image.
        out_channels (int): Channels of the super-resolved output image.
        channels (int): Width of the feature-extraction trunk.
        growth_channels (int): Growth rate inside each dense block.
        num_blocks (int): Number of residual-in-residual dense blocks in the trunk.
        upscale_factor (int): Total upsampling ratio; each stage doubles the size,
            so this is expected to be a power of two.
    """

    def __init__(
            self,
            in_channels: int = 3,
            out_channels: int = 3,
            channels: int = 64,
            growth_channels: int = 32,
            num_blocks: int = 23,
            upscale_factor: int = 4,
    ) -> None:
        super(LSRGAN, self).__init__()
        # Shallow feature extraction.
        self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1))

        # Deep feature extraction backbone.
        self.trunk = nn.Sequential(
            *(_ResidualResidualDenseBlock(channels, growth_channels) for _ in range(num_blocks))
        )

        # Fuses trunk features before the long skip connection.
        self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1))

        # One x2 upsampling stage per power of two in `upscale_factor`.
        stages = []
        for _ in range(int(math.log(upscale_factor, 2))):
            stages += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True),
            ]
        self.upsampling = nn.Sequential(*stages)

        # Refinement convolution after upsampling.
        self.conv3 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True)
        )

        # Projection back to image space.
        self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1))

        # Initialize model parameters.
        self._initialize_weights()

    # The model should be defined in the Torch.script method.
    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        shallow = self.conv1(x)
        deep = self.conv2(self.trunk(shallow))
        fused = torch.add(shallow, deep)
        sr = self.conv4(self.conv3(self.upsampling(fused)))

        # Clamp pixel values in place to the valid [0, 1] range.
        return torch.clamp_(sr, 0.0, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    def _initialize_weights(self) -> None:
        # Kaiming init scaled down by 0.1 keeps early residual outputs small.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                module.weight.data *= 0.1
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.

    Compares feature maps of the super-resolved image and its ground truth taken
    from a chosen node of an ImageNet-pretrained VGG19; high-level feature layers
    focus the loss on texture content rather than per-pixel differences.

    Paper reference list:
        -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
        -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
        -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.

    Args:
        feature_model_extractor_node: Name of the VGG19 node whose output is compared.
        feature_model_normalize_mean: Per-channel mean for input normalization.
        feature_model_normalize_std: Per-channel std for input normalization.
    """

    def __init__(
            self,
            feature_model_extractor_node: str,
            feature_model_normalize_mean: list,
            feature_model_normalize_std: list
    ) -> None:
        super(ContentLoss, self).__init__()
        # Remember which node's output to read from the extractor's dict.
        self.feature_model_extractor_node = feature_model_extractor_node

        # ImageNet-pretrained VGG19 truncated at the requested node.
        backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        self.feature_extractor = create_feature_extractor(backbone, [feature_model_extractor_node])
        # Inference only.
        self.feature_extractor.eval()

        # Standard ImageNet preprocessing expected by the VGG weights.
        self.normalize = transforms.Normalize(feature_model_normalize_mean, feature_model_normalize_std)

        # The loss network is fixed; no gradients flow into it.
        for parameter in self.feature_extractor.parameters():
            parameter.requires_grad = False

    def forward(self, sr_tensor: torch.Tensor, gt_tensor: torch.Tensor) -> torch.Tensor:
        # Normalize both inputs, extract the chosen feature maps, and
        # measure their L1 distance.
        sr_feature = self.feature_extractor(self.normalize(sr_tensor))[self.feature_model_extractor_node]
        gt_feature = self.feature_extractor(self.normalize(gt_tensor))[self.feature_model_extractor_node]

        return F.l1_loss(sr_feature, gt_feature)
|
||||
|
||||
|
||||
def discriminator() -> Discriminator:
    """Build a fresh :class:`Discriminator` with its default configuration."""
    return Discriminator()
|
||||
|
||||
|
||||
def lsrgan_x2(**kwargs: Any) -> LSRGAN:
    """Build an LSRGAN generator that upscales by a factor of 2."""
    return LSRGAN(upscale_factor=2, **kwargs)
|
||||
|
||||
|
||||
def lsrgan_x4(**kwargs: Any) -> LSRGAN:
    """Build an LSRGAN generator that upscales by a factor of 4."""
    return LSRGAN(upscale_factor=4, **kwargs)
|
||||
|
||||
def lsrgan_x8(**kwargs: Any) -> LSRGAN:
    """Build an LSRGAN generator that upscales by a factor of 8."""
    return LSRGAN(upscale_factor=8, **kwargs)
|
||||
|
||||
|
||||
def content_loss(feature_model_extractor_node,
                 feature_model_normalize_mean,
                 feature_model_normalize_std) -> ContentLoss:
    """Build a :class:`ContentLoss` for the given VGG19 node and normalization stats."""
    # Return directly instead of binding a local that shadows this function's name.
    return ContentLoss(feature_model_extractor_node,
                       feature_model_normalize_mean,
                       feature_model_normalize_std)
|
||||
|
||||
|
||||
def content_loss_for_vgg19_34() -> ContentLoss:
    """Build the content loss over VGG19 node ``features.34`` with ImageNet statistics."""
    return ContentLoss(feature_model_extractor_node="features.34",
                       feature_model_normalize_mean=[0.485, 0.456, 0.406],
                       feature_model_normalize_std=[0.229, 0.224, 0.225])
|
||||
@ -0,0 +1,16 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SRCNN(nn.Module):
    """Three-layer SRCNN: patch extraction, non-linear mapping, reconstruction."""

    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        # kernel // 2 padding keeps the spatial resolution unchanged.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        # The last layer is linear: it regresses pixel values, no activation.
        return self.conv3(hidden)
|
||||
@ -0,0 +1,200 @@
|
||||
Epoch 0. Training loss: 0.14875313472002744 Average PSNR: 13.378548447840299 dB.
|
||||
Epoch 1. Training loss: 0.015275207967497408 Average PSNR: 17.42924177677117 dB.
|
||||
Epoch 2. Training loss: 0.011266281926073134 Average PSNR: 18.757091258443367 dB.
|
||||
Epoch 3. Training loss: 0.007268386427313089 Average PSNR: 21.45991297849006 dB.
|
||||
Epoch 4. Training loss: 0.003995528084924444 Average PSNR: 23.848708920906056 dB.
|
||||
Epoch 5. Training loss: 0.00243228601408191 Average PSNR: 25.814329549022027 dB.
|
||||
Epoch 6. Training loss: 0.0015943490510107949 Average PSNR: 27.529330365919293 dB.
|
||||
Epoch 7. Training loss: 0.00112663698149845 Average PSNR: 28.973019031043023 dB.
|
||||
Epoch 8. Training loss: 0.0008293315611081198 Average PSNR: 30.121271320226853 dB.
|
||||
Epoch 9. Training loss: 0.0006575708842137828 Average PSNR: 31.04464052247794 dB.
|
||||
Epoch 10. Training loss: 0.0005367766652489081 Average PSNR: 31.820446047979445 dB.
|
||||
Epoch 11. Training loss: 0.0004539210665097926 Average PSNR: 32.491218057269165 dB.
|
||||
Epoch 12. Training loss: 0.00039154114405391737 Average PSNR: 33.05704847426454 dB.
|
||||
Epoch 13. Training loss: 0.00034473506471840667 Average PSNR: 33.58429206508774 dB.
|
||||
Epoch 14. Training loss: 0.00030716530316567516 Average PSNR: 34.06440795826225 dB.
|
||||
Epoch 15. Training loss: 0.0002766982484899927 Average PSNR: 34.51910808772584 dB.
|
||||
Epoch 16. Training loss: 0.00024886662838980556 Average PSNR: 34.9183669618796 dB.
|
||||
Epoch 17. Training loss: 0.0002273252428130945 Average PSNR: 35.32857831324897 dB.
|
||||
Epoch 18. Training loss: 0.00020853087546129245 Average PSNR: 35.67967789504088 dB.
|
||||
Epoch 19. Training loss: 0.00019246405652666 Average PSNR: 36.02033897287565 dB.
|
||||
Epoch 20. Training loss: 0.00017955698385776486 Average PSNR: 36.30282652504606 dB.
|
||||
Epoch 21. Training loss: 0.00016742635118134784 Average PSNR: 36.61986081553303 dB.
|
||||
Epoch 22. Training loss: 0.00015529726297245361 Average PSNR: 36.932148671635645 dB.
|
||||
Epoch 23. Training loss: 0.00014399774467165115 Average PSNR: 37.365006814093846 dB.
|
||||
Epoch 24. Training loss: 0.00013206048992287834 Average PSNR: 37.65861852072448 dB.
|
||||
Epoch 25. Training loss: 0.00012246826609043637 Average PSNR: 37.97759224613426 dB.
|
||||
Epoch 26. Training loss: 0.00011503802044899203 Average PSNR: 38.234452682907644 dB.
|
||||
Epoch 27. Training loss: 0.00010774759506603004 Average PSNR: 38.52051534464395 dB.
|
||||
Epoch 28. Training loss: 0.00010174096278205979 Average PSNR: 38.76806032187628 dB.
|
||||
Epoch 29. Training loss: 9.698616515379398e-05 Average PSNR: 38.96585890558862 dB.
|
||||
Epoch 30. Training loss: 9.265545588277746e-05 Average PSNR: 39.132399583419286 dB.
|
||||
Epoch 31. Training loss: 8.88960807787953e-05 Average PSNR: 39.32665802812593 dB.
|
||||
Epoch 32. Training loss: 8.514736950019142e-05 Average PSNR: 39.485947129970526 dB.
|
||||
Epoch 33. Training loss: 8.228369857533835e-05 Average PSNR: 39.657517115527824 dB.
|
||||
Epoch 34. Training loss: 7.935761332191759e-05 Average PSNR: 39.78166850360353 dB.
|
||||
Epoch 35. Training loss: 7.67016623285599e-05 Average PSNR: 39.955637582206506 dB.
|
||||
Epoch 36. Training loss: 7.411128972307779e-05 Average PSNR: 40.03724905997241 dB.
|
||||
Epoch 37. Training loss: 7.20307761366712e-05 Average PSNR: 40.27968283854483 dB.
|
||||
Epoch 38. Training loss: 6.979531452088849e-05 Average PSNR: 40.453226841911 dB.
|
||||
Epoch 39. Training loss: 6.788882783439476e-05 Average PSNR: 40.4245090632741 dB. 32.6363/0.6468
|
||||
Epoch 40. Training loss: 6.613676512642996e-05 Average PSNR: 40.41326497885555 dB.
|
||||
Epoch 41. Training loss: 6.460600769059966e-05 Average PSNR: 40.807618644183165 dB.
|
||||
Epoch 42. Training loss: 6.28032401073142e-05 Average PSNR: 40.85120959811469 dB.
|
||||
Epoch 43. Training loss: 6.187801533087622e-05 Average PSNR: 41.01150926102096 dB.
|
||||
Epoch 44. Training loss: 6.058884808226139e-05 Average PSNR: 41.05180893277702 dB.
|
||||
Epoch 45. Training loss: 5.893255285627674e-05 Average PSNR: 41.168570742704055 dB.
|
||||
Epoch 46. Training loss: 5.732719488150906e-05 Average PSNR: 41.240308580465154 dB.
|
||||
Epoch 47. Training loss: 5.702058719180059e-05 Average PSNR: 41.3511633017136 dB.
|
||||
Epoch 48. Training loss: 5.561599014981766e-05 Average PSNR: 41.45872854749355 dB.
|
||||
Epoch 49. Training loss: 5.5232235354196746e-05 Average PSNR: 41.40003123922509 dB.
|
||||
Epoch 50. Training loss: 5.44620567234233e-05 Average PSNR: 41.14656813110207 dB.
|
||||
Epoch 51. Training loss: 5.3858329320064516e-05 Average PSNR: 41.412219824049416 dB.
|
||||
Epoch 52. Training loss: 5.249960475339322e-05 Average PSNR: 41.06763311036325 dB.
|
||||
Epoch 53. Training loss: 5.251374377621687e-05 Average PSNR: 41.837420965513424 dB.
|
||||
Epoch 54. Training loss: 5.06843776020105e-05 Average PSNR: 41.77355106160621 dB.
|
||||
Epoch 55. Training loss: 4.994360175260226e-05 Average PSNR: 41.98269945660057 dB.
|
||||
Epoch 56. Training loss: 5.027707884437404e-05 Average PSNR: 41.9793556863657 dB.
|
||||
Epoch 57. Training loss: 4.9384301437385144e-05 Average PSNR: 42.03275814844621 dB.
|
||||
Epoch 58. Training loss: 4.956375089022913e-05 Average PSNR: 42.06607378817673 dB.
|
||||
Epoch 59. Training loss: 4.793416825123131e-05 Average PSNR: 42.188537812402394 dB.
|
||||
Epoch 60. Training loss: 4.764747898661881e-05 Average PSNR: 42.144209742139125 dB.
|
||||
Epoch 61. Training loss: 4.712960180768278e-05 Average PSNR: 42.25111711841055 dB.
|
||||
Epoch 62. Training loss: 4.628681665053591e-05 Average PSNR: 41.783143717871646 dB.
|
||||
Epoch 63. Training loss: 4.6740010584471745e-05 Average PSNR: 42.2285704715098 dB.
|
||||
Epoch 64. Training loss: 4.535250043772976e-05 Average PSNR: 42.27144939622808 dB.
|
||||
Epoch 65. Training loss: 4.527481007244205e-05 Average PSNR: 42.311776678022156 dB.
|
||||
Epoch 66. Training loss: 4.5228196504467633e-05 Average PSNR: 41.49500222336806 dB.
|
||||
Epoch 67. Training loss: 4.637343330614385e-05 Average PSNR: 42.45478651917092 dB.
|
||||
Epoch 68. Training loss: 4.4330210948828606e-05 Average PSNR: 42.34659372619247 dB.
|
||||
Epoch 69. Training loss: 4.575668157485779e-05 Average PSNR: 42.50239584981128 dB.
|
||||
Epoch 70. Training loss: 4.334861048846506e-05 Average PSNR: 42.51811162030329 dB.
|
||||
Epoch 71. Training loss: 4.316709642807837e-05 Average PSNR: 42.56836168218139 dB.
|
||||
Epoch 72. Training loss: 4.290127848435077e-05 Average PSNR: 42.59131327932355 dB.
|
||||
Epoch 73. Training loss: 4.279395383491647e-05 Average PSNR: 42.505901953472375 dB.
|
||||
Epoch 74. Training loss: 4.227319686833653e-05 Average PSNR: 42.62832676739763 dB.
|
||||
Epoch 75. Training loss: 4.167593468082487e-05 Average PSNR: 42.62986386448593 dB.
|
||||
Epoch 76. Training loss: 4.219136577376048e-05 Average PSNR: 42.44606031537528 dB.
|
||||
Epoch 77. Training loss: 4.272524030966452e-05 Average PSNR: 42.67717208214105 dB.
|
||||
Epoch 78. Training loss: 4.352791798737599e-05 Average PSNR: 42.62172339083419 dB.
|
||||
Epoch 79. Training loss: 4.0833448019839123e-05 Average PSNR: 42.710664224920905 dB.
|
||||
Epoch 80. Training loss: 4.063137839693809e-05 Average PSNR: 42.815601984767554 dB.
|
||||
Epoch 81. Training loss: 4.323298475355841e-05 Average PSNR: 42.40471426652512 dB.
|
||||
Epoch 82. Training loss: 4.1421410132898016e-05 Average PSNR: 42.458958440721815 dB.
|
||||
Epoch 83. Training loss: 4.1137531261483674e-05 Average PSNR: 42.80103678818227 dB.
|
||||
Epoch 84. Training loss: 4.023500352559495e-05 Average PSNR: 42.680222548077765 dB.
|
||||
Epoch 85. Training loss: 3.937705409043702e-05 Average PSNR: 42.92232871453005 dB.
|
||||
Epoch 86. Training loss: 3.9499145996160225e-05 Average PSNR: 42.77000655125947 dB.
|
||||
Epoch 87. Training loss: 3.99408045632299e-05 Average PSNR: 42.262120154845974 dB.
|
||||
Epoch 88. Training loss: 4.2969265323336e-05 Average PSNR: 42.89088169577039 dB.
|
||||
Epoch 89. Training loss: 3.853916996376938e-05 Average PSNR: 43.00806079611334 dB.
|
||||
Epoch 90. Training loss: 3.812810891759e-05 Average PSNR: 42.98660421435421 dB.
|
||||
Epoch 91. Training loss: 3.8021667151042494e-05 Average PSNR: 42.95568073408279 dB.
|
||||
Epoch 92. Training loss: 3.910218174496549e-05 Average PSNR: 43.07431405920869 dB.
|
||||
Epoch 93. Training loss: 3.732714743819088e-05 Average PSNR: 43.00797843180268 dB.
|
||||
Epoch 94. Training loss: 3.779505837883334e-05 Average PSNR: 43.07809379231574 dB.
|
||||
Epoch 95. Training loss: 3.9144075235526546e-05 Average PSNR: 43.07315375774023 dB.
|
||||
Epoch 96. Training loss: 3.886822338245111e-05 Average PSNR: 43.069418450353936 dB.
|
||||
Epoch 97. Training loss: 3.6940460213372716e-05 Average PSNR: 43.07666543842636 dB.
|
||||
Epoch 98. Training loss: 3.6657753798863266e-05 Average PSNR: 43.1537746333162 dB.
|
||||
Epoch 99. Training loss: 3.5986074690299576e-05 Average PSNR: 43.15389075403043 dB.
|
||||
Epoch 100. Training loss: 3.5906440225517144e-05 Average PSNR: 43.20493074892996 dB. 37.9652/0.7040
|
||||
Epoch 101. Training loss: 3.989933359662245e-05 Average PSNR: 39.93472076434047 dB. 36.9117/0.7040
|
||||
Epoch 102. Training loss: 3.733712854227633e-05 Average PSNR: 43.29688828021639 dB.
|
||||
Epoch 103. Training loss: 3.6041745934198845e-05 Average PSNR: 43.306737995998866 dB.
|
||||
Epoch 104. Training loss: 3.5658771166708904e-05 Average PSNR: 43.26810466586173 dB.
|
||||
Epoch 105. Training loss: 3.524670941260411e-05 Average PSNR: 43.40545911645734 dB.
|
||||
Epoch 106. Training loss: 3.731644037543447e-05 Average PSNR: 43.29618589002626 dB.
|
||||
Epoch 107. Training loss: 3.515424312354298e-05 Average PSNR: 42.57789014273901 dB.
|
||||
Epoch 108. Training loss: 3.533762910592486e-05 Average PSNR: 43.420234193644454 dB.
|
||||
Epoch 109. Training loss: 3.596072063373867e-05 Average PSNR: 43.39850092295053 dB.
|
||||
Epoch 110. Training loss: 3.533482920829556e-05 Average PSNR: 43.392093948907174 dB.
|
||||
Epoch 111. Training loss: 3.410994330806716e-05 Average PSNR: 43.463910796398245 dB.
|
||||
Epoch 112. Training loss: 3.4490836678742195e-05 Average PSNR: 43.54414430731717 dB.
|
||||
Epoch 113. Training loss: 3.498832632431004e-05 Average PSNR: 43.54282871335326 dB.
|
||||
Epoch 114. Training loss: 3.4872444011853076e-05 Average PSNR: 43.569616480742184 dB.
|
||||
Epoch 115. Training loss: 3.339561070788477e-05 Average PSNR: 43.5617513391657 dB.
|
||||
Epoch 116. Training loss: 3.365256839970243e-05 Average PSNR: 43.36659571686377 dB.
|
||||
Epoch 117. Training loss: 3.525727189298777e-05 Average PSNR: 38.92335442332546 dB. 38.3984/0.7774
|
||||
Epoch 118. Training loss: 4.9161803908646106e-05 Average PSNR: 43.489483815400796 dB.
|
||||
Epoch 119. Training loss: 3.2969348540063946e-05 Average PSNR: 43.4802004816044 dB.
|
||||
Epoch 120. Training loss: 3.229604599255254e-05 Average PSNR: 43.635621539803566 dB.
|
||||
Epoch 121. Training loss: 3.2639079699947614e-05 Average PSNR: 43.637332431503616 dB.
|
||||
Epoch 122. Training loss: 3.275221746662282e-05 Average PSNR: 43.691427829855876 dB.
|
||||
Epoch 123. Training loss: 3.207607833246584e-05 Average PSNR: 43.58842822259747 dB.
|
||||
Epoch 124. Training loss: 3.205541729585093e-05 Average PSNR: 43.33105584418493 dB.
|
||||
Epoch 125. Training loss: 3.2615698537483694e-05 Average PSNR: 43.69760001820072 dB.
|
||||
Epoch 126. Training loss: 3.211692940567445e-05 Average PSNR: 43.74373051946786 dB.
|
||||
Epoch 127. Training loss: 3.152981980747427e-05 Average PSNR: 43.57205078205611 dB.
|
||||
Epoch 128. Training loss: 4.316570361879712e-05 Average PSNR: 42.18685314747573 dB.
|
||||
Epoch 129. Training loss: 3.181462780048605e-05 Average PSNR: 43.776277789506906 dB.
|
||||
Epoch 130. Training loss: 3.13720618669322e-05 Average PSNR: 43.7757626679266 dB.
|
||||
Epoch 131. Training loss: 3.09608046154608e-05 Average PSNR: 43.52274171700122 dB.
|
||||
Epoch 132. Training loss: 3.1449040807274284e-05 Average PSNR: 43.817116080109216 dB.
|
||||
Epoch 133. Training loss: 3.110551078862045e-05 Average PSNR: 43.770872311405626 dB.
|
||||
Epoch 134. Training loss: 3.084682341977896e-05 Average PSNR: 43.855956409067424 dB.
|
||||
Epoch 135. Training loss: 3.1018329782455115e-05 Average PSNR: 43.81357529648038 dB.
|
||||
Epoch 136. Training loss: 3.064724745854619e-05 Average PSNR: 43.680037600006685 dB.
|
||||
Epoch 137. Training loss: 3.6394121634657496e-05 Average PSNR: 43.60290275220899 dB.
|
||||
Epoch 138. Training loss: 3.05222990209586e-05 Average PSNR: 43.840158709280004 dB.
|
||||
Epoch 139. Training loss: 3.10103852825705e-05 Average PSNR: 39.05086275015894 dB. 40.0427/0.7782
|
||||
Epoch 140. Training loss: 3.857239870740159e-05 Average PSNR: 43.72159945394756 dB.
|
||||
Epoch 141. Training loss: 3.032407699720352e-05 Average PSNR: 43.91590803530713 dB.
|
||||
Epoch 142. Training loss: 3.0461475425909157e-05 Average PSNR: 43.59204624965716 dB.
|
||||
Epoch 143. Training loss: 3.1531663016721725e-05 Average PSNR: 43.92504877190417 dB.
|
||||
Epoch 144. Training loss: 3.007462703862984e-05 Average PSNR: 43.934042165438214 dB.
|
||||
Epoch 145. Training loss: 3.0099126925051677e-05 Average PSNR: 43.93920749295095 dB.
|
||||
Epoch 146. Training loss: 3.079762317611312e-05 Average PSNR: 43.033225114146504 dB.
|
||||
Epoch 147. Training loss: 3.0168946523190243e-05 Average PSNR: 43.98033678417097 dB.
|
||||
Epoch 148. Training loss: 3.091180842602625e-05 Average PSNR: 43.96775819891362 dB.
|
||||
Epoch 149. Training loss: 2.9445438831317005e-05 Average PSNR: 44.00493881286475 dB.
|
||||
Epoch 150. Training loss: 2.9688153917959427e-05 Average PSNR: 43.98941263333619 dB.
|
||||
Epoch 151. Training loss: 2.938612261459639e-05 Average PSNR: 43.75649797927036 dB.
|
||||
Epoch 152. Training loss: 3.6253114376449955e-05 Average PSNR: 43.770361945515845 dB.
|
||||
Epoch 153. Training loss: 3.077322558965534e-05 Average PSNR: 44.04419670964887 dB.
|
||||
Epoch 154. Training loss: 2.9240441353977075e-05 Average PSNR: 44.04990209284622 dB.
|
||||
Epoch 155. Training loss: 2.9035014831606533e-05 Average PSNR: 44.02112733800398 dB.
|
||||
Epoch 156. Training loss: 2.8992041661695113e-05 Average PSNR: 43.97957428947031 dB.
|
||||
Epoch 157. Training loss: 2.9251676714920904e-05 Average PSNR: 43.9443508048078 dB.
|
||||
Epoch 158. Training loss: 3.05789163121517e-05 Average PSNR: 43.63132674102104 dB.
|
||||
Epoch 159. Training loss: 2.977376703711343e-05 Average PSNR: 44.10396993940835 dB. 40.9453/0.7787
|
||||
Epoch 160. Training loss: 2.883360203668417e-05 Average PSNR: 43.93621610363439 dB.
|
||||
Epoch 161. Training loss: 2.8674156428678543e-05 Average PSNR: 44.10925236769381 dB.
|
||||
Epoch 162. Training loss: 2.8620581260838663e-05 Average PSNR: 44.11039775102721 dB.
|
||||
Epoch 163. Training loss: 2.8966812988073797e-05 Average PSNR: 44.109695515726564 dB.
|
||||
Epoch 164. Training loss: 3.071662084948912e-05 Average PSNR: 44.091161834586956 dB.
|
||||
Epoch 165. Training loss: 2.9427598883557947e-05 Average PSNR: 43.93181032368704 dB.
|
||||
Epoch 166. Training loss: 2.9850875907868613e-05 Average PSNR: 44.07792097912315 dB.
|
||||
Epoch 167. Training loss: 5.5166150741570166e-05 Average PSNR: 44.13405225123569 dB.
|
||||
Epoch 168. Training loss: 2.8183579524920786e-05 Average PSNR: 44.14559047230213 dB.
|
||||
Epoch 169. Training loss: 2.817333066559513e-05 Average PSNR: 44.156953682913084 dB. 42.7420/0.8448
|
||||
Epoch 170. Training loss: 2.806818698445568e-05 Average PSNR: 44.15158185239579 dB.
|
||||
Epoch 171. Training loss: 2.8043716583852073e-05 Average PSNR: 44.14570809594094 dB.
|
||||
Epoch 172. Training loss: 2.8252245147086797e-05 Average PSNR: 44.16260957123795 dB.
|
||||
Epoch 173. Training loss: 2.8405940138327424e-05 Average PSNR: 44.16192530002021 dB.
|
||||
Epoch 174. Training loss: 2.809159637763514e-05 Average PSNR: 44.15552258345899 dB.
|
||||
Epoch 175. Training loss: 2.790903268760303e-05 Average PSNR: 44.165708173777 dB.
|
||||
Epoch 176. Training loss: 2.7898249418285558e-05 Average PSNR: 44.20985241672864 dB.
|
||||
Epoch 177. Training loss: 2.8110662278777453e-05 Average PSNR: 44.1875459155592 dB.
|
||||
Epoch 178. Training loss: 2.7957651727774646e-05 Average PSNR: 44.214501714516025 dB.
|
||||
Epoch 179. Training loss: 2.8742814974975773e-05 Average PSNR: 44.17747672651946 dB.
|
||||
Epoch 180. Training loss: 2.903205723669089e-05 Average PSNR: 43.798061653532415 dB. 41.3036/0.7786
|
||||
Epoch 181. Training loss: 2.9952536524433527e-05 Average PSNR: 44.25979137488734 dB.
|
||||
Epoch 182. Training loss: 2.783904421448824e-05 Average PSNR: 43.9045814046101 dB.
|
||||
Epoch 183. Training loss: 2.80613460472523e-05 Average PSNR: 44.27714496274151 dB.
|
||||
Epoch 184. Training loss: 2.757695238869928e-05 Average PSNR: 44.19456243836852 dB.
|
||||
Epoch 185. Training loss: 2.7383821561670628e-05 Average PSNR: 44.267053477129664 dB.
|
||||
Epoch 186. Training loss: 2.783937245112611e-05 Average PSNR: 44.00551626475267 dB.
|
||||
Epoch 187. Training loss: 3.04154180776095e-05 Average PSNR: 44.22232211221754 dB.
|
||||
Epoch 188. Training loss: 3.2381991049987844e-05 Average PSNR: 44.13192002677785 dB.
|
||||
Epoch 189. Training loss: 2.7306975744068042e-05 Average PSNR: 44.296146733430916 dB.
|
||||
Epoch 190. Training loss: 2.7956372150583775e-05 Average PSNR: 44.30142511718106 dB. 42.7548/0.8448
|
||||
Epoch 191. Training loss: 2.707152742004837e-05 Average PSNR: 44.29096410515267 dB.
|
||||
Epoch 192. Training loss: 2.9267182644616695e-05 Average PSNR: 44.2269224074133 dB.
|
||||
Epoch 193. Training loss: 2.7651238588077832e-05 Average PSNR: 44.236247742259124 dB.
|
||||
Epoch 194. Training loss: 2.7692641624526003e-05 Average PSNR: 44.20613613915918 dB.
|
||||
Epoch 195. Training loss: 2.7414094383857445e-05 Average PSNR: 44.267585435124985 dB.
|
||||
Epoch 196. Training loss: 2.7187735149709624e-05 Average PSNR: 44.29366917495008 dB.
|
||||
Epoch 197. Training loss: 2.735268495598575e-05 Average PSNR: 43.435957249824526 dB.
|
||||
Epoch 198. Training loss: 2.8096141177229585e-05 Average PSNR: 44.31250727582331 dB.
|
||||
Epoch 199. Training loss: 2.695965677048662e-05 Average PSNR: 39.01822184344504 dB. 39.6769/0.8454
|
||||
|
After Width: | Height: | Size: 265 KiB |
|
After Width: | Height: | Size: 1.0 MiB |
|
After Width: | Height: | Size: 254 KiB |
@ -0,0 +1,222 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from typing import Any, Tuple, Union, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
from torch.optim import Optimizer
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from typing import Union, Optional, Any, Tuple
|
||||
|
||||
def convert_rgb_to_y(img):
    """Return the luma (Y) plane of an RGB image.

    Args:
        img: Either an H x W x 3 ``np.ndarray`` (channels last) or a
            (1 x) 3 x H x W ``torch.Tensor`` (channels first).

    Returns:
        The H x W luma plane, same numeric type as the input.

    Raises:
        TypeError: If ``img`` is neither a numpy array nor a torch tensor.
    """
    # isinstance instead of `type(img) ==`: also accepts subclasses and is
    # the idiomatic type check.
    if isinstance(img, np.ndarray):
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    if isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop a leading batch dimension of size one.
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    raise TypeError(f'Unknown Type {type(img)}')
|
||||
|
||||
|
||||
def convert_rgb_to_ycbcr(img):
    """Convert an RGB image to a channels-last YCbCr image.

    Args:
        img: Either an H x W x 3 ``np.ndarray`` (channels last) or a
            (1 x) 3 x H x W ``torch.Tensor`` (channels first).

    Returns:
        An H x W x 3 image with channels ordered (Y, Cb, Cr).

    Raises:
        TypeError: If ``img`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(img, np.ndarray):
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    if isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop a leading batch dimension of size one.
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # Bug fix: the channel planes are 2-D, so they must be stacked along a
        # new leading dim. torch.cat concatenated rows into a (3H, W) tensor
        # and the 3-axis permute then raised a dimension-mismatch error.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    raise TypeError(f'Unknown Type {type(img)}')
|
||||
|
||||
|
||||
def convert_ycbcr_to_rgb(img):
    """Convert a channels-last YCbCr image back to a channels-last RGB image.

    Args:
        img: Either an H x W x 3 ``np.ndarray`` or a (1 x) 3 x H x W
            ``torch.Tensor``; channels ordered (Y, Cb, Cr).

    Returns:
        An H x W x 3 image with channels ordered (R, G, B).

    Raises:
        TypeError: If ``img`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(img, np.ndarray):
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    if isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            # Drop a leading batch dimension of size one.
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # Bug fix: stack the 2-D channel planes along a new leading dim.
        # torch.cat produced a (3H, W) tensor on which permute(1, 2, 0) raised.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    raise TypeError(f'Unknown Type {type(img)}')
|
||||
|
||||
|
||||
def calc_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two images with a peak value of 1."""
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(1. / mse)
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Tracks the most recent value and a running average of observations."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
||||
|
||||
|
||||
|
||||
def load_state_dict(
        model: nn.Module,
        model_weights_path: str,
        ema_model: Optional[nn.Module] = None,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[_LRScheduler] = None,
        load_mode: Optional[str] = None,
) -> Union[Tuple[nn.Module, Optional[nn.Module], Any, Any, Any, Optional[Optimizer], Optional[_LRScheduler]],
           Tuple[nn.Module, Any, Any, Any, Optional[Optimizer], Optional[_LRScheduler]],
           nn.Module]:
    """Load model weights — and, when resuming, full training state — from a checkpoint.

    Args:
        model: Base model whose matching weights are overwritten in place.
        model_weights_path: Path to a ``torch.save``-d checkpoint dict.
        ema_model: Optional EMA shadow model, restored from ``"ema_state_dict"``.
        optimizer: Optional optimizer, restored from ``"optimizer"`` when resuming.
        scheduler: Optional LR scheduler, restored from ``"scheduler"`` when resuming.
        load_mode: ``"resume"`` restores epoch/best-metric/optimizer/scheduler
            state as well; any other value loads only compatible model weights.

    Returns:
        ``load_mode == "resume"`` with ``ema_model``:
            ``(model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler)``
        ``load_mode == "resume"`` without ``ema_model``:
            ``(model, start_epoch, best_psnr, best_ssim, optimizer, scheduler)``
        otherwise: ``model`` alone.
    """
    # Always materialize on CPU so a checkpoint saved on a GPU machine can be
    # restored anywhere; the caller moves the model to its target device.
    checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)

    if load_mode == "resume":
        # Restore the bookkeeping needed to continue training at this point.
        start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        best_ssim = checkpoint["best_ssim"]
        # Keep only checkpoint entries whose keys exist in the current model,
        # then overwrite the current (base) model weights with them.
        model_state_dict = model.state_dict()
        state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()}
        model_state_dict.update(state_dict)
        model.load_state_dict(model_state_dict)
        # BUG FIX: optimizer/scheduler are typed Optional but were previously
        # dereferenced unconditionally — resuming without them crashed with
        # AttributeError. Restore each only when it was actually supplied.
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])

        if ema_model is not None:
            # Same filtered-merge procedure for the EMA shadow model.
            ema_model_state_dict = ema_model.state_dict()
            ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()}
            ema_model_state_dict.update(ema_state_dict)
            ema_model.load_state_dict(ema_model_state_dict)
            return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler

        return model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
    else:
        # Weights-only load: additionally require matching tensor sizes so a
        # checkpoint from a slightly different architecture loads what it can.
        model_state_dict = model.state_dict()
        state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
                      k in model_state_dict.keys() and v.size() == model_state_dict[k].size()}
        model_state_dict.update(state_dict)
        model.load_state_dict(model_state_dict)

        return model
|
||||
|
||||
|
||||
def make_directory(dir_path: str) -> None:
    """Create ``dir_path`` (including missing parents) if it does not exist.

    FIX: uses ``exist_ok=True`` instead of the check-then-create pair, which
    removes the race where another process creates the directory between the
    ``os.path.exists`` test and the ``os.makedirs`` call.
    """
    os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
|
||||
def save_checkpoint(
        state_dict: dict,
        file_name: str,
        samples_dir: str,
        results_dir: str,
        is_best: bool = False,
        is_last: bool = False,
        best_file_name: str = "LSRGAN_x2.pth.tar",
        last_file_name: str = "last.pth.tar",
) -> None:
    """Save a training checkpoint and optionally mirror best/last copies.

    Args:
        state_dict: Checkpoint payload to serialize with ``torch.save``.
        file_name: File name for the per-epoch checkpoint inside ``samples_dir``.
        samples_dir: Directory receiving every checkpoint.
        results_dir: Directory receiving the "best"/"last" mirror copies.
        is_best: Copy the checkpoint to ``results_dir/best_file_name``.
        is_last: Copy the checkpoint to ``results_dir/last_file_name``.
        best_file_name: Name of the best-model copy. Default preserves the
            previously hard-coded "LSRGAN_x2.pth.tar" (generalized so other
            scale factors / model names can reuse this helper).
        last_file_name: Name of the last-model copy (default "last.pth.tar").
    """
    checkpoint_path = os.path.join(samples_dir, file_name)
    torch.save(state_dict, checkpoint_path)

    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name))
    if is_last:
        shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name))
|
||||
|
||||
|
||||
class Summary(Enum):
    """How an ``AverageMeter`` reports itself in a final summary line."""
    NONE = 0     # suppressed from the summary
    AVERAGE = 1  # report the running mean
    SUM = 2      # report the accumulated total
    COUNT = 3    # report the number of observations


class AverageMeter(object):
    """Tracks the current and average value of a metric and formats both for display."""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render "name current (average)" using the meter's format spec."""
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)

    def summary(self):
        """Render the single summary field selected by ``summary_type``."""
        templates = {
            Summary.NONE: "",
            Summary.AVERAGE: "{name} {avg:.2f}",
            Summary.SUM: "{name} {sum:.2f}",
            Summary.COUNT: "{name} {count:.2f}",
        }
        if self.summary_type not in templates:
            raise ValueError(f"Invalid summary type {self.summary_type}")
        return templates[self.summary_type].format(**self.__dict__)
|
||||
|
||||
|
||||
class ProgressMeter(object):
    """Formats and prints tab-separated per-batch progress lines for a set of meters.

    Args:
        num_batches: Total number of batches; fixes the width of the batch counter.
        meters: Iterable of objects with ``__str__`` (and ``summary()`` for
            ``display_summary``) — e.g. ``AverageMeter`` instances.
        prefix: String prepended to each progress line (e.g. "Epoch[1] ").
    """

    def __init__(self, num_batches, meters, prefix=""):
        # Pre-build the "[  cur/total]" template once; only the current batch
        # index changes between display() calls.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one progress line for the given batch index."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        """Print the end-of-epoch summary of every meter on one line."""
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template like "[{:3d}/100]" padded to the total's digit width."""
        # FIX: dropped the no-op ``num_batches // 1``; the field width is
        # simply the number of digits in the total batch count.
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
||||
|
After Width: | Height: | Size: 448 KiB |