diff --git a/ESPCN.py b/ESPCN.py index f0273fe..520272b 100644 --- a/ESPCN.py +++ b/ESPCN.py @@ -12,7 +12,6 @@ class Model(object): self.images = config.images self.batch = config.batch self.label_size = config.label_size - self.c_dim = config.c_dim def model(self): d = self.model_params diff --git a/FSRCNN.py b/FSRCNN.py index 7406ca2..7aa7994 100644 --- a/FSRCNN.py +++ b/FSRCNN.py @@ -15,7 +15,6 @@ class Model(object): self.batch = config.batch self.image_size = config.image_size - self.padding self.label_size = config.label_size - self.c_dim = config.c_dim def model(self): diff --git a/LapSRN.py b/LapSRN.py index 16a959a..11be706 100644 --- a/LapSRN.py +++ b/LapSRN.py @@ -15,7 +15,6 @@ class Model(object): self.batch = config.batch self.image_size = config.image_size - self.padding self.label_size = config.label_size - self.c_dim = config.c_dim def model(self): d, m, r = self.model_params diff --git a/README.md b/README.md index 42476d7..67525b5 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ TensorFlow implementation of the Fast Super-Resolution Convolutional Neural Netw ## Prerequisites * Python 3 - * TensorFlow-gpu >= 1.3 + * TensorFlow-gpu >= 1.8 * CUDA & cuDNN >= 6.0 * Pillow * ImageMagick (optional) diff --git a/main.py b/main.py index 5a710ce..97a03d3 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,6 @@ flags.DEFINE_boolean("fast", False, "Use the fast model (FSRCNN-s) [False]") flags.DEFINE_integer("epoch", 10, "Number of epochs [10]") flags.DEFINE_integer("batch_size", 32, "The size of batch images [32]") flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of the adam optimizer [1e-4]") -flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]") flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [2]") flags.DEFINE_integer("radius", 1, "Max radius of the deconvolution input tensor [1]") flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]") diff --git a/model.py b/model.py index 5873636..5db5488 100644 --- a/model.py +++ b/model.py @@ -26,8 +26,6 @@ class Model(object): self.arch = config.arch self.fast = config.fast self.train = config.train - self.c_dim = config.c_dim - self.is_grayscale = (self.c_dim == 1) self.epoch = config.epoch self.scale = config.scale self.radius = config.radius @@ -51,8 +49,12 @@ class Model(object): def init_model(self): - self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images') - self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels') + if self.train: + self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 1], name='images') + self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, 1], name='labels') + else: + self.images = tf.placeholder(tf.float32, [None, None, None, 1], name='images') + self.labels = tf.placeholder(tf.float32, [None, None, None, 1], name='labels') # Batch size differs in training vs testing self.batch = tf.placeholder(tf.int32, shape=[], name='batch') @@ -154,20 +156,24 @@ class Model(object): def run_test(self): - test_data, test_label, nx, ny = test_input_setup(self) + test_data, test_label = test_input_setup(self) print("Testing...") start_time = time.time() - result = self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: nx * ny}) - print("Took %.3f seconds" % (time.time() - start_time)) + result = np.clip(self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: 1}), 0, 1) + passed = time.time() - start_time + img1 = tf.convert_to_tensor(test_label, dtype=tf.float32) + img2 = tf.convert_to_tensor(result, dtype=tf.float32) + psnr = self.sess.run(tf.image.psnr(img1, img2, 1)) + ssim = self.sess.run(tf.image.ssim(img1, img2, 1)) + print("Took %.3f seconds, PSNR: %.6f, SSIM: %.6f" % (passed, psnr, ssim)) - result = merge(result, [nx, ny, self.c_dim]) - result = result.squeeze() + result = merge(self, result) image_path = os.path.join(os.getcwd(), self.output_dir) image_path = os.path.join(image_path, "test_image.png") - array_image_save(result * 255, image_path) + array_image_save(result, image_path) def save(self, step): model_name = self.model.name + ".model" diff --git a/utils.py b/utils.py index 609fffa..94382a7 100644 --- a/utils.py +++ b/utils.py @@ -240,45 +240,48 @@ def train_input_setup(config): return (arrdata, arrlabel) def test_input_setup(config): - """ - Read image files, make their sub-images, and save them as a h5 file format. - """ sess = config.sess - image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2 # Load data path data = prepare_data(sess, dataset="Test") - sub_input_sequence, sub_label_sequence = [], [] - - pic_index = 2 # Index of image based on lexicographic order in data folder - input_, label_ = preprocess(data[pic_index], config.scale) + input_, label_ = preprocess(data[2], config.scale) if len(input_.shape) == 3: h, w, _ = input_.shape else: h, w = input_.shape - nx, ny = 0, 0 - for x in range(0, h - image_size + 1, stride): - nx += 1 - ny = 0 - for y in range(0, w - image_size + 1, stride): - ny += 1 - sub_input = input_[x : x + image_size, y : y + image_size] - x_loc, y_loc = x + padding, y + padding - sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size] + arrdata = np.pad(input_.reshape([1, h, w, 1]), ((0,0),(2,2),(2,2),(0,0)), 'reflect') - sub_input = sub_input.reshape([image_size, image_size, 1]) - sub_label = sub_label.reshape([label_size, label_size, 1]) + if len(label_.shape) == 3: + h, w, _ = label_.shape + else: + h, w = label_.shape - sub_input_sequence.append(sub_input) - sub_label_sequence.append(sub_label) + arrlabel = label_.reshape([1, h, w, 1]) - arrdata = np.asarray(sub_input_sequence) - arrlabel = np.asarray(sub_label_sequence) + return (arrdata, arrlabel) - return (arrdata, arrlabel, nx, ny) +def merge(config, Y): + """ + Merges super-resolved image with chroma components + """ + h, w = Y.shape[1], Y.shape[2] + Y = Y.reshape(h, w, 1) * 255 + Y = Y.round().astype(np.uint8) + + data = prepare_data(config.sess, dataset="Test") + src = Image.open(data[2]).convert('YCbCr') + (width, height) = src.size + if downsample is False: + src = src.resize((width * config.scale, height * config.scale), Image.BICUBIC) + (width, height) = src.size + CbCr = np.frombuffer(src.tobytes(), dtype=np.uint8).reshape(height, width, 3)[:,:,1:] + + img = np.concatenate((Y, CbCr), axis=-1) + + return img def save_params(sess, params): param_dir = "params/" @@ -313,24 +316,11 @@ def save_params(sess, params): h.close() -def merge(images, size): - """ - Merges sub-images back into original image size - """ - h, w = images.shape[1], images.shape[2] - img = np.zeros((h * size[0], w * size[1], size[2])) - for idx, image in enumerate(images): - i = idx % size[1] - j = idx // size[1] - img[j*h:j*h+h, i*w:i*w+w, :] = image - - return img - def array_image_save(array, image_path): """ Converts np array to image and saves it """ - image = Image.fromarray(array) + image = Image.fromarray(array, 'YCbCr') if image.mode != 'RGB': image = image.convert('RGB') image.save(image_path)