You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

141 lines
4.9 KiB

4 years ago
import numpy as np
from PIL import Image
import tensorflow as tf
import scipy.stats as st
from skimage import io,data,color
from functools import reduce
def gauss_kernel(kernlen=21, nsig=3, channels=1):
interval = (2*nsig+1.)/(kernlen)
x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
kern1d = np.diff(st.norm.cdf(x))
kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
kernel = kernel_raw/kernel_raw.sum()
out_filter = np.array(kernel, dtype = np.float32)
out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
out_filter = np.repeat(out_filter, channels, axis = 2)
return out_filter
def tensor_size(tensor):
from operator import mul
return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
def blur(x):
kernel_var = gauss_kernel(21, 3, 3)
return tf.nn.depthwise_conv2d(x, kernel_var, [1, 1, 1, 1], padding='SAME')
def tensor_size(tensor):
from operator import mul
return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
def data_augmentation(image, mode):
if mode == 0:
# original
return image
elif mode == 1:
# flip up and down
return np.flipud(image)
elif mode == 2:
# rotate counterwise 90 degree
return np.rot90(image)
elif mode == 3:
# rotate 90 degree and flip up and down
image = np.rot90(image)
return np.flipud(image)
elif mode == 4:
# rotate 180 degree
return np.rot90(image, k=2)
elif mode == 5:
# rotate 180 degree and flip
image = np.rot90(image, k=2)
return np.flipud(image)
elif mode == 6:
# rotate 270 degree
return np.rot90(image, k=3)
elif mode == 7:
# rotate 270 degree and flip
image = np.rot90(image, k=3)
return np.flipud(image)
def load_images(file):
im = Image.open(file)
img = np.array(im, dtype="float32") / 255.0
img_max = np.max(img)
img_min = np.min(img)
img_norm = np.float32((img - img_min) / np.maximum((img_max - img_min), 0.001))
return img_norm
def gradient(input_tensor, direction):
smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
smooth_kernel_y = tf.transpose(smooth_kernel_x, [1, 0, 2, 3])
if direction == "x":
kernel = smooth_kernel_x
elif direction == "y":
kernel = smooth_kernel_y
gradient_orig = tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))
grad_min = tf.reduce_min(gradient_orig)
grad_max = tf.reduce_max(gradient_orig)
grad_norm = tf.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001))
return grad_norm
def _tf_fspecial_gauss(size, sigma):
"""Function to mimic the 'fspecial' gaussian MATLAB function
"""
x_data, y_data = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
x_data = np.expand_dims(x_data, axis=-1)
x_data = np.expand_dims(x_data, axis=-1)
y_data = np.expand_dims(y_data, axis=-1)
y_data = np.expand_dims(y_data, axis=-1)
x = tf.constant(x_data, dtype=tf.float32)
y = tf.constant(y_data, dtype=tf.float32)
g = tf.exp(-((x**2 + y**2)/(2.0*sigma**2)))
return g / tf.reduce_sum(g)
def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):
window = _tf_fspecial_gauss(size, sigma) # window shape [size, size]
K1 = 0.01
K2 = 0.03
L = 1 # depth of image (255 in case the image has a differnt scale)
C1 = (K1*L)**2
C2 = (K2*L)**2
mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID')
mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1],padding='VALID')
mu1_sq = mu1*mu1
mu2_sq = mu2*mu2
mu1_mu2 = mu1*mu2
sigma1_sq = tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1],padding='VALID') - mu1_sq
sigma2_sq = tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1],padding='VALID') - mu2_sq
sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1],padding='VALID') - mu1_mu2
if cs_map:
value = (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
(sigma1_sq + sigma2_sq + C2)),
(2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2))
else:
value = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
(sigma1_sq + sigma2_sq + C2))
if mean_metric:
value = tf.reduce_mean(value)
return value
def save_images(filepath, result_1, result_2 = None, result_3 = None):
result_1 = np.squeeze(result_1)
result_2 = np.squeeze(result_2)
result_3 = np.squeeze(result_3)
if not result_2.any():
cat_image = result_1
else:
cat_image = np.concatenate([result_1, result_2], axis = 1)
if not result_3.any():
cat_image = cat_image
else:
cat_image = np.concatenate([cat_image, result_3], axis = 1)
im = Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8'))
im.save(filepath, 'png')