# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import random

import numpy as np
import torch
from torch.backends import cudnn

# Random seed to maintain reproducible results
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
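# Note (added, not part of the original config): these calls seed the CPU-side RNGs only.
# If fully deterministic GPU runs were required, one would presumably also seed CUDA,
# e.g. `torch.cuda.manual_seed_all(0)`, and trade speed for determinism by setting
# `cudnn.deterministic = True` instead of the benchmark mode enabled below.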
# Use GPU for training by default
device = torch.device("cuda", 0)
# When the input image size does not change during training, enabling cuDNN
# benchmark mode can speed up training
cudnn.benchmark = True
# When evaluating the SR model, whether to compute metrics on the Y channel only
only_test_y_channel = True
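# Note (added): Y-channel-only evaluation is the common convention for SR benchmarks.
# The test script presumably converts RGB outputs to YCbCr and measures PSNR/SSIM on
# the luma channel, e.g. with the ITU-R BT.601 weighting Y = 0.299*R + 0.587*G + 0.114*B.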
# Model architecture name
d_arch_name = "discriminator"
g_arch_name = "lsrgan_x2"
# Model architecture configuration
in_channels = 3
out_channels = 3
channels = 64
growth_channels = 32
num_blocks = 23
upscale_factor = 2
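# Note (added): these values match an ESRGAN-style RRDB generator: 23
# residual-in-residual dense blocks, 64 feature channels, 32 growth channels per
# dense layer, and a x2 upsampling stage; the actual builder presumably lives in
# the repository's model module.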
# Current mode: "train" or "test"
mode = "test"
# Experiment name, used to name saved weights and log files
exp_name = "LSRGAN_x2"

if mode == "train":
    # Dataset directories
    train_gt_images_dir = "./data/DIV2K/LSRGAN/train"

    test_gt_images_dir = "./data/Set5/GTmod12"
    test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}"

    gt_image_size = 128
    batch_size = 16
    num_workers = 4
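    # Note (added): with upscale_factor = 2, a 128-pixel GT crop corresponds to a
    # 64x64 LR patch, assuming the training pipeline crops the GT images and
    # bicubically downsamples them to produce the LR inputs.
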
    # Paths to pretrained model weights (leave empty to train from scratch)
    pretrained_d_model_weights_path = ""
    pretrained_g_model_weights_path = ""

    # Checkpoints for resuming interrupted training or for transfer training
    resume_d = ""
    resume_g = ""

    # Total number of epochs (about 500,000 iterations)
    epochs = 512

    # Feature extraction layer configuration
    feature_model_extractor_node = "features.34"
    feature_model_normalize_mean = [0.485, 0.456, 0.406]
    feature_model_normalize_std = [0.229, 0.224, 0.225]
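    # Illustrative sketch (added, not part of the original config): the node name
    # "features.34" and the ImageNet mean/std above suggest a VGG19-based feature
    # extractor for the content loss, which a training script could build roughly
    # like this (names here are assumptions, not the repository's actual code):
    #
    #     from torchvision.models import vgg19
    #     from torchvision.models.feature_extraction import create_feature_extractor
    #
    #     vgg = vgg19(weights="IMAGENET1K_V1").eval()
    #     feature_extractor = create_feature_extractor(
    #         vgg, return_nodes=[feature_model_extractor_node])
    #     # features = feature_extractor(normalized_image)[feature_model_extractor_node]
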
    # Loss function weights
    pixel_weight = 0.01
    content_weight = 1.0
    adversarial_weight = 0.005
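    # Note (added): the generator objective is presumably the weighted sum
    # pixel_weight * pixel_loss + content_weight * content_loss
    # + adversarial_weight * adversarial_loss, matching the usual ESRGAN recipe.
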
    # Optimizer parameters
    model_lr = 1e-4
    model_betas = (0.9, 0.999)
    model_eps = 1e-8
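    # Illustrative sketch (added): the betas and eps values indicate an Adam optimizer,
    # which a training script might construct as follows (an assumption, not the
    # repository's actual code):
    #
    #     d_optimizer = torch.optim.Adam(d_model.parameters(),
    #                                    lr=model_lr, betas=model_betas, eps=model_eps)
    #     g_optimizer = torch.optim.Adam(g_model.parameters(),
    #                                    lr=model_lr, betas=model_betas, eps=model_eps)
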
    # EMA (exponential moving average) parameter
    model_ema_decay = 0.99998
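    # Illustrative sketch (added): one way a training script could maintain an EMA of
    # the generator weights with this decay (an assumption, not the repository's code):
    #
    #     from torch.optim.swa_utils import AveragedModel
    #     ema_avg_fn = lambda avg, new, num: model_ema_decay * avg + (1 - model_ema_decay) * new
    #     ema_g_model = AveragedModel(g_model, avg_fn=ema_avg_fn)
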
    # Dynamically adjust the learning rate policy (200,000 iters)
    lr_scheduler_milestones = [int(epochs * 0.1), int(epochs * 0.2), int(epochs * 0.4), int(epochs * 0.6)]
    lr_scheduler_gamma = 0.5
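    # Illustrative sketch (added): these values map naturally onto a MultiStepLR
    # schedule; a training script might use them roughly as follows (an assumption,
    # not the repository's actual code):
    #
    #     d_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #         d_optimizer, milestones=lr_scheduler_milestones, gamma=lr_scheduler_gamma)
    #     g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #         g_optimizer, milestones=lr_scheduler_milestones, gamma=lr_scheduler_gamma)
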
    # How often (in iterations) to print training and validation results
    train_print_frequency = 100
    valid_print_frequency = 1

if mode == "test":
    # Test data directories
    test_gt_images_dir = "./data/hr"
    test_lr_images_dir = "./data/lr"
    sr_dir = f"./data/{exp_name}"

    g_model_weights_path = "LSRGAN_x2.pth.tar"
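
# Usage note (added): in repositories organized this way, other scripts typically do
# `import config` and read these module-level names directly, for example:
#
#     import config
#     weights = torch.load(config.g_model_weights_path, map_location=config.device)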