You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
5.2 KiB
97 lines
5.2 KiB
import tensorflow as tf
|
|
import tensorflow.contrib.slim as slim
|
|
from tensorflow.contrib.layers.python.layers import initializers
|
|
|
|
def lrelu(x, trainable=None):
    """Leaky ReLU activation with a fixed negative slope of 0.2.

    Computes ``max(0.2 * x, x)`` elementwise: identity for positive
    values, slope 0.2 for negative values.

    Args:
        x: input tensor.
        trainable: unused; kept (renaming the original ``trainbable``
            typo) so the two-argument call signature used as a slim
            ``activation_fn`` stays intact.

    Returns:
        Tensor of the same shape as ``x``.
    """
    return tf.maximum(x * 0.2, x)
|
|
|
|
def upsample_and_concat(x1, x2, output_channels, in_channels, scope_name, trainable=True):
    """Upsample ``x1`` by 2x with a transposed conv and concat it with ``x2``.

    Args:
        x1: NHWC tensor to upsample; assumed to have ``in_channels``
            channels -- TODO confirm against callers.
        x2: NHWC skip-connection tensor whose shape is the upsampling
            target; assumed to have ``output_channels`` channels (the
            output static shape is set to ``2 * output_channels``).
        output_channels: number of channels produced by the deconvolution.
        in_channels: number of channels of ``x1``.
        scope_name: variable scope name (reused via ``tf.AUTO_REUSE``).
        trainable: whether the deconvolution filter is a trainable variable.

    Returns:
        Channel-wise concatenation of the upsampled ``x1`` and ``x2``.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        pool_size = 2
        # Bug fix: honor the `trainable` argument. The original always
        # created the filter with trainable=True, silently ignoring the
        # parameter.
        deconv_filter = tf.get_variable(
            'weights',
            [pool_size, pool_size, output_channels, in_channels],
            trainable=trainable)
        # Output spatial shape is taken dynamically from the skip tensor so
        # odd input sizes still line up after the stride-2 upsampling.
        deconv = tf.nn.conv2d_transpose(
            x1, deconv_filter, tf.shape(x2),
            strides=[1, pool_size, pool_size, 1], name=scope_name)
        deconv_output = tf.concat([deconv, x2], 3)
        # Pin the static channel count; spatial dims stay dynamic.
        deconv_output.set_shape([None, None, None, output_channels * 2])
        return deconv_output
|
|
|
|
def DecomNet_simple(input):
    """Decompose an image into reflectance and illumination components.

    A small two-level U-Net produces a 3-channel reflectance map; a
    shallow side branch reusing the first feature map produces a
    1-channel illumination map. Both outputs go through a sigmoid.

    Args:
        input: 4-D NHWC image tensor.

    Returns:
        Tuple ``(R_out, L_out)`` of the reflectance and illumination maps.
    """
    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        # Encoder: two stride-2 poolings, doubling channels each level.
        feat1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
        down1 = slim.max_pool2d(feat1, [2, 2], stride=2, padding='SAME')
        feat2 = slim.conv2d(down1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
        down2 = slim.max_pool2d(feat2, [2, 2], stride=2, padding='SAME')
        bottleneck = slim.conv2d(down2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')

        # Decoder: upsample and fuse with the matching encoder feature.
        merged_a = upsample_and_concat(bottleneck, feat2, 64, 128, 'g_up_1')
        feat8 = slim.conv2d(merged_a, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
        merged_b = upsample_and_concat(feat8, feat1, 32, 64, 'g_up_2')
        feat9 = slim.conv2d(merged_b, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')

        # Here, we use 1*1 kernel to replace the 3*3 ones in the paper to get better results.
        refl_logits = slim.conv2d(feat9, 3, [1, 1], rate=1, activation_fn=None, scope='g_conv10')
        R_out = tf.sigmoid(refl_logits)

        # Illumination branch: shallow path off the first feature map,
        # concatenated with the last decoder features.
        l_feat = slim.conv2d(feat1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='l_conv1_2')
        l_merged = tf.concat([l_feat, feat9], 3)
        # Here, we use 1*1 kernel to replace the 3*3 ones in the paper to get better results.
        illum_logits = slim.conv2d(l_merged, 1, [1, 1], rate=1, activation_fn=None, scope='l_conv1_4')
        L_out = tf.sigmoid(illum_logits)

        return R_out, L_out
|
|
|
|
def Restoration_net(input_r, input_i):
    """Restore a reflectance map conditioned on an illumination map.

    A four-level U-Net (32 -> 512 channels) over the channel-wise
    concatenation of the two inputs, finishing with a 3-channel conv
    and a sigmoid.

    Args:
        input_r: NHWC reflectance tensor.
        input_i: NHWC illumination tensor (same spatial size as
            ``input_r`` -- TODO confirm against callers).

    Returns:
        Sigmoid-activated 3-channel restored image tensor.
    """
    with tf.variable_scope('Restoration_net', reuse=tf.AUTO_REUSE):
        fused = tf.concat([input_r, input_i], 3)

        # ----- encoder: double conv per level, then 2x2 max-pool -----
        enc1 = slim.conv2d(fused, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv1_1')
        enc1 = slim.conv2d(enc1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv1_2')
        down1 = slim.max_pool2d(enc1, [2, 2], padding='SAME')

        enc2 = slim.conv2d(down1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv2_1')
        enc2 = slim.conv2d(enc2, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv2_2')
        down2 = slim.max_pool2d(enc2, [2, 2], padding='SAME')

        enc3 = slim.conv2d(down2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv3_1')
        enc3 = slim.conv2d(enc3, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv3_2')
        down3 = slim.max_pool2d(enc3, [2, 2], padding='SAME')

        enc4 = slim.conv2d(down3, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv4_1')
        enc4 = slim.conv2d(enc4, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv4_2')
        down4 = slim.max_pool2d(enc4, [2, 2], padding='SAME')

        # ----- bottleneck -----
        bottom = slim.conv2d(down4, 512, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv5_1')
        bottom = slim.conv2d(bottom, 512, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv5_2')

        # ----- decoder: upsample, fuse with skip, double conv -----
        dec6 = upsample_and_concat(bottom, enc4, 256, 512, 'up_6')
        dec6 = slim.conv2d(dec6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv6_1')
        dec6 = slim.conv2d(dec6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv6_2')

        dec7 = upsample_and_concat(dec6, enc3, 128, 256, 'up_7')
        dec7 = slim.conv2d(dec7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv7_1')
        dec7 = slim.conv2d(dec7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv7_2')

        dec8 = upsample_and_concat(dec7, enc2, 64, 128, 'up_8')
        dec8 = slim.conv2d(dec8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv8_1')
        dec8 = slim.conv2d(dec8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv8_2')

        dec9 = upsample_and_concat(dec8, enc1, 32, 64, 'up_9')
        dec9 = slim.conv2d(dec9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv9_1')
        dec9 = slim.conv2d(dec9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv9_2')

        # Final projection to RGB, squashed to [0, 1].
        logits = slim.conv2d(dec9, 3, [3, 3], rate=1, activation_fn=None, scope='de_conv10')
        out = tf.sigmoid(logits)
    return out
|
|
|
|
def Illumination_adjust_net(input_i, input_ratio):
    """Adjust an illumination map toward a target brightness ratio.

    Four plain 3x3 convolutions over the channel-wise concatenation of
    the illumination map and the ratio map, ending in a sigmoid.

    Args:
        input_i: NHWC illumination tensor.
        input_ratio: NHWC adjustment-ratio tensor (same spatial size as
            ``input_i`` -- TODO confirm against callers).

    Returns:
        Sigmoid-activated 1-channel adjusted illumination map.
    """
    with tf.variable_scope('Illumination_adjust_net', reuse=tf.AUTO_REUSE):
        stacked = tf.concat([input_i, input_ratio], 3)

        h = slim.conv2d(stacked, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_1')
        h = slim.conv2d(h, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_2')
        h = slim.conv2d(h, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_3')
        # NOTE(review): lrelu here is immediately followed by a sigmoid --
        # preserved as-is from the original.
        h = slim.conv2d(h, 1, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_4')

        L_enhance = tf.sigmoid(h)
    return L_enhance