Fixes

.gitignore (3 changes; vendored)
@@ -1,3 +1,6 @@
+*checkpoint/
+params/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
README.md (23 changes)
@@ -2,11 +2,11 @@
 TensorFlow implementation of the Fast Super-Resolution Convolutional Neural Network (FSRCNN). This implements two models: FSRCNN, which is more accurate but slower, and FSRCNN-s, which is faster but less accurate. Based on this [project](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html).

 ## Prerequisites
-* Python 2.7
-* TensorFlow
-* Scipy version > 0.18
+* Python 3
+* TensorFlow-gpu >= 1.3
+* CUDA & cuDNN >= 6.0
 * h5py
-* PIL
+* Pillow

 ## Usage
 For training: `python main.py`
@@ -17,7 +17,7 @@ To use FSRCNN-s instead of FSRCNN: `python main.py --fast True`

 Can specify epochs, learning rate, data directory, etc:
 <br>
-`python main.py --epochs 10 --learning_rate 0.0001 --data_dir Train`
+`python main.py --epoch 100 --learning_rate 0.0002 --data_dir Train`
 <br>
 Check `main.py` for all the possible flags

@@ -27,23 +27,26 @@ Also includes script `expand_data.py` which scales and rotates all the images in

 Original butterfly image:

-![original butterfly](…)
+![original butterfly](…)


-Bicubic interpolated image:
+Ewa_lanczos interpolated image:

-![bicubic](…)
+![ewa_lanczos](…)


 Super-resolved image:

-![super-resolved](…)
+![super-resolved](…)

+## Additional datasets
+
+* [General-100](https://drive.google.com/open?id=0B7tU5Pj1dfCMVVdJelZqV0prWnM)
+
 ## TODO

 * Add RGB support (increase each layer depth to 3)
 * Speed up pre-processing for large datasets
 * Set learning rate for deconvolutional layer to 1e-4 (vs 1e-3 for the rest)

 ## References
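Note on the README TODO above about a separate deconvolution learning rate: in TF1 this is typically done with one optimizer per variable group. A minimal sketch with made-up variable names, not code from this repo:

```python
import tensorflow as tf

# Toy stand-in for the model: one "regular" weight and one "deconv" weight.
w1 = tf.Variable(tf.random_normal([3, 3, 8, 8]), name='w1')
w_deconv = tf.Variable(tf.random_normal([9, 9, 1, 8]), name='w_deconv')
loss = tf.reduce_sum(tf.square(w1)) + tf.reduce_sum(tf.square(w_deconv))

# One optimizer per learning-rate group, stepped together.
deconv_vars = [v for v in tf.trainable_variables() if 'deconv' in v.name]
other_vars = [v for v in tf.trainable_variables() if 'deconv' not in v.name]
train_op = tf.group(
    tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=deconv_vars),
    tf.train.AdamOptimizer(1e-3).minimize(loss, var_list=other_vars))
```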
BIN Train/t51.bmp (normal file → executable file; 246 KiB → 120 KiB)
BIN Train/t65.bmp (normal file → executable file; 111 KiB → 90 KiB)
expand_data.py
@@ -13,7 +13,7 @@ def main():
    print("Missing argument: You must specify a folder with images to expand")
    return

-  for i in xrange(len(data)):
+  for i in range(len(data)):
    scale(data[i])
    rotate(data[i])

@@ -45,4 +45,4 @@ def rotate(file):
  new_image.save(new_path)

 if __name__ == '__main__':
-  main()
\ No newline at end of file
+  main()
main.py (8 changes)
@@ -9,13 +9,13 @@ import os
 flags = tf.app.flags
 flags.DEFINE_boolean("fast", False, "Use the fast model (FSRCNN-s) [False]")
 flags.DEFINE_integer("epoch", 10, "Number of epochs [10]")
-flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
+flags.DEFINE_integer("batch_size", 32, "The size of batch images [32]")
 flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of the adam optimizer [1e-4]")
 flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
-flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
-flags.DEFINE_integer("stride", 4, "The size of stride to apply to input image [4]")
+flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [2]")
 flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
 flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
-flags.DEFINE_string("data_dir", "FastTrain", "Name of data directory to train on [FastTrain]")
+flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [Train]")
 flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
+flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
 flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
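For reference, flag definitions like these are consumed through `flags.FLAGS` after `tf.app.run()` parses argv; a generic TF1 sketch, not the remainder of this file:

```python
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("epoch", 10, "Number of epochs [10]")
flags.DEFINE_integer("scale", 2, "Scale factor [2]")
FLAGS = flags.FLAGS

def main(unused_argv):
  # Values are available as attributes after tf.app.run() parses argv.
  print("epoch = {}, scale = {}".format(FLAGS.epoch, FLAGS.scale))

if __name__ == '__main__':
  tf.app.run()  # e.g. python demo.py --epoch 100 --scale 3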
model.py (48 changes)
@@ -30,20 +30,22 @@ class FSRCNN(object):
    self.is_grayscale = (self.c_dim == 1)
    self.epoch = config.epoch
    self.scale = config.scale
-    self.stride = config.stride
    self.batch_size = config.batch_size
    self.learning_rate = config.learning_rate
+    self.threads = config.threads
+    self.params = config.params

    # Different image/label sub-sizes for different scaling factors x2, x3, x4
-    scale_factors = [[14, 20], [11, 21], [10, 24]]
+    scale_factors = [[24, 40], [18, 42], [16, 48]]
    self.image_size, self.label_size = scale_factors[self.scale - 2]
+    # Testing uses different strides to ensure sub-images line up correctly
    if not self.train:
-      self.stride = [10, 7, 6][self.scale - 2]
+      self.stride = [20, 14, 12][self.scale - 2]
    else:
      self.stride = [12, 8, 7][self.scale - 2]

-    # Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (s, d, m) in paper
-    model_params = [[56, 12, 4], [32, 5, 1]]
+    # Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
+    model_params = [[56, 12, 4], [32, 8, 1]]
    self.model_params = model_params[self.fast]

    self.checkpoint_dir = config.checkpoint_dir
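Note: the new `scale_factors` table keeps the relation label_size = (image_size - 4) * scale at every supported scale, the 4-pixel border presumably accounting for pixels lost at filter boundaries. A quick arithmetic check (illustrative, not repo code):

```python
# label_size == (image_size - 4) * scale for scales 2, 3, 4.
scale_factors = [[24, 40], [18, 42], [16, 48]]
for scale, (image_size, label_size) in zip([2, 3, 4], scale_factors):
    assert label_size == (image_size - 4) * scale
    print(scale, image_size, label_size)   # 2 24 40 / 3 18 42 / 4 16 48
```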
@@ -59,21 +61,21 @@ class FSRCNN(object):
    self.batch = tf.placeholder(tf.int32, shape=[], name='batch')

    # FSRCNN-s (fast) has smaller filters and fewer layers but can achieve faster performance
-    s, d, m = self.model_params
+    d, s, m = self.model_params

    expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4)
    self.weights = {
-      'w1': tf.Variable(tf.random_normal([5, 5, 1, s], stddev=0.0378, dtype=tf.float32), name='w1'),
-      'w2': tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.3536, dtype=tf.float32), name='w2'),
-      expand_weight: tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.189, dtype=tf.float32), name=expand_weight),
-      deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, s], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
+      'w1': tf.Variable(tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32), name='w1'),
+      'w2': tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32), name='w2'),
+      expand_weight: tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32), name=expand_weight),
+      deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
    }

    expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)
    self.biases = {
-      'b1': tf.Variable(tf.zeros([s]), name='b1'),
-      'b2': tf.Variable(tf.zeros([d]), name='b2'),
-      expand_bias: tf.Variable(tf.zeros([s]), name=expand_bias),
+      'b1': tf.Variable(tf.zeros([d]), name='b1'),
+      'b2': tf.Variable(tf.zeros([s]), name='b2'),
+      expand_bias: tf.Variable(tf.zeros([d]), name=expand_bias),
      deconv_bias: tf.Variable(tf.zeros([1]), name=deconv_bias)
    }
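With the (d, s, m) reorder, the shapes follow the FSRCNN pipeline from the paper: extract d features with 5x5 filters, shrink to s channels, run m 3x3 mapping layers at width s (created in the next hunk), expand back to d, then deconvolve with a 9x9 filter. A worked listing for the default setting, as a standalone sketch:

```python
# Filter shapes implied above for the default FSRCNN setting (d, s, m) = (56, 12, 4).
d, s, m = 56, 12, 4
shapes = {'w1': [5, 5, 1, d],   # feature extraction: 1 -> 56 channels
          'w2': [1, 1, d, s]}   # shrinking: 56 -> 12
for i in range(3, m + 3):       # w3..w6: 3x3 mapping layers at width 12
    shapes['w{}'.format(i)] = [3, 3, s, s]
shapes['w{}'.format(m + 3)] = [1, 1, s, d]  # w7, expanding: 12 -> 56
# w8, 9x9 deconvolution; a conv2d_transpose filter is [h, w, out, in], so 56 -> 1.
shapes['w{}'.format(m + 4)] = [9, 9, 1, d]
for name in sorted(shapes):
    print(name, shapes[name])
```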
@@ -82,8 +84,8 @@ class FSRCNN(object):
    # Create the m mapping layers weights/biases
    for i in range(3, m + 3):
      weight_name, bias_name = 'w{}'.format(i), 'b{}'.format(i)
-      self.weights[weight_name] = tf.Variable(tf.random_normal([3, 3, d, d], stddev=0.1179, dtype=tf.float32), name=weight_name)
-      self.biases[bias_name] = tf.Variable(tf.zeros([d]), name=bias_name)
+      self.weights[weight_name] = tf.Variable(tf.random_normal([3, 3, s, s], stddev=0.1179, dtype=tf.float32), name=weight_name)
+      self.biases[bias_name] = tf.Variable(tf.zeros([s]), name=bias_name)

    self.pred = self.model()
@@ -94,9 +96,9 @@ class FSRCNN(object):
    self.saver = tf.train.Saver()

  def run(self):
-    self.train_op = tf.train.AdamOptimizer().minimize(self.loss)
+    self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)

-    tf.initialize_all_variables().run()
+    tf.global_variables_initializer().run()

    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
@@ -104,8 +106,8 @@ class FSRCNN(object):
      print(" [!] Load failed...")

    if self.params:
-      s, d, m = self.model_params
-      save_params(self.sess, self.weights, self.biases, self.alphas, s, d, m)
+      d, s, m = self.model_params
+      save_params(self.sess, d, s, m)
    elif self.train:
      self.run_train()
    else:
@@ -128,11 +130,11 @@ class FSRCNN(object):
    start_time = time.time()
    start_average, end_average, counter = 0, 0, 0

-    for ep in xrange(self.epoch):
+    for ep in range(self.epoch):
      # Run by batch images
      batch_idxs = len(train_data) // self.batch_size
      batch_average = 0
-      for idx in xrange(0, batch_idxs):
+      for idx in range(0, batch_idxs):
        batch_images = train_data[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_labels = train_label[idx * self.batch_size : (idx + 1) * self.batch_size]
@@ -178,7 +180,7 @@ class FSRCNN(object):
    result = self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: nx * ny})
    print("Took %.3f seconds" % (time.time() - start_time))

-    result = merge(result, [nx, ny])
+    result = merge(result, [nx, ny, self.c_dim])
    result = result.squeeze()
    image_path = os.path.join(os.getcwd(), self.output_dir)
    image_path = os.path.join(image_path, "test_image.png")
@@ -244,4 +246,4 @@ class FSRCNN(object):
      self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
      return True
    else:
-      return False
\ No newline at end of file
+      return False
BIN … (deleted image; 35 KiB)
BIN result/ewa_lanczos.png (new file; 112 KiB)
BIN … (replaced image; 48 KiB → 120 KiB)
BIN … (deleted image; 42 KiB)
utils.py (140 changes)
@@ -5,13 +5,11 @@ Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread func
 import os
 import glob
 import h5py
 import random
-from math import floor
+from math import ceil
-import struct

 import tensorflow as tf
 from PIL import Image
-from scipy.misc import imread
 import numpy as np
 from multiprocessing import Pool, Lock, active_children
@@ -45,18 +43,18 @@ def preprocess(path, scale=3):

   image = Image.open(path).convert('L')
   (width, height) = image.size
-  label_ = np.array(list(image.getdata())).astype(np.float).reshape((height, width)) / 255
+  label_ = np.fromstring(image.tobytes(), dtype=np.uint8).reshape((height, width)) / 255
   image.close()

   cropped_image = Image.fromarray(modcrop(label_, scale))

   (width, height) = cropped_image.size
   new_width, new_height = int(width / scale), int(height / scale)
-  scaled_image = cropped_image.resize((new_width, new_height), Image.ANTIALIAS)
+  scaled_image = cropped_image.resize((new_width, new_height), Image.BICUBIC)
   cropped_image.close()

   (width, height) = scaled_image.size
-  input_ = np.array(list(scaled_image.getdata())).astype(np.float).reshape((height, width))
+  input_ = np.array(scaled_image.getdata()).astype(np.float).reshape((height, width))

   return input_, label_
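preprocess() produces one (input_, label_) pair per file: the grayscale image is cropped so both sides divide evenly by scale (modcrop), kept at full size as the label, and downscaled by scale for the input. Shape arithmetic for one hypothetical image:

```python
# Shape bookkeeping for preprocess() at scale = 3 on a hypothetical 100x77 image:
scale = 3
h, w = 100, 77
h_c, w_c = h - (h % scale), w - (w % scale)  # modcrop -> label_ is 99 x 75
h_in, w_in = h_c // scale, w_c // scale      # downscale -> input_ is 33 x 25
assert (h_c, w_c) == (99, 75) and (h_in, w_in) == (33, 25)
```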
@@ -90,16 +88,6 @@ def make_data(sess, checkpoint_dir, data, label):
     hf.create_dataset('data', data=data)
     hf.create_dataset('label', data=label)

-def image_read(path, is_grayscale=True):
-  """
-  Read image using its path.
-  Default value is gray-scale, and image is read by YCbCr format as the paper said.
-  """
-  if is_grayscale:
-    return imread(path, flatten=True, mode='YCbCr').astype(np.float)
-  else:
-    return imread(path, mode='YCbCr').astype(np.float)
-
 def modcrop(image, scale=3):
   """
   To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
@@ -122,11 +110,11 @@ def modcrop(image, scale=3):

 def train_input_worker(args):
   image_data, config = args
-  image_size, label_size, stride, scale, save_image = config
+  image_size, label_size, stride, scale = config

   single_input_sequence, single_label_sequence = [], []
-  padding = abs(image_size - label_size) / 2 # eg. for 3x: (21 - 11) / 2 = 5
-  label_padding = label_size / scale # eg. for 3x: 21 / 3 = 7
+  padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
+  label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

   input_, label_ = preprocess(image_data, scale)
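With the new sub-image sizes, the revised padding formulas stay integral at every scale; checking the arithmetic against the scale_factors table from model.py:

```python
# Padding arithmetic for the revised formulas, using the new
# scale_factors table from model.py.
for scale, (image_size, label_size) in zip([2, 3, 4], [[24, 40], [18, 42], [16, 48]]):
    padding = abs(image_size - label_size) // 2
    label_padding = abs((image_size - 4) - label_size) // 2
    print(scale, padding, label_padding)
# -> 2 8 10
# -> 3 12 14
# -> 4 16 18
```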
@@ -167,8 +155,8 @@ def thread_train_setup(config):
   pool = Pool(config.threads)

   # Distribute |images_per_thread| images across each worker process
-  config_values = [config.image_size, config.label_size, config.stride, config.scale, config.save_image]
-  images_per_thread = len(data) / config.threads
+  config_values = [config.image_size, config.label_size, config.stride, config.scale]
+  images_per_thread = len(data) // config.threads
   workers = []
   for thread in range(config.threads):
     args_list = [(data[i], config_values) for i in range(thread * images_per_thread, (thread + 1) * images_per_thread)]
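One caveat worth noting on the floor division here: when len(data) is not a multiple of config.threads, the trailing len(data) % config.threads images are never assigned to any worker. A tiny illustration, not repo code:

```python
# 10 images across 3 workers with floor division: image 9 is dropped.
data, threads = list(range(10)), 3
images_per_thread = len(data) // threads            # 3
assigned = [data[t * images_per_thread:(t + 1) * images_per_thread]
            for t in range(threads)]
print(assigned)             # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(len(data) % threads)  # 1 image left over
```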
@@ -210,10 +198,10 @@ def train_input_setup(config):
   data = prepare_data(sess, dataset=config.data_dir)

   sub_input_sequence, sub_label_sequence = [], []
-  padding = abs(image_size - label_size) / 2 # eg. for 3x: (21 - 11) / 2 = 5
-  label_padding = label_size / scale # eg. for 3x: 21 / 3 = 7
+  padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
+  label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

-  for i in xrange(len(data)):
+  for i in range(len(data)):
     input_, label_ = preprocess(data[i], scale)

     if len(input_.shape) == 3:
@@ -250,8 +238,8 @@ def test_input_setup(config):
   data = prepare_data(sess, dataset="Test")

   sub_input_sequence, sub_label_sequence = [], []
-  padding = abs(image_size - label_size) / 2 # eg. (21 - 11) / 2 = 5
-  label_padding = label_size / scale # eg. 21 / 3 = 7
+  padding = abs(image_size - label_size) // 2 # eg. (21 - 11) / 2 = 5
+  label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

   pic_index = 2 # Index of image based on lexicographic order in data folder
   input_, label_ = preprocess(data[pic_index], config.scale)
@@ -273,7 +261,7 @@ def test_input_setup(config):
     sub_input = sub_input.reshape([image_size, image_size, 1])
     sub_label = sub_label.reshape([label_size, label_size, 1])

     sub_input_sequence.append(sub_input)
     sub_label_sequence.append(sub_label)
@@ -284,39 +272,36 @@ def test_input_setup(config):

   return nx, ny

 # You can ignore, I just wanted to see how much space all the parameters would take up
-def save_params(sess, weights, biases, alphas, s, d, m):
+def save_params(sess, d, s, m):
   param_dir = "params/"

   if not os.path.exists(param_dir):
     os.makedirs(param_dir)

-  h = open(param_dir + "weights{}_{}_{}.txt".format(s, d, m), 'w')
+  h = open(param_dir + "weights{}_{}_{}.txt".format(d, s, m), 'w')

-  for layer in weights:
-    h.write("{} =\n [".format(layer))
-    layer_weights = sess.run(weights[layer])
-    sep = False
+  variables = dict((var.name, sess.run(var)) for var in tf.trainable_variables())

-    for filter_x in range(len(layer_weights)):
-      for filter_y in range(len(layer_weights[filter_x])):
-        filter_weights = layer_weights[filter_x][filter_y]
-        for input_channel in range(len(filter_weights)):
-          for output_channel in range(len(filter_weights[input_channel])):
-            val = filter_weights[input_channel][output_channel]
-            if sep:
-              h.write(', ')
-            h.write("{}".format(val))
-            sep = True
-          h.write("\n ")
+  for name, weights in variables.items():
+    h.write("{} =\n".format(name[:name.index(':')]))

-    h.write("]\n\n")
-
-  for layer, tensor in biases.items() + alphas.items():
-    h.write("{} = [".format(layer))
-    vals = sess.run(tensor)
-    h.write(",".join(map(str, vals)))
-    h.write("]\n")
+    if len(weights.shape) < 4:
+      h.write("{}\n\n".format(weights.flatten().tolist()))
+    else:
+      h.write("[")
+      sep = False
+      for filter_x in range(len(weights)):
+        for filter_y in range(len(weights[filter_x])):
+          filter_weights = weights[filter_x][filter_y]
+          for input_channel in range(len(filter_weights)):
+            for output_channel in range(len(filter_weights[input_channel])):
+              val = filter_weights[input_channel][output_channel]
+              if sep:
+                h.write(', ')
+              h.write("{}".format(val))
+              sep = True
+            h.write("\n ")
+      h.write("]\n\n")

   h.close()
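The rewritten save_params() walks tf.trainable_variables(); each variable name carries a ':0' output suffix that the name[:name.index(':')] slice strips. A standalone demonstration:

```python
import tensorflow as tf

w1 = tf.Variable(tf.zeros([3, 3, 12, 12]), name='w1')
b1 = tf.Variable(tf.zeros([12]), name='b1')

for var in tf.trainable_variables():
    name = var.name                           # e.g. 'w1:0'
    print(name[:name.index(':')], var.shape)  # w1 (3, 3, 12, 12) / b1 (12,)
```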
@@ -325,7 +310,7 @@ def merge(images, size):
   """
   Merges sub-images back into original image size
   """
   h, w = images.shape[1], images.shape[2]
-  img = np.zeros((h * size[0], w * size[1], 1))
+  img = np.zeros((h * size[0], w * size[1], size[2]))
   for idx, image in enumerate(images):
     i = idx % size[1]
     j = idx // size[1]
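merge() now takes the channel count through size[2] instead of hardcoding 1; the row/column bookkeeping is unchanged. A toy run of the tiling logic, assuming the elided loop body writes each patch at img[j*h:(j+1)*h, i*w:(i+1)*w, :]:

```python
import numpy as np

# Toy merge: 6 patches of 2x2x1, size = [nx, ny, c] = [2, 3, 1],
# i.e. 2 rows of 3 columns, as passed from model.py.
patches = np.arange(6).reshape(6, 1, 1, 1) * np.ones((6, 2, 2, 1))
nx, ny, c = 2, 3, 1
h, w = patches.shape[1], patches.shape[2]
img = np.zeros((h * nx, w * ny, c))
for idx, patch in enumerate(patches):
    i = idx % ny    # column
    j = idx // ny   # row
    img[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = patch
print(img[:, :, 0])  # rows: [0 0 1 1 2 2] x2, then [3 3 4 4 5 5] x2
```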
@@ -361,25 +346,25 @@ def _tf_fspecial_gauss(size, sigma):
   return g / tf.reduce_sum(g)


-def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):
-  window = _tf_fspecial_gauss(size, sigma) # window shape [size, size]
+def tf_ssim(img1, img2, cs_map=False, mean_metric=True, sigma=1.5):
+  size = int(sigma * 3) * 2 + 1
+  window = _tf_fspecial_gauss(size, sigma)
   K1 = 0.01
   K2 = 0.03
   L = 1 # depth of image (255 in case the image has a different scale)
   C1 = (K1*L)**2
   C2 = (K2*L)**2
-  mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID')
-  mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1], padding='VALID')
+  mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID', data_format='NHWC')
+  mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1], padding='VALID', data_format='NHWC')
   mu1_sq = mu1*mu1
   mu2_sq = mu2*mu2
   mu1_mu2 = mu1*mu2
-  sigma1_sq = tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1], padding='VALID') - mu1_sq
-  sigma2_sq = tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1], padding='VALID') - mu2_sq
-  sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1], padding='VALID') - mu1_mu2
+  sigma1_sq = tf.abs(tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1], padding='VALID', data_format='NHWC') - mu1_sq)
+  sigma2_sq = tf.abs(tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1], padding='VALID', data_format='NHWC') - mu2_sq)
+  sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1], padding='VALID', data_format='NHWC') - mu1_mu2

   if cs_map:
-    value = (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
-            (sigma1_sq + sigma2_sq + C2)),
-            (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2))
+    value = (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)
   else:
     value = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
             (sigma1_sq + sigma2_sq + C2))
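tf_ssim() now derives the Gaussian window from sigma instead of a fixed size=11; size = int(sigma * 3) * 2 + 1 keeps the window odd and roughly plus/minus 3 sigma wide (9 taps at the default sigma of 1.5). The implied sizes:

```python
# Window sizes implied by size = int(sigma * 3) * 2 + 1 (always odd):
for sigma in [0.5, 1.0, 1.5, 2.0, 2.5]:
    print(sigma, int(sigma * 3) * 2 + 1)
# 0.5 -> 3, 1.0 -> 7, 1.5 -> 9, 2.0 -> 13, 2.5 -> 15
```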
@@ -389,26 +374,15 @@ def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):
   return value


-def tf_ms_ssim(img1, img2, mean_metric=True, level=5):
-  weight = tf.constant([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=tf.float32)
+def tf_ms_ssim(img1, img2, sigma=1.5, weights=[0.1, 0.9]):
+  weights = weights / np.sum(weights)
+  window = _tf_fspecial_gauss(5, 1)
   mssim = []
-  mcs = []
-  for l in range(level):
-    ssim_map, cs_map = tf_ssim(img1, img2, cs_map=True, mean_metric=False)
-    mssim.append(tf.reduce_mean(ssim_map))
-    mcs.append(tf.reduce_mean(cs_map))
-    filtered_im1 = tf.nn.avg_pool(img1, [1,2,2,1], [1,2,2,1], padding='SAME')
-    filtered_im2 = tf.nn.avg_pool(img2, [1,2,2,1], [1,2,2,1], padding='SAME')
-    img1 = filtered_im1
-    img2 = filtered_im2
+  for i in range(len(weights)):
+    mssim.append(tf_ssim(img1, img2, sigma=sigma))
+    img1 = tf.nn.conv2d(img1, window, [1,2,2,1], 'VALID')
+    img2 = tf.nn.conv2d(img2, window, [1,2,2,1], 'VALID')

-  # list to tensor of dim D+1
-  mssim = tf.stack(mssim, axis=0)
-  mcs = tf.stack(mcs, axis=0)
+  value = tf.reduce_sum(tf.multiply(tf.stack(mssim), weights))

-  value = (tf.reduce_prod(mcs[0:level-1]**weight[0:level-1])*
-           (mssim[level-1]**weight[level-1]))
-
-  if mean_metric:
-    value = tf.reduce_mean(value)

   return value
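The rewritten tf_ms_ssim() swaps the usual product-form multi-scale combination for a normalized weighted sum of per-scale SSIM values, downsampling between scales with a stride-2 Gaussian convolution. Worked numbers for the default weights (the per-scale SSIM values here are made up for illustration):

```python
import numpy as np

# Normalized weighted sum over two scales with the default weights [0.1, 0.9].
weights = np.array([0.1, 0.9]) / np.sum([0.1, 0.9])  # already sums to 1
mssim = np.array([0.92, 0.88])
print(np.sum(mssim * weights))  # 0.1*0.92 + 0.9*0.88 = 0.884
```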