diff --git a/Decom_net_train/high_10_200.png b/Decom_net_train/high_10_200.png
new file mode 100644
index 0000000..baa8998
Binary files /dev/null and b/Decom_net_train/high_10_200.png differ
diff --git a/Decom_net_train/high_10_400.png b/Decom_net_train/high_10_400.png
new file mode 100644
index 0000000..c087bca
Binary files /dev/null and b/Decom_net_train/high_10_400.png differ
diff --git a/Decom_net_train/high_11_200.png b/Decom_net_train/high_11_200.png
new file mode 100644
index 0000000..21eca2b
Binary files /dev/null and b/Decom_net_train/high_11_200.png differ
diff --git a/Decom_net_train/high_11_400.png b/Decom_net_train/high_11_400.png
new file mode 100644
index 0000000..1ad3e5d
Binary files /dev/null and b/Decom_net_train/high_11_400.png differ
diff --git a/Decom_net_train/high_12_200.png b/Decom_net_train/high_12_200.png
new file mode 100644
index 0000000..5f04294
Binary files /dev/null and b/Decom_net_train/high_12_200.png differ
diff --git a/Decom_net_train/high_12_400.png b/Decom_net_train/high_12_400.png
new file mode 100644
index 0000000..6e7a17b
Binary files /dev/null and b/Decom_net_train/high_12_400.png differ
diff --git a/Decom_net_train/high_13_200.png b/Decom_net_train/high_13_200.png
new file mode 100644
index 0000000..e8bb0a7
Binary files /dev/null and b/Decom_net_train/high_13_200.png differ
diff --git a/Decom_net_train/high_13_400.png b/Decom_net_train/high_13_400.png
new file mode 100644
index 0000000..e553445
Binary files /dev/null and b/Decom_net_train/high_13_400.png differ
diff --git a/Decom_net_train/high_14_200.png b/Decom_net_train/high_14_200.png
new file mode 100644
index 0000000..4abaa0c
Binary files /dev/null and b/Decom_net_train/high_14_200.png differ
diff --git a/Decom_net_train/high_14_400.png b/Decom_net_train/high_14_400.png
new file mode 100644
index 0000000..6e5a567
Binary files /dev/null and b/Decom_net_train/high_14_400.png differ
diff --git a/Decom_net_train/high_15_200.png b/Decom_net_train/high_15_200.png
new file mode 100644
index 0000000..4ee9301
Binary files /dev/null and b/Decom_net_train/high_15_200.png differ
diff --git a/Decom_net_train/high_15_400.png b/Decom_net_train/high_15_400.png
new file mode 100644
index 0000000..eb7b3fb
Binary files /dev/null and b/Decom_net_train/high_15_400.png differ
diff --git a/Decom_net_train/high_1_200.png b/Decom_net_train/high_1_200.png
new file mode 100644
index 0000000..def1108
Binary files /dev/null and b/Decom_net_train/high_1_200.png differ
diff --git a/Decom_net_train/high_1_400.png b/Decom_net_train/high_1_400.png
new file mode 100644
index 0000000..86bc1a6
Binary files /dev/null and b/Decom_net_train/high_1_400.png differ
diff --git a/Decom_net_train/high_2_200.png b/Decom_net_train/high_2_200.png
new file mode 100644
index 0000000..fa43b57
Binary files /dev/null and b/Decom_net_train/high_2_200.png differ
diff --git a/Decom_net_train/high_2_400.png b/Decom_net_train/high_2_400.png
new file mode 100644
index 0000000..b9e962a
Binary files /dev/null and b/Decom_net_train/high_2_400.png differ
diff --git a/Decom_net_train/high_3_200.png b/Decom_net_train/high_3_200.png
new file mode 100644
index 0000000..27ae7ae
Binary files /dev/null and b/Decom_net_train/high_3_200.png differ
diff --git a/Decom_net_train/high_3_400.png b/Decom_net_train/high_3_400.png
new file mode 100644
index 0000000..127cc73
Binary files /dev/null and b/Decom_net_train/high_3_400.png differ
diff --git a/Decom_net_train/high_4_200.png b/Decom_net_train/high_4_200.png
new file mode 100644
index 0000000..a08ee7b
Binary files /dev/null and b/Decom_net_train/high_4_200.png differ
diff --git a/Decom_net_train/high_4_400.png b/Decom_net_train/high_4_400.png
new file mode 100644
index 0000000..15ca50c
Binary files /dev/null and b/Decom_net_train/high_4_400.png differ
diff --git a/Decom_net_train/high_5_200.png b/Decom_net_train/high_5_200.png
new file mode 100644
index 0000000..6a01754
Binary files /dev/null and b/Decom_net_train/high_5_200.png differ
diff --git a/Decom_net_train/high_5_400.png b/Decom_net_train/high_5_400.png
new file mode 100644
index 0000000..fb320ef
Binary files /dev/null and b/Decom_net_train/high_5_400.png differ
diff --git a/Decom_net_train/high_6_200.png b/Decom_net_train/high_6_200.png
new file mode 100644
index 0000000..988eb8f
Binary files /dev/null and b/Decom_net_train/high_6_200.png differ
diff --git a/Decom_net_train/high_6_400.png b/Decom_net_train/high_6_400.png
new file mode 100644
index 0000000..7cb240b
Binary files /dev/null and b/Decom_net_train/high_6_400.png differ
diff --git a/Decom_net_train/high_7_200.png b/Decom_net_train/high_7_200.png
new file mode 100644
index 0000000..e2cca46
Binary files /dev/null and b/Decom_net_train/high_7_200.png differ
diff --git a/Decom_net_train/high_7_400.png b/Decom_net_train/high_7_400.png
new file mode 100644
index 0000000..59fed1b
Binary files /dev/null and b/Decom_net_train/high_7_400.png differ
diff --git a/Decom_net_train/high_8_200.png b/Decom_net_train/high_8_200.png
new file mode 100644
index 0000000..fb441c4
Binary files /dev/null and b/Decom_net_train/high_8_200.png differ
diff --git a/Decom_net_train/high_8_400.png b/Decom_net_train/high_8_400.png
new file mode 100644
index 0000000..d33f0ea
Binary files /dev/null and b/Decom_net_train/high_8_400.png differ
diff --git a/Decom_net_train/high_9_200.png b/Decom_net_train/high_9_200.png
new file mode 100644
index 0000000..b104994
Binary files /dev/null and b/Decom_net_train/high_9_200.png differ
diff --git a/Decom_net_train/high_9_400.png b/Decom_net_train/high_9_400.png
new file mode 100644
index 0000000..c68b600
Binary files /dev/null and b/Decom_net_train/high_9_400.png differ
diff --git a/Decom_net_train/low_10_200.png b/Decom_net_train/low_10_200.png
new file mode 100644
index 0000000..672eaef
Binary files /dev/null and b/Decom_net_train/low_10_200.png differ
diff --git a/Decom_net_train/low_10_400.png b/Decom_net_train/low_10_400.png
new file mode 100644
index 0000000..5ff8c4e
Binary files /dev/null and b/Decom_net_train/low_10_400.png differ
diff --git a/Decom_net_train/low_11_200.png b/Decom_net_train/low_11_200.png
new file mode 100644
index 0000000..c4b633d
Binary files /dev/null and b/Decom_net_train/low_11_200.png differ
diff --git a/Decom_net_train/low_11_400.png b/Decom_net_train/low_11_400.png
new file mode 100644
index 0000000..ddc3657
Binary files /dev/null and b/Decom_net_train/low_11_400.png differ
diff --git a/Decom_net_train/low_12_200.png b/Decom_net_train/low_12_200.png
new file mode 100644
index 0000000..287d678
Binary files /dev/null and b/Decom_net_train/low_12_200.png differ
diff --git a/Decom_net_train/low_12_400.png b/Decom_net_train/low_12_400.png
new file mode 100644
index 0000000..0afe04d
Binary files /dev/null and b/Decom_net_train/low_12_400.png differ
diff --git a/Decom_net_train/low_13_200.png b/Decom_net_train/low_13_200.png
new file mode 100644
index 0000000..45c659e
Binary files /dev/null and b/Decom_net_train/low_13_200.png differ
diff --git a/Decom_net_train/low_13_400.png b/Decom_net_train/low_13_400.png
new file mode 100644
index 0000000..3c8fa15
Binary files /dev/null and b/Decom_net_train/low_13_400.png differ
diff --git a/Decom_net_train/low_14_200.png b/Decom_net_train/low_14_200.png
new file mode 100644
index 0000000..05e42e6
Binary files /dev/null and b/Decom_net_train/low_14_200.png differ
diff --git a/Decom_net_train/low_14_400.png b/Decom_net_train/low_14_400.png
new file mode 100644
index 0000000..77cabe8
Binary files /dev/null and b/Decom_net_train/low_14_400.png differ
diff --git a/Decom_net_train/low_15_200.png b/Decom_net_train/low_15_200.png
new file mode 100644
index 0000000..6b5d89e
Binary files /dev/null and b/Decom_net_train/low_15_200.png differ
diff --git a/Decom_net_train/low_15_400.png b/Decom_net_train/low_15_400.png
new file mode 100644
index 0000000..236c9c0
Binary files /dev/null and b/Decom_net_train/low_15_400.png differ
diff --git a/Decom_net_train/low_1_200.png b/Decom_net_train/low_1_200.png
new file mode 100644
index 0000000..bef41b7
Binary files /dev/null and b/Decom_net_train/low_1_200.png differ
diff --git a/Decom_net_train/low_1_400.png b/Decom_net_train/low_1_400.png
new file mode 100644
index 0000000..d43e609
Binary files /dev/null and b/Decom_net_train/low_1_400.png differ
diff --git a/Decom_net_train/low_2_200.png b/Decom_net_train/low_2_200.png
new file mode 100644
index 0000000..1d144d0
Binary files /dev/null and b/Decom_net_train/low_2_200.png differ
diff --git a/Decom_net_train/low_2_400.png b/Decom_net_train/low_2_400.png
new file mode 100644
index 0000000..3df85c7
Binary files /dev/null and b/Decom_net_train/low_2_400.png differ
diff --git a/Decom_net_train/low_3_200.png b/Decom_net_train/low_3_200.png
new file mode 100644
index 0000000..8107588
Binary files /dev/null and b/Decom_net_train/low_3_200.png differ
diff --git a/Decom_net_train/low_3_400.png b/Decom_net_train/low_3_400.png
new file mode 100644
index 0000000..00795f1
Binary files /dev/null and b/Decom_net_train/low_3_400.png differ
diff --git a/Decom_net_train/low_4_200.png b/Decom_net_train/low_4_200.png
new file mode 100644
index 0000000..4073eaf
Binary files /dev/null and b/Decom_net_train/low_4_200.png differ
diff --git a/Decom_net_train/low_4_400.png b/Decom_net_train/low_4_400.png
new file mode 100644
index 0000000..0abbddc
Binary files /dev/null and b/Decom_net_train/low_4_400.png differ
diff --git a/Decom_net_train/low_5_200.png b/Decom_net_train/low_5_200.png
new file mode 100644
index 0000000..2ed0adf
Binary files /dev/null and b/Decom_net_train/low_5_200.png differ
diff --git a/Decom_net_train/low_5_400.png b/Decom_net_train/low_5_400.png
new file mode 100644
index 0000000..b42eccf
Binary files /dev/null and b/Decom_net_train/low_5_400.png differ
diff --git a/Decom_net_train/low_6_200.png b/Decom_net_train/low_6_200.png
new file mode 100644
index 0000000..c48142a
Binary files /dev/null and b/Decom_net_train/low_6_200.png differ
diff --git a/Decom_net_train/low_6_400.png b/Decom_net_train/low_6_400.png
new file mode 100644
index 0000000..9c3bbee
Binary files /dev/null and b/Decom_net_train/low_6_400.png differ
diff --git a/Decom_net_train/low_7_200.png b/Decom_net_train/low_7_200.png
new file mode 100644
index 0000000..4e7308d
Binary files /dev/null and b/Decom_net_train/low_7_200.png differ
diff --git a/Decom_net_train/low_7_400.png b/Decom_net_train/low_7_400.png
new file mode 100644
index 0000000..0202d35
Binary files /dev/null and b/Decom_net_train/low_7_400.png differ
diff --git a/Decom_net_train/low_8_200.png b/Decom_net_train/low_8_200.png
new file mode 100644
index 0000000..f70885b
Binary files /dev/null and b/Decom_net_train/low_8_200.png differ
diff --git a/Decom_net_train/low_8_400.png b/Decom_net_train/low_8_400.png
new file mode 100644
index 0000000..04d90cd
Binary files /dev/null and b/Decom_net_train/low_8_400.png differ
diff --git a/Decom_net_train/low_9_200.png b/Decom_net_train/low_9_200.png
new file mode 100644
index 0000000..21125f2
Binary files /dev/null and b/Decom_net_train/low_9_200.png differ
diff --git a/Decom_net_train/low_9_400.png b/Decom_net_train/low_9_400.png
new file mode 100644
index 0000000..ecb1731
Binary files /dev/null and b/Decom_net_train/low_9_400.png differ
diff --git a/LOLdataset.zip b/LOLdataset.zip
new file mode 100644
index 0000000..9cd7823
Binary files /dev/null and b/LOLdataset.zip differ
diff --git a/README.md b/README.md
index 0d6ef0d..9947386 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,58 @@
-#### Create a new repository from the command line
+# KinD
+This is a TensorFlow implementation of KinD.
-```bash
-touch README.md
-git init
-git add README.md
-git commit -m "first commit"
-git remote add origin https://bdgit.educoder.net/ZhengHui/KinD.git
-git push -u origin master
+### The [KinD++](https://github.com/zhangyhuaee/KinD_plus) is an improved version. ###
-```
+Kindling the Darkness: A Practical Low-light Image Enhancer. In ACM MM 2019.
+Yonghua Zhang, Jiawan Zhang, Xiaojie Guo
+
+### [Paper](http://doi.acm.org/10.1145/3343031.3350926)
+
+
+
+
+### Requirements ###
+1. Python
+2. TensorFlow >= 1.10.0
+3. numpy, PIL
-#### Push an existing repository from the command line
+### Test ###
+First download the pre-trained checkpoints from [BaiduNetdisk](https://pan.baidu.com/s/1c4ZLYEIoR-8skNMiAVbl_A) or [google drive](https://drive.google.com/open?id=1-ljWntl7FExf6BSQtl5Mz3rMGWgnXDz4), then just run
+```shell
+python evaluate.py
+```
+Our pre-trained model has changed, so the results differ slightly from those reported in our paper. However, you can adjust the illumination ratio to get better results.
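+For example, to change the exposure level, edit the `ratio` variable in `evaluate.py` (a value in 0-5.0; larger values generally give brighter results):
+```python
+ratio = 5.0
+```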
-```bash
-git remote add origin https://bdgit.educoder.net/ZhengHui/KinD.git
-git push -u origin master
+### Train ###
+The original LOLdataset can be downloaded from [here](https://daooshee.github.io/BMVC2018website/). We rearrange the original LOLdataset and add several all-zero images to improve the decomposition and restoration results. The new dataset can be downloaded from [BaiduNetdisk](https://pan.baidu.com/s/1sn3vWJ2I5U2dlVUD7eqIBQ) or [google drive](https://drive.google.com/open?id=1-MaOVG7ylOkmGv1K4HWWcrai01i_FeDK). Save the training pairs of the LOL dataset under './LOLdataset/our485/' and the evaluating pairs under './LOLdataset/eval15/'. Train the decomposition net first, since the adjustment and restoration training scripts restore its checkpoint. For training, just run
+```shell
+python decomposition_net_train.py
+python adjustment_net_train.py
+python reflectance_restoration_net_train.py
+```
+You can also evaluate on the LOLdataset, just run
+```shell
+python evaluate_LOLdataset.py
+```
+Our code is partially based on the [RetinexNet code](https://github.com/weichen582/RetinexNet).
+### Citation ###
+```
+@inproceedings{zhang2019kindling,
+ author = {Zhang, Yonghua and Zhang, Jiawan and Guo, Xiaojie},
+ title = {Kindling the Darkness: A Practical Low-light Image Enhancer},
+ booktitle = {Proceedings of the 27th ACM International Conference on Multimedia},
+ series = {MM '19},
+ year = {2019},
+ isbn = {978-1-4503-6889-6},
+ location = {Nice, France},
+ pages = {1632--1640},
+ numpages = {9},
+ url = {http://doi.acm.org/10.1145/3343031.3350926},
+ doi = {10.1145/3343031.3350926},
+ acmid = {3350926},
+ publisher = {ACM},
+ address = {New York, NY, USA},
+ keywords = {image decomposition, image restoration, low light enhancement},
+}
```
diff --git a/adjustment_net_train.py b/adjustment_net_train.py
new file mode 100644
index 0000000..59cf95e
--- /dev/null
+++ b/adjustment_net_train.py
@@ -0,0 +1,223 @@
+# coding: utf-8
+from __future__ import print_function
+import os
+import time
+import random
+#from skimage import color
+from PIL import Image
+import tensorflow as tf
+import numpy as np
+from utils import *
+from model import *
+from glob import glob
+
+batch_size = 10
+patch_size = 48
+
+sess = tf.Session()
+#the input of decomposition net
+input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
+#the input of illumination adjustment net
+input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
+input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio')
+input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i')
+
+[R_decom, I_decom] = DecomNet_simple(input_decom)
+#the output of decomposition network
+decom_output_R = R_decom
+decom_output_I = I_decom
+#the output of illumination adjustment net
+output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
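+# output_i is the adjusted illumination map; the ratio input is a spatially constant map built in the training loop below.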
+
+#define loss
+
+def grad_loss(input_i_low, input_i_high):
+ x_loss = tf.square(gradient(input_i_low, 'x') - gradient(input_i_high, 'x'))
+ y_loss = tf.square(gradient(input_i_low, 'y') - gradient(input_i_high, 'y'))
+ grad_loss_all = tf.reduce_mean(x_loss + y_loss)
+ return grad_loss_all
+
+loss_grad = grad_loss(output_i, input_high_i)
+loss_square = tf.reduce_mean(tf.square(output_i - input_high_i))
+
+loss_adjust = loss_square + loss_grad
+
+lr = tf.placeholder(tf.float32, name='learning_rate')
+
+optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')
+
+var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
+var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name]
+
+saver_adjust = tf.train.Saver(var_list=var_adjust)
+saver_Decom = tf.train.Saver(var_list = var_Decom)
+train_op_adjust = optimizer.minimize(loss_adjust, var_list = var_adjust)
+sess.run(tf.global_variables_initializer())
+print("[*] Initialize model successfully...")
+
+### load data
+### Based on the decomposition net, we first get the decomposed reflectance maps
+### and illumination maps, then train the adjust net.
+###train_data
+train_low_data = []
+train_high_data = []
+train_low_data_names = glob('./LOLdataset/our485/low/*.png')
+train_low_data_names.sort()
+train_high_data_names = glob('./LOLdataset/our485/high/*.png')
+train_high_data_names.sort()
+assert len(train_low_data_names) == len(train_high_data_names)
+print('[*] Number of training data: %d' % len(train_low_data_names))
+for idx in range(len(train_low_data_names)):
+ low_im = load_images(train_low_data_names[idx])
+ train_low_data.append(low_im)
+ high_im = load_images(train_high_data_names[idx])
+ train_high_data.append(high_im)
+
+pre_decom_checkpoint_dir = './checkpoint/decom_net_train/'
+ckpt_pre=tf.train.get_checkpoint_state(pre_decom_checkpoint_dir)
+if ckpt_pre:
+ print('loaded '+ckpt_pre.model_checkpoint_path)
+ saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path)
+else:
+ print('No pre_decom_net checkpoint!')
+
+#decomposed_low_r_data_480 = []
+decomposed_low_i_data_480 = []
+#decomposed_high_r_data_480 = []
+decomposed_high_i_data_480 = []
+for idx in range(len(train_low_data)):
+ input_low = np.expand_dims(train_low_data[idx], axis=0)
+ RR, II = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low})
+ RR0 = np.squeeze(RR)
+ II0 = np.squeeze(II)
+ print(RR0.shape, II0.shape)
+ #decomposed_high_r_data_480.append(result_1_sq)
+ decomposed_low_i_data_480.append(II0)
+for idx in range(len(train_high_data)):
+ input_high = np.expand_dims(train_high_data[idx], axis=0)
+ RR2, II2 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high})
+ RR02 = np.squeeze(RR2)
+ II02 = np.squeeze(II2)
+ print(RR02.shape, II02.shape)
+ #decomposed_high_r_data_480.append(result_1_sq)
+ decomposed_high_i_data_480.append(II02)
+
+eval_adjust_low_i_data = decomposed_low_i_data_480[451:480]
+eval_adjust_high_i_data = decomposed_high_i_data_480[451:480]
+
+train_adjust_low_i_data = decomposed_low_i_data_480[0:450]
+train_adjust_high_i_data = decomposed_high_i_data_480[0:450]
+
+print('[*] Number of adjustment training data: %d' % len(train_adjust_high_i_data))
+
+learning_rate = 0.0001
+epoch = 2000
+eval_every_epoch = 200
+train_phase = 'adjustment'
+numBatch = len(train_adjust_low_i_data) // int(batch_size)
+train_op = train_op_adjust
+train_loss = loss_adjust
+saver = saver_adjust
+
+checkpoint_dir = './checkpoint/illumination_adjust_net_train/'
+if not os.path.isdir(checkpoint_dir):
+ os.makedirs(checkpoint_dir)
+ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
+if ckpt:
+ print('loaded '+ckpt.model_checkpoint_path)
+ saver.restore(sess,ckpt.model_checkpoint_path)
+else:
+ print("No adjustment net pre model!")
+
+start_step = 0
+start_epoch = 0
+iter_num = 0
+print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
+
+sample_dir = './illumination_adjust_net_train/'
+if not os.path.isdir(sample_dir):
+ os.makedirs(sample_dir)
+
+start_time = time.time()
+image_id = 0
+
+for epoch in range(start_epoch, epoch):
+ for batch_id in range(start_step, numBatch):
+ batch_input_low_i_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ batch_input_high_i_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ batch_input_low_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ batch_input_high_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ input_low_i_rand = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ input_high_i_rand = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ input_low_i_rand_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+ input_high_i_rand_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+
+ for patch_id in range(batch_size):
+ i_low_data = train_adjust_low_i_data[image_id]
+ i_low_expand = np.expand_dims(i_low_data, axis = 2)
+ i_high_data = train_adjust_high_i_data[image_id]
+ i_high_expand = np.expand_dims(i_high_data, axis = 2)
+
+ h, w = train_adjust_low_i_data[image_id].shape
+ x = random.randint(0, h - patch_size)
+ y = random.randint(0, w - patch_size)
+ i_low_data_crop = i_low_expand[x : x+patch_size, y : y+patch_size, :]
+ i_high_data_crop = i_high_expand[x : x+patch_size, y : y+patch_size, :]
+
+ rand_mode = np.random.randint(0, 7)
+ batch_input_low_i[patch_id, :, :, :] = data_augmentation(i_low_data_crop , rand_mode)
+ batch_input_high_i[patch_id, :, :, :] = data_augmentation(i_high_data_crop, rand_mode)
+
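+            # Per-patch ratio: mean(I_low / I_high). Its reciprocal (low -> high) is paired
+            # with the low patch below, and the ratio itself (high -> low) with the high patch.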
+ ratio = np.mean(i_low_data_crop/(i_high_data_crop+0.0001))
+            i_low_data_ratio = np.ones([patch_size,patch_size])*(1/(ratio+0.0001))
+ i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2)
+ i_high_data_ratio = np.ones([patch_size,patch_size])*(ratio)
+ i_high_ratio_expand = np.expand_dims(i_high_data_ratio , axis =2)
+ batch_input_low_i_ratio[patch_id, :, :, :] = i_low_ratio_expand
+ batch_input_high_i_ratio[patch_id, :, :, :] = i_high_ratio_expand
+
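+            # Randomly swap the low/high pair (and the corresponding ratio maps) so the
+            # adjust net learns both brightening (ratio > 1) and darkening (ratio < 1).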
+ rand_mode = np.random.randint(0, 2)
+ if rand_mode == 1:
+ input_low_i_rand[patch_id, :, :, :] = batch_input_low_i[patch_id, :, :, :]
+ input_high_i_rand[patch_id, :, :, :] = batch_input_high_i[patch_id, :, :, :]
+ input_low_i_rand_ratio[patch_id, :, :, :] = batch_input_low_i_ratio[patch_id, :, :, :]
+ input_high_i_rand_ratio[patch_id, :, :, :] = batch_input_high_i_ratio[patch_id, :, :, :]
+ else:
+ input_low_i_rand[patch_id, :, :, :] = batch_input_high_i[patch_id, :, :, :]
+ input_high_i_rand[patch_id, :, :, :] = batch_input_low_i[patch_id, :, :, :]
+ input_low_i_rand_ratio[patch_id, :, :, :] = batch_input_high_i_ratio[patch_id, :, :, :]
+ input_high_i_rand_ratio[patch_id, :, :, :] = batch_input_low_i_ratio[patch_id, :, :, :]
+
+ image_id = (image_id + 1) % len(train_adjust_low_i_data)
+
+ _, loss = sess.run([train_op, train_loss], feed_dict={input_low_i: input_low_i_rand,input_low_i_ratio: input_low_i_rand_ratio,\
+ input_high_i: input_high_i_rand, \
+ lr: learning_rate})
+ print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
+ % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
+ iter_num += 1
+ if (epoch + 1) % eval_every_epoch == 0:
+ print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1))
+
+ for idx in range(10):
+ rand_idx = idx#np.random.randint(26)
+ input_uu_i = eval_adjust_low_i_data[rand_idx]
+ input_low_eval_i = np.expand_dims(input_uu_i, axis=0)
+ input_low_eval_ii = np.expand_dims(input_low_eval_i, axis=3)
+ h, w = eval_adjust_low_i_data[idx].shape
+ rand_ratio = np.random.random(1)*2
+ input_uu_i_ratio = np.ones([h,w]) * rand_ratio
+ input_low_eval_i_ratio = np.expand_dims(input_uu_i_ratio, axis=0)
+ input_low_eval_ii_ratio = np.expand_dims(input_low_eval_i_ratio, axis=3)
+
+ result_1 = sess.run(output_i, feed_dict={input_low_i: input_low_eval_ii, input_low_i_ratio: input_low_eval_ii_ratio})
+ save_images(os.path.join(sample_dir, 'h_eval_%d_%d_%5f.png' % ( epoch + 1 , rand_idx + 1, rand_ratio)), input_uu_i, result_1)
+
+
+ saver.save(sess, checkpoint_dir + 'model.ckpt')
+
+print("[*] Finish training for phase %s." % train_phase)
+
+
+
diff --git a/checkpoint/Restoration_net_train/checkpoint b/checkpoint/Restoration_net_train/checkpoint
new file mode 100644
index 0000000..093fdf1
--- /dev/null
+++ b/checkpoint/Restoration_net_train/checkpoint
@@ -0,0 +1 @@
+model_checkpoint_path: "model.ckpt"
diff --git a/checkpoint/Restoration_net_train/model.ckpt.data-00000-of-00001 b/checkpoint/Restoration_net_train/model.ckpt.data-00000-of-00001
new file mode 100644
index 0000000..46b2789
Binary files /dev/null and b/checkpoint/Restoration_net_train/model.ckpt.data-00000-of-00001 differ
diff --git a/checkpoint/Restoration_net_train/model.ckpt.index b/checkpoint/Restoration_net_train/model.ckpt.index
new file mode 100644
index 0000000..996bad3
Binary files /dev/null and b/checkpoint/Restoration_net_train/model.ckpt.index differ
diff --git a/checkpoint/Restoration_net_train/model.ckpt.meta b/checkpoint/Restoration_net_train/model.ckpt.meta
new file mode 100644
index 0000000..bb96652
Binary files /dev/null and b/checkpoint/Restoration_net_train/model.ckpt.meta differ
diff --git a/checkpoint/decom_net_train/checkpoint b/checkpoint/decom_net_train/checkpoint
new file mode 100644
index 0000000..febd7d5
--- /dev/null
+++ b/checkpoint/decom_net_train/checkpoint
@@ -0,0 +1,2 @@
+model_checkpoint_path: "model.ckpt"
+all_model_checkpoint_paths: "model.ckpt"
diff --git a/checkpoint/decom_net_train/model.ckpt.data-00000-of-00001 b/checkpoint/decom_net_train/model.ckpt.data-00000-of-00001
new file mode 100644
index 0000000..f500187
Binary files /dev/null and b/checkpoint/decom_net_train/model.ckpt.data-00000-of-00001 differ
diff --git a/checkpoint/decom_net_train/model.ckpt.index b/checkpoint/decom_net_train/model.ckpt.index
new file mode 100644
index 0000000..6da6a7e
Binary files /dev/null and b/checkpoint/decom_net_train/model.ckpt.index differ
diff --git a/checkpoint/decom_net_train/model.ckpt.meta b/checkpoint/decom_net_train/model.ckpt.meta
new file mode 100644
index 0000000..025ea8c
Binary files /dev/null and b/checkpoint/decom_net_train/model.ckpt.meta differ
diff --git a/checkpoint/illumination_adjust_net_train/checkpoint b/checkpoint/illumination_adjust_net_train/checkpoint
new file mode 100644
index 0000000..febd7d5
--- /dev/null
+++ b/checkpoint/illumination_adjust_net_train/checkpoint
@@ -0,0 +1,2 @@
+model_checkpoint_path: "model.ckpt"
+all_model_checkpoint_paths: "model.ckpt"
diff --git a/checkpoint/illumination_adjust_net_train/model.ckpt.data-00000-of-00001 b/checkpoint/illumination_adjust_net_train/model.ckpt.data-00000-of-00001
new file mode 100644
index 0000000..88e2dd2
Binary files /dev/null and b/checkpoint/illumination_adjust_net_train/model.ckpt.data-00000-of-00001 differ
diff --git a/checkpoint/illumination_adjust_net_train/model.ckpt.index b/checkpoint/illumination_adjust_net_train/model.ckpt.index
new file mode 100644
index 0000000..826bfe4
Binary files /dev/null and b/checkpoint/illumination_adjust_net_train/model.ckpt.index differ
diff --git a/checkpoint/illumination_adjust_net_train/model.ckpt.meta b/checkpoint/illumination_adjust_net_train/model.ckpt.meta
new file mode 100644
index 0000000..9615198
Binary files /dev/null and b/checkpoint/illumination_adjust_net_train/model.ckpt.meta differ
diff --git a/decomposition_net_train.py b/decomposition_net_train.py
new file mode 100644
index 0000000..2098469
--- /dev/null
+++ b/decomposition_net_train.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+from __future__ import print_function
+import os, time, random
+import tensorflow as tf
+from PIL import Image
+import numpy as np
+from utils import *
+from model import *
+from glob import glob
+
+batch_size = 10
+patch_size = 48
+
+sess = tf.Session()
+
+input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
+input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')
+
+[R_low, I_low] = DecomNet_simple(input_low)
+[R_high, I_high] = DecomNet_simple(input_high)
+
+I_low_3 = tf.concat([I_low, I_low, I_low], axis=3)
+I_high_3 = tf.concat([I_high, I_high, I_high], axis=3)
+
+#network output
+output_R_low = R_low
+output_R_high = R_high
+output_I_low = I_low_3
+output_I_high = I_high_3
+
+# define loss
+
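+# Mutual consistency loss for the two illumination maps: g*exp(-10*g) is small both where
+# the combined gradient g is near zero (flat regions) and where it is large (strong shared
+# edges), penalizing weak, inconsistent gradients in between.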
+def mutual_i_loss(input_I_low, input_I_high):
+ low_gradient_x = gradient(input_I_low, "x")
+ high_gradient_x = gradient(input_I_high, "x")
+ x_loss = (low_gradient_x + high_gradient_x)* tf.exp(-10*(low_gradient_x+high_gradient_x))
+ low_gradient_y = gradient(input_I_low, "y")
+ high_gradient_y = gradient(input_I_high, "y")
+ y_loss = (low_gradient_y + high_gradient_y) * tf.exp(-10*(low_gradient_y+high_gradient_y))
+ mutual_loss = tf.reduce_mean( x_loss + y_loss)
+ return mutual_loss
+
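+# Structure-aware smoothness loss: illumination gradients are weighted by the inverse of the
+# input image gradients, keeping the illumination map smooth except at image edges.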
+def mutual_i_input_loss(input_I_low, input_im):
+ input_gray = tf.image.rgb_to_grayscale(input_im)
+ low_gradient_x = gradient(input_I_low, "x")
+ input_gradient_x = gradient(input_gray, "x")
+ x_loss = tf.abs(tf.div(low_gradient_x, tf.maximum(input_gradient_x, 0.01)))
+ low_gradient_y = gradient(input_I_low, "y")
+ input_gradient_y = gradient(input_gray, "y")
+ y_loss = tf.abs(tf.div(low_gradient_y, tf.maximum(input_gradient_y, 0.01)))
+ mut_loss = tf.reduce_mean(x_loss + y_loss)
+ return mut_loss
+
+recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - input_low))
+recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - input_high))
+
+equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
+
+i_mutual_loss = mutual_i_loss(I_low, I_high)
+
+i_input_mutual_loss_high = mutual_i_input_loss(I_high, input_high)
+i_input_mutual_loss_low = mutual_i_input_loss(I_low, input_low)
+
+loss_Decom = 1*recon_loss_high + 1*recon_loss_low \
+ + 0.01 * equal_R_loss + 0.2*i_mutual_loss \
+ + 0.15* i_input_mutual_loss_high + 0.15* i_input_mutual_loss_low
+
+###
+lr = tf.placeholder(tf.float32, name='learning_rate')
+
+optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')
+var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
+
+train_op_Decom = optimizer.minimize(loss_Decom, var_list = var_Decom)
+sess.run(tf.global_variables_initializer())
+
+saver_Decom = tf.train.Saver(var_list = var_Decom)
+print("[*] Initialize model successfully...")
+
+#load data
+###train_data
+train_low_data = []
+train_high_data = []
+train_low_data_names = glob('./LOLdataset/our485/low/*.png')
+train_low_data_names.sort()
+train_high_data_names = glob('./LOLdataset/our485/high/*.png')
+train_high_data_names.sort()
+assert len(train_low_data_names) == len(train_high_data_names)
+print('[*] Number of training data: %d' % len(train_low_data_names))
+for idx in range(len(train_low_data_names)):
+ low_im = load_images(train_low_data_names[idx])
+ train_low_data.append(low_im)
+ high_im = load_images(train_high_data_names[idx])
+ train_high_data.append(high_im)
+###eval_data
+eval_low_data = []
+eval_high_data = []
+eval_low_data_name = glob('./LOLdataset/eval15/low/*.png')
+eval_low_data_name.sort()
+eval_high_data_name = glob('./LOLdataset/eval15/high/*.png')
+eval_high_data_name.sort()
+for idx in range(len(eval_low_data_name)):
+ eval_low_im = load_images(eval_low_data_name[idx])
+ eval_low_data.append(eval_low_im)
+ eval_high_im = load_images(eval_high_data_name[idx])
+ eval_high_data.append(eval_high_im)
+
+
+epoch = 2000
+learning_rate = 0.0001
+
+sample_dir = './Decom_net_train/'
+if not os.path.isdir(sample_dir):
+ os.makedirs(sample_dir)
+
+eval_every_epoch = 200
+train_phase = 'decomposition'
+numBatch = len(train_low_data) // int(batch_size)
+train_op = train_op_Decom
+train_loss = loss_Decom
+saver = saver_Decom
+
+checkpoint_dir = './checkpoint/decom_net_train/'
+if not os.path.isdir(checkpoint_dir):
+ os.makedirs(checkpoint_dir)
+ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
+if ckpt:
+ print('loaded '+ckpt.model_checkpoint_path)
+ saver.restore(sess,ckpt.model_checkpoint_path)
+
+start_step = 0
+start_epoch = 0
+iter_num = 0
+print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
+
+start_time = time.time()
+image_id = 0
+for epoch in range(start_epoch, epoch):
+ for batch_id in range(start_step, numBatch):
+ batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
+ batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
+ for patch_id in range(batch_size):
+ h, w, _ = train_low_data[image_id].shape
+ x = random.randint(0, h - patch_size)
+ y = random.randint(0, w - patch_size)
+ rand_mode = random.randint(0, 7)
+ batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
+ batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
+ image_id = (image_id + 1) % len(train_low_data)
+ if image_id == 0:
+ tmp = list(zip(train_low_data, train_high_data))
+ random.shuffle(tmp)
+ train_low_data, train_high_data = zip(*tmp)
+
+ _, loss = sess.run([train_op, train_loss], feed_dict={input_low: batch_input_low, \
+ input_high: batch_input_high, \
+ lr: learning_rate})
+ print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
+ % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
+ iter_num += 1
+ if (epoch + 1) % eval_every_epoch == 0:
+ print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1))
+ for idx in range(len(eval_low_data)):
+ input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)
+ result_1, result_2 = sess.run([output_R_low, output_I_low], feed_dict={input_low: input_low_eval})
+ save_images(os.path.join(sample_dir, 'low_%d_%d.png' % ( idx + 1, epoch + 1)), result_1, result_2)
+ for idx in range(len(eval_high_data)):
+ input_high_eval = np.expand_dims(eval_high_data[idx], axis=0)
+ result_11, result_22 = sess.run([output_R_high, output_I_high], feed_dict={input_high: input_high_eval})
+ save_images(os.path.join(sample_dir, 'high_%d_%d.png' % ( idx + 1, epoch + 1)), result_11, result_22)
+
+ saver.save(sess, checkpoint_dir + 'model.ckpt')
+
+print("[*] Finish training for phase %s." % train_phase)
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..5d8a633
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,114 @@
+# coding: utf-8
+from __future__ import print_function
+import os
+import time
+import random
+from PIL import Image
+import tensorflow as tf
+import numpy as np
+from utils import *
+from model import *
+from glob import glob
+from skimage import color,filters
+
+sess = tf.Session()
+
+input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
+input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r')
+input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
+input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r')
+input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i')
+input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio')
+
+[R_decom, I_decom] = DecomNet_simple(input_decom)
+decom_output_R = R_decom
+decom_output_I = I_decom
+output_r = Restoration_net(input_low_r, input_low_i)
+output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
+
+var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
+var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name]
+var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name]
+
+saver_Decom = tf.train.Saver(var_list = var_Decom)
+saver_adjust = tf.train.Saver(var_list=var_adjust)
+saver_restoration = tf.train.Saver(var_list=var_restoration)
+
+decom_checkpoint_dir ='./checkpoint/decom_net_train/'
+ckpt_pre=tf.train.get_checkpoint_state(decom_checkpoint_dir)
+if ckpt_pre:
+ print('loaded '+ckpt_pre.model_checkpoint_path)
+ saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path)
+else:
+ print('No decomnet checkpoint!')
+
+checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/'
+ckpt_adjust=tf.train.get_checkpoint_state(checkpoint_dir_adjust)
+if ckpt_adjust:
+ print('loaded '+ckpt_adjust.model_checkpoint_path)
+ saver_adjust.restore(sess,ckpt_adjust.model_checkpoint_path)
+else:
+ print("No adjust pre model!")
+
+checkpoint_dir_restoration = './checkpoint/Restoration_net_train/'
+ckpt=tf.train.get_checkpoint_state(checkpoint_dir_restoration)
+if ckpt:
+ print('loaded '+ckpt.model_checkpoint_path)
+ saver_restoration.restore(sess,ckpt.model_checkpoint_path)
+else:
+ print("No restoration pre model!")
+
+###load eval data
+eval_low_data = []
+eval_img_name =[]
+eval_low_data_name = glob('./test/*')
+eval_low_data_name.sort()
+for idx in range(len(eval_low_data_name)):
+ [_, name] = os.path.split(eval_low_data_name[idx])
+ suffix = name[name.find('.') + 1:]
+ name = name[:name.find('.')]
+ eval_img_name.append(name)
+ eval_low_im = load_images(eval_low_data_name[idx])
+ eval_low_data.append(eval_low_im)
+ print(eval_low_im.shape)
+
+sample_dir = './results/test/'
+if not os.path.isdir(sample_dir):
+ os.makedirs(sample_dir)
+
+print("Start evaluating!")
+start_time = time.time()
+for idx in range(len(eval_low_data)):
+ print(idx)
+ name = eval_img_name[idx]
+ input_low = eval_low_data[idx]
+ input_low_eval = np.expand_dims(input_low, axis=0)
+ h, w, _ = input_low.shape
+
+ decom_r_low, decom_i_low = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval})
+
+ restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})
+    # Change the ratio to get a different exposure level; the value can be 0-5.0.
+ ratio = 5.0
+ i_low_data_ratio = np.ones([h, w])*(ratio)
+ i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2)
+ i_low_ratio_expand2 = np.expand_dims(i_low_ratio_expand, axis=0)
+ adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio_expand2})
+
+    # The restoration result recovers more detail from very dark regions; however, those
+    # regions tend to be restored with gray colors. The following operation alleviates this weakness.
+ decom_r_sq = np.squeeze(decom_r_low)
+ r_gray = color.rgb2gray(decom_r_sq)
+    r_gray_gaussian = filters.gaussian(r_gray, 3)
+    low_i = np.minimum((r_gray_gaussian*2)**0.5, 1)
+ low_i_expand_0 = np.expand_dims(low_i, axis = 0)
+ low_i_expand_3 = np.expand_dims(low_i_expand_0, axis = 3)
+ result_denoise = restoration_r*low_i_expand_3
+ fusion4 = result_denoise*adjust_i
+
+    #fusion = restoration_r*adjust_i
+    # Fuse with the original input to avoid over-exposure.
+ fusion2 = decom_i_low*input_low_eval + (1-decom_i_low)*fusion4
+ #print(fusion2.shape)
+ save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion2)
+
diff --git a/evaluate_LOLdataset.py b/evaluate_LOLdataset.py
new file mode 100644
index 0000000..39d37d2
--- /dev/null
+++ b/evaluate_LOLdataset.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+from __future__ import print_function
+import os
+import time
+import random
+from PIL import Image
+import tensorflow as tf
+import numpy as np
+from utils import *
+from model import *
+from glob import glob
+
+sess = tf.Session()
+
+input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
+input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r')
+input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
+input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r')
+input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i')
+input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio')
+
+[R_decom, I_decom] = DecomNet_simple(input_decom)
+decom_output_R = R_decom
+decom_output_I = I_decom
+output_r = Restoration_net(input_low_r, input_low_i)
+output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
+
+var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
+var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name]
+var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name]
+
+saver_Decom = tf.train.Saver(var_list = var_Decom)
+saver_adjust = tf.train.Saver(var_list=var_adjust)
+saver_restoration = tf.train.Saver(var_list=var_restoration)
+
+decom_checkpoint_dir ='./checkpoint/decom_net_train/'
+ckpt_pre=tf.train.get_checkpoint_state(decom_checkpoint_dir)
+if ckpt_pre:
+ print('loaded '+ckpt_pre.model_checkpoint_path)
+ saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path)
+else:
+ print('No decomnet checkpoint!')
+
+checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/'
+ckpt_adjust=tf.train.get_checkpoint_state(checkpoint_dir_adjust)
+if ckpt_adjust:
+ print('loaded '+ckpt_adjust.model_checkpoint_path)
+ saver_adjust.restore(sess,ckpt_adjust.model_checkpoint_path)
+else:
+ print("No adjust pre model!")
+
+checkpoint_dir_restoration = './checkpoint/Restoration_net_train/'
+ckpt=tf.train.get_checkpoint_state(checkpoint_dir_restoration)
+if ckpt:
+ print('loaded '+ckpt.model_checkpoint_path)
+ saver_restoration.restore(sess,ckpt.model_checkpoint_path)
+else:
+ print("No restoration pre model!")
+
+###load eval data
+eval_low_data = []
+eval_img_name =[]
+eval_low_data_name = glob('./LOLdataset/eval15/low/*.png')
+eval_low_data_name.sort()
+for idx in range(len(eval_low_data_name)):
+ [_, name] = os.path.split(eval_low_data_name[idx])
+ suffix = name[name.find('.') + 1:]
+ name = name[:name.find('.')]
+ eval_img_name.append(name)
+ eval_low_im = load_images(eval_low_data_name[idx])
+ eval_low_data.append(eval_low_im)
+ print(eval_low_im.shape)
+# To get better results, the illumination adjustment ratio is computed based on the decom_i_high, so we also need the high data.
+eval_high_data = []
+eval_high_data_name = glob('./LOLdataset/eval15/high/*.png')
+eval_high_data_name.sort()
+for idx in range(len(eval_high_data_name)):
+ eval_high_im = load_images(eval_high_data_name[idx])
+ eval_high_data.append(eval_high_im)
+
+sample_dir = './results/LOLdataset_eval15/'
+if not os.path.isdir(sample_dir):
+ os.makedirs(sample_dir)
+
+print("Start evaluating!")
+start_time = time.time()
+for idx in range(len(eval_low_data)):
+ print(idx)
+ name = eval_img_name[idx]
+ input_low = eval_low_data[idx]
+ input_low_eval = np.expand_dims(input_low, axis=0)
+ input_high = eval_high_data[idx]
+ input_high_eval = np.expand_dims(input_high, axis=0)
+ h, w, _ = input_low.shape
+
+ decom_r_low, decom_i_low = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval})
+ decom_r_high, decom_i_high = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high_eval})
+
+ restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})
+
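+    # Target adjustment ratio from the ground-truth illumination: mean(I_high / I_low).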
+ ratio = np.mean(((decom_i_high))/(decom_i_low+0.0001))
+
+ i_low_data_ratio = np.ones([h, w])*(ratio)
+ i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2)
+ i_low_ratio_expand2 = np.expand_dims(i_low_ratio_expand, axis=0)
+
+ adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio_expand2})
+ fusion = restoration_r*adjust_i
+ save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion)
+
diff --git a/figures/network.jpg b/figures/network.jpg
new file mode 100644
index 0000000..62001c2
Binary files /dev/null and b/figures/network.jpg differ
diff --git a/figures/result.jpg b/figures/result.jpg
new file mode 100644
index 0000000..3d89461
Binary files /dev/null and b/figures/result.jpg differ
diff --git a/illumination_adjust_net_train/h_eval_200_10_0.243455.png b/illumination_adjust_net_train/h_eval_200_10_0.243455.png
new file mode 100644
index 0000000..3fc33a8
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_10_0.243455.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_1_0.283560.png b/illumination_adjust_net_train/h_eval_200_1_0.283560.png
new file mode 100644
index 0000000..6876f23
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_1_0.283560.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_2_1.097869.png b/illumination_adjust_net_train/h_eval_200_2_1.097869.png
new file mode 100644
index 0000000..a62ce03
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_2_1.097869.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_3_0.790207.png b/illumination_adjust_net_train/h_eval_200_3_0.790207.png
new file mode 100644
index 0000000..94091d6
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_3_0.790207.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_4_1.368601.png b/illumination_adjust_net_train/h_eval_200_4_1.368601.png
new file mode 100644
index 0000000..bcf93ea
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_4_1.368601.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_5_0.979668.png b/illumination_adjust_net_train/h_eval_200_5_0.979668.png
new file mode 100644
index 0000000..e65bd36
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_5_0.979668.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_6_1.648318.png b/illumination_adjust_net_train/h_eval_200_6_1.648318.png
new file mode 100644
index 0000000..494f340
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_6_1.648318.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_7_0.342475.png b/illumination_adjust_net_train/h_eval_200_7_0.342475.png
new file mode 100644
index 0000000..755ec5f
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_7_0.342475.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_8_0.597354.png b/illumination_adjust_net_train/h_eval_200_8_0.597354.png
new file mode 100644
index 0000000..c2a3280
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_8_0.597354.png differ
diff --git a/illumination_adjust_net_train/h_eval_200_9_0.969436.png b/illumination_adjust_net_train/h_eval_200_9_0.969436.png
new file mode 100644
index 0000000..fe970a8
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_200_9_0.969436.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_10_1.080474.png b/illumination_adjust_net_train/h_eval_400_10_1.080474.png
new file mode 100644
index 0000000..14b27bb
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_10_1.080474.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_1_1.829798.png b/illumination_adjust_net_train/h_eval_400_1_1.829798.png
new file mode 100644
index 0000000..91aeeb9
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_1_1.829798.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_2_0.833631.png b/illumination_adjust_net_train/h_eval_400_2_0.833631.png
new file mode 100644
index 0000000..55eea1a
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_2_0.833631.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_3_0.639403.png b/illumination_adjust_net_train/h_eval_400_3_0.639403.png
new file mode 100644
index 0000000..6efd2ec
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_3_0.639403.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_4_0.080941.png b/illumination_adjust_net_train/h_eval_400_4_0.080941.png
new file mode 100644
index 0000000..c8797c9
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_4_0.080941.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_5_0.118340.png b/illumination_adjust_net_train/h_eval_400_5_0.118340.png
new file mode 100644
index 0000000..b8c4f00
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_5_0.118340.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_6_1.416026.png b/illumination_adjust_net_train/h_eval_400_6_1.416026.png
new file mode 100644
index 0000000..c37bba8
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_6_1.416026.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_7_0.024722.png b/illumination_adjust_net_train/h_eval_400_7_0.024722.png
new file mode 100644
index 0000000..6b61d59
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_7_0.024722.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_8_1.792914.png b/illumination_adjust_net_train/h_eval_400_8_1.792914.png
new file mode 100644
index 0000000..29c7696
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_8_1.792914.png differ
diff --git a/illumination_adjust_net_train/h_eval_400_9_1.454717.png b/illumination_adjust_net_train/h_eval_400_9_1.454717.png
new file mode 100644
index 0000000..b425dda
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_400_9_1.454717.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_10_1.570546.png b/illumination_adjust_net_train/h_eval_600_10_1.570546.png
new file mode 100644
index 0000000..48bf314
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_10_1.570546.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_1_1.194817.png b/illumination_adjust_net_train/h_eval_600_1_1.194817.png
new file mode 100644
index 0000000..567112b
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_1_1.194817.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_2_0.413976.png b/illumination_adjust_net_train/h_eval_600_2_0.413976.png
new file mode 100644
index 0000000..ed1c71d
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_2_0.413976.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_3_0.765548.png b/illumination_adjust_net_train/h_eval_600_3_0.765548.png
new file mode 100644
index 0000000..db2f816
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_3_0.765548.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_4_0.745525.png b/illumination_adjust_net_train/h_eval_600_4_0.745525.png
new file mode 100644
index 0000000..6673647
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_4_0.745525.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_5_1.079725.png b/illumination_adjust_net_train/h_eval_600_5_1.079725.png
new file mode 100644
index 0000000..1f6dc30
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_5_1.079725.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_6_0.365893.png b/illumination_adjust_net_train/h_eval_600_6_0.365893.png
new file mode 100644
index 0000000..7eed807
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_6_0.365893.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_7_0.716873.png b/illumination_adjust_net_train/h_eval_600_7_0.716873.png
new file mode 100644
index 0000000..08369a3
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_7_0.716873.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_8_0.492159.png b/illumination_adjust_net_train/h_eval_600_8_0.492159.png
new file mode 100644
index 0000000..85df0d4
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_8_0.492159.png differ
diff --git a/illumination_adjust_net_train/h_eval_600_9_0.445431.png b/illumination_adjust_net_train/h_eval_600_9_0.445431.png
new file mode 100644
index 0000000..b5a8597
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_600_9_0.445431.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_10_0.235056.png b/illumination_adjust_net_train/h_eval_800_10_0.235056.png
new file mode 100644
index 0000000..b2217a7
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_10_0.235056.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_1_1.459935.png b/illumination_adjust_net_train/h_eval_800_1_1.459935.png
new file mode 100644
index 0000000..e9d3700
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_1_1.459935.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_2_0.965174.png b/illumination_adjust_net_train/h_eval_800_2_0.965174.png
new file mode 100644
index 0000000..afcbda4
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_2_0.965174.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_3_0.243200.png b/illumination_adjust_net_train/h_eval_800_3_0.243200.png
new file mode 100644
index 0000000..6ee9fa2
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_3_0.243200.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_4_1.968109.png b/illumination_adjust_net_train/h_eval_800_4_1.968109.png
new file mode 100644
index 0000000..895f40c
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_4_1.968109.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_5_1.170218.png b/illumination_adjust_net_train/h_eval_800_5_1.170218.png
new file mode 100644
index 0000000..a9994de
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_5_1.170218.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_6_0.804922.png b/illumination_adjust_net_train/h_eval_800_6_0.804922.png
new file mode 100644
index 0000000..29d807d
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_6_0.804922.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_7_1.683641.png b/illumination_adjust_net_train/h_eval_800_7_1.683641.png
new file mode 100644
index 0000000..82ffbee
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_7_1.683641.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_8_1.448064.png b/illumination_adjust_net_train/h_eval_800_8_1.448064.png
new file mode 100644
index 0000000..c924e5c
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_8_1.448064.png differ
diff --git a/illumination_adjust_net_train/h_eval_800_9_1.708184.png b/illumination_adjust_net_train/h_eval_800_9_1.708184.png
new file mode 100644
index 0000000..a3309b9
Binary files /dev/null and b/illumination_adjust_net_train/h_eval_800_9_1.708184.png differ
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..eae6e85
--- /dev/null
+++ b/model.py
@@ -0,0 +1,97 @@
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from tensorflow.contrib.layers.python.layers import initializers
+
+def lrelu(x, trainable=None):
+    return tf.maximum(x*0.2, x)
+
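+# Upsample x1 by a factor of 2 with a learned transposed convolution, then concatenate it with the skip feature x2.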
+def upsample_and_concat(x1, x2, output_channels, in_channels, scope_name, trainable=True):
+ with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
+ pool_size = 2
+ deconv_filter = tf.get_variable('weights', [pool_size, pool_size, output_channels, in_channels], trainable= True)
+ deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2) , strides=[1, pool_size, pool_size, 1], name=scope_name)
+
+ deconv_output = tf.concat([deconv, x2],3)
+ deconv_output.set_shape([None, None, None, output_channels*2])
+
+ return deconv_output
+
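+# Decomposition net: a small U-Net branch predicts the 3-channel reflectance map R, and a
+# shallow branch reusing its features predicts the 1-channel illumination map I.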
+def DecomNet_simple(input):
+ with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
+ conv1=slim.conv2d(input,32,[3,3], rate=1, activation_fn=lrelu,scope='g_conv1_1')
+ pool1=slim.max_pool2d(conv1, [2, 2], stride = 2, padding='SAME' )
+ conv2=slim.conv2d(pool1,64,[3,3], rate=1, activation_fn=lrelu,scope='g_conv2_1')
+ pool2=slim.max_pool2d(conv2, [2, 2], stride = 2, padding='SAME' )
+ conv3=slim.conv2d(pool2,128,[3,3], rate=1, activation_fn=lrelu,scope='g_conv3_1')
+ up8 = upsample_and_concat( conv3, conv2, 64, 128 , 'g_up_1')
+ conv8=slim.conv2d(up8, 64,[3,3], rate=1, activation_fn=lrelu,scope='g_conv8_1')
+ up9 = upsample_and_concat( conv8, conv1, 32, 64 , 'g_up_2')
+ conv9=slim.conv2d(up9, 32,[3,3], rate=1, activation_fn=lrelu,scope='g_conv9_1')
+        # Here, we use a 1x1 kernel instead of the 3x3 kernel in the paper to get better results.
+ conv10=slim.conv2d(conv9,3,[1,1], rate=1, activation_fn=None, scope='g_conv10')
+ R_out = tf.sigmoid(conv10)
+
+ l_conv2=slim.conv2d(conv1,32,[3,3], rate=1, activation_fn=lrelu,scope='l_conv1_2')
+ l_conv3=tf.concat([l_conv2, conv9],3)
+        # Here, we use a 1x1 kernel instead of the 3x3 kernel in the paper to get better results.
+ l_conv4=slim.conv2d(l_conv3,1,[1,1], rate=1, activation_fn=None,scope='l_conv1_4')
+ L_out = tf.sigmoid(l_conv4)
+
+ return R_out, L_out
+
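+# Restoration net: a U-Net that restores the reflectance map, with the illumination map concatenated as a fourth input channel.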
+def Restoration_net(input_r, input_i):
+    with tf.variable_scope('Restoration_net', reuse=tf.AUTO_REUSE):
+        input_all = tf.concat([input_r, input_i], 3)
+
+        conv1 = slim.conv2d(input_all, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv1_1')
+        conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv1_2')
+        pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')
+
+        conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv2_1')
+        conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv2_2')
+        pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')
+
+        conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv3_1')
+        conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv3_2')
+        pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')
+
+        conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv4_1')
+        conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv4_2')
+        pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')
+
+        conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv5_1')
+        conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv5_2')
+
+        up6 = upsample_and_concat(conv5, conv4, 256, 512, 'up_6')
+        conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv6_1')
+        conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv6_2')
+
+        up7 = upsample_and_concat(conv6, conv3, 128, 256, 'up_7')
+        conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv7_1')
+        conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv7_2')
+
+        up8 = upsample_and_concat(conv7, conv2, 64, 128, 'up_8')
+        conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv8_1')
+        conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv8_2')
+
+        up9 = upsample_and_concat(conv8, conv1, 32, 64, 'up_9')
+        conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv9_1')
+        conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv9_2')
+
+        conv10 = slim.conv2d(conv9, 3, [3, 3], rate=1, activation_fn=None, scope='de_conv10')
+
+        out = tf.sigmoid(conv10)
+        return out
+
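+# Illumination adjustment net: four plain 3x3 conv layers that map the
+# illumination map plus a per-pixel ratio map to the adjusted illumination.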
+def Illumination_adjust_net(input_i, input_ratio):
+    with tf.variable_scope('Illumination_adjust_net', reuse=tf.AUTO_REUSE):
+        input_all = tf.concat([input_i, input_ratio], 3)
+
+        conv1 = slim.conv2d(input_all, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_1')
+        conv2 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_2')
+        conv3 = slim.conv2d(conv2, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_3')
+        conv4 = slim.conv2d(conv3, 1, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_4')
+
+        L_enhance = tf.sigmoid(conv4)
+        return L_enhance
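+
+# A minimal smoke-test sketch (ours, not part of the original file): builds the
+# three networks on placeholder inputs to check that the graph wires up.
+if __name__ == '__main__':
+    input_im = tf.placeholder(tf.float32, [None, None, None, 3], name='demo_input')
+    ratio_map = tf.placeholder(tf.float32, [None, None, None, 1], name='demo_ratio')
+    R, I = DecomNet_simple(input_im)
+    R_restored = Restoration_net(R, I)
+    I_adjusted = Illumination_adjust_net(I, ratio_map)
+    # The enhanced image is assumed to be the element-wise product of the
+    # restored reflectance and the adjusted illumination broadcast to 3 channels.
+    enhanced = R_restored * tf.concat([I_adjusted, I_adjusted, I_adjusted], axis=3)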
\ No newline at end of file
diff --git a/reflectance_restoration_net_train.py b/reflectance_restoration_net_train.py
new file mode 100644
index 0000000..e746420
--- /dev/null
+++ b/reflectance_restoration_net_train.py
@@ -0,0 +1,244 @@
+# coding: utf-8
+from __future__ import print_function
+import os
+import time
+import random
+from PIL import Image
+import tensorflow as tf
+import numpy as np
+from utils import *
+from model import *
+from glob import glob
+
+batch_size = 4
+patch_size = 384
+
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+sess=tf.Session(config=config)
+# Input of the decomposition net.
+input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
+# Inputs of the restoration net.
+input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r')
+input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
+input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r')
+
+[R_decom, I_decom] = DecomNet_simple(input_decom)
+# Outputs of the decomposition network.
+decom_output_R = R_decom
+decom_output_I = I_decom
+
+output_r = Restoration_net(input_low_r, input_low_i)
+
+# Define the losses.
+def grad_loss(input_r_low, input_r_high):
+    # Gradient-consistency loss: compare horizontal and vertical image
+    # gradients of the two reflectance maps in grayscale.
+    input_r_low_gray = tf.image.rgb_to_grayscale(input_r_low)
+    input_r_high_gray = tf.image.rgb_to_grayscale(input_r_high)
+    x_loss = tf.square(gradient(input_r_low_gray, 'x') - gradient(input_r_high_gray, 'x'))
+    y_loss = tf.square(gradient(input_r_low_gray, 'y') - gradient(input_r_high_gray, 'y'))
+    grad_loss_all = tf.reduce_mean(x_loss + y_loss)
+    return grad_loss_all
+
+def ssim_loss(output_r, input_high_r):
+    # SSIM loss: SSIM is computed per color channel, averaged, and turned
+    # into a loss as 1 - SSIM.
+    ssim_r_1 = tf_ssim(output_r[:, :, :, 0:1], input_high_r[:, :, :, 0:1])
+    ssim_r_2 = tf_ssim(output_r[:, :, :, 1:2], input_high_r[:, :, :, 1:2])
+    ssim_r_3 = tf_ssim(output_r[:, :, :, 2:3], input_high_r[:, :, :, 2:3])
+    ssim_r = (ssim_r_1 + ssim_r_2 + ssim_r_3) / 3.0
+    return 1 - ssim_r
+
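+# Total restoration loss: pixel-wise MSE + gradient consistency + (1 - SSIM).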
+loss_square = tf.reduce_mean(tf.square(output_r - input_high_r))
+loss_ssim = ssim_loss(output_r, input_high_r)
+loss_grad = grad_loss(output_r, input_high_r)
+
+loss_restoration = loss_square + loss_grad + loss_ssim
+
+### Initialize the optimizer and training op.
+lr = tf.placeholder(tf.float32, name='learning_rate')
+global_step = tf.get_variable('global_step', [], dtype=tf.int32, initializer=tf.constant_initializer(0), trainable=False)
+update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')
+with tf.control_dependencies(update_ops):
+    grads = optimizer.compute_gradients(loss_restoration)
+    train_op_restoration = optimizer.apply_gradients(grads, global_step=global_step)
+
+var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
+var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name]
+
+saver_restoration = tf.train.Saver(var_list=var_restoration)
+saver_Decom = tf.train.Saver(var_list = var_Decom)
+sess.run(tf.global_variables_initializer())
+print("[*] Initialize model successfully...")
+
+### Load data.
+### Based on the decomposition net, we first compute the decomposed reflectance
+### and illumination maps, then train the restoration net on them.
+### Training data:
+train_low_data = []
+train_high_data = []
+train_low_data_names = glob('./LOLdataset/our485/low/*.png')
+train_low_data_names.sort()
+train_high_data_names = glob('./LOLdataset/our485/high/*.png')
+train_high_data_names.sort()
+assert len(train_low_data_names) == len(train_high_data_names)
+print('[*] Number of training data: %d' % len(train_low_data_names))
+for idx in range(len(train_low_data_names)):
+    low_im = load_images(train_low_data_names[idx])
+    train_low_data.append(low_im)
+    high_im = load_images(train_high_data_names[idx])
+    train_high_data.append(high_im)
+
+eval_low_data = []
+eval_low_data_names = glob('./LOLdataset/eval15/low/*.png')
+eval_low_data_names.sort()
+for idx in range(len(eval_low_data_names)):
+    eval_low_im = load_images(eval_low_data_names[idx])
+    eval_low_data.append(eval_low_im)
+
+pre_decom_checkpoint_dir = './checkpoint/decom_net_train/'
+ckpt_pre = tf.train.get_checkpoint_state(pre_decom_checkpoint_dir)
+if ckpt_pre:
+    print('loaded ' + ckpt_pre.model_checkpoint_path)
+    saver_Decom.restore(sess, ckpt_pre.model_checkpoint_path)
+else:
+    print('No pre_decom_net checkpoint!')
+
+decomposed_low_r_data_480 = []
+decomposed_low_i_data_480 = []
+decomposed_high_r_data_480 = []
+for idx in range(len(train_low_data)):
+    input_low = np.expand_dims(train_low_data[idx], axis=0)
+    RR, II = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low})
+    RR0 = np.squeeze(RR)
+    II0 = np.squeeze(II)
+    print(idx, RR0.shape, II0.shape)
+    decomposed_low_r_data_480.append(RR0)
+    decomposed_low_i_data_480.append(II0)
+for idx in range(len(train_high_data)):
+    input_high = np.expand_dims(train_high_data[idx], axis=0)
+    RR2, II2 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high})
+    # To improve the contrast, we slightly adjust decom_r_high by raising it to the power 1.2.
+    RR02 = np.squeeze(RR2 ** 1.2)
+    print(idx, RR02.shape)
+    decomposed_high_r_data_480.append(RR02)
+
+decomposed_eval_low_r_data = []
+decomposed_eval_low_i_data = []
+for idx in range(len(eval_low_data)):
+    input_eval = np.expand_dims(eval_low_data[idx], axis=0)
+    RR3, II3 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_eval})
+    RR03 = np.squeeze(RR3)
+    II03 = np.squeeze(II3)
+    print(idx, RR03.shape, II03.shape)
+    decomposed_eval_low_r_data.append(RR03)
+    decomposed_eval_low_i_data.append(II03)
+
+
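+# Split the decomposed data: a held-out slice of the decomposed training images
+# (indices 467-479) plus the 15 decomposed eval images form the evaluation set;
+# the first 466 decomposed pairs form the training set.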
+eval_restoration_low_r_data = decomposed_low_r_data_480[467:480] + decomposed_eval_low_r_data[0:15]
+eval_restoration_low_i_data = decomposed_low_i_data_480[467:480] + decomposed_eval_low_i_data[0:15]
+
+train_restoration_low_r_data = decomposed_low_r_data_480[0:466]
+train_restoration_low_i_data = decomposed_low_i_data_480[0:466]
+train_restoration_high_r_data = decomposed_high_r_data_480[0:466]
+#train_restoration_high_i_data = train_restoration_high_i_data_480[0:466]
+print(len(train_restoration_high_r_data), len(train_restoration_low_r_data),len(train_restoration_low_i_data))
+print(len(eval_restoration_low_r_data),len(eval_restoration_low_i_data))
+assert len(train_restoration_high_r_data) == len(train_restoration_low_r_data)
+assert len(train_restoration_low_i_data) == len(train_restoration_low_r_data)
+print('[*] Number of training data: %d' % len(train_restoration_high_r_data))
+
+learning_rate = 0.0001
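+# Piecewise-constant decay: the full rate for the first 800 epochs, then 1/2,
+# 1/4, and finally 1/10 of the initial rate.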
+def lr_schedule(epoch):
+    initial_lr = learning_rate
+    if epoch <= 800:
+        lr = initial_lr
+    elif epoch <= 1250:
+        lr = initial_lr / 2
+    elif epoch <= 1500:
+        lr = initial_lr / 4
+    else:
+        lr = initial_lr / 10
+    return lr
+
+total_epoch = 1000
+
+sample_dir = './Restoration_net_train/'
+if not os.path.isdir(sample_dir):
+    os.makedirs(sample_dir)
+
+eval_every_epoch = 50
+train_phase = 'Restoration'
+numBatch = len(train_restoration_low_r_data) // int(batch_size)
+train_op = train_op_restoration
+train_loss = loss_restoration
+saver = saver_restoration
+
+checkpoint_dir = './checkpoint/Restoration_net_train/'
+if not os.path.isdir(checkpoint_dir):
+    os.makedirs(checkpoint_dir)
+ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+if ckpt:
+    print('loaded ' + ckpt.model_checkpoint_path)
+    saver_restoration.restore(sess, ckpt.model_checkpoint_path)
+else:
+    print('No pre_restoration_net checkpoint!')
+
+start_step = 0
+start_epoch = 0
+iter_num = 0
+print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
+start_time = time.time()
+image_id = 0
+
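+# Training loop: each step samples batch_size random patches, applies one of
+# eight flip/rotation augmentations, and reshuffles the data once per full
+# pass; evaluation images and a checkpoint are written every
+# eval_every_epoch epochs.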
+for epoch in range(start_epoch, total_epoch):
+    for batch_id in range(start_step, numBatch):
+        batch_input_low_r = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
+        batch_input_low_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
+        batch_input_high_r = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
+
+        for patch_id in range(batch_size):
+            h, w, _ = train_restoration_low_r_data[image_id].shape
+            x = random.randint(0, h - patch_size)
+            y = random.randint(0, w - patch_size)
+            i_low_expand = np.expand_dims(train_restoration_low_i_data[image_id], axis=2)
+            rand_mode = random.randint(0, 7)
+            batch_input_low_r[patch_id, :, :, :] = data_augmentation(train_restoration_low_r_data[image_id][x: x + patch_size, y: y + patch_size, :], rand_mode)
+            batch_input_low_i[patch_id, :, :, :] = data_augmentation(i_low_expand[x: x + patch_size, y: y + patch_size, :], rand_mode)
+            batch_input_high_r[patch_id, :, :, :] = data_augmentation(train_restoration_high_r_data[image_id][x: x + patch_size, y: y + patch_size, :], rand_mode)
+
+            image_id = (image_id + 1) % len(train_restoration_low_r_data)
+            if image_id == 0:
+                tmp = list(zip(train_restoration_low_r_data, train_restoration_low_i_data, train_restoration_high_r_data))
+                random.shuffle(tmp)
+                train_restoration_low_r_data, train_restoration_low_i_data, train_restoration_high_r_data = zip(*tmp)
+
+        _, loss = sess.run([train_op, train_loss],
+                           feed_dict={input_low_r: batch_input_low_r, input_low_i: batch_input_low_i,
+                                      input_high_r: batch_input_high_r, lr: lr_schedule(epoch)})
+        print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f"
+              % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
+        iter_num += 1
+    if (epoch + 1) % eval_every_epoch == 0:
+        print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1))
+        for idx in range(len(eval_restoration_low_r_data)):
+            input_uu_r = eval_restoration_low_r_data[idx]
+            input_low_eval_r = np.expand_dims(input_uu_r, axis=0)
+            input_uu_i = eval_restoration_low_i_data[idx]
+            input_low_eval_i = np.expand_dims(input_uu_i, axis=0)
+            input_low_eval_ii = np.expand_dims(input_low_eval_i, axis=3)
+            result_1 = sess.run(output_r, feed_dict={input_low_r: input_low_eval_r, input_low_i: input_low_eval_ii})
+            save_images(os.path.join(sample_dir, 'eval_%d_%d.png' % (idx + 1, epoch + 1)), input_uu_r, result_1)
+        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=epoch)
+
+
+
diff --git a/results/test/04_kindle.png b/results/test/04_kindle.png
new file mode 100644
index 0000000..7d20726
Binary files /dev/null and b/results/test/04_kindle.png differ
diff --git a/results/test/1_kindle.png b/results/test/1_kindle.png
new file mode 100644
index 0000000..393127b
Binary files /dev/null and b/results/test/1_kindle.png differ
diff --git a/results/test/4_kindle.png b/results/test/4_kindle.png
new file mode 100644
index 0000000..9589f2c
Binary files /dev/null and b/results/test/4_kindle.png differ
diff --git a/test/04.JPG b/test/04.JPG
new file mode 100644
index 0000000..14cedda
Binary files /dev/null and b/test/04.JPG differ
diff --git a/test/1.bmp b/test/1.bmp
new file mode 100644
index 0000000..4ba6173
Binary files /dev/null and b/test/1.bmp differ
diff --git a/test/4.bmp b/test/4.bmp
new file mode 100644
index 0000000..59199af
Binary files /dev/null and b/test/4.bmp differ
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..ca114d2
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,140 @@
+import numpy as np
+from PIL import Image
+import tensorflow as tf
+import scipy.stats as st
+from functools import reduce
+
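+# Builds a normalized 2-D Gaussian filter of shape
+# [kernlen, kernlen, channels, 1], suitable for tf.nn.depthwise_conv2d.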
+def gauss_kernel(kernlen=21, nsig=3, channels=1):
+    interval = (2 * nsig + 1.) / kernlen
+    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
+    kern1d = np.diff(st.norm.cdf(x))
+    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
+    kernel = kernel_raw / kernel_raw.sum()
+    out_filter = np.array(kernel, dtype=np.float32)
+    out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
+    out_filter = np.repeat(out_filter, channels, axis=2)
+    return out_filter
+
+def tensor_size(tensor):
+    # Number of elements per example (all dimensions except the batch axis).
+    from operator import mul
+    return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
+
+def blur(x):
+    # Depthwise Gaussian blur with a fixed 21x21, sigma-3 kernel.
+    kernel_var = gauss_kernel(21, 3, 3)
+    return tf.nn.depthwise_conv2d(x, kernel_var, [1, 1, 1, 1], padding='SAME')
+
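+# Applies one of eight dihedral transforms (identity, flips, 90/180/270-degree
+# rotations, and rotation+flip combinations), indexed by mode in [0, 7].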
+def data_augmentation(image, mode):
+    if mode == 0:
+        # original
+        return image
+    elif mode == 1:
+        # flip up and down
+        return np.flipud(image)
+    elif mode == 2:
+        # rotate counterclockwise 90 degrees
+        return np.rot90(image)
+    elif mode == 3:
+        # rotate 90 degrees and flip up and down
+        return np.flipud(np.rot90(image))
+    elif mode == 4:
+        # rotate 180 degrees
+        return np.rot90(image, k=2)
+    elif mode == 5:
+        # rotate 180 degrees and flip
+        return np.flipud(np.rot90(image, k=2))
+    elif mode == 6:
+        # rotate 270 degrees
+        return np.rot90(image, k=3)
+    elif mode == 7:
+        # rotate 270 degrees and flip
+        return np.flipud(np.rot90(image, k=3))
+
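+# Loads an image, scales it to [0, 1], and then min-max normalizes it
+# (the denominator is floored at 0.001 to avoid division by zero).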
+def load_images(file):
+    im = Image.open(file)
+    img = np.array(im, dtype="float32") / 255.0
+    img_max = np.max(img)
+    img_min = np.min(img)
+    img_norm = np.float32((img - img_min) / np.maximum((img_max - img_min), 0.001))
+    return img_norm
+
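+# Finite-difference image gradient along 'x' or 'y', computed with a fixed
+# 2x2 convolution kernel and min-max normalized to [0, 1].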
+def gradient(input_tensor, direction):
+    smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
+    smooth_kernel_y = tf.transpose(smooth_kernel_x, [1, 0, 2, 3])
+    if direction == "x":
+        kernel = smooth_kernel_x
+    elif direction == "y":
+        kernel = smooth_kernel_y
+    gradient_orig = tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))
+    grad_min = tf.reduce_min(gradient_orig)
+    grad_max = tf.reduce_max(gradient_orig)
+    grad_norm = tf.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001))
+    return grad_norm
+
+def _tf_fspecial_gauss(size, sigma):
+    """Function to mimic the 'fspecial' Gaussian MATLAB function."""
+    x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
+
+    x_data = np.expand_dims(x_data, axis=-1)
+    x_data = np.expand_dims(x_data, axis=-1)
+
+    y_data = np.expand_dims(y_data, axis=-1)
+    y_data = np.expand_dims(y_data, axis=-1)
+
+    x = tf.constant(x_data, dtype=tf.float32)
+    y = tf.constant(y_data, dtype=tf.float32)
+
+    g = tf.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
+    return g / tf.reduce_sum(g)
+
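+# Single-scale SSIM (Wang et al. 2004) computed with an 11x11 Gaussian window;
+# L = 1 assumes the inputs are scaled to [0, 1].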
+def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):
+    window = _tf_fspecial_gauss(size, sigma)  # window shape [size, size, 1, 1]
+    K1 = 0.01
+    K2 = 0.03
+    L = 1  # depth of image (255 in case the image has a different scale)
+    C1 = (K1 * L)**2
+    C2 = (K2 * L)**2
+    mu1 = tf.nn.conv2d(img1, window, strides=[1, 1, 1, 1], padding='VALID')
+    mu2 = tf.nn.conv2d(img2, window, strides=[1, 1, 1, 1], padding='VALID')
+    mu1_sq = mu1 * mu1
+    mu2_sq = mu2 * mu2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = tf.nn.conv2d(img1 * img1, window, strides=[1, 1, 1, 1], padding='VALID') - mu1_sq
+    sigma2_sq = tf.nn.conv2d(img2 * img2, window, strides=[1, 1, 1, 1], padding='VALID') - mu2_sq
+    sigma12 = tf.nn.conv2d(img1 * img2, window, strides=[1, 1, 1, 1], padding='VALID') - mu1_mu2
+    if cs_map:
+        value = (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) /
+                 ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)),
+                 (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
+    else:
+        value = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
+                ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if mean_metric:
+        value = tf.reduce_mean(value)
+    return value
+
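+# Concatenates up to three results side by side and saves them as an
+# 8-bit PNG; result_2 and result_3 may be omitted.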
+def save_images(filepath, result_1, result_2=None, result_3=None):
+    result_1 = np.squeeze(result_1)
+    result_2 = np.squeeze(result_2)
+    result_3 = np.squeeze(result_3)
+
+    if not result_2.any():
+        cat_image = result_1
+    else:
+        cat_image = np.concatenate([result_1, result_2], axis=1)
+    if result_3.any():
+        cat_image = np.concatenate([cat_image, result_3], axis=1)
+
+    im = Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8'))
+    im.save(filepath, 'png')