mirror of
https://github.com/igv/FSRCNN-TensorFlow.git
synced 2026-02-06 15:11:56 +08:00
Don't divide test image into patches, don't remove colors and calculate PSNR and SSIM.
This commit is contained in:
1
ESPCN.py
1
ESPCN.py
@@ -12,7 +12,6 @@ class Model(object):
|
||||
self.images = config.images
|
||||
self.batch = config.batch
|
||||
self.label_size = config.label_size
|
||||
self.c_dim = config.c_dim
|
||||
|
||||
def model(self):
|
||||
d = self.model_params
|
||||
|
||||
@@ -15,7 +15,6 @@ class Model(object):
|
||||
self.batch = config.batch
|
||||
self.image_size = config.image_size - self.padding
|
||||
self.label_size = config.label_size
|
||||
self.c_dim = config.c_dim
|
||||
|
||||
def model(self):
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ class Model(object):
|
||||
self.batch = config.batch
|
||||
self.image_size = config.image_size - self.padding
|
||||
self.label_size = config.label_size
|
||||
self.c_dim = config.c_dim
|
||||
|
||||
def model(self):
|
||||
d, m, r = self.model_params
|
||||
|
||||
@@ -3,7 +3,7 @@ TensorFlow implementation of the Fast Super-Resolution Convolutional Neural Netw
|
||||
|
||||
## Prerequisites
|
||||
* Python 3
|
||||
* TensorFlow-gpu >= 1.3
|
||||
* TensorFlow-gpu >= 1.8
|
||||
* CUDA & cuDNN >= 6.0
|
||||
* Pillow
|
||||
* ImageMagick (optional)
|
||||
|
||||
1
main.py
1
main.py
@@ -12,7 +12,6 @@ flags.DEFINE_boolean("fast", False, "Use the fast model (FSRCNN-s) [False]")
|
||||
flags.DEFINE_integer("epoch", 10, "Number of epochs [10]")
|
||||
flags.DEFINE_integer("batch_size", 32, "The size of batch images [32]")
|
||||
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of the adam optimizer [1e-4]")
|
||||
flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
|
||||
flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [2]")
|
||||
flags.DEFINE_integer("radius", 1, "Max radius of the deconvolution input tensor [1]")
|
||||
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
|
||||
|
||||
26
model.py
26
model.py
@@ -26,8 +26,6 @@ class Model(object):
|
||||
self.arch = config.arch
|
||||
self.fast = config.fast
|
||||
self.train = config.train
|
||||
self.c_dim = config.c_dim
|
||||
self.is_grayscale = (self.c_dim == 1)
|
||||
self.epoch = config.epoch
|
||||
self.scale = config.scale
|
||||
self.radius = config.radius
|
||||
@@ -51,8 +49,12 @@ class Model(object):
|
||||
|
||||
|
||||
def init_model(self):
|
||||
self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
|
||||
self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
|
||||
if self.train:
|
||||
self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 1], name='images')
|
||||
self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, 1], name='labels')
|
||||
else:
|
||||
self.images = tf.placeholder(tf.float32, [None, None, None, 1], name='images')
|
||||
self.labels = tf.placeholder(tf.float32, [None, None, None, 1], name='labels')
|
||||
# Batch size differs in training vs testing
|
||||
self.batch = tf.placeholder(tf.int32, shape=[], name='batch')
|
||||
|
||||
@@ -154,20 +156,24 @@ class Model(object):
|
||||
|
||||
|
||||
def run_test(self):
|
||||
test_data, test_label, nx, ny = test_input_setup(self)
|
||||
test_data, test_label = test_input_setup(self)
|
||||
|
||||
print("Testing...")
|
||||
|
||||
start_time = time.time()
|
||||
result = self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: nx * ny})
|
||||
print("Took %.3f seconds" % (time.time() - start_time))
|
||||
result = np.clip(self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: 1}), 0, 1)
|
||||
passed = time.time() - start_time
|
||||
img1 = tf.convert_to_tensor(test_label, dtype=tf.float32)
|
||||
img2 = tf.convert_to_tensor(result, dtype=tf.float32)
|
||||
psnr = self.sess.run(tf.image.psnr(img1, img2, 1))
|
||||
ssim = self.sess.run(tf.image.ssim(img1, img2, 1))
|
||||
print("Took %.3f seconds, PSNR: %.6f, SSIM: %.6f" % (passed, psnr, ssim))
|
||||
|
||||
result = merge(result, [nx, ny, self.c_dim])
|
||||
result = result.squeeze()
|
||||
result = merge(self, result)
|
||||
image_path = os.path.join(os.getcwd(), self.output_dir)
|
||||
image_path = os.path.join(image_path, "test_image.png")
|
||||
|
||||
array_image_save(result * 255, image_path)
|
||||
array_image_save(result, image_path)
|
||||
|
||||
def save(self, step):
|
||||
model_name = self.model.name + ".model"
|
||||
|
||||
66
utils.py
66
utils.py
@@ -240,45 +240,48 @@ def train_input_setup(config):
|
||||
return (arrdata, arrlabel)
|
||||
|
||||
def test_input_setup(config):
|
||||
"""
|
||||
Read image files, make their sub-images, and save them as a h5 file format.
|
||||
"""
|
||||
sess = config.sess
|
||||
image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2
|
||||
|
||||
# Load data path
|
||||
data = prepare_data(sess, dataset="Test")
|
||||
|
||||
sub_input_sequence, sub_label_sequence = [], []
|
||||
|
||||
pic_index = 2 # Index of image based on lexicographic order in data folder
|
||||
input_, label_ = preprocess(data[pic_index], config.scale)
|
||||
input_, label_ = preprocess(data[2], config.scale)
|
||||
|
||||
if len(input_.shape) == 3:
|
||||
h, w, _ = input_.shape
|
||||
else:
|
||||
h, w = input_.shape
|
||||
|
||||
nx, ny = 0, 0
|
||||
for x in range(0, h - image_size + 1, stride):
|
||||
nx += 1
|
||||
ny = 0
|
||||
for y in range(0, w - image_size + 1, stride):
|
||||
ny += 1
|
||||
sub_input = input_[x : x + image_size, y : y + image_size]
|
||||
x_loc, y_loc = x + padding, y + padding
|
||||
sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]
|
||||
arrdata = np.pad(input_.reshape([1, h, w, 1]), ((0,0),(2,2),(2,2),(0,0)), 'reflect')
|
||||
|
||||
sub_input = sub_input.reshape([image_size, image_size, 1])
|
||||
sub_label = sub_label.reshape([label_size, label_size, 1])
|
||||
if len(label_.shape) == 3:
|
||||
h, w, _ = label_.shape
|
||||
else:
|
||||
h, w = label_.shape
|
||||
|
||||
sub_input_sequence.append(sub_input)
|
||||
sub_label_sequence.append(sub_label)
|
||||
arrlabel = label_.reshape([1, h, w, 1])
|
||||
|
||||
arrdata = np.asarray(sub_input_sequence)
|
||||
arrlabel = np.asarray(sub_label_sequence)
|
||||
return (arrdata, arrlabel)
|
||||
|
||||
return (arrdata, arrlabel, nx, ny)
|
||||
def merge(config, Y):
|
||||
"""
|
||||
Merges super-resolved image with chroma components
|
||||
"""
|
||||
h, w = Y.shape[1], Y.shape[2]
|
||||
Y = Y.reshape(h, w, 1) * 255
|
||||
Y = Y.round().astype(np.uint8)
|
||||
|
||||
data = prepare_data(config.sess, dataset="Test")
|
||||
src = Image.open(data[2]).convert('YCbCr')
|
||||
(width, height) = src.size
|
||||
if downsample is False:
|
||||
src = src.resize((width * config.scale, height * config.scale), Image.BICUBIC)
|
||||
(width, height) = src.size
|
||||
CbCr = np.frombuffer(src.tobytes(), dtype=np.uint8).reshape(height, width, 3)[:,:,1:]
|
||||
|
||||
img = np.concatenate((Y, CbCr), axis=-1)
|
||||
|
||||
return img
|
||||
|
||||
def save_params(sess, params):
|
||||
param_dir = "params/"
|
||||
@@ -313,24 +316,11 @@ def save_params(sess, params):
|
||||
|
||||
h.close()
|
||||
|
||||
def merge(images, size):
|
||||
"""
|
||||
Merges sub-images back into original image size
|
||||
"""
|
||||
h, w = images.shape[1], images.shape[2]
|
||||
img = np.zeros((h * size[0], w * size[1], size[2]))
|
||||
for idx, image in enumerate(images):
|
||||
i = idx % size[1]
|
||||
j = idx // size[1]
|
||||
img[j*h:j*h+h, i*w:i*w+w, :] = image
|
||||
|
||||
return img
|
||||
|
||||
def array_image_save(array, image_path):
|
||||
"""
|
||||
Converts np array to image and saves it
|
||||
"""
|
||||
image = Image.fromarray(array)
|
||||
image = Image.fromarray(array, 'YCbCr')
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
image.save(image_path)
|
||||
|
||||
Reference in New Issue
Block a user