You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

84 lines
3.3 KiB

# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import cv2
import torch
from torchvision import transforms
import lsrgan_config
import model_LSRGAN
from utils import make_directory
def preprocess_image(image_path, device):
    """Read an image file and return it as a batched tensor on ``device``.

    OpenCV decodes the file in BGR channel order, so the pixels are reordered
    to RGB before being handed to ``torchvision.transforms.ToTensor``. A leading
    batch axis of size 1 is then added and the tensor is moved to ``device``.

    Args:
        image_path: Path of the image file to read.
        device: Target device, passed straight to ``Tensor.to``.

    Returns:
        A channel-first tensor with a leading batch axis of size 1. For
        OpenCV's default 8-bit decode, ToTensor scales values into [0.0, 1.0].

    Raises:
        ValueError: If OpenCV cannot read/decode ``image_path``.
    """
    bgr_image = cv2.imread(image_path)
    # OpenCV signals a missing or undecodable file by returning None rather
    # than raising, so turn that into an explicit error here.
    if bgr_image is None:
        raise ValueError(f"Failed to load image: {image_path}")

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    chw_tensor = transforms.ToTensor()(rgb_image)
    batched = chw_tensor.unsqueeze(0)
    return batched.to(device)
def main(lr_image_path: str = "eye_lr.jpg", sr_image_path: "str | None" = None) -> None:
    """Super-resolve a single image with the configured LSRGAN generator.

    Builds the generator named by ``lsrgan_config.g_arch_name``, loads its
    weights from ``lsrgan_config.g_model_weights_path``, runs one forward pass
    on the low-resolution input and writes the result to disk.

    Args:
        lr_image_path: Path of the low-resolution input image. Defaults to
            ``"eye_lr.jpg"``, resolved against the current working directory.
        sr_image_path: Destination path of the super-resolved image. Defaults
            to ``<lsrgan_config.sr_dir>/eye_LSRGAN_x2.jpg``.

    Raises:
        ValueError: If the input image cannot be read (from ``preprocess_image``).
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        OSError: If OpenCV fails to write the output image.
    """
    if sr_image_path is None:
        sr_image_path = os.path.join(lsrgan_config.sr_dir, "eye_LSRGAN_x2.jpg")
    device = lsrgan_config.device

    # Initialize the super-resolution model
    sr_model = model_LSRGAN.__dict__[lsrgan_config.g_arch_name](
        in_channels=lsrgan_config.in_channels,
        out_channels=lsrgan_config.out_channels,
        channels=lsrgan_config.channels,
        growth_channels=lsrgan_config.growth_channels,
        num_blocks=lsrgan_config.num_blocks,
    )
    sr_model = sr_model.to(device=device)
    print(f"Build `{lsrgan_config.g_arch_name}` model successfully.")

    # Deserialize onto the CPU (equivalent to `lambda storage, loc: storage`);
    # load_state_dict then copies the values into parameters already on `device`.
    # NOTE(security): torch.load unpickles the file, so only point
    # g_model_weights_path at trusted checkpoints.
    checkpoint = torch.load(lsrgan_config.g_model_weights_path, map_location="cpu")
    sr_model.load_state_dict(checkpoint["state_dict"])
    print(f"Load `{lsrgan_config.g_arch_name}` model weights "
          f"`{os.path.abspath(lsrgan_config.g_model_weights_path)}` successfully.")

    # Create the experiment results folder, and also the output file's own
    # folder in case a custom sr_image_path points somewhere else.
    make_directory(lsrgan_config.sr_dir)
    os.makedirs(os.path.dirname(os.path.abspath(sr_image_path)), exist_ok=True)

    # Switch to inference mode (changes layers such as dropout/batch-norm, if present).
    sr_model.eval()

    print(f"Processing `{os.path.abspath(lr_image_path)}`...")
    lr_tensor = preprocess_image(lr_image_path, device)

    # The generator consumes and produces the full RGB image; there is no
    # Y-channel split in this script.
    with torch.no_grad():
        sr_tensor = sr_model(lr_tensor)

    # Drop the batch axis and go channel-last for OpenCV. `.float()` is a no-op
    # for float32 and guards against reduced-precision outputs numpy can't hold.
    sr_image = sr_tensor.squeeze(0).permute(1, 2, 0).float().cpu().numpy()
    # Round (rather than truncate) so that values like 127.9999 map to 128,
    # then clamp out-of-range model outputs before the uint8 cast.
    sr_image = (sr_image * 255.0).round().clip(0, 255).astype("uint8")
    sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)

    # cv2.imwrite reports failure (unsupported extension, unwritable path) via
    # its return value instead of raising; don't claim success in that case.
    if not cv2.imwrite(sr_image_path, sr_image):
        raise OSError(f"Failed to write image: {sr_image_path}")
    print(f"Super-resolved image saved to `{os.path.abspath(sr_image_path)}`")
if __name__ == "__main__":
    # Run the single-image demo only when executed as a script, not on import.
    main()