Don't divide the test image into patches, don't discard the color channels, and calculate PSNR and SSIM.

igv
2018-08-31 16:21:39 +03:00
parent c096b67069
commit 45c629e5c7
7 changed files with 45 additions and 53 deletions

View File

@@ -12,7 +12,6 @@ class Model(object):
     self.images = config.images
     self.batch = config.batch
     self.label_size = config.label_size
-    self.c_dim = config.c_dim

   def model(self):
     d = self.model_params

View File

@@ -15,7 +15,6 @@ class Model(object):
     self.batch = config.batch
     self.image_size = config.image_size - self.padding
     self.label_size = config.label_size
-    self.c_dim = config.c_dim

   def model(self):

View File

@@ -15,7 +15,6 @@ class Model(object):
     self.batch = config.batch
     self.image_size = config.image_size - self.padding
     self.label_size = config.label_size
-    self.c_dim = config.c_dim

   def model(self):
     d, m, r = self.model_params

View File

@@ -3,7 +3,7 @@ TensorFlow implementation of the Fast Super-Resolution Convolutional Neural Netw
 ## Prerequisites
 * Python 3
-* TensorFlow-gpu >= 1.3
+* TensorFlow-gpu >= 1.8
 * CUDA & cuDNN >= 6.0
 * Pillow
 * ImageMagick (optional)

View File

@@ -12,7 +12,6 @@ flags.DEFINE_boolean("fast", False, "Use the fast model (FSRCNN-s) [False]")
 flags.DEFINE_integer("epoch", 10, "Number of epochs [10]")
 flags.DEFINE_integer("batch_size", 32, "The size of batch images [32]")
 flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of the adam optimizer [1e-4]")
-flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
 flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [2]")
 flags.DEFINE_integer("radius", 1, "Max radius of the deconvolution input tensor [1]")
 flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")

View File

@@ -26,8 +26,6 @@ class Model(object):
     self.arch = config.arch
     self.fast = config.fast
     self.train = config.train
-    self.c_dim = config.c_dim
-    self.is_grayscale = (self.c_dim == 1)
     self.epoch = config.epoch
     self.scale = config.scale
     self.radius = config.radius
@@ -51,8 +49,12 @@ class Model(object):
   def init_model(self):
-    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
-    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
+    if self.train:
+      self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 1], name='images')
+      self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, 1], name='labels')
+    else:
+      self.images = tf.placeholder(tf.float32, [None, None, None, 1], name='images')
+      self.labels = tf.placeholder(tf.float32, [None, None, None, 1], name='labels')
     # Batch size differs in training vs testing
     self.batch = tf.placeholder(tf.int32, shape=[], name='batch')
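
For context, a standalone sketch (not part of this commit; it assumes TensorFlow 1.x and an arbitrary one-layer conv net) of why the [None, None, None, 1] placeholders matter at test time: a fully convolutional graph with unknown spatial dimensions accepts whole images of any size, so the test image no longer has to be cut into fixed-size patches.

import numpy as np
import tensorflow as tf

# Dynamic height/width, single (grayscale) channel, as in the non-training branch above.
x = tf.placeholder(tf.float32, [None, None, None, 1], name='x')
# Hypothetical one-layer network; any fully convolutional model behaves the same way.
y = tf.layers.conv2d(x, filters=1, kernel_size=3, padding='same')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # The same graph handles two inputs with different spatial sizes.
  out_a = sess.run(y, {x: np.zeros((1, 32, 32, 1), np.float32)})
  out_b = sess.run(y, {x: np.zeros((1, 480, 640, 1), np.float32)})
  print(out_a.shape, out_b.shape)  # (1, 32, 32, 1) (1, 480, 640, 1)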
@@ -154,20 +156,24 @@ class Model(object):
   def run_test(self):
-    test_data, test_label, nx, ny = test_input_setup(self)
+    test_data, test_label = test_input_setup(self)
     print("Testing...")

     start_time = time.time()
-    result = self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: nx * ny})
-    print("Took %.3f seconds" % (time.time() - start_time))
+    result = np.clip(self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: 1}), 0, 1)
+    passed = time.time() - start_time
+    img1 = tf.convert_to_tensor(test_label, dtype=tf.float32)
+    img2 = tf.convert_to_tensor(result, dtype=tf.float32)
+    psnr = self.sess.run(tf.image.psnr(img1, img2, 1))
+    ssim = self.sess.run(tf.image.ssim(img1, img2, 1))
+    print("Took %.3f seconds, PSNR: %.6f, SSIM: %.6f" % (passed, psnr, ssim))

-    result = merge(result, [nx, ny, self.c_dim])
-    result = result.squeeze()
+    result = merge(self, result)
     image_path = os.path.join(os.getcwd(), self.output_dir)
     image_path = os.path.join(image_path, "test_image.png")
-    array_image_save(result * 255, image_path)
+    array_image_save(result, image_path)

   def save(self, step):
     model_name = self.model.name + ".model"
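
A self-contained sketch of the metric evaluation added above (assumptions: TensorFlow 1.8+, values normalized to [0, 1], and random arrays standing in for the real label and prediction). tf.image.psnr and tf.image.ssim operate on NHWC batches and take the maximum possible pixel value as their third argument, hence the 1 for normalized data.

import numpy as np
import tensorflow as tf

# Hypothetical ground truth and a slightly noisy "prediction", both in [0, 1].
label = np.random.rand(1, 64, 64, 1).astype(np.float32)
pred = np.clip(label + np.random.normal(0, 0.05, label.shape), 0, 1).astype(np.float32)

img1 = tf.convert_to_tensor(label, dtype=tf.float32)
img2 = tf.convert_to_tensor(pred, dtype=tf.float32)

with tf.Session() as sess:
  psnr = sess.run(tf.image.psnr(img1, img2, 1))  # one value per image in the batch
  ssim = sess.run(tf.image.ssim(img1, img2, 1))
  print("PSNR: %.6f, SSIM: %.6f" % (psnr[0], ssim[0]))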

View File

@@ -240,45 +240,48 @@ def train_input_setup(config):
   return (arrdata, arrlabel)

 def test_input_setup(config):
   """
   Read image files, make their sub-images, and save them as a h5 file format.
   """
   sess = config.sess

-  image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2
-
   # Load data path
   data = prepare_data(sess, dataset="Test")

-  sub_input_sequence, sub_label_sequence = [], []
-
-  pic_index = 2 # Index of image based on lexicographic order in data folder
-  input_, label_ = preprocess(data[pic_index], config.scale)
+  input_, label_ = preprocess(data[2], config.scale)

   if len(input_.shape) == 3:
     h, w, _ = input_.shape
   else:
     h, w = input_.shape
+  arrdata = np.pad(input_.reshape([1, h, w, 1]), ((0,0),(2,2),(2,2),(0,0)), 'reflect')

-  nx, ny = 0, 0
-  for x in range(0, h - image_size + 1, stride):
-    nx += 1
-    ny = 0
-    for y in range(0, w - image_size + 1, stride):
-      ny += 1
-
-      sub_input = input_[x : x + image_size, y : y + image_size]
-      x_loc, y_loc = x + padding, y + padding
-      sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]
-
-      sub_input = sub_input.reshape([image_size, image_size, 1])
-      sub_label = sub_label.reshape([label_size, label_size, 1])
-
-      sub_input_sequence.append(sub_input)
-      sub_label_sequence.append(sub_label)
-
-  arrdata = np.asarray(sub_input_sequence)
-  arrlabel = np.asarray(sub_label_sequence)
-
-  return (arrdata, arrlabel, nx, ny)
+  if len(label_.shape) == 3:
+    h, w, _ = label_.shape
+  else:
+    h, w = label_.shape
+  arrlabel = label_.reshape([1, h, w, 1])
+
+  return (arrdata, arrlabel)
+
+def merge(config, Y):
+  """
+  Merges super-resolved image with chroma components
+  """
+  h, w = Y.shape[1], Y.shape[2]
+  Y = Y.reshape(h, w, 1) * 255
+  Y = Y.round().astype(np.uint8)
+
+  data = prepare_data(config.sess, dataset="Test")
+  src = Image.open(data[2]).convert('YCbCr')
+  (width, height) = src.size
+  if downsample is False:
+    src = src.resize((width * config.scale, height * config.scale), Image.BICUBIC)
+    (width, height) = src.size
+  CbCr = np.frombuffer(src.tobytes(), dtype=np.uint8).reshape(height, width, 3)[:,:,1:]
+
+  img = np.concatenate((Y, CbCr), axis=-1)
+
+  return img

 def save_params(sess, params):
   param_dir = "params/"
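
A small standalone NumPy example (hypothetical 3x4 image, not from the repo) of what the 'reflect' padding in the new test_input_setup does: it mirrors a 2-pixel border onto each spatial edge of the single-item batch, which appears to offset the border pixels the network otherwise crops away.

import numpy as np

img = np.arange(12, dtype=np.float32).reshape(3, 4)  # pretend 3x4 grayscale image
batch = np.pad(img.reshape(1, 3, 4, 1),
               ((0, 0), (2, 2), (2, 2), (0, 0)), 'reflect')
print(batch.shape)  # (1, 7, 8, 1): two mirrored rows/columns added on every edge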
@@ -313,24 +316,11 @@ def save_params(sess, params):
   h.close()

-def merge(images, size):
-  """
-  Merges sub-images back into original image size
-  """
-  h, w = images.shape[1], images.shape[2]
-  img = np.zeros((h * size[0], w * size[1], size[2]))
-  for idx, image in enumerate(images):
-    i = idx % size[1]
-    j = idx // size[1]
-    img[j*h:j*h+h, i*w:i*w+w, :] = image
-
-  return img
-
 def array_image_save(array, image_path):
   """
   Converts np array to image and saves it
   """
-  image = Image.fromarray(array)
+  image = Image.fromarray(array, 'YCbCr')
   if image.mode != 'RGB':
     image = image.convert('RGB')
   image.save(image_path)
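
To show the color path end to end, a standalone Pillow/NumPy sketch (random data stands in for a real prediction; sizes are hypothetical): the upscaled Y (luma) plane is paired with bicubically upscaled Cb/Cr taken from the source image, and Pillow then converts the YCbCr result to RGB for saving, which is roughly what merge() and the updated array_image_save() do together.

import numpy as np
from PIL import Image

h, w, scale = 32, 48, 2
# Hypothetical source image; in the repo this comes from the Test dataset.
src = Image.fromarray(np.random.randint(0, 256, (h, w, 3), np.uint8)).convert('YCbCr')

# Stand-in for the network's upscaled Y channel, already mapped to [0, 255] uint8.
Y = np.random.randint(0, 256, (h * scale, w * scale, 1), np.uint8)

# Bicubically upscale the source and keep only its Cb and Cr planes.
up = src.resize((w * scale, h * scale), Image.BICUBIC)
CbCr = np.frombuffer(up.tobytes(), dtype=np.uint8).reshape(h * scale, w * scale, 3)[:, :, 1:]

# Recombine luma with chroma and save via a YCbCr-to-RGB conversion.
out = Image.fromarray(np.concatenate((Y, CbCr), axis=-1), 'YCbCr').convert('RGB')
out.save("test_image.png")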