You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
KinD/adjustment_net_train.py

224 lines
9.8 KiB

# coding: utf-8
from __future__ import print_function
import os
import time
import random
#from skimage import color
from PIL import Image
import tensorflow as tf
import numpy as np
from utils import *
from model import *
from glob import glob
# Patch-sampling settings shared by the rest of the script, and the session
# every later sess.run() call goes through.
batch_size = 10
patch_size = 48
sess = tf.Session()


def _map_placeholder(channels, name):
    """Return a float32 4-D placeholder whose last dimension is `channels`.

    The first three dimensions are left unspecified (None).
    """
    return tf.placeholder(tf.float32, [None, None, None, channels], name=name)


# Input of the decomposition net.
input_decom = _map_placeholder(3, 'input_decom')
# Inputs of the illumination adjustment net.
input_low_i = _map_placeholder(1, 'input_low_i')
input_low_i_ratio = _map_placeholder(1, 'input_low_i_ratio')
input_high_i = _map_placeholder(1, 'input_high_i')

# Outputs of the decomposition network, also exposed under decom_output_*.
R_decom, I_decom = DecomNet_simple(input_decom)
decom_output_R, decom_output_I = R_decom, I_decom

# Output of the illumination adjustment net.
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
#define loss
def grad_loss(input_i_low, input_i_high):
    """Return the mean of the summed squared x- and y-gradient differences.

    For each direction ('x', then 'y') the project's `gradient` helper is
    applied to both inputs and the squared difference is taken; the two
    squared-difference tensors are added and reduced with tf.reduce_mean.
    """
    squared_diffs = []
    for direction in ('x', 'y'):
        low_grad = gradient(input_i_low, direction)
        high_grad = gradient(input_i_high, direction)
        squared_diffs.append(tf.square(low_grad - high_grad))
    x_term, y_term = squared_diffs
    return tf.reduce_mean(x_term + y_term)
# Adjustment loss = gradient term + mean squared error against the target map.
loss_grad = grad_loss(output_i, input_high_i)
loss_square = tf.reduce_mean(tf.square(output_i - input_high_i))
loss_adjust = loss_square + loss_grad

# The learning rate is fed at run time through a placeholder.
lr = tf.placeholder(tf.float32, name='learning_rate')
optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')

# Split the trainable variables by the sub-network name found in var.name.
_all_trainable = tf.trainable_variables()
var_Decom = [v for v in _all_trainable if 'DecomNet' in v.name]
var_adjust = [v for v in _all_trainable if 'Illumination_adjust_net' in v.name]

# One saver per sub-network so each can be checkpointed/restored on its own.
saver_adjust = tf.train.Saver(var_list=var_adjust)
saver_Decom = tf.train.Saver(var_list=var_Decom)

# Only the adjustment-net variables are updated by this training op.
train_op_adjust = optimizer.minimize(loss_adjust, var_list=var_adjust)

# Runs after minimize() so variables created by the optimizer are initialised too.
sess.run(tf.global_variables_initializer())
print("[*] Initialize model successfully...")
### Load data.
### The adjustment net is trained on maps produced by the decomposition net,
### so the raw low/high image pairs are loaded first and decomposed later.
train_low_data_names = sorted(glob('./LOLdataset/our485/low/*.png'))
train_high_data_names = sorted(glob('./LOLdataset/our485/high/*.png'))

# Low/high images are paired purely by sorted position, so the two listings
# must have the same length. Raise explicitly instead of using `assert`,
# which is removed when Python runs with -O.
if len(train_low_data_names) != len(train_high_data_names):
    raise ValueError(
        'Mismatched training set: %d low images vs %d high images'
        % (len(train_low_data_names), len(train_high_data_names)))
# An empty listing would otherwise go unnoticed until much later in the script.
if not train_low_data_names:
    raise ValueError('No training images found under ./LOLdataset/our485/')
print('[*] Number of training data: %d' % len(train_low_data_names))

train_low_data = []
train_high_data = []
for low_name, high_name in zip(train_low_data_names, train_high_data_names):
    train_low_data.append(load_images(low_name))
    train_high_data.append(load_images(high_name))
# Restore the decomposition-net weights from an earlier training run.
# NOTE(review): when no checkpoint is found the script only prints a message
# and carries on - confirm that continuing without restored weights is intended.
pre_decom_checkpoint_dir = './checkpoint/decom_net_train/'
ckpt_pre = tf.train.get_checkpoint_state(pre_decom_checkpoint_dir)
if not ckpt_pre:
    print('No pre_decom_net checkpoint!')
else:
    decom_ckpt_path = ckpt_pre.model_checkpoint_path
    print('loaded ' + decom_ckpt_path)
    saver_Decom.restore(sess, decom_ckpt_path)
def _illumination_maps(images):
    """Run the decomposition net on each image and collect its I output.

    Every image gets a leading axis added before being fed to `input_decom`.
    Both network outputs are squeezed and their shapes printed; only the
    squeezed illumination output is kept.
    """
    collected = []
    for image in images:
        fed = np.expand_dims(image, axis=0)
        r_out, i_out = sess.run([decom_output_R, decom_output_I],
                                feed_dict={input_decom: fed})
        r_map = np.squeeze(r_out)
        i_map = np.squeeze(i_out)
        print(r_map.shape, i_map.shape)
        collected.append(i_map)
    return collected


# Illumination maps of the low-light images, then of the normal-light images.
decomposed_low_i_data_480 = _illumination_maps(train_low_data)
decomposed_high_i_data_480 = _illumination_maps(train_high_data)
# Fixed positional split of the decomposed illumination maps.
# NOTE(review): index 450 falls in neither slice and everything from index
# 480 onwards is unused - confirm this gap is intentional.
_train_part = slice(0, 450)
_eval_part = slice(451, 480)

train_adjust_low_i_data = decomposed_low_i_data_480[_train_part]
train_adjust_high_i_data = decomposed_high_i_data_480[_train_part]
eval_adjust_low_i_data = decomposed_low_i_data_480[_eval_part]
eval_adjust_high_i_data = decomposed_high_i_data_480[_eval_part]

print('[*] Number of training data: %d' % len(train_adjust_high_i_data))
# Training schedule.
learning_rate = 1e-4
epoch = 2000             # total number of epochs to run
eval_every_epoch = 200   # evaluation cadence, in epochs
train_phase = 'adjustment'

# Whole batches per epoch; a trailing partial batch is not counted.
numBatch = len(train_adjust_low_i_data) // int(batch_size)

# Generic aliases used by the training loop.
train_op, train_loss, saver = train_op_adjust, loss_adjust, saver_adjust
# Directory for adjustment-net checkpoints; created on first use.
checkpoint_dir = './checkpoint/illumination_adjust_net_train/'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Resume from an existing adjustment-net checkpoint when one is present.
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if not ckpt:
    print("No adjustment net pre model!")
else:
    adjust_ckpt_path = ckpt.model_checkpoint_path
    print('loaded ' + adjust_ckpt_path)
    saver.restore(sess, adjust_ckpt_path)
# Loop bookkeeping: training always starts from step 0 of epoch 0.
start_step = start_epoch = iter_num = 0
print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))

# Directory that receives the images written during evaluation.
sample_dir = './illumination_adjust_net_train/'
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

# Reference point for the elapsed-time column of the training log.
start_time = time.time()
# Running index into the training maps; advanced by the training loop.
image_id = 0
def _sample_adjust_pair(low_map, high_map):
    """Cut one aligned, augmented patch pair and build its two ratio maps.

    Both 2-D maps are cropped at the same random location to
    patch_size x patch_size and passed through `data_augmentation` with the
    same randomly drawn mode. The scalar `ratio` is the mean of
    low / (high + 0.0001) over the un-augmented crops.

    Returns:
        (low_patch, high_patch, low_ratio_map, high_ratio_map), where the
        ratio maps are constant (patch_size, patch_size, 1) arrays holding
        1 / ratio + 0.0001 and ratio respectively.
    """
    h, w = low_map.shape
    low_hw1 = np.expand_dims(low_map, axis=2)
    high_hw1 = np.expand_dims(high_map, axis=2)
    # random.randint is inclusive on both ends, so the crop always fits.
    x = random.randint(0, h - patch_size)
    y = random.randint(0, w - patch_size)
    low_crop = low_hw1[x:x + patch_size, y:y + patch_size, :]
    high_crop = high_hw1[x:x + patch_size, y:y + patch_size, :]

    # Same augmentation mode for both patches keeps them aligned.
    aug_mode = np.random.randint(0, 7)
    low_patch = data_augmentation(low_crop, aug_mode)
    high_patch = data_augmentation(high_crop, aug_mode)

    ratio = np.mean(low_crop / (high_crop + 0.0001))
    constant_plane = np.ones([patch_size, patch_size, 1])
    # NOTE(review): the epsilon is added after the division, so it does not
    # guard against ratio == 0; kept as-is to leave training numerics
    # unchanged - confirm whether 1 / (ratio + 0.0001) was intended.
    low_ratio_map = constant_plane * (1 / ratio + 0.0001)
    high_ratio_map = constant_plane * ratio
    return low_patch, high_patch, low_ratio_map, high_ratio_map


# Keep the configured total under its own name so the loop index below does
# not overwrite the value it is iterating towards.
num_epochs = epoch
batch_shape = (batch_size, patch_size, patch_size, 1)
# Never index past the end of the evaluation set.
num_eval_samples = min(10, len(eval_adjust_low_i_data))

for epoch_idx in range(start_epoch, num_epochs):
    for batch_id in range(start_step, numBatch):
        input_low_i_rand = np.zeros(batch_shape, dtype="float32")
        input_high_i_rand = np.zeros(batch_shape, dtype="float32")
        input_low_i_rand_ratio = np.zeros(batch_shape, dtype="float32")

        for patch_id in range(batch_size):
            low_patch, high_patch, low_ratio_map, high_ratio_map = _sample_adjust_pair(
                train_adjust_low_i_data[image_id],
                train_adjust_high_i_data[image_id])

            # Coin flip per sample: either feed the low map with the high map
            # as target, or the other way round with the other ratio map.
            if np.random.randint(0, 2) == 1:
                input_low_i_rand[patch_id] = low_patch
                input_high_i_rand[patch_id] = high_patch
                input_low_i_rand_ratio[patch_id] = low_ratio_map
            else:
                input_low_i_rand[patch_id] = high_patch
                input_high_i_rand[patch_id] = low_patch
                input_low_i_rand_ratio[patch_id] = high_ratio_map

            # Walk through the training maps in order, wrapping at the end.
            image_id = (image_id + 1) % len(train_adjust_low_i_data)

        _, loss = sess.run(
            [train_op, train_loss],
            feed_dict={input_low_i: input_low_i_rand,
                       input_low_i_ratio: input_low_i_rand_ratio,
                       input_high_i: input_high_i_rand,
                       lr: learning_rate})
        print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
              % (train_phase, epoch_idx + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
        iter_num += 1

    if (epoch_idx + 1) % eval_every_epoch == 0:
        print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_idx + 1))

        for idx in range(num_eval_samples):
            eval_map = eval_adjust_low_i_data[idx]
            h, w = eval_map.shape
            # Plain Python float in [0, 2): safe to format with %f and it
            # still broadcasts over the constant ratio plane.
            rand_ratio = np.random.random() * 2
            eval_input = eval_map[np.newaxis, :, :, np.newaxis]
            eval_ratio = np.ones([1, h, w, 1]) * rand_ratio
            result_1 = sess.run(output_i, feed_dict={input_low_i: eval_input,
                                                     input_low_i_ratio: eval_ratio})
            save_images(os.path.join(sample_dir, 'h_eval_%d_%d_%5f.png' % (epoch_idx + 1, idx + 1, rand_ratio)),
                        eval_map, result_1)

        # NOTE(review): checkpointing is placed on the evaluation cadence;
        # confirm this matches the intended nesting.
        saver.save(sess, checkpoint_dir + 'model.ckpt')

print("[*] Finish training for phase %s." % train_phase)