After Width: | Height: | Size: 531 KiB |
After Width: | Height: | Size: 535 KiB |
After Width: | Height: | Size: 465 KiB |
After Width: | Height: | Size: 477 KiB |
After Width: | Height: | Size: 536 KiB |
After Width: | Height: | Size: 538 KiB |
After Width: | Height: | Size: 646 KiB |
After Width: | Height: | Size: 650 KiB |
After Width: | Height: | Size: 495 KiB |
After Width: | Height: | Size: 500 KiB |
After Width: | Height: | Size: 418 KiB |
After Width: | Height: | Size: 439 KiB |
After Width: | Height: | Size: 440 KiB |
After Width: | Height: | Size: 447 KiB |
After Width: | Height: | Size: 445 KiB |
After Width: | Height: | Size: 454 KiB |
After Width: | Height: | Size: 463 KiB |
After Width: | Height: | Size: 476 KiB |
After Width: | Height: | Size: 520 KiB |
After Width: | Height: | Size: 523 KiB |
After Width: | Height: | Size: 486 KiB |
After Width: | Height: | Size: 500 KiB |
After Width: | Height: | Size: 472 KiB |
After Width: | Height: | Size: 493 KiB |
After Width: | Height: | Size: 638 KiB |
After Width: | Height: | Size: 644 KiB |
After Width: | Height: | Size: 500 KiB |
After Width: | Height: | Size: 509 KiB |
After Width: | Height: | Size: 422 KiB |
After Width: | Height: | Size: 447 KiB |
After Width: | Height: | Size: 602 KiB |
After Width: | Height: | Size: 612 KiB |
After Width: | Height: | Size: 693 KiB |
After Width: | Height: | Size: 693 KiB |
After Width: | Height: | Size: 564 KiB |
After Width: | Height: | Size: 569 KiB |
After Width: | Height: | Size: 557 KiB |
After Width: | Height: | Size: 561 KiB |
After Width: | Height: | Size: 580 KiB |
After Width: | Height: | Size: 586 KiB |
After Width: | Height: | Size: 541 KiB |
After Width: | Height: | Size: 545 KiB |
After Width: | Height: | Size: 610 KiB |
After Width: | Height: | Size: 618 KiB |
After Width: | Height: | Size: 610 KiB |
After Width: | Height: | Size: 616 KiB |
After Width: | Height: | Size: 558 KiB |
After Width: | Height: | Size: 561 KiB |
After Width: | Height: | Size: 602 KiB |
After Width: | Height: | Size: 603 KiB |
After Width: | Height: | Size: 650 KiB |
After Width: | Height: | Size: 652 KiB |
After Width: | Height: | Size: 676 KiB |
After Width: | Height: | Size: 675 KiB |
After Width: | Height: | Size: 702 KiB |
After Width: | Height: | Size: 708 KiB |
After Width: | Height: | Size: 714 KiB |
After Width: | Height: | Size: 715 KiB |
After Width: | Height: | Size: 677 KiB |
After Width: | Height: | Size: 678 KiB |
@ -1,19 +1,58 @@
|
||||
#### 从命令行创建一个新的仓库
|
||||
# KinD
|
||||
This is a Tensorflow implementation of KinD
|
||||
|
||||
```bash
|
||||
touch README.md
|
||||
git init
|
||||
git add README.md
|
||||
git commit -m "first commit"
|
||||
git remote add origin https://bdgit.educoder.net/ZhengHui/KinD.git
|
||||
git push -u origin master
|
||||
### The [KinD++](https://github.com/zhangyhuaee/KinD_plus) is an improved version. ###
|
||||
|
||||
```
|
||||
Kindling the Darkness: a Practical Low-light Image Enhancer. In ACMMM2019<br>
|
||||
Yonghua Zhang, Jiawan Zhang, Xiaojie Guo
|
||||
|
||||
### [Paper](http://doi.acm.org/10.1145/3343031.3350926)
|
||||
<img src="figures/network.jpg" width="800px"/>
|
||||
|
||||
<img src="figures/result.jpg" width="800px"/>
|
||||
|
||||
### Requirements ###
|
||||
1. Python
|
||||
2. Tensorflow >= 1.10.0
|
||||
3. numpy, PIL
|
||||
|
||||
#### 从命令行推送已经创建的仓库
|
||||
### Test ###
|
||||
First download the pre-trained checkpoints from [BaiduNetdisk](https://pan.baidu.com/s/1c4ZLYEIoR-8skNMiAVbl_A) or [google drive](https://drive.google.com/open?id=1-ljWntl7FExf6BSQtl5Mz3rMGWgnXDz4), then just run
|
||||
```shell
|
||||
python evaluate.py
|
||||
```
|
||||
Our pre-trained model has changed. Thus, the results have some difference with the report in our paper. However, you can adjust the illumination ratio to get better results.
|
||||
|
||||
```bash
|
||||
git remote add origin https://bdgit.educoder.net/ZhengHui/KinD.git
|
||||
git push -u origin master
|
||||
### Train ###
|
||||
The original LOLdataset can be downloaded from [here](https://daooshee.github.io/BMVC2018website/). We rearrange the original LOLdataset and add several all-zero images to improve the decomposition results and restoration results. The new dataset can be download from [BaiduNetdisk](https://pan.baidu.com/s/1sn3vWJ2I5U2dlVUD7eqIBQ) or [google drive](https://drive.google.com/open?id=1-MaOVG7ylOkmGv1K4HWWcrai01i_FeDK). Save training pairs of LOL dataset under './LOLdataset/our485/' and save evaluating pairs under './LOLdataset/eval15/'. For training, just run
|
||||
```shell
|
||||
python decomposition_net_train.py
|
||||
python adjustment_net_train.py
|
||||
python reflectance_restoration_net_train.py
|
||||
```
|
||||
You can also evaluate on the LOLdataset, just run
|
||||
```shell
|
||||
python evaluate_LOLdataset.py
|
||||
```
|
||||
Our code partly refers to the [code](https://github.com/weichen582/RetinexNet).
|
||||
|
||||
### Citation ###
|
||||
```
|
||||
@inproceedings{zhang2019kindling,
|
||||
author = {Zhang, Yonghua and Zhang, Jiawan and Guo, Xiaojie},
|
||||
title = {Kindling the Darkness: A Practical Low-light Image Enhancer},
|
||||
booktitle = {Proceedings of the 27th ACM International Conference on Multimedia},
|
||||
series = {MM '19},
|
||||
year = {2019},
|
||||
isbn = {978-1-4503-6889-6},
|
||||
location = {Nice, France},
|
||||
pages = {1632--1640},
|
||||
numpages = {9},
|
||||
url = {http://doi.acm.org/10.1145/3343031.3350926},
|
||||
doi = {10.1145/3343031.3350926},
|
||||
acmid = {3350926},
|
||||
publisher = {ACM},
|
||||
address = {New York, NY, USA},
|
||||
keywords = {image decomposition, image restoration, low light enhancement},
|
||||
}
|
||||
```
|
||||
|
@ -0,0 +1,223 @@
|
||||
# coding: utf-8
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
#from skimage import color
|
||||
from PIL import Image
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from utils import *
|
||||
from model import *
|
||||
from glob import glob
|
||||
|
||||
batch_size = 10
|
||||
patch_size = 48
|
||||
|
||||
sess = tf.Session()
|
||||
#the input of decomposition net
|
||||
input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
|
||||
#the input of illumination adjustment net
|
||||
input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
|
||||
input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio')
|
||||
input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i')
|
||||
|
||||
[R_decom, I_decom] = DecomNet_simple(input_decom)
|
||||
#the output of decomposition network
|
||||
decom_output_R = R_decom
|
||||
decom_output_I = I_decom
|
||||
#the output of illumination adjustment net
|
||||
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
|
||||
|
||||
#define loss
|
||||
|
||||
def grad_loss(input_i_low, input_i_high):
|
||||
x_loss = tf.square(gradient(input_i_low, 'x') - gradient(input_i_high, 'x'))
|
||||
y_loss = tf.square(gradient(input_i_low, 'y') - gradient(input_i_high, 'y'))
|
||||
grad_loss_all = tf.reduce_mean(x_loss + y_loss)
|
||||
return grad_loss_all
|
||||
|
||||
loss_grad = grad_loss(output_i, input_high_i)
|
||||
loss_square = tf.reduce_mean(tf.square(output_i - input_high_i))# * ( 1 - input_low_r ))#* (1- input_low_i)))
|
||||
|
||||
loss_adjust = loss_square + loss_grad
|
||||
|
||||
lr = tf.placeholder(tf.float32, name='learning_rate')
|
||||
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')
|
||||
|
||||
var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
|
||||
var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name]
|
||||
|
||||
saver_adjust = tf.train.Saver(var_list=var_adjust)
|
||||
saver_Decom = tf.train.Saver(var_list = var_Decom)
|
||||
train_op_adjust = optimizer.minimize(loss_adjust, var_list = var_adjust)
|
||||
sess.run(tf.global_variables_initializer())
|
||||
print("[*] Initialize model successfully...")
|
||||
|
||||
### load data
|
||||
### Based on the decomposition net, we first get the decomposed reflectance maps
|
||||
### and illumination maps, then train the adjust net.
|
||||
###train_data
|
||||
train_low_data = []
|
||||
train_high_data = []
|
||||
train_low_data_names = glob('./LOLdataset/our485/low/*.png')
|
||||
train_low_data_names.sort()
|
||||
train_high_data_names = glob('./LOLdataset/our485/high/*.png')
|
||||
train_high_data_names.sort()
|
||||
assert len(train_low_data_names) == len(train_high_data_names)
|
||||
print('[*] Number of training data: %d' % len(train_low_data_names))
|
||||
for idx in range(len(train_low_data_names)):
|
||||
low_im = load_images(train_low_data_names[idx])
|
||||
train_low_data.append(low_im)
|
||||
high_im = load_images(train_high_data_names[idx])
|
||||
train_high_data.append(high_im)
|
||||
|
||||
pre_decom_checkpoint_dir = './checkpoint/decom_net_train/'
|
||||
ckpt_pre=tf.train.get_checkpoint_state(pre_decom_checkpoint_dir)
|
||||
if ckpt_pre:
|
||||
print('loaded '+ckpt_pre.model_checkpoint_path)
|
||||
saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path)
|
||||
else:
|
||||
print('No pre_decom_net checkpoint!')
|
||||
|
||||
#decomposed_low_r_data_480 = []
|
||||
decomposed_low_i_data_480 = []
|
||||
#decomposed_high_r_data_480 = []
|
||||
decomposed_high_i_data_480 = []
|
||||
for idx in range(len(train_low_data)):
|
||||
input_low = np.expand_dims(train_low_data[idx], axis=0)
|
||||
RR, II = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low})
|
||||
RR0 = np.squeeze(RR)
|
||||
II0 = np.squeeze(II)
|
||||
print(RR0.shape, II0.shape)
|
||||
#decomposed_high_r_data_480.append(result_1_sq)
|
||||
decomposed_low_i_data_480.append(II0)
|
||||
for idx in range(len(train_high_data)):
|
||||
input_high = np.expand_dims(train_high_data[idx], axis=0)
|
||||
RR2, II2 = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high})
|
||||
RR02 = np.squeeze(RR2)
|
||||
II02 = np.squeeze(II2)
|
||||
print(RR02.shape, II02.shape)
|
||||
#decomposed_high_r_data_480.append(result_1_sq)
|
||||
decomposed_high_i_data_480.append(II02)
|
||||
|
||||
eval_adjust_low_i_data = decomposed_low_i_data_480[451:480]
|
||||
eval_adjust_high_i_data = decomposed_high_i_data_480[451:480]
|
||||
|
||||
train_adjust_low_i_data = decomposed_low_i_data_480[0:450]
|
||||
train_adjust_high_i_data = decomposed_high_i_data_480[0:450]
|
||||
|
||||
print('[*] Number of training data: %d' % len(train_adjust_high_i_data))
|
||||
|
||||
learning_rate = 0.0001
|
||||
epoch = 2000
|
||||
eval_every_epoch = 200
|
||||
train_phase = 'adjustment'
|
||||
numBatch = len(train_adjust_low_i_data) // int(batch_size)
|
||||
train_op = train_op_adjust
|
||||
train_loss = loss_adjust
|
||||
saver = saver_adjust
|
||||
|
||||
checkpoint_dir = './checkpoint/illumination_adjust_net_train/'
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
|
||||
if ckpt:
|
||||
print('loaded '+ckpt.model_checkpoint_path)
|
||||
saver.restore(sess,ckpt.model_checkpoint_path)
|
||||
else:
|
||||
print("No adjustment net pre model!")
|
||||
|
||||
start_step = 0
|
||||
start_epoch = 0
|
||||
iter_num = 0
|
||||
print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
|
||||
|
||||
sample_dir = './illumination_adjust_net_train/'
|
||||
if not os.path.isdir(sample_dir):
|
||||
os.makedirs(sample_dir)
|
||||
|
||||
start_time = time.time()
|
||||
image_id = 0
|
||||
|
||||
for epoch in range(start_epoch, epoch):
|
||||
for batch_id in range(start_step, numBatch):
|
||||
batch_input_low_i_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
batch_input_high_i_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
batch_input_low_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
batch_input_high_i = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
input_low_i_rand = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
input_high_i_rand = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
input_low_i_rand_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
input_high_i_rand_ratio = np.zeros((batch_size, patch_size, patch_size, 1), dtype="float32")
|
||||
|
||||
for patch_id in range(batch_size):
|
||||
i_low_data = train_adjust_low_i_data[image_id]
|
||||
i_low_expand = np.expand_dims(i_low_data, axis = 2)
|
||||
i_high_data = train_adjust_high_i_data[image_id]
|
||||
i_high_expand = np.expand_dims(i_high_data, axis = 2)
|
||||
|
||||
h, w = train_adjust_low_i_data[image_id].shape
|
||||
x = random.randint(0, h - patch_size)
|
||||
y = random.randint(0, w - patch_size)
|
||||
i_low_data_crop = i_low_expand[x : x+patch_size, y : y+patch_size, :]
|
||||
i_high_data_crop = i_high_expand[x : x+patch_size, y : y+patch_size, :]
|
||||
|
||||
rand_mode = np.random.randint(0, 7)
|
||||
batch_input_low_i[patch_id, :, :, :] = data_augmentation(i_low_data_crop , rand_mode)
|
||||
batch_input_high_i[patch_id, :, :, :] = data_augmentation(i_high_data_crop, rand_mode)
|
||||
|
||||
ratio = np.mean(i_low_data_crop/(i_high_data_crop+0.0001))
|
||||
#print(ratio)
|
||||
i_low_data_ratio = np.ones([patch_size,patch_size])*(1/ratio+0.0001)
|
||||
i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2)
|
||||
i_high_data_ratio = np.ones([patch_size,patch_size])*(ratio)
|
||||
i_high_ratio_expand = np.expand_dims(i_high_data_ratio , axis =2)
|
||||
batch_input_low_i_ratio[patch_id, :, :, :] = i_low_ratio_expand
|
||||
batch_input_high_i_ratio[patch_id, :, :, :] = i_high_ratio_expand
|
||||
|
||||
rand_mode = np.random.randint(0, 2)
|
||||
if rand_mode == 1:
|
||||
input_low_i_rand[patch_id, :, :, :] = batch_input_low_i[patch_id, :, :, :]
|
||||
input_high_i_rand[patch_id, :, :, :] = batch_input_high_i[patch_id, :, :, :]
|
||||
input_low_i_rand_ratio[patch_id, :, :, :] = batch_input_low_i_ratio[patch_id, :, :, :]
|
||||
input_high_i_rand_ratio[patch_id, :, :, :] = batch_input_high_i_ratio[patch_id, :, :, :]
|
||||
else:
|
||||
input_low_i_rand[patch_id, :, :, :] = batch_input_high_i[patch_id, :, :, :]
|
||||
input_high_i_rand[patch_id, :, :, :] = batch_input_low_i[patch_id, :, :, :]
|
||||
input_low_i_rand_ratio[patch_id, :, :, :] = batch_input_high_i_ratio[patch_id, :, :, :]
|
||||
input_high_i_rand_ratio[patch_id, :, :, :] = batch_input_low_i_ratio[patch_id, :, :, :]
|
||||
|
||||
image_id = (image_id + 1) % len(train_adjust_low_i_data)
|
||||
|
||||
_, loss = sess.run([train_op, train_loss], feed_dict={input_low_i: input_low_i_rand,input_low_i_ratio: input_low_i_rand_ratio,\
|
||||
input_high_i: input_high_i_rand, \
|
||||
lr: learning_rate})
|
||||
print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
|
||||
% (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
|
||||
iter_num += 1
|
||||
if (epoch + 1) % eval_every_epoch == 0:
|
||||
print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1))
|
||||
|
||||
for idx in range(10):
|
||||
rand_idx = idx#np.random.randint(26)
|
||||
input_uu_i = eval_adjust_low_i_data[rand_idx]
|
||||
input_low_eval_i = np.expand_dims(input_uu_i, axis=0)
|
||||
input_low_eval_ii = np.expand_dims(input_low_eval_i, axis=3)
|
||||
h, w = eval_adjust_low_i_data[idx].shape
|
||||
rand_ratio = np.random.random(1)*2
|
||||
input_uu_i_ratio = np.ones([h,w]) * rand_ratio
|
||||
input_low_eval_i_ratio = np.expand_dims(input_uu_i_ratio, axis=0)
|
||||
input_low_eval_ii_ratio = np.expand_dims(input_low_eval_i_ratio, axis=3)
|
||||
|
||||
result_1 = sess.run(output_i, feed_dict={input_low_i: input_low_eval_ii, input_low_i_ratio: input_low_eval_ii_ratio})
|
||||
save_images(os.path.join(sample_dir, 'h_eval_%d_%d_%5f.png' % ( epoch + 1 , rand_idx + 1, rand_ratio)), input_uu_i, result_1)
|
||||
|
||||
|
||||
saver.save(sess, checkpoint_dir + 'model.ckpt')
|
||||
|
||||
print("[*] Finish training for phase %s." % train_phase)
|
||||
|
||||
|
||||
|
@ -0,0 +1 @@
|
||||
model_checkpoint_path: "model.ckpt"
|
@ -0,0 +1,2 @@
|
||||
model_checkpoint_path: "model.ckpt"
|
||||
all_model_checkpoint_paths: "model.ckpt"
|
@ -0,0 +1,2 @@
|
||||
model_checkpoint_path: "model.ckpt"
|
||||
all_model_checkpoint_paths: "model.ckpt"
|
@ -0,0 +1,174 @@
|
||||
# coding: utf-8
|
||||
from __future__ import print_function
|
||||
import os, time, random
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from utils import *
|
||||
from model import *
|
||||
from glob import glob
|
||||
|
||||
batch_size = 10
|
||||
patch_size = 48
|
||||
|
||||
sess = tf.Session()
|
||||
|
||||
input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
|
||||
input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')
|
||||
|
||||
[R_low, I_low] = DecomNet_simple(input_low)
|
||||
[R_high, I_high] = DecomNet_simple(input_high)
|
||||
|
||||
I_low_3 = tf.concat([I_low, I_low, I_low], axis=3)
|
||||
I_high_3 = tf.concat([I_high, I_high, I_high], axis=3)
|
||||
|
||||
#network output
|
||||
output_R_low = R_low
|
||||
output_R_high = R_high
|
||||
output_I_low = I_low_3
|
||||
output_I_high = I_high_3
|
||||
|
||||
# define loss
|
||||
|
||||
def mutual_i_loss(input_I_low, input_I_high):
|
||||
low_gradient_x = gradient(input_I_low, "x")
|
||||
high_gradient_x = gradient(input_I_high, "x")
|
||||
x_loss = (low_gradient_x + high_gradient_x)* tf.exp(-10*(low_gradient_x+high_gradient_x))
|
||||
low_gradient_y = gradient(input_I_low, "y")
|
||||
high_gradient_y = gradient(input_I_high, "y")
|
||||
y_loss = (low_gradient_y + high_gradient_y) * tf.exp(-10*(low_gradient_y+high_gradient_y))
|
||||
mutual_loss = tf.reduce_mean( x_loss + y_loss)
|
||||
return mutual_loss
|
||||
|
||||
def mutual_i_input_loss(input_I_low, input_im):
|
||||
input_gray = tf.image.rgb_to_grayscale(input_im)
|
||||
low_gradient_x = gradient(input_I_low, "x")
|
||||
input_gradient_x = gradient(input_gray, "x")
|
||||
x_loss = tf.abs(tf.div(low_gradient_x, tf.maximum(input_gradient_x, 0.01)))
|
||||
low_gradient_y = gradient(input_I_low, "y")
|
||||
input_gradient_y = gradient(input_gray, "y")
|
||||
y_loss = tf.abs(tf.div(low_gradient_y, tf.maximum(input_gradient_y, 0.01)))
|
||||
mut_loss = tf.reduce_mean(x_loss + y_loss)
|
||||
return mut_loss
|
||||
|
||||
recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - input_low))
|
||||
recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - input_high))
|
||||
|
||||
equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
|
||||
|
||||
i_mutual_loss = mutual_i_loss(I_low, I_high)
|
||||
|
||||
i_input_mutual_loss_high = mutual_i_input_loss(I_high, input_high)
|
||||
i_input_mutual_loss_low = mutual_i_input_loss(I_low, input_low)
|
||||
|
||||
loss_Decom = 1*recon_loss_high + 1*recon_loss_low \
|
||||
+ 0.01 * equal_R_loss + 0.2*i_mutual_loss \
|
||||
+ 0.15* i_input_mutual_loss_high + 0.15* i_input_mutual_loss_low
|
||||
|
||||
###
|
||||
lr = tf.placeholder(tf.float32, name='learning_rate')
|
||||
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')
|
||||
var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
|
||||
|
||||
train_op_Decom = optimizer.minimize(loss_Decom, var_list = var_Decom)
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
saver_Decom = tf.train.Saver(var_list = var_Decom)
|
||||
print("[*] Initialize model successfully...")
|
||||
|
||||
#load data
|
||||
###train_data
|
||||
train_low_data = []
|
||||
train_high_data = []
|
||||
train_low_data_names = glob('./LOLdataset/our485/low/*.png')
|
||||
train_low_data_names.sort()
|
||||
train_high_data_names = glob('./LOLdataset/our485/high/*.png')
|
||||
train_high_data_names.sort()
|
||||
assert len(train_low_data_names) == len(train_high_data_names)
|
||||
print('[*] Number of training data: %d' % len(train_low_data_names))
|
||||
for idx in range(len(train_low_data_names)):
|
||||
low_im = load_images(train_low_data_names[idx])
|
||||
train_low_data.append(low_im)
|
||||
high_im = load_images(train_high_data_names[idx])
|
||||
train_high_data.append(high_im)
|
||||
###eval_data
|
||||
eval_low_data = []
|
||||
eval_high_data = []
|
||||
eval_low_data_name = glob('./LOLdataset/eval15/low/*.png')
|
||||
eval_low_data_name.sort()
|
||||
eval_high_data_name = glob('./LOLdataset/eval15/high/*.png*')
|
||||
eval_high_data_name.sort()
|
||||
for idx in range(len(eval_low_data_name)):
|
||||
eval_low_im = load_images(eval_low_data_name[idx])
|
||||
eval_low_data.append(eval_low_im)
|
||||
eval_high_im = load_images(eval_high_data_name[idx])
|
||||
eval_high_data.append(eval_high_im)
|
||||
|
||||
|
||||
epoch = 2000
|
||||
learning_rate = 0.0001
|
||||
|
||||
sample_dir = './Decom_net_train/'
|
||||
if not os.path.isdir(sample_dir):
|
||||
os.makedirs(sample_dir)
|
||||
|
||||
eval_every_epoch = 200
|
||||
train_phase = 'decomposition'
|
||||
numBatch = len(train_low_data) // int(batch_size)
|
||||
train_op = train_op_Decom
|
||||
train_loss = loss_Decom
|
||||
saver = saver_Decom
|
||||
|
||||
checkpoint_dir = './checkpoint/decom_net_train/'
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
|
||||
if ckpt:
|
||||
print('loaded '+ckpt.model_checkpoint_path)
|
||||
saver.restore(sess,ckpt.model_checkpoint_path)
|
||||
|
||||
start_step = 0
|
||||
start_epoch = 0
|
||||
iter_num = 0
|
||||
print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
|
||||
|
||||
start_time = time.time()
|
||||
image_id = 0
|
||||
for epoch in range(start_epoch, epoch):
|
||||
for batch_id in range(start_step, numBatch):
|
||||
batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
|
||||
batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
|
||||
for patch_id in range(batch_size):
|
||||
h, w, _ = train_low_data[image_id].shape
|
||||
x = random.randint(0, h - patch_size)
|
||||
y = random.randint(0, w - patch_size)
|
||||
rand_mode = random.randint(0, 7)
|
||||
batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
|
||||
batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
|
||||
image_id = (image_id + 1) % len(train_low_data)
|
||||
if image_id == 0:
|
||||
tmp = list(zip(train_low_data, train_high_data))
|
||||
random.shuffle(tmp)
|
||||
train_low_data, train_high_data = zip(*tmp)
|
||||
|
||||
_, loss = sess.run([train_op, train_loss], feed_dict={input_low: batch_input_low, \
|
||||
input_high: batch_input_high, \
|
||||
lr: learning_rate})
|
||||
print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
|
||||
% (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
|
||||
iter_num += 1
|
||||
if (epoch + 1) % eval_every_epoch == 0:
|
||||
print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1))
|
||||
for idx in range(len(eval_low_data)):
|
||||
input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)
|
||||
result_1, result_2 = sess.run([output_R_low, output_I_low], feed_dict={input_low: input_low_eval})
|
||||
save_images(os.path.join(sample_dir, 'low_%d_%d.png' % ( idx + 1, epoch + 1)), result_1, result_2)
|
||||
for idx in range(len(eval_high_data)):
|
||||
input_high_eval = np.expand_dims(eval_high_data[idx], axis=0)
|
||||
result_11, result_22 = sess.run([output_R_high, output_I_high], feed_dict={input_high: input_high_eval})
|
||||
save_images(os.path.join(sample_dir, 'high_%d_%d.png' % ( idx + 1, epoch + 1)), result_11, result_22)
|
||||
|
||||
saver.save(sess, checkpoint_dir + 'model.ckpt')
|
||||
|
||||
print("[*] Finish training for phase %s." % train_phase)
|
@ -0,0 +1,114 @@
|
||||
# coding: utf-8
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from PIL import Image
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from utils import *
|
||||
from model import *
|
||||
from glob import glob
|
||||
from skimage import color,filters
|
||||
|
||||
sess = tf.Session()
|
||||
|
||||
input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
|
||||
input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r')
|
||||
input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
|
||||
input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r')
|
||||
input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i')
|
||||
input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio')
|
||||
|
||||
[R_decom, I_decom] = DecomNet_simple(input_decom)
|
||||
decom_output_R = R_decom
|
||||
decom_output_I = I_decom
|
||||
output_r = Restoration_net(input_low_r, input_low_i)
|
||||
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
|
||||
|
||||
var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
|
||||
var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name]
|
||||
var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name]
|
||||
|
||||
saver_Decom = tf.train.Saver(var_list = var_Decom)
|
||||
saver_adjust = tf.train.Saver(var_list=var_adjust)
|
||||
saver_restoration = tf.train.Saver(var_list=var_restoration)
|
||||
|
||||
decom_checkpoint_dir ='./checkpoint/decom_net_train/'
|
||||
ckpt_pre=tf.train.get_checkpoint_state(decom_checkpoint_dir)
|
||||
if ckpt_pre:
|
||||
print('loaded '+ckpt_pre.model_checkpoint_path)
|
||||
saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path)
|
||||
else:
|
||||
print('No decomnet checkpoint!')
|
||||
|
||||
checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/'
|
||||
ckpt_adjust=tf.train.get_checkpoint_state(checkpoint_dir_adjust)
|
||||
if ckpt_adjust:
|
||||
print('loaded '+ckpt_adjust.model_checkpoint_path)
|
||||
saver_adjust.restore(sess,ckpt_adjust.model_checkpoint_path)
|
||||
else:
|
||||
print("No adjust pre model!")
|
||||
|
||||
checkpoint_dir_restoration = './checkpoint/Restoration_net_train/'
|
||||
ckpt=tf.train.get_checkpoint_state(checkpoint_dir_restoration)
|
||||
if ckpt:
|
||||
print('loaded '+ckpt.model_checkpoint_path)
|
||||
saver_restoration.restore(sess,ckpt.model_checkpoint_path)
|
||||
else:
|
||||
print("No restoration pre model!")
|
||||
|
||||
###load eval data
|
||||
eval_low_data = []
|
||||
eval_img_name =[]
|
||||
eval_low_data_name = glob('./test/*')
|
||||
eval_low_data_name.sort()
|
||||
for idx in range(len(eval_low_data_name)):
|
||||
[_, name] = os.path.split(eval_low_data_name[idx])
|
||||
suffix = name[name.find('.') + 1:]
|
||||
name = name[:name.find('.')]
|
||||
eval_img_name.append(name)
|
||||
eval_low_im = load_images(eval_low_data_name[idx])
|
||||
eval_low_data.append(eval_low_im)
|
||||
print(eval_low_im.shape)
|
||||
|
||||
sample_dir = './results/test/'
|
||||
if not os.path.isdir(sample_dir):
|
||||
os.makedirs(sample_dir)
|
||||
|
||||
print("Start evalating!")
|
||||
start_time = time.time()
|
||||
for idx in range(len(eval_low_data)):
|
||||
print(idx)
|
||||
name = eval_img_name[idx]
|
||||
input_low = eval_low_data[idx]
|
||||
input_low_eval = np.expand_dims(input_low, axis=0)
|
||||
h, w, _ = input_low.shape
|
||||
|
||||
decom_r_low, decom_i_low = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval})
|
||||
|
||||
restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})
|
||||
### change the ratio to get different exposure level, the value can be 0-5.0
|
||||
ratio = 5.0
|
||||
i_low_data_ratio = np.ones([h, w])*(ratio)
|
||||
i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2)
|
||||
i_low_ratio_expand2 = np.expand_dims(i_low_ratio_expand, axis=0)
|
||||
adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio_expand2})
|
||||
|
||||
#The restoration result can find more details from very dark regions, however, it will restore the very dark regions
|
||||
#with gray colors, we use the following operator to alleviate this weakness.
|
||||
decom_r_sq = np.squeeze(decom_r_low)
|
||||
r_gray = color.rgb2gray(decom_r_sq)
|
||||
r_gray_gaussion = filters.gaussian(r_gray, 3)
|
||||
low_i = np.minimum((r_gray_gaussion*2)**0.5,1)
|
||||
low_i_expand_0 = np.expand_dims(low_i, axis = 0)
|
||||
low_i_expand_3 = np.expand_dims(low_i_expand_0, axis = 3)
|
||||
result_denoise = restoration_r*low_i_expand_3
|
||||
fusion4 = result_denoise*adjust_i
|
||||
|
||||
#fusion = restoration_r*adjust_i
|
||||
# fuse with the original input to avoid over-exposure
|
||||
fusion2 = decom_i_low*input_low_eval + (1-decom_i_low)*fusion4
|
||||
#print(fusion2.shape)
|
||||
save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion2)
|
||||
|
@ -0,0 +1,110 @@
|
||||
# coding: utf-8
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from PIL import Image
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from utils import *
|
||||
from model import *
|
||||
from glob import glob
|
||||
|
||||
sess = tf.Session()
|
||||
|
||||
input_decom = tf.placeholder(tf.float32, [None, None, None, 3], name='input_decom')
|
||||
input_low_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low_r')
|
||||
input_low_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i')
|
||||
input_high_r = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high_r')
|
||||
input_high_i = tf.placeholder(tf.float32, [None, None, None, 1], name='input_high_i')
|
||||
input_low_i_ratio = tf.placeholder(tf.float32, [None, None, None, 1], name='input_low_i_ratio')
|
||||
|
||||
[R_decom, I_decom] = DecomNet_simple(input_decom)
|
||||
decom_output_R = R_decom
|
||||
decom_output_I = I_decom
|
||||
output_r = Restoration_net(input_low_r, input_low_i)
|
||||
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
|
||||
|
||||
var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
|
||||
var_adjust = [var for var in tf.trainable_variables() if 'Illumination_adjust_net' in var.name]
|
||||
var_restoration = [var for var in tf.trainable_variables() if 'Restoration_net' in var.name]
|
||||
|
||||
saver_Decom = tf.train.Saver(var_list = var_Decom)
|
||||
saver_adjust = tf.train.Saver(var_list=var_adjust)
|
||||
saver_restoration = tf.train.Saver(var_list=var_restoration)
|
||||
|
||||
decom_checkpoint_dir ='./checkpoint/decom_net_train/'
|
||||
ckpt_pre=tf.train.get_checkpoint_state(decom_checkpoint_dir)
|
||||
if ckpt_pre:
|
||||
print('loaded '+ckpt_pre.model_checkpoint_path)
|
||||
saver_Decom.restore(sess,ckpt_pre.model_checkpoint_path)
|
||||
else:
|
||||
print('No decomnet checkpoint!')
|
||||
|
||||
checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/'
|
||||
ckpt_adjust=tf.train.get_checkpoint_state(checkpoint_dir_adjust)
|
||||
if ckpt_adjust:
|
||||
print('loaded '+ckpt_adjust.model_checkpoint_path)
|
||||
saver_adjust.restore(sess,ckpt_adjust.model_checkpoint_path)
|
||||
else:
|
||||
print("No adjust pre model!")
|
||||
|
||||
checkpoint_dir_restoration = './checkpoint/Restoration_net_train/'
|
||||
ckpt=tf.train.get_checkpoint_state(checkpoint_dir_restoration)
|
||||
if ckpt:
|
||||
print('loaded '+ckpt.model_checkpoint_path)
|
||||
saver_restoration.restore(sess,ckpt.model_checkpoint_path)
|
||||
else:
|
||||
print("No restoration pre model!")
|
||||
|
||||
###load eval data
|
||||
eval_low_data = []
|
||||
eval_img_name =[]
|
||||
eval_low_data_name = glob('./LOLdataset/eval15/low/*.png')
|
||||
eval_low_data_name.sort()
|
||||
for idx in range(len(eval_low_data_name)):
|
||||
[_, name] = os.path.split(eval_low_data_name[idx])
|
||||
suffix = name[name.find('.') + 1:]
|
||||
name = name[:name.find('.')]
|
||||
eval_img_name.append(name)
|
||||
eval_low_im = load_images(eval_low_data_name[idx])
|
||||
eval_low_data.append(eval_low_im)
|
||||
print(eval_low_im.shape)
|
||||
# To get better results, the illumination adjustment ratio is computed based on the decom_i_high, so we also need the high data.
|
||||
eval_high_data = []
|
||||
eval_high_data_name = glob('./LOLdataset/eval15/high/*.png')
|
||||
eval_high_data_name.sort()
|
||||
for idx in range(len(eval_high_data_name)):
|
||||
eval_high_im = load_images(eval_high_data_name[idx])
|
||||
eval_high_data.append(eval_high_im)
|
||||
|
||||
sample_dir = './results/LOLdataset_eval15/'
|
||||
if not os.path.isdir(sample_dir):
|
||||
os.makedirs(sample_dir)
|
||||
|
||||
print("Start evalating!")
|
||||
start_time = time.time()
|
||||
for idx in range(len(eval_low_data)):
|
||||
print(idx)
|
||||
name = eval_img_name[idx]
|
||||
input_low = eval_low_data[idx]
|
||||
input_low_eval = np.expand_dims(input_low, axis=0)
|
||||
input_high = eval_high_data[idx]
|
||||
input_high_eval = np.expand_dims(input_high, axis=0)
|
||||
h, w, _ = input_low.shape
|
||||
|
||||
decom_r_low, decom_i_low = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval})
|
||||
decom_r_high, decom_i_high = sess.run([decom_output_R, decom_output_I], feed_dict={input_decom: input_high_eval})
|
||||
|
||||
restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})
|
||||
|
||||
ratio = np.mean(((decom_i_high))/(decom_i_low+0.0001))
|
||||
|
||||
i_low_data_ratio = np.ones([h, w])*(ratio)
|
||||
i_low_ratio_expand = np.expand_dims(i_low_data_ratio , axis =2)
|
||||
i_low_ratio_expand2 = np.expand_dims(i_low_ratio_expand, axis=0)
|
||||
|
||||
adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio_expand2})
|
||||
fusion = restoration_r*adjust_i
|
||||
save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion)
|
||||
|
After Width: | Height: | Size: 205 KiB |
After Width: | Height: | Size: 174 KiB |
After Width: | Height: | Size: 67 KiB |
After Width: | Height: | Size: 29 KiB |
After Width: | Height: | Size: 74 KiB |
After Width: | Height: | Size: 68 KiB |
After Width: | Height: | Size: 76 KiB |
After Width: | Height: | Size: 63 KiB |
After Width: | Height: | Size: 75 KiB |
After Width: | Height: | Size: 82 KiB |
After Width: | Height: | Size: 91 KiB |
After Width: | Height: | Size: 99 KiB |
After Width: | Height: | Size: 95 KiB |
After Width: | Height: | Size: 56 KiB |
After Width: | Height: | Size: 69 KiB |
After Width: | Height: | Size: 66 KiB |
After Width: | Height: | Size: 41 KiB |
After Width: | Height: | Size: 41 KiB |
After Width: | Height: | Size: 73 KiB |
After Width: | Height: | Size: 64 KiB |
After Width: | Height: | Size: 108 KiB |
After Width: | Height: | Size: 106 KiB |