diff --git a/Decom_net_train/high_10_200.png b/Decom_net_train/high_10_200.png new file mode 100644 index 0000000..baa8998 Binary files /dev/null and b/Decom_net_train/high_10_200.png differ diff --git a/Decom_net_train/high_10_400.png b/Decom_net_train/high_10_400.png new file mode 100644 index 0000000..c087bca Binary files /dev/null and b/Decom_net_train/high_10_400.png differ diff --git a/Decom_net_train/high_11_200.png b/Decom_net_train/high_11_200.png new file mode 100644 index 0000000..21eca2b Binary files /dev/null and b/Decom_net_train/high_11_200.png differ diff --git a/Decom_net_train/high_11_400.png b/Decom_net_train/high_11_400.png new file mode 100644 index 0000000..1ad3e5d Binary files /dev/null and b/Decom_net_train/high_11_400.png differ diff --git a/Decom_net_train/high_12_200.png b/Decom_net_train/high_12_200.png new file mode 100644 index 0000000..5f04294 Binary files /dev/null and b/Decom_net_train/high_12_200.png differ diff --git a/Decom_net_train/high_12_400.png b/Decom_net_train/high_12_400.png new file mode 100644 index 0000000..6e7a17b Binary files /dev/null and b/Decom_net_train/high_12_400.png differ diff --git a/Decom_net_train/high_13_200.png b/Decom_net_train/high_13_200.png new file mode 100644 index 0000000..e8bb0a7 Binary files /dev/null and b/Decom_net_train/high_13_200.png differ diff --git a/Decom_net_train/high_13_400.png b/Decom_net_train/high_13_400.png new file mode 100644 index 0000000..e553445 Binary files /dev/null and b/Decom_net_train/high_13_400.png differ diff --git a/Decom_net_train/high_14_200.png b/Decom_net_train/high_14_200.png new file mode 100644 index 0000000..4abaa0c Binary files /dev/null and b/Decom_net_train/high_14_200.png differ diff --git a/Decom_net_train/high_14_400.png b/Decom_net_train/high_14_400.png new file mode 100644 index 0000000..6e5a567 Binary files /dev/null and b/Decom_net_train/high_14_400.png differ diff --git a/Decom_net_train/high_15_200.png b/Decom_net_train/high_15_200.png new file mode 100644 index 0000000..4ee9301 Binary files /dev/null and b/Decom_net_train/high_15_200.png differ diff --git a/Decom_net_train/high_15_400.png b/Decom_net_train/high_15_400.png new file mode 100644 index 0000000..eb7b3fb Binary files /dev/null and b/Decom_net_train/high_15_400.png differ diff --git a/Decom_net_train/high_1_200.png b/Decom_net_train/high_1_200.png new file mode 100644 index 0000000..def1108 Binary files /dev/null and b/Decom_net_train/high_1_200.png differ diff --git a/Decom_net_train/high_1_400.png b/Decom_net_train/high_1_400.png new file mode 100644 index 0000000..86bc1a6 Binary files /dev/null and b/Decom_net_train/high_1_400.png differ diff --git a/Decom_net_train/high_2_200.png b/Decom_net_train/high_2_200.png new file mode 100644 index 0000000..fa43b57 Binary files /dev/null and b/Decom_net_train/high_2_200.png differ diff --git a/Decom_net_train/high_2_400.png b/Decom_net_train/high_2_400.png new file mode 100644 index 0000000..b9e962a Binary files /dev/null and b/Decom_net_train/high_2_400.png differ diff --git a/Decom_net_train/high_3_200.png b/Decom_net_train/high_3_200.png new file mode 100644 index 0000000..27ae7ae Binary files /dev/null and b/Decom_net_train/high_3_200.png differ diff --git a/Decom_net_train/high_3_400.png b/Decom_net_train/high_3_400.png new file mode 100644 index 0000000..127cc73 Binary files /dev/null and b/Decom_net_train/high_3_400.png differ diff --git a/Decom_net_train/high_4_200.png b/Decom_net_train/high_4_200.png new file mode 100644 index 0000000..a08ee7b Binary files /dev/null and b/Decom_net_train/high_4_200.png differ diff --git a/Decom_net_train/high_4_400.png b/Decom_net_train/high_4_400.png new file mode 100644 index 0000000..15ca50c Binary files /dev/null and b/Decom_net_train/high_4_400.png differ diff --git a/Decom_net_train/high_5_200.png b/Decom_net_train/high_5_200.png new file mode 100644 index 0000000..6a01754 Binary files /dev/null and b/Decom_net_train/high_5_200.png differ diff --git a/Decom_net_train/high_5_400.png b/Decom_net_train/high_5_400.png new file mode 100644 index 0000000..fb320ef Binary files /dev/null and b/Decom_net_train/high_5_400.png differ diff --git a/Decom_net_train/high_6_200.png b/Decom_net_train/high_6_200.png new file mode 100644 index 0000000..988eb8f Binary files /dev/null and b/Decom_net_train/high_6_200.png differ diff --git a/Decom_net_train/high_6_400.png b/Decom_net_train/high_6_400.png new file mode 100644 index 0000000..7cb240b Binary files /dev/null and b/Decom_net_train/high_6_400.png differ diff --git a/Decom_net_train/high_7_200.png b/Decom_net_train/high_7_200.png new file mode 100644 index 0000000..e2cca46 Binary files /dev/null and b/Decom_net_train/high_7_200.png differ diff --git a/Decom_net_train/high_7_400.png b/Decom_net_train/high_7_400.png new file mode 100644 index 0000000..59fed1b Binary files /dev/null and b/Decom_net_train/high_7_400.png differ diff --git a/Decom_net_train/high_8_200.png b/Decom_net_train/high_8_200.png new file mode 100644 index 0000000..fb441c4 Binary files /dev/null and b/Decom_net_train/high_8_200.png differ diff --git a/Decom_net_train/high_8_400.png b/Decom_net_train/high_8_400.png new file mode 100644 index 0000000..d33f0ea Binary files /dev/null and b/Decom_net_train/high_8_400.png differ diff --git a/Decom_net_train/high_9_200.png b/Decom_net_train/high_9_200.png new file mode 100644 index 0000000..b104994 Binary files /dev/null and b/Decom_net_train/high_9_200.png differ diff --git a/Decom_net_train/high_9_400.png b/Decom_net_train/high_9_400.png new file mode 100644 index 0000000..c68b600 Binary files /dev/null and b/Decom_net_train/high_9_400.png differ diff --git a/Decom_net_train/low_10_200.png b/Decom_net_train/low_10_200.png new file mode 100644 index 0000000..672eaef Binary files /dev/null and b/Decom_net_train/low_10_200.png differ diff --git a/Decom_net_train/low_10_400.png b/Decom_net_train/low_10_400.png new file mode 100644 index 0000000..5ff8c4e Binary files /dev/null and b/Decom_net_train/low_10_400.png differ diff --git a/Decom_net_train/low_11_200.png b/Decom_net_train/low_11_200.png new file mode 100644 index 0000000..c4b633d Binary files /dev/null and b/Decom_net_train/low_11_200.png differ diff --git a/Decom_net_train/low_11_400.png b/Decom_net_train/low_11_400.png new file mode 100644 index 0000000..ddc3657 Binary files /dev/null and b/Decom_net_train/low_11_400.png differ diff --git a/Decom_net_train/low_12_200.png b/Decom_net_train/low_12_200.png new file mode 100644 index 0000000..287d678 Binary files /dev/null and b/Decom_net_train/low_12_200.png differ diff --git a/Decom_net_train/low_12_400.png b/Decom_net_train/low_12_400.png new file mode 100644 index 0000000..0afe04d Binary files /dev/null and b/Decom_net_train/low_12_400.png differ diff --git a/Decom_net_train/low_13_200.png b/Decom_net_train/low_13_200.png new file mode 100644 index 0000000..45c659e Binary files /dev/null and b/Decom_net_train/low_13_200.png differ diff --git a/Decom_net_train/low_13_400.png b/Decom_net_train/low_13_400.png new file mode 100644 index 0000000..3c8fa15 Binary files /dev/null and b/Decom_net_train/low_13_400.png differ diff --git a/Decom_net_train/low_14_200.png b/Decom_net_train/low_14_200.png new file mode 100644 index 0000000..05e42e6 Binary files /dev/null and b/Decom_net_train/low_14_200.png differ diff --git a/Decom_net_train/low_14_400.png b/Decom_net_train/low_14_400.png new file mode 100644 index 0000000..77cabe8 Binary files /dev/null and b/Decom_net_train/low_14_400.png differ diff --git a/Decom_net_train/low_15_200.png b/Decom_net_train/low_15_200.png new file mode 100644 index 0000000..6b5d89e Binary files /dev/null and b/Decom_net_train/low_15_200.png differ diff --git a/Decom_net_train/low_15_400.png b/Decom_net_train/low_15_400.png new file mode 100644 index 0000000..236c9c0 Binary files /dev/null and b/Decom_net_train/low_15_400.png differ diff --git a/Decom_net_train/low_1_200.png b/Decom_net_train/low_1_200.png new file mode 100644 index 0000000..bef41b7 Binary files /dev/null and b/Decom_net_train/low_1_200.png differ diff --git a/Decom_net_train/low_1_400.png b/Decom_net_train/low_1_400.png new file mode 100644 index 0000000..d43e609 Binary files /dev/null and b/Decom_net_train/low_1_400.png differ diff --git a/Decom_net_train/low_2_200.png b/Decom_net_train/low_2_200.png new file mode 100644 index 0000000..1d144d0 Binary files /dev/null and b/Decom_net_train/low_2_200.png differ diff --git a/Decom_net_train/low_2_400.png b/Decom_net_train/low_2_400.png new file mode 100644 index 0000000..3df85c7 Binary files /dev/null and b/Decom_net_train/low_2_400.png differ diff --git a/Decom_net_train/low_3_200.png b/Decom_net_train/low_3_200.png new file mode 100644 index 0000000..8107588 Binary files /dev/null and b/Decom_net_train/low_3_200.png differ diff --git a/Decom_net_train/low_3_400.png b/Decom_net_train/low_3_400.png new file mode 100644 index 0000000..00795f1 Binary files /dev/null and b/Decom_net_train/low_3_400.png differ diff --git a/Decom_net_train/low_4_200.png b/Decom_net_train/low_4_200.png new file mode 100644 index 0000000..4073eaf Binary files /dev/null and b/Decom_net_train/low_4_200.png differ diff --git a/Decom_net_train/low_4_400.png b/Decom_net_train/low_4_400.png new file mode 100644 index 0000000..0abbddc Binary files /dev/null and b/Decom_net_train/low_4_400.png differ diff --git a/Decom_net_train/low_5_200.png b/Decom_net_train/low_5_200.png new file mode 100644 index 0000000..2ed0adf Binary files /dev/null and b/Decom_net_train/low_5_200.png differ diff --git a/Decom_net_train/low_5_400.png b/Decom_net_train/low_5_400.png new file mode 100644 index 0000000..b42eccf Binary files /dev/null and b/Decom_net_train/low_5_400.png differ diff --git a/Decom_net_train/low_6_200.png b/Decom_net_train/low_6_200.png new file mode 100644 index 0000000..c48142a Binary files /dev/null and b/Decom_net_train/low_6_200.png differ diff --git a/Decom_net_train/low_6_400.png b/Decom_net_train/low_6_400.png new file mode 100644 index 0000000..9c3bbee Binary files /dev/null and b/Decom_net_train/low_6_400.png differ diff --git a/Decom_net_train/low_7_200.png b/Decom_net_train/low_7_200.png new file mode 100644 index 0000000..4e7308d Binary files /dev/null and b/Decom_net_train/low_7_200.png differ diff --git a/Decom_net_train/low_7_400.png b/Decom_net_train/low_7_400.png new file mode 100644 index 0000000..0202d35 Binary files /dev/null and b/Decom_net_train/low_7_400.png differ diff --git a/Decom_net_train/low_8_200.png b/Decom_net_train/low_8_200.png new file mode 100644 index 0000000..f70885b Binary files /dev/null and b/Decom_net_train/low_8_200.png differ diff --git a/Decom_net_train/low_8_400.png b/Decom_net_train/low_8_400.png new file mode 100644 index 0000000..04d90cd Binary files /dev/null and b/Decom_net_train/low_8_400.png differ diff --git a/Decom_net_train/low_9_200.png b/Decom_net_train/low_9_200.png new file mode 100644 index 0000000..21125f2 Binary files /dev/null and b/Decom_net_train/low_9_200.png differ diff --git a/Decom_net_train/low_9_400.png b/Decom_net_train/low_9_400.png new file mode 100644 index 0000000..ecb1731 Binary files /dev/null and b/Decom_net_train/low_9_400.png differ diff --git a/LOLdataset.zip b/LOLdataset.zip new file mode 100644 index 0000000..9cd7823 Binary files /dev/null and b/LOLdataset.zip differ diff --git a/README.md b/README.md index 0d6ef0d..9947386 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,58 @@ -#### 从命令行创建一个新的仓库 +# KinD +This is a Tensorflow implementation of KinD -```bash -touch README.md -git init -git add README.md -git commit -m "first commit" -git remote add origin https://bdgit.educoder.net/ZhengHui/KinD.git -git push -u origin master +### The [KinD++](https://github.com/zhangyhuaee/KinD_plus) is an improved version. ### -``` +Kindling the Darkness: a Practical Low-light Image Enhancer. In ACMMM2019
+Yonghua Zhang, Jiawan Zhang, Xiaojie Guo + +### [Paper](http://doi.acm.org/10.1145/3343031.3350926) + + + + +### Requirements ### +1. Python +2. Tensorflow >= 1.10.0 +3. numpy, PIL -#### 从命令行推送已经创建的仓库 +### Test ### +First download the pre-trained checkpoints from [BaiduNetdisk](https://pan.baidu.com/s/1c4ZLYEIoR-8skNMiAVbl_A) or [google drive](https://drive.google.com/open?id=1-ljWntl7FExf6BSQtl5Mz3rMGWgnXDz4), then just run +```shell +python evaluate.py +``` +Our pre-trained model has changed. Thus, the results have some difference with the report in our paper. However, you can adjust the illumination ratio to get better results. -```bash -git remote add origin https://bdgit.educoder.net/ZhengHui/KinD.git -git push -u origin master +### Train ### +The original LOLdataset can be downloaded from [here](https://daooshee.github.io/BMVC2018website/). We rearrange the original LOLdataset and add several all-zero images to improve the decomposition results and restoration results. The new dataset can be download from [BaiduNetdisk](https://pan.baidu.com/s/1sn3vWJ2I5U2dlVUD7eqIBQ) or [google drive](https://drive.google.com/open?id=1-MaOVG7ylOkmGv1K4HWWcrai01i_FeDK). Save training pairs of LOL dataset under './LOLdataset/our485/' and save evaluating pairs under './LOLdataset/eval15/'. For training, just run +```shell +python decomposition_net_train.py +python adjustment_net_train.py +python reflectance_restoration_net_train.py +``` +You can also evaluate on the LOLdataset, just run +```shell +python evaluate_LOLdataset.py +``` +Our code partly refers to the [code](https://github.com/weichen582/RetinexNet). +### Citation ### +``` +@inproceedings{zhang2019kindling, + author = {Zhang, Yonghua and Zhang, Jiawan and Guo, Xiaojie}, + title = {Kindling the Darkness: A Practical Low-light Image Enhancer}, + booktitle = {Proceedings of the 27th ACM International Conference on Multimedia}, + series = {MM '19}, + year = {2019}, + isbn = {978-1-4503-6889-6}, + location = {Nice, France}, + pages = {1632--1640}, + numpages = {9}, + url = {http://doi.acm.org/10.1145/3343031.3350926}, + doi = {10.1145/3343031.3350926}, + acmid = {3350926}, + publisher = {ACM}, + address = {New York, NY, USA}, + keywords = {image decomposition, image restoration, low light enhancement}, +} ``` diff --git a/adjustment_net_train.py b/adjustment_net_train.py new file mode 100644 index 0000000..59cf95e --- /dev/null +++ b/adjustment_net_train.py @@ -0,0 +1,223 @@ +# coding: utf-8 +from __future__ import print_function +import os +import time +import random +#from skimage import color +from PIL import Image +import tensorflow as tf +import numpy as np +from utils import * +from model import * +from glob import glob + +batch_size = 10 +patch_size = 48 + +sess = tf.Session() +#the input of decomposition net +input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom') +#the input of illumination adjustment net +input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i') +input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio') +input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i') + +[R_decom, I_decom] = DecomNet_simple(input_decom) +#the output of decomposition network +decom_output_R = R_decom +decom_output_I = I_decom +#the output of illumination adjustment net +output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio) + +#define loss + +def grad_loss(input_i_low, input_i_high): + x_loss = tf.square(gradient(input_i_low, 'x') - gradient(input_i_high, 'x')) + y_loss = tf.square(gradient(input_i_low, 'y') - gradient(input_i_high, 'y')) + grad_loss_all = tf.reduce_mean(x_loss + y_loss) + return grad_loss_all + +loss_grad = grad_loss(output_i, input_high_i) +loss_square = tf.reduce_mean(tf.square(output_i - input_high_i))# * ( 1 - input_low_r ))#* (1- input_low_i))) + +loss_adjust = loss_square + loss_grad + +lr = tf.placeholder(tf.float32, name='learning_rate') + +optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer') + +var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name] +var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name] + +saver_adjust = tf.train.Saver(var_list=var_adjust) +saver_Decom = tf.train.Saver(var_list = var_Decom) +train_op_adjust = optimizer.minimize(loss_adjust, var_list = var_adjust) +sess.run(tf.global_variables_initializer()) +print("[*] Initialize model successfully...") + +### load data +### Based on the decomposition net, we first get the decomposed reflectance maps +### and illumination maps, then train the adjust net. +###train_data +train_low_data = [] +train_high_data = [] +train_low_data_names = glob('./LOLdataset/our485/low/*.png') +train_low_data_names.sort() +train_high_data_names = glob('./LOLdataset/our485/high/*.png') +train_high_data_names.sort() +assert len(train_low_data_names) == len(train_high_data_names) +print('[*] Number of training data: %d' % len(train_low_data_names)) +for idx in range(len(train_low_data_names)): + low_im = load_images(train_low_data_names[idx]) + train_low_data.append(low_im) + high_im = load_images(train_high_data_names[idx]) + train_high_data.append(high_im) + +pre_decom_checkpoint_dir = './checkpoint/decom_net_train/' +ckpt_pre=tf.train.get_checkpoint_state(pre_decom_checkpoint_dir) +if ckpt_pre: + print('loaded '+ckpt_pre.model_checkpoint_path) + saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path) +else: + print('No pre_decom_net checkpoint!') + +#decomposed_low_r_data_480 = [] +decomposed_low_i_data_480 = [] +#decomposed_high_r_data_480 = [] +decomposed_high_i_data_480 = [] +for idx in range(len(train_low_data)): + input_low = np.expand_dims(train_low_data[idx], axis=0) + RR, II = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low}) + RR0 = np.squeeze(RR) + II0 = np.squeeze(II) + print(RR0.shape, II0.shape) + #decomposed_high_r_data_480.append(result_1_sq) + decomposed_low_i_data_480.append(II0) +for idx in range(len(train_high_data)): + input_high = np.expand_dims(train_high_data[idx], axis=0) + RR2, II2 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high}) + RR02 = np.squeeze(RR2) + II02 = np.squeeze(II2) + print(RR02.shape, II02.shape) + #decomposed_high_r_data_480.append(result_1_sq) + decomposed_high_i_data_480.append(II02) + +eval_adjust_low_i_data = decomposed_low_i_data_480[451:480] +eval_adjust_high_i_data = decomposed_high_i_data_480[451:480] + +train_adjust_low_i_data = decomposed_low_i_data_480[0:450] +train_adjust_high_i_data = decomposed_high_i_data_480[0:450] + +print('[*] Number of training data: %d' % len(train_adjust_high_i_data)) + +learning_rate = 0.0001 +epoch = 2000 +eval_every_epoch = 200 +train_phase = 'adjustment' +numBatch = len(train_adjust_low_i_data) // int(batch_size) +train_op = train_op_adjust +train_loss = loss_adjust +saver = saver_adjust + +checkpoint_dir = './checkpoint/illumination_adjust_net_train/' +if not os.path.isdir(checkpoint_dir): + os.makedirs(checkpoint_dir) +ckpt=tf.train.get_checkpoint_state(checkpoint_dir) +if ckpt: + print('loaded '+ckpt.model_checkpoint_path) + saver.restore(sess,ckpt.model_checkpoint_path) +else: + print("No adjustment net pre model!") + +start_step = 0 +start_epoch = 0 +iter_num = 0 +print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num)) + +sample_dir = './illumination_adjust_net_train/' +if not os.path.isdir(sample_dir): + os.makedirs(sample_dir) + +start_time = time.time() +image_id = 0 + +for epoch in range(start_epoch, epoch): + for batch_id in range(start_step, numBatch): + batch_input_low_i_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + batch_input_high_i_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + batch_input_low_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + batch_input_high_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + input_low_i_rand = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + input_high_i_rand = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + input_low_i_rand_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + input_high_i_rand_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + + for patch_id in range(batch_size): + i_low_data = train_adjust_low_i_data[image_id] + i_low_expand = np.expand_dims(i_low_data, axis = 2) + i_high_data = train_adjust_high_i_data[image_id] + i_high_expand = np.expand_dims(i_high_data, axis = 2) + + h, w = train_adjust_low_i_data[image_id].shape + x = random.randint(0, h - patch_size) + y = random.randint(0, w - patch_size) + i_low_data_crop = i_low_expand[x : x+patch_size, y : y+patch_size, :] + i_high_data_crop = i_high_expand[x : x+patch_size, y : y+patch_size, :] + + rand_mode = np.random.randint(0, 7) + batch_input_low_i[patch_id, :, :, :] = data_augmentation(i_low_data_crop , rand_mode) + batch_input_high_i[patch_id, :, :, :] = data_augmentation(i_high_data_crop, rand_mode) + + ratio = np.mean(i_low_data_crop/(i_high_data_crop+0.0001)) + #print(ratio) + i_low_data_ratio = np.ones([patch_size,patch_size])*(1/ratio+0.0001) + i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2) + i_high_data_ratio = np.ones([patch_size,patch_size])*(ratio) + i_high_ratio_expand = np.expand_dims(i_high_data_ratio , axis =2) + batch_input_low_i_ratio[patch_id, :, :, :] = i_low_ratio_expand + batch_input_high_i_ratio[patch_id, :, :, :] = i_high_ratio_expand + + rand_mode = np.random.randint(0, 2) + if rand_mode == 1: + input_low_i_rand[patch_id, :, :, :] = batch_input_low_i[patch_id, :, :, :] + input_high_i_rand[patch_id, :, :, :] = batch_input_high_i[patch_id, :, :, :] + input_low_i_rand_ratio[patch_id, :, :, :] = batch_input_low_i_ratio[patch_id, :, :, :] + input_high_i_rand_ratio[patch_id, :, :, :] = batch_input_high_i_ratio[patch_id, :, :, :] + else: + input_low_i_rand[patch_id, :, :, :] = batch_input_high_i[patch_id, :, :, :] + input_high_i_rand[patch_id, :, :, :] = batch_input_low_i[patch_id, :, :, :] + input_low_i_rand_ratio[patch_id, :, :, :] = batch_input_high_i_ratio[patch_id, :, :, :] + input_high_i_rand_ratio[patch_id, :, :, :] = batch_input_low_i_ratio[patch_id, :, :, :] + + image_id = (image_id + 1) % len(train_adjust_low_i_data) + + _, loss = sess.run([train_op, train_loss], feed_dict={input_low_i: input_low_i_rand,input_low_i_ratio: input_low_i_rand_ratio,\ + input_high_i: input_high_i_rand, \ + lr: learning_rate}) + print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \ + % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss)) + iter_num += 1 + if (epoch + 1) % eval_every_epoch == 0: + print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1)) + + for idx in range(10): + rand_idx = idx#np.random.randint(26) + input_uu_i = eval_adjust_low_i_data[rand_idx] + input_low_eval_i = np.expand_dims(input_uu_i, axis=0) + input_low_eval_ii = np.expand_dims(input_low_eval_i, axis=3) + h, w = eval_adjust_low_i_data[idx].shape + rand_ratio = np.random.random(1)*2 + input_uu_i_ratio = np.ones([h,w]) * rand_ratio + input_low_eval_i_ratio = np.expand_dims(input_uu_i_ratio, axis=0) + input_low_eval_ii_ratio = np.expand_dims(input_low_eval_i_ratio, axis=3) + + result_1 = sess.run(output_i, feed_dict={input_low_i: input_low_eval_ii, input_low_i_ratio: input_low_eval_ii_ratio}) + save_images(os.path.join(sample_dir, 'h_eval_%d_%d_%5f.png' % ( epoch + 1 , rand_idx + 1, rand_ratio)), input_uu_i, result_1) + + + saver.save(sess, checkpoint_dir + 'model.ckpt') + +print("[*] Finish training for phase %s." % train_phase) + + + diff --git a/checkpoint/Restoration_net_train/checkpoint b/checkpoint/Restoration_net_train/checkpoint new file mode 100644 index 0000000..093fdf1 --- /dev/null +++ b/checkpoint/Restoration_net_train/checkpoint @@ -0,0 +1 @@ +model_checkpoint_path: "model.ckpt" diff --git a/checkpoint/Restoration_net_train/model.ckpt.data-00000-of-00001 b/checkpoint/Restoration_net_train/model.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..46b2789 Binary files /dev/null and b/checkpoint/Restoration_net_train/model.ckpt.data-00000-of-00001 differ diff --git a/checkpoint/Restoration_net_train/model.ckpt.index b/checkpoint/Restoration_net_train/model.ckpt.index new file mode 100644 index 0000000..996bad3 Binary files /dev/null and b/checkpoint/Restoration_net_train/model.ckpt.index differ diff --git a/checkpoint/Restoration_net_train/model.ckpt.meta b/checkpoint/Restoration_net_train/model.ckpt.meta new file mode 100644 index 0000000..bb96652 Binary files /dev/null and b/checkpoint/Restoration_net_train/model.ckpt.meta differ diff --git a/checkpoint/decom_net_train/checkpoint b/checkpoint/decom_net_train/checkpoint new file mode 100644 index 0000000..febd7d5 --- /dev/null +++ b/checkpoint/decom_net_train/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "model.ckpt" +all_model_checkpoint_paths: "model.ckpt" diff --git a/checkpoint/decom_net_train/model.ckpt.data-00000-of-00001 b/checkpoint/decom_net_train/model.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..f500187 Binary files /dev/null and b/checkpoint/decom_net_train/model.ckpt.data-00000-of-00001 differ diff --git a/checkpoint/decom_net_train/model.ckpt.index b/checkpoint/decom_net_train/model.ckpt.index new file mode 100644 index 0000000..6da6a7e Binary files /dev/null and b/checkpoint/decom_net_train/model.ckpt.index differ diff --git a/checkpoint/decom_net_train/model.ckpt.meta b/checkpoint/decom_net_train/model.ckpt.meta new file mode 100644 index 0000000..025ea8c Binary files /dev/null and b/checkpoint/decom_net_train/model.ckpt.meta differ diff --git a/checkpoint/illumination_adjust_net_train/checkpoint b/checkpoint/illumination_adjust_net_train/checkpoint new file mode 100644 index 0000000..febd7d5 --- /dev/null +++ b/checkpoint/illumination_adjust_net_train/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "model.ckpt" +all_model_checkpoint_paths: "model.ckpt" diff --git a/checkpoint/illumination_adjust_net_train/model.ckpt.data-00000-of-00001 b/checkpoint/illumination_adjust_net_train/model.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..88e2dd2 Binary files /dev/null and b/checkpoint/illumination_adjust_net_train/model.ckpt.data-00000-of-00001 differ diff --git a/checkpoint/illumination_adjust_net_train/model.ckpt.index b/checkpoint/illumination_adjust_net_train/model.ckpt.index new file mode 100644 index 0000000..826bfe4 Binary files /dev/null and b/checkpoint/illumination_adjust_net_train/model.ckpt.index differ diff --git a/checkpoint/illumination_adjust_net_train/model.ckpt.meta b/checkpoint/illumination_adjust_net_train/model.ckpt.meta new file mode 100644 index 0000000..9615198 Binary files /dev/null and b/checkpoint/illumination_adjust_net_train/model.ckpt.meta differ diff --git a/decomposition_net_train.py b/decomposition_net_train.py new file mode 100644 index 0000000..2098469 --- /dev/null +++ b/decomposition_net_train.py @@ -0,0 +1,174 @@ +# coding: utf-8 +from __future__ import print_function +import os, time, random +import tensorflow as tf +from PIL import Image +import numpy as np +from utils import * +from model import * +from glob import glob + +batch_size = 10 +patch_size = 48 + +sess = tf.Session() + +input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low') +input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high') + +[R_low, I_low] = DecomNet_simple(input_low) +[R_high, I_high] = DecomNet_simple(input_high) + +I_low_3 = tf.concat([I_low, I_low, I_low], axis=3) +I_high_3 = tf.concat([I_high, I_high, I_high], axis=3) + +#network output +output_R_low = R_low +output_R_high = R_high +output_I_low = I_low_3 +output_I_high = I_high_3 + +# define loss + +def mutual_i_loss(input_I_low, input_I_high): + low_gradient_x = gradient(input_I_low, "x") + high_gradient_x = gradient(input_I_high, "x") + x_loss = (low_gradient_x + high_gradient_x)* tf.exp(-10*(low_gradient_x+high_gradient_x)) + low_gradient_y = gradient(input_I_low, "y") + high_gradient_y = gradient(input_I_high, "y") + y_loss = (low_gradient_y + high_gradient_y) * tf.exp(-10*(low_gradient_y+high_gradient_y)) + mutual_loss = tf.reduce_mean( x_loss + y_loss) + return mutual_loss + +def mutual_i_input_loss(input_I_low, input_im): + input_gray = tf.image.rgb_to_grayscale(input_im) + low_gradient_x = gradient(input_I_low, "x") + input_gradient_x = gradient(input_gray, "x") + x_loss = tf.abs(tf.div(low_gradient_x, tf.maximum(input_gradient_x, 0.01))) + low_gradient_y = gradient(input_I_low, "y") + input_gradient_y = gradient(input_gray, "y") + y_loss = tf.abs(tf.div(low_gradient_y, tf.maximum(input_gradient_y, 0.01))) + mut_loss = tf.reduce_mean(x_loss + y_loss) + return mut_loss + +recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - input_low)) +recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - input_high)) + +equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high)) + +i_mutual_loss = mutual_i_loss(I_low, I_high) + +i_input_mutual_loss_high = mutual_i_input_loss(I_high, input_high) +i_input_mutual_loss_low = mutual_i_input_loss(I_low, input_low) + +loss_Decom = 1*recon_loss_high + 1*recon_loss_low \ + + 0.01 * equal_R_loss + 0.2*i_mutual_loss \ + + 0.15* i_input_mutual_loss_high + 0.15* i_input_mutual_loss_low + +### +lr = tf.placeholder(tf.float32, name='learning_rate') + +optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer') +var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name] + +train_op_Decom = optimizer.minimize(loss_Decom, var_list = var_Decom) +sess.run(tf.global_variables_initializer()) + +saver_Decom = tf.train.Saver(var_list = var_Decom) +print("[*] Initialize model successfully...") + +#load data +###train_data +train_low_data = [] +train_high_data = [] +train_low_data_names = glob('./LOLdataset/our485/low/*.png') +train_low_data_names.sort() +train_high_data_names = glob('./LOLdataset/our485/high/*.png') +train_high_data_names.sort() +assert len(train_low_data_names) == len(train_high_data_names) +print('[*] Number of training data: %d' % len(train_low_data_names)) +for idx in range(len(train_low_data_names)): + low_im = load_images(train_low_data_names[idx]) + train_low_data.append(low_im) + high_im = load_images(train_high_data_names[idx]) + train_high_data.append(high_im) +###eval_data +eval_low_data = [] +eval_high_data = [] +eval_low_data_name = glob('./LOLdataset/eval15/low/*.png') +eval_low_data_name.sort() +eval_high_data_name = glob('./LOLdataset/eval15/high/*.png*') +eval_high_data_name.sort() +for idx in range(len(eval_low_data_name)): + eval_low_im = load_images(eval_low_data_name[idx]) + eval_low_data.append(eval_low_im) + eval_high_im = load_images(eval_high_data_name[idx]) + eval_high_data.append(eval_high_im) + + +epoch = 2000 +learning_rate = 0.0001 + +sample_dir = './Decom_net_train/' +if not os.path.isdir(sample_dir): + os.makedirs(sample_dir) + +eval_every_epoch = 200 +train_phase = 'decomposition' +numBatch = len(train_low_data) // int(batch_size) +train_op = train_op_Decom +train_loss = loss_Decom +saver = saver_Decom + +checkpoint_dir = './checkpoint/decom_net_train/' +if not os.path.isdir(checkpoint_dir): + os.makedirs(checkpoint_dir) +ckpt=tf.train.get_checkpoint_state(checkpoint_dir) +if ckpt: + print('loaded '+ckpt.model_checkpoint_path) + saver.restore(sess,ckpt.model_checkpoint_path) + +start_step = 0 +start_epoch = 0 +iter_num = 0 +print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num)) + +start_time = time.time() +image_id = 0 +for epoch in range(start_epoch, epoch): + for batch_id in range(start_step, numBatch): + batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32") + batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32") + for patch_id in range(batch_size): + h, w, _ = train_low_data[image_id].shape + x = random.randint(0, h - patch_size) + y = random.randint(0, w - patch_size) + rand_mode = random.randint(0, 7) + batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode) + batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode) + image_id = (image_id + 1) % len(train_low_data) + if image_id == 0: + tmp = list(zip(train_low_data, train_high_data)) + random.shuffle(tmp) + train_low_data, train_high_data = zip(*tmp) + + _, loss = sess.run([train_op, train_loss], feed_dict={input_low: batch_input_low, \ + input_high: batch_input_high, \ + lr: learning_rate}) + print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \ + % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss)) + iter_num += 1 + if (epoch + 1) % eval_every_epoch == 0: + print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1)) + for idx in range(len(eval_low_data)): + input_low_eval = np.expand_dims(eval_low_data[idx], axis=0) + result_1, result_2 = sess.run([output_R_low, output_I_low], feed_dict={input_low: input_low_eval}) + save_images(os.path.join(sample_dir, 'low_%d_%d.png' % ( idx + 1, epoch + 1)), result_1, result_2) + for idx in range(len(eval_high_data)): + input_high_eval = np.expand_dims(eval_high_data[idx], axis=0) + result_11, result_22 = sess.run([output_R_high, output_I_high], feed_dict={input_high: input_high_eval}) + save_images(os.path.join(sample_dir, 'high_%d_%d.png' % ( idx + 1, epoch + 1)), result_11, result_22) + + saver.save(sess, checkpoint_dir + 'model.ckpt') + +print("[*] Finish training for phase %s." % train_phase) diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..5d8a633 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,114 @@ +# coding: utf-8 +from __future__ import print_function +import os +import time +import random +from PIL import Image +import tensorflow as tf +import numpy as np +from utils import * +from model import * +from glob import glob +from skimage import color,filters + +sess = tf.Session() + +input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom') +input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r') +input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i') +input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r') +input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i') +input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio') + +[R_decom, I_decom] = DecomNet_simple(input_decom) +decom_output_R = R_decom +decom_output_I = I_decom +output_r = Restoration_net(input_low_r, input_low_i) +output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio) + +var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name] +var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name] +var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name] + +saver_Decom = tf.train.Saver(var_list = var_Decom) +saver_adjust = tf.train.Saver(var_list=var_adjust) +saver_restoration = tf.train.Saver(var_list=var_restoration) + +decom_checkpoint_dir ='./checkpoint/decom_net_train/' +ckpt_pre=tf.train.get_checkpoint_state(decom_checkpoint_dir) +if ckpt_pre: + print('loaded '+ckpt_pre.model_checkpoint_path) + saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path) +else: + print('No decomnet checkpoint!') + +checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/' +ckpt_adjust=tf.train.get_checkpoint_state(checkpoint_dir_adjust) +if ckpt_adjust: + print('loaded '+ckpt_adjust.model_checkpoint_path) + saver_adjust.restore(sess,ckpt_adjust.model_checkpoint_path) +else: + print("No adjust pre model!") + +checkpoint_dir_restoration = './checkpoint/Restoration_net_train/' +ckpt=tf.train.get_checkpoint_state(checkpoint_dir_restoration) +if ckpt: + print('loaded '+ckpt.model_checkpoint_path) + saver_restoration.restore(sess,ckpt.model_checkpoint_path) +else: + print("No restoration pre model!") + +###load eval data +eval_low_data = [] +eval_img_name =[] +eval_low_data_name = glob('./test/*') +eval_low_data_name.sort() +for idx in range(len(eval_low_data_name)): + [_, name] = os.path.split(eval_low_data_name[idx]) + suffix = name[name.find('.') + 1:] + name = name[:name.find('.')] + eval_img_name.append(name) + eval_low_im = load_images(eval_low_data_name[idx]) + eval_low_data.append(eval_low_im) + print(eval_low_im.shape) + +sample_dir = './results/test/' +if not os.path.isdir(sample_dir): + os.makedirs(sample_dir) + +print("Start evalating!") +start_time = time.time() +for idx in range(len(eval_low_data)): + print(idx) + name = eval_img_name[idx] + input_low = eval_low_data[idx] + input_low_eval = np.expand_dims(input_low, axis=0) + h, w, _ = input_low.shape + + decom_r_low, decom_i_low = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval}) + + restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low}) +### change the ratio to get different exposure level, the value can be 0-5.0 + ratio = 5.0 + i_low_data_ratio = np.ones([h, w])*(ratio) + i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2) + i_low_ratio_expand2 = np.expand_dims(i_low_ratio_expand, axis=0) + adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio_expand2}) + + #The restoration result can find more details from very dark regions, however, it will restore the very dark regions +#with gray colors, we use the following operator to alleviate this weakness. + decom_r_sq = np.squeeze(decom_r_low) + r_gray = color.rgb2gray(decom_r_sq) + r_gray_gaussion = filters.gaussian(r_gray, 3) + low_i = np.minimum((r_gray_gaussion*2)**0.5,1) + low_i_expand_0 = np.expand_dims(low_i, axis = 0) + low_i_expand_3 = np.expand_dims(low_i_expand_0, axis = 3) + result_denoise = restoration_r*low_i_expand_3 + fusion4 = result_denoise*adjust_i + + #fusion = restoration_r*adjust_i +# fuse with the original input to avoid over-exposure + fusion2 = decom_i_low*input_low_eval + (1-decom_i_low)*fusion4 + #print(fusion2.shape) + save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion2) + diff --git a/evaluate_LOLdataset.py b/evaluate_LOLdataset.py new file mode 100644 index 0000000..39d37d2 --- /dev/null +++ b/evaluate_LOLdataset.py @@ -0,0 +1,110 @@ +# coding: utf-8 +from __future__ import print_function +import os +import time +import random +from PIL import Image +import tensorflow as tf +import numpy as np +from utils import * +from model import * +from glob import glob + +sess = tf.Session() + +input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom') +input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r') +input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i') +input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r') +input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i') +input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio') + +[R_decom, I_decom] = DecomNet_simple(input_decom) +decom_output_R = R_decom +decom_output_I = I_decom +output_r = Restoration_net(input_low_r, input_low_i) +output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio) + +var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name] +var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name] +var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name] + +saver_Decom = tf.train.Saver(var_list = var_Decom) +saver_adjust = tf.train.Saver(var_list=var_adjust) +saver_restoration = tf.train.Saver(var_list=var_restoration) + +decom_checkpoint_dir ='./checkpoint/decom_net_train/' +ckpt_pre=tf.train.get_checkpoint_state(decom_checkpoint_dir) +if ckpt_pre: + print('loaded '+ckpt_pre.model_checkpoint_path) + saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path) +else: + print('No decomnet checkpoint!') + +checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/' +ckpt_adjust=tf.train.get_checkpoint_state(checkpoint_dir_adjust) +if ckpt_adjust: + print('loaded '+ckpt_adjust.model_checkpoint_path) + saver_adjust.restore(sess,ckpt_adjust.model_checkpoint_path) +else: + print("No adjust pre model!") + +checkpoint_dir_restoration = './checkpoint/Restoration_net_train/' +ckpt=tf.train.get_checkpoint_state(checkpoint_dir_restoration) +if ckpt: + print('loaded '+ckpt.model_checkpoint_path) + saver_restoration.restore(sess,ckpt.model_checkpoint_path) +else: + print("No restoration pre model!") + +###load eval data +eval_low_data = [] +eval_img_name =[] +eval_low_data_name = glob('./LOLdataset/eval15/low/*.png') +eval_low_data_name.sort() +for idx in range(len(eval_low_data_name)): + [_, name] = os.path.split(eval_low_data_name[idx]) + suffix = name[name.find('.') + 1:] + name = name[:name.find('.')] + eval_img_name.append(name) + eval_low_im = load_images(eval_low_data_name[idx]) + eval_low_data.append(eval_low_im) + print(eval_low_im.shape) +# To get better results, the illumination adjustment ratio is computed based on the decom_i_high, so we also need the high data. +eval_high_data = [] +eval_high_data_name = glob('./LOLdataset/eval15/high/*.png') +eval_high_data_name.sort() +for idx in range(len(eval_high_data_name)): + eval_high_im = load_images(eval_high_data_name[idx]) + eval_high_data.append(eval_high_im) + +sample_dir = './results/LOLdataset_eval15/' +if not os.path.isdir(sample_dir): + os.makedirs(sample_dir) + +print("Start evalating!") +start_time = time.time() +for idx in range(len(eval_low_data)): + print(idx) + name = eval_img_name[idx] + input_low = eval_low_data[idx] + input_low_eval = np.expand_dims(input_low, axis=0) + input_high = eval_high_data[idx] + input_high_eval = np.expand_dims(input_high, axis=0) + h, w, _ = input_low.shape + + decom_r_low, decom_i_low = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval}) + decom_r_high, decom_i_high = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high_eval}) + + restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low}) + + ratio = np.mean(((decom_i_high))/(decom_i_low+0.0001)) + + i_low_data_ratio = np.ones([h, w])*(ratio) + i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2) + i_low_ratio_expand2 = np.expand_dims(i_low_ratio_expand, axis=0) + + adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio_expand2}) + fusion = restoration_r*adjust_i + save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion) + diff --git a/figures/network.jpg b/figures/network.jpg new file mode 100644 index 0000000..62001c2 Binary files /dev/null and b/figures/network.jpg differ diff --git a/figures/result.jpg b/figures/result.jpg new file mode 100644 index 0000000..3d89461 Binary files /dev/null and b/figures/result.jpg differ diff --git a/illumination_adjust_net_train/h_eval_200_10_0.243455.png b/illumination_adjust_net_train/h_eval_200_10_0.243455.png new file mode 100644 index 0000000..3fc33a8 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_10_0.243455.png differ diff --git a/illumination_adjust_net_train/h_eval_200_1_0.283560.png b/illumination_adjust_net_train/h_eval_200_1_0.283560.png new file mode 100644 index 0000000..6876f23 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_1_0.283560.png differ diff --git a/illumination_adjust_net_train/h_eval_200_2_1.097869.png b/illumination_adjust_net_train/h_eval_200_2_1.097869.png new file mode 100644 index 0000000..a62ce03 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_2_1.097869.png differ diff --git a/illumination_adjust_net_train/h_eval_200_3_0.790207.png b/illumination_adjust_net_train/h_eval_200_3_0.790207.png new file mode 100644 index 0000000..94091d6 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_3_0.790207.png differ diff --git a/illumination_adjust_net_train/h_eval_200_4_1.368601.png b/illumination_adjust_net_train/h_eval_200_4_1.368601.png new file mode 100644 index 0000000..bcf93ea Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_4_1.368601.png differ diff --git a/illumination_adjust_net_train/h_eval_200_5_0.979668.png b/illumination_adjust_net_train/h_eval_200_5_0.979668.png new file mode 100644 index 0000000..e65bd36 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_5_0.979668.png differ diff --git a/illumination_adjust_net_train/h_eval_200_6_1.648318.png b/illumination_adjust_net_train/h_eval_200_6_1.648318.png new file mode 100644 index 0000000..494f340 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_6_1.648318.png differ diff --git a/illumination_adjust_net_train/h_eval_200_7_0.342475.png b/illumination_adjust_net_train/h_eval_200_7_0.342475.png new file mode 100644 index 0000000..755ec5f Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_7_0.342475.png differ diff --git a/illumination_adjust_net_train/h_eval_200_8_0.597354.png b/illumination_adjust_net_train/h_eval_200_8_0.597354.png new file mode 100644 index 0000000..c2a3280 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_8_0.597354.png differ diff --git a/illumination_adjust_net_train/h_eval_200_9_0.969436.png b/illumination_adjust_net_train/h_eval_200_9_0.969436.png new file mode 100644 index 0000000..fe970a8 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_9_0.969436.png differ diff --git a/illumination_adjust_net_train/h_eval_400_10_1.080474.png b/illumination_adjust_net_train/h_eval_400_10_1.080474.png new file mode 100644 index 0000000..14b27bb Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_10_1.080474.png differ diff --git a/illumination_adjust_net_train/h_eval_400_1_1.829798.png b/illumination_adjust_net_train/h_eval_400_1_1.829798.png new file mode 100644 index 0000000..91aeeb9 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_1_1.829798.png differ diff --git a/illumination_adjust_net_train/h_eval_400_2_0.833631.png b/illumination_adjust_net_train/h_eval_400_2_0.833631.png new file mode 100644 index 0000000..55eea1a Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_2_0.833631.png differ diff --git a/illumination_adjust_net_train/h_eval_400_3_0.639403.png b/illumination_adjust_net_train/h_eval_400_3_0.639403.png new file mode 100644 index 0000000..6efd2ec Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_3_0.639403.png differ diff --git a/illumination_adjust_net_train/h_eval_400_4_0.080941.png b/illumination_adjust_net_train/h_eval_400_4_0.080941.png new file mode 100644 index 0000000..c8797c9 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_4_0.080941.png differ diff --git a/illumination_adjust_net_train/h_eval_400_5_0.118340.png b/illumination_adjust_net_train/h_eval_400_5_0.118340.png new file mode 100644 index 0000000..b8c4f00 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_5_0.118340.png differ diff --git a/illumination_adjust_net_train/h_eval_400_6_1.416026.png b/illumination_adjust_net_train/h_eval_400_6_1.416026.png new file mode 100644 index 0000000..c37bba8 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_6_1.416026.png differ diff --git a/illumination_adjust_net_train/h_eval_400_7_0.024722.png b/illumination_adjust_net_train/h_eval_400_7_0.024722.png new file mode 100644 index 0000000..6b61d59 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_7_0.024722.png differ diff --git a/illumination_adjust_net_train/h_eval_400_8_1.792914.png b/illumination_adjust_net_train/h_eval_400_8_1.792914.png new file mode 100644 index 0000000..29c7696 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_8_1.792914.png differ diff --git a/illumination_adjust_net_train/h_eval_400_9_1.454717.png b/illumination_adjust_net_train/h_eval_400_9_1.454717.png new file mode 100644 index 0000000..b425dda Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_9_1.454717.png differ diff --git a/illumination_adjust_net_train/h_eval_600_10_1.570546.png b/illumination_adjust_net_train/h_eval_600_10_1.570546.png new file mode 100644 index 0000000..48bf314 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_10_1.570546.png differ diff --git a/illumination_adjust_net_train/h_eval_600_1_1.194817.png b/illumination_adjust_net_train/h_eval_600_1_1.194817.png new file mode 100644 index 0000000..567112b Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_1_1.194817.png differ diff --git a/illumination_adjust_net_train/h_eval_600_2_0.413976.png b/illumination_adjust_net_train/h_eval_600_2_0.413976.png new file mode 100644 index 0000000..ed1c71d Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_2_0.413976.png differ diff --git a/illumination_adjust_net_train/h_eval_600_3_0.765548.png b/illumination_adjust_net_train/h_eval_600_3_0.765548.png new file mode 100644 index 0000000..db2f816 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_3_0.765548.png differ diff --git a/illumination_adjust_net_train/h_eval_600_4_0.745525.png b/illumination_adjust_net_train/h_eval_600_4_0.745525.png new file mode 100644 index 0000000..6673647 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_4_0.745525.png differ diff --git a/illumination_adjust_net_train/h_eval_600_5_1.079725.png b/illumination_adjust_net_train/h_eval_600_5_1.079725.png new file mode 100644 index 0000000..1f6dc30 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_5_1.079725.png differ diff --git a/illumination_adjust_net_train/h_eval_600_6_0.365893.png b/illumination_adjust_net_train/h_eval_600_6_0.365893.png new file mode 100644 index 0000000..7eed807 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_6_0.365893.png differ diff --git a/illumination_adjust_net_train/h_eval_600_7_0.716873.png b/illumination_adjust_net_train/h_eval_600_7_0.716873.png new file mode 100644 index 0000000..08369a3 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_7_0.716873.png differ diff --git a/illumination_adjust_net_train/h_eval_600_8_0.492159.png b/illumination_adjust_net_train/h_eval_600_8_0.492159.png new file mode 100644 index 0000000..85df0d4 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_8_0.492159.png differ diff --git a/illumination_adjust_net_train/h_eval_600_9_0.445431.png b/illumination_adjust_net_train/h_eval_600_9_0.445431.png new file mode 100644 index 0000000..b5a8597 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_9_0.445431.png differ diff --git a/illumination_adjust_net_train/h_eval_800_10_0.235056.png b/illumination_adjust_net_train/h_eval_800_10_0.235056.png new file mode 100644 index 0000000..b2217a7 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_10_0.235056.png differ diff --git a/illumination_adjust_net_train/h_eval_800_1_1.459935.png b/illumination_adjust_net_train/h_eval_800_1_1.459935.png new file mode 100644 index 0000000..e9d3700 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_1_1.459935.png differ diff --git a/illumination_adjust_net_train/h_eval_800_2_0.965174.png b/illumination_adjust_net_train/h_eval_800_2_0.965174.png new file mode 100644 index 0000000..afcbda4 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_2_0.965174.png differ diff --git a/illumination_adjust_net_train/h_eval_800_3_0.243200.png b/illumination_adjust_net_train/h_eval_800_3_0.243200.png new file mode 100644 index 0000000..6ee9fa2 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_3_0.243200.png differ diff --git a/illumination_adjust_net_train/h_eval_800_4_1.968109.png b/illumination_adjust_net_train/h_eval_800_4_1.968109.png new file mode 100644 index 0000000..895f40c Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_4_1.968109.png differ diff --git a/illumination_adjust_net_train/h_eval_800_5_1.170218.png b/illumination_adjust_net_train/h_eval_800_5_1.170218.png new file mode 100644 index 0000000..a9994de Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_5_1.170218.png differ diff --git a/illumination_adjust_net_train/h_eval_800_6_0.804922.png b/illumination_adjust_net_train/h_eval_800_6_0.804922.png new file mode 100644 index 0000000..29d807d Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_6_0.804922.png differ diff --git a/illumination_adjust_net_train/h_eval_800_7_1.683641.png b/illumination_adjust_net_train/h_eval_800_7_1.683641.png new file mode 100644 index 0000000..82ffbee Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_7_1.683641.png differ diff --git a/illumination_adjust_net_train/h_eval_800_8_1.448064.png b/illumination_adjust_net_train/h_eval_800_8_1.448064.png new file mode 100644 index 0000000..c924e5c Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_8_1.448064.png differ diff --git a/illumination_adjust_net_train/h_eval_800_9_1.708184.png b/illumination_adjust_net_train/h_eval_800_9_1.708184.png new file mode 100644 index 0000000..a3309b9 Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_9_1.708184.png differ diff --git a/model.py b/model.py new file mode 100644 index 0000000..eae6e85 --- /dev/null +++ b/model.py @@ -0,0 +1,97 @@ +import tensorflow as tf +import tensorflow.contrib.slim as slim +from tensorflow.contrib.layers.python.layers import initializers + +def lrelu(x, trainbable=None): + return tf.maximum(x*0.2,x) + +def upsample_and_concat(x1, x2, output_channels, in_channels, scope_name, trainable=True): + with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope: + pool_size = 2 + deconv_filter = tf.get_variable('weights', [pool_size, pool_size, output_channels, in_channels], trainable= True) + deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2) , strides=[1, pool_size, pool_size, 1], name=scope_name) + + deconv_output = tf.concat([deconv, x2],3) + deconv_output.set_shape([None, None, None, output_channels*2]) + + return deconv_output + +def DecomNet_simple(input): + with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE): + conv1=slim.conv2d(input,32,[3,3], rate=1, activation_fn=lrelu,scope='g_conv1_1') + pool1=slim.max_pool2d(conv1, [2, 2], stride = 2, padding='SAME' ) + conv2=slim.conv2d(pool1,64,[3,3], rate=1, activation_fn=lrelu,scope='g_conv2_1') + pool2=slim.max_pool2d(conv2, [2, 2], stride = 2, padding='SAME' ) + conv3=slim.conv2d(pool2,128,[3,3], rate=1, activation_fn=lrelu,scope='g_conv3_1') + up8 = upsample_and_concat( conv3, conv2, 64, 128 , 'g_up_1') + conv8=slim.conv2d(up8, 64,[3,3], rate=1, activation_fn=lrelu,scope='g_conv8_1') + up9 = upsample_and_concat( conv8, conv1, 32, 64 , 'g_up_2') + conv9=slim.conv2d(up9, 32,[3,3], rate=1, activation_fn=lrelu,scope='g_conv9_1') + # Here, we use 1*1 kernel to replace the 3*3 ones in the paper to get better results. + conv10=slim.conv2d(conv9,3,[1,1], rate=1, activation_fn=None, scope='g_conv10') + R_out = tf.sigmoid(conv10) + + l_conv2=slim.conv2d(conv1,32,[3,3], rate=1, activation_fn=lrelu,scope='l_conv1_2') + l_conv3=tf.concat([l_conv2, conv9],3) + # Here, we use 1*1 kernel to replace the 3*3 ones in the paper to get better results. + l_conv4=slim.conv2d(l_conv3,1,[1,1], rate=1, activation_fn=None,scope='l_conv1_4') + L_out = tf.sigmoid(l_conv4) + + return R_out, L_out + +def Restoration_net(input_r, input_i): + with tf.variable_scope('Restoration_net', reuse=tf.AUTO_REUSE): + input_all = tf.concat([input_r,input_i], 3) + + conv1=slim.conv2d(input_all,32,[3,3], rate=1, activation_fn=lrelu,scope='de_conv1_1') + conv1=slim.conv2d(conv1,32,[3,3], rate=1, activation_fn=lrelu,scope='de_conv1_2') + pool1=slim.max_pool2d(conv1, [2, 2], padding='SAME' ) + + conv2=slim.conv2d(pool1,64,[3,3], rate=1, activation_fn=lrelu,scope='de_conv2_1') + conv2=slim.conv2d(conv2,64,[3,3], rate=1, activation_fn=lrelu,scope='de_conv2_2') + pool2=slim.max_pool2d(conv2, [2, 2], padding='SAME' ) + + conv3=slim.conv2d(pool2,128,[3,3], rate=1, activation_fn=lrelu,scope='de_conv3_1') + conv3=slim.conv2d(conv3,128,[3,3], rate=1, activation_fn=lrelu,scope='de_conv3_2') + pool3=slim.max_pool2d(conv3, [2, 2], padding='SAME' ) + + conv4=slim.conv2d(pool3,256,[3,3], rate=1, activation_fn=lrelu,scope='de_conv4_1') + conv4=slim.conv2d(conv4,256,[3,3], rate=1, activation_fn=lrelu,scope='de_conv4_2') + pool4=slim.max_pool2d(conv4, [2, 2], padding='SAME' ) + + conv5=slim.conv2d(pool4,512,[3,3], rate=1, activation_fn=lrelu,scope='de_conv5_1') + conv5=slim.conv2d(conv5,512,[3,3], rate=1, activation_fn=lrelu,scope='de_conv5_2') + + up6 = upsample_and_concat( conv5, conv4, 256, 512, 'up_6') + + conv6=slim.conv2d(up6, 256,[3,3], rate=1, activation_fn=lrelu,scope='de_conv6_1') + conv6=slim.conv2d(conv6,256,[3,3], rate=1, activation_fn=lrelu,scope='de_conv6_2') + + up7 = upsample_and_concat( conv6, conv3, 128, 256, 'up_7' ) + conv7=slim.conv2d(up7, 128,[3,3], rate=1, activation_fn=lrelu,scope='de_conv7_1') + conv7=slim.conv2d(conv7,128,[3,3], rate=1, activation_fn=lrelu,scope='de_conv7_2') + + up8 = upsample_and_concat( conv7, conv2, 64, 128, 'up_8' ) + conv8=slim.conv2d(up8, 64,[3,3], rate=1, activation_fn=lrelu,scope='de_conv8_1') + conv8=slim.conv2d(conv8,64,[3,3], rate=1, activation_fn=lrelu,scope='de_conv8_2') + + up9 = upsample_and_concat( conv8, conv1, 32, 64, 'up_9' ) + conv9=slim.conv2d(up9, 32,[3,3], rate=1, activation_fn=lrelu,scope='de_conv9_1') + conv9=slim.conv2d(conv9,32,[3,3], rate=1, activation_fn=lrelu,scope='de_conv9_2') + + conv10=slim.conv2d(conv9,3,[3,3], rate=1, activation_fn=None, scope='de_conv10') + + out = tf.sigmoid(conv10) + return out + +def Illumination_adjust_net(input_i, input_ratio): + with tf.variable_scope('Illumination_adjust_net', reuse=tf.AUTO_REUSE): + input_all = tf.concat([input_i, input_ratio], 3) + + conv1=slim.conv2d(input_all,32,[3,3], rate=1, activation_fn=lrelu,scope='en_conv_1') + conv2=slim.conv2d(conv1,32,[3,3], rate=1, activation_fn=lrelu,scope='en_conv_2') + conv3=slim.conv2d(conv2,32,[3,3], rate=1, activation_fn=lrelu,scope='en_conv_3') + conv4=slim.conv2d(conv3,1,[3,3], rate=1, activation_fn=lrelu,scope='en_conv_4') + + L_enhance = tf.sigmoid(conv4) + return L_enhance \ No newline at end of file diff --git a/reflectance_restoration_net_train.py b/reflectance_restoration_net_train.py new file mode 100644 index 0000000..e746420 --- /dev/null +++ b/reflectance_restoration_net_train.py @@ -0,0 +1,244 @@ +# coding: utf-8 +from __future__ import print_function +import os +import time +import random +from PIL import Image +import tensorflow as tf +import numpy as np +from utils import * +from model import * +from glob import glob + +batch_size = 4 +patch_size = 384 + +config = tf.ConfigProto() +config.gpu_options.allow_growth = True +sess=tf.Session(config=config) +#the input of decomposition net +input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom') +#restoration input +input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r') +input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i') +input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r') + +[R_decom, I_decom] = DecomNet_simple(input_decom) +#the output of decomposition network +decom_output_R = R_decom +decom_output_I = I_decom + +output_r = Restoration_net(input_low_r, input_low_i) + +#define loss +def grad_loss(input_r_low, input_r_high): + input_r_low_gray = tf.image.rgb_to_grayscale(input_r_low) + input_r_high_gray = tf.image.rgb_to_grayscale(input_r_high) + x_loss = tf.square(gradient(input_r_low_gray, 'x') - gradient(input_r_high_gray, 'x')) + y_loss = tf.square(gradient(input_r_low_gray, 'y') - gradient(input_r_high_gray, 'y')) + grad_loss_all = tf.reduce_mean(x_loss + y_loss) + return grad_loss_all + +def ssim_loss(output_r, input_high_r): + output_r_1 = output_r[:,:,:,0:1] + input_high_r_1 = input_high_r[:,:,:,0:1] + ssim_r_1 = tf_ssim(output_r_1, input_high_r_1) + output_r_2 = output_r[:,:,:,1:2] + input_high_r_2 = input_high_r[:,:,:,1:2] + ssim_r_2 = tf_ssim(output_r_2, input_high_r_2) + output_r_3 = output_r[:,:,:,2:3] + input_high_r_3 = input_high_r[:,:,:,2:3] + ssim_r_3 = tf_ssim(output_r_3, input_high_r_3) + ssim_r = (ssim_r_1 + ssim_r_2 + ssim_r_3)/3.0 + loss_ssim1 = 1-ssim_r + return loss_ssim1 + +loss_square = tf.reduce_mean(tf.square(output_r - input_high_r)) +loss_ssim = ssim_loss(output_r, input_high_r) +loss_grad = grad_loss(output_r, input_high_r) + +loss_restoration = loss_square + loss_grad + loss_ssim + +### initialize +lr = tf.placeholder(tf.float32, name='learning_rate') +global_step = tf.get_variable('global_step', [], dtype=tf.int32, initializer=tf.constant_initializer(0), trainable=False) +update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) +optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer') +with tf.control_dependencies(update_ops): + grads = optimizer.compute_gradients(loss_restoration) + train_op_restoration = optimizer.apply_gradients(grads, global_step=global_step) + +var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name] +var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name] + +saver_restoration = tf.train.Saver(var_list=var_restoration) +saver_Decom = tf.train.Saver(var_list = var_Decom) +sess.run(tf.global_variables_initializer()) +print("[*] Initialize model successfully...") + +### load data +### Based on the decomposition net, we first get the decomposed reflectance maps +### and illumination maps, then train the restoration net. +###train_data +train_low_data = [] +train_high_data = [] +train_low_data_names = glob('./LOLdataset/our485/low/*.png') +train_low_data_names.sort() +train_high_data_names = glob('./LOLdataset/our485/high/*.png') +train_high_data_names.sort() +assert len(train_low_data_names) == len(train_high_data_names) +print('[*] Number of training data: %d' % len(train_low_data_names)) +for idx in range(len(train_low_data_names)): + low_im = load_images(train_low_data_names[idx]) + train_low_data.append(low_im) + high_im = load_images(train_high_data_names[idx]) + train_high_data.append(high_im) + +eval_low_data = [] +eval_low_data_names = glob('./LOLdataset/eval15/low/*.png') +eval_low_data_names.sort() +for idx in range(len(eval_low_data_names)): + eval_low_im = load_images(eval_low_data_names[idx]) + eval_low_data.append(eval_low_im) + +pre_decom_checkpoint_dir = './checkpoint/decom_net_train/' +ckpt_pre=tf.train.get_checkpoint_state(pre_decom_checkpoint_dir) +if ckpt_pre: + print('loaded '+ckpt_pre.model_checkpoint_path) + saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path) +else: + print('No pre_decom_net checkpoint!') + +decomposed_low_r_data_480 = [] +decomposed_low_i_data_480 = [] +decomposed_high_r_data_480 = [] +for idx in range(len(train_low_data)): + input_low = np.expand_dims(train_low_data[idx], axis=0) + RR, II = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low}) + RR0 = np.squeeze(RR) + II0 = np.squeeze(II) + print(idx, RR0.shape, II0.shape) + decomposed_low_r_data_480.append(RR0) + decomposed_low_i_data_480.append(II0) +for idx in range(len(train_high_data)): + input_high = np.expand_dims(train_high_data[idx], axis=0) + RR2, II2 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high}) + ### To improve the constrast, we slightly change the decom_r_high by using decom_r_high**1.2 + RR02 = np.squeeze(RR2**1.2) + print(idx, RR02.shape) + decomposed_high_r_data_480.append(RR02) + +decomposed_eval_low_r_data = [] +decomposed_eval_low_i_data = [] +for idx in range(len(eval_low_data)): + input_eval = np.expand_dims(eval_low_data[idx], axis=0) + RR3, II3 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_eval}) + RR03 = np.squeeze(RR3) + II03 = np.squeeze(II3) + print(idx, RR03.shape, II03.shape) + decomposed_eval_low_r_data.append(RR03) + decomposed_eval_low_i_data.append(II03) + + +eval_restoration_low_r_data = decomposed_low_r_data_480[467:480] + decomposed_eval_low_r_data[0:15] +eval_restoration_low_i_data = decomposed_low_i_data_480[467:480] + decomposed_eval_low_i_data[0:15] + +train_restoration_low_r_data = decomposed_low_r_data_480[0:466] +train_restoration_low_i_data = decomposed_low_i_data_480[0:466] +train_restoration_high_r_data = decomposed_high_r_data_480[0:466] +#train_restoration_high_i_data = train_restoration_high_i_data_480[0:466] +print(len(train_restoration_high_r_data), len(train_restoration_low_r_data),len(train_restoration_low_i_data)) +print(len(eval_restoration_low_r_data),len(eval_restoration_low_i_data)) +assert len(train_restoration_high_r_data) == len(train_restoration_low_r_data) +assert len(train_restoration_low_i_data) == len(train_restoration_low_r_data) +print('[*] Number of training data: %d' % len(train_restoration_high_r_data)) + +learning_rate = 0.0001 +def lr_schedule(epoch): + initial_lr = learning_rate + if epoch<=800: + lr = initial_lr + elif epoch<=1250: + lr = initial_lr/2 + elif epoch<=1500: + lr = initial_lr/4 + else: + lr = initial_lr/10 + return lr + +epoch = 1000 + +sample_dir = './Restoration_net_train/' +if not os.path.isdir(sample_dir): + os.makedirs(sample_dir) + +eval_every_epoch = 50 +train_phase = 'Restoration' +numBatch = len(train_restoration_low_r_data) // int(batch_size) +train_op = train_op_restoration +train_loss = loss_restoration +saver = saver_restoration + +checkpoint_dir = './checkpoint/Restoration_net_train/' +if not os.path.isdir(checkpoint_dir): + os.makedirs(checkpoint_dir) +ckpt=tf.train.get_checkpoint_state(checkpoint_dir) +if ckpt: + print('loaded '+ckpt.model_checkpoint_path) + saver_restoration.restore(sess,ckpt.model_checkpoint_path) +else: + print('No pre_restoration_net checkpoint!') + +start_step = 0 +start_epoch = 0 +iter_num = 0 +print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num)) +start_time = time.time() +image_id = 0 + +for epoch in range(start_epoch, epoch): + for batch_id in range(start_step, numBatch): + batch_input_low_r = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32") + batch_input_low_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32") + + batch_input_high_r = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32") + + for patch_id in range(batch_size): + h, w, _ = train_restoration_low_r_data[image_id].shape + x = random.randint(0, h - patch_size) + y = random.randint(0, w - patch_size) + i_low_expand = np.expand_dims(train_restoration_low_i_data[image_id], axis = 2) + rand_mode = random.randint(0, 7) + batch_input_low_r[patch_id, :, :, :] = data_augmentation(train_restoration_low_r_data[image_id][x : x+patch_size, y : y+patch_size, :] , rand_mode)#+ np.random.normal(0, 0.1, (patch_size,patch_size,3)) , rand_mode) + batch_input_low_i[patch_id, :, :, :] = data_augmentation(i_low_expand[x : x+patch_size, y : y+patch_size, :] , rand_mode)#+ np.random.normal(0, 0.1, (patch_size,patch_size,3)) , rand_mode) + + batch_input_high_r[patch_id, :, :, :] = data_augmentation(train_restoration_high_r_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode) + + image_id = (image_id + 1) % len(train_restoration_low_r_data) + if image_id == 0: + tmp = list(zip(train_restoration_low_r_data, train_restoration_low_i_data, train_restoration_high_r_data)) + random.shuffle(tmp) + train_restoration_low_r_data, train_restoration_low_i_data, train_restoration_high_r_data = zip(*tmp) + + _, loss = sess.run([train_op, train_loss], feed_dict={input_low_r: batch_input_low_r,input_low_i: batch_input_low_i,\ + input_high_r: batch_input_high_r, lr: lr_schedule(epoch)}) + print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \ + % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss)) + iter_num += 1 + if (epoch + 1) % eval_every_epoch == 0: + print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1)) + for idx in range(len(eval_restoration_low_r_data)): + input_uu_r = eval_restoration_low_r_data[idx] + input_low_eval_r = np.expand_dims(input_uu_r, axis=0) + input_uu_i = eval_restoration_low_i_data[idx] + input_low_eval_i = np.expand_dims(input_uu_i, axis=0) + input_low_eval_ii = np.expand_dims(input_low_eval_i, axis=3) + result_1 = sess.run(output_r, feed_dict={input_low_r: input_low_eval_r, input_low_i: input_low_eval_ii}) + + save_images(os.path.join(sample_dir, 'eval_%d_%d.png' % ( idx + 1, epoch + 1)), input_uu_r, result_1) + saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=epoch) + +print("[*] Finish training for phase %s." % train_phase) + + + diff --git a/results/test/04_kindle.png b/results/test/04_kindle.png new file mode 100644 index 0000000..7d20726 Binary files /dev/null and b/results/test/04_kindle.png differ diff --git a/results/test/1_kindle.png b/results/test/1_kindle.png new file mode 100644 index 0000000..393127b Binary files /dev/null and b/results/test/1_kindle.png differ diff --git a/results/test/4_kindle.png b/results/test/4_kindle.png new file mode 100644 index 0000000..9589f2c Binary files /dev/null and b/results/test/4_kindle.png differ diff --git a/test/04.JPG b/test/04.JPG new file mode 100644 index 0000000..14cedda Binary files /dev/null and b/test/04.JPG differ diff --git a/test/1.bmp b/test/1.bmp new file mode 100644 index 0000000..4ba6173 Binary files /dev/null and b/test/1.bmp differ diff --git a/test/4.bmp b/test/4.bmp new file mode 100644 index 0000000..59199af Binary files /dev/null and b/test/4.bmp differ diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..ca114d2 --- /dev/null +++ b/utils.py @@ -0,0 +1,140 @@ +import numpy as np +from PIL import Image +import tensorflow as tf +import scipy.stats as st +from skimage import io,data,color +from functools import reduce + +def gauss_kernel(kernlen=21, nsig=3, channels=1): + interval = (2*nsig+1.)/(kernlen) + x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1) + kern1d = np.diff(st.norm.cdf(x)) + kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) + kernel = kernel_raw/kernel_raw.sum() + out_filter = np.array(kernel, dtype = np.float32) + out_filter = out_filter.reshape((kernlen, kernlen, 1, 1)) + out_filter = np.repeat(out_filter, channels, axis = 2) + return out_filter + +def tensor_size(tensor): + from operator import mul + return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1) + +def blur(x): + kernel_var = gauss_kernel(21, 3, 3) + return tf.nn.depthwise_conv2d(x, kernel_var, [1, 1, 1, 1], padding='SAME') + +def tensor_size(tensor): + from operator import mul + return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1) + +def data_augmentation(image, mode): + if mode == 0: + # original + return image + elif mode == 1: + # flip up and down + return np.flipud(image) + elif mode == 2: + # rotate counterwise 90 degree + return np.rot90(image) + elif mode == 3: + # rotate 90 degree and flip up and down + image = np.rot90(image) + return np.flipud(image) + elif mode == 4: + # rotate 180 degree + return np.rot90(image, k=2) + elif mode == 5: + # rotate 180 degree and flip + image = np.rot90(image, k=2) + return np.flipud(image) + elif mode == 6: + # rotate 270 degree + return np.rot90(image, k=3) + elif mode == 7: + # rotate 270 degree and flip + image = np.rot90(image, k=3) + return np.flipud(image) + +def load_images(file): + im = Image.open(file) + img = np.array(im, dtype="float32") / 255.0 + img_max = np.max(img) + img_min = np.min(img) + img_norm = np.float32((img - img_min) / np.maximum((img_max - img_min), 0.001)) + return img_norm + +def gradient(input_tensor, direction): + smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1]) + smooth_kernel_y = tf.transpose(smooth_kernel_x, [1, 0, 2, 3]) + if direction == "x": + kernel = smooth_kernel_x + elif direction == "y": + kernel = smooth_kernel_y + gradient_orig = tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME')) + grad_min = tf.reduce_min(gradient_orig) + grad_max = tf.reduce_max(gradient_orig) + grad_norm = tf.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001)) + return grad_norm + +def _tf_fspecial_gauss(size, sigma): + """Function to mimic the 'fspecial' gaussian MATLAB function + """ + x_data, y_data = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] + + x_data = np.expand_dims(x_data, axis=-1) + x_data = np.expand_dims(x_data, axis=-1) + + y_data = np.expand_dims(y_data, axis=-1) + y_data = np.expand_dims(y_data, axis=-1) + + x = tf.constant(x_data, dtype=tf.float32) + y = tf.constant(y_data, dtype=tf.float32) + + g = tf.exp(-((x**2 + y**2)/(2.0*sigma**2))) + return g / tf.reduce_sum(g) + +def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5): + window = _tf_fspecial_gauss(size, sigma) # window shape [size, size] + K1 = 0.01 + K2 = 0.03 + L = 1 # depth of image (255 in case the image has a differnt scale) + C1 = (K1*L)**2 + C2 = (K2*L)**2 + mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID') + mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1],padding='VALID') + mu1_sq = mu1*mu1 + mu2_sq = mu2*mu2 + mu1_mu2 = mu1*mu2 + sigma1_sq = tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1],padding='VALID') - mu1_sq + sigma2_sq = tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1],padding='VALID') - mu2_sq + sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1],padding='VALID') - mu1_mu2 + if cs_map: + value = (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* + (sigma1_sq + sigma2_sq + C2)), + (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)) + else: + value = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* + (sigma1_sq + sigma2_sq + C2)) + + if mean_metric: + value = tf.reduce_mean(value) + return value + +def save_images(filepath, result_1, result_2 = None, result_3 = None): + result_1 = np.squeeze(result_1) + result_2 = np.squeeze(result_2) + result_3 = np.squeeze(result_3) + + if not result_2.any(): + cat_image = result_1 + else: + cat_image = np.concatenate([result_1, result_2], axis = 1) + if not result_3.any(): + cat_image = cat_image + else: + cat_image = np.concatenate([cat_image, result_3], axis = 1) + + im = Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8')) + im.save(filepath, 'png')