From 15bc7987019698dee4f4677d4f17a47ff9b6011f Mon Sep 17 00:00:00 2001
From: igv
Date: Thu, 27 Jul 2017 14:46:28 +0300
Subject: [PATCH] Option to add distortions to random low-res images

---
 main.py  |  1 +
 model.py |  1 +
 utils.py | 24 +++++++++++++++++-------
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index 7c43b07..36a2ea5 100644
--- a/main.py
+++ b/main.py
@@ -18,6 +18,7 @@ flags.DEFINE_string("output_dir", "result", "Name of test output directory [resu
 flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
 flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
 flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
+flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
 flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
 
 FLAGS = flags.FLAGS
diff --git a/model.py b/model.py
index b5038ea..e573307 100644
--- a/model.py
+++ b/model.py
@@ -33,6 +33,7 @@ class FSRCNN(object):
     self.batch_size = config.batch_size
     self.learning_rate = config.learning_rate
     self.threads = config.threads
+    self.distort = config.distort
     self.params = config.params
 
     # Different image/label sub-sizes for different scaling factors x2, x3, x4
diff --git a/utils.py b/utils.py
index 6502589..d11ff59 100644
--- a/utils.py
+++ b/utils.py
@@ -7,6 +7,8 @@ import glob
 import h5py
 from math import ceil
 import struct
+import io
+from random import randrange
 
 import tensorflow as tf
 from PIL import Image
@@ -33,12 +35,12 @@ def read_data(path):
     label = np.array(hf.get('label'))
   return data, label
 
-def preprocess(path, scale=3):
+def preprocess(path, scale=3, distort=False):
   """
   Preprocess single image file 
-    (1) Read original image as YCbCr format (and grayscale as default)
+    (1) Read original image as YCbCr format
     (2) Normalize
-    (3) Downsampled by scale factor (using anti-aliasing)
+    (3) Downsampled by scale factor
   """
 
   image = Image.open(path).convert('L')
@@ -56,6 +58,14 @@ def preprocess(path, scale=3):
   (width, height) = scaled_image.size
   input_ = np.array(scaled_image.getdata()).astype(np.float).reshape((height, width))
 
+  if randrange(3) == 2 and distort==True:
+    buf = io.BytesIO()
+    i = Image.fromarray(input_ * 255)
+    i.convert('RGB').save(buf, "JPEG", quality=randrange(50, 99, 5))
+    buf.seek(0)
+    scaled_image = Image.open(buf).convert('L')
+    input_ = np.fromstring(scaled_image.tobytes(), dtype=np.uint8).reshape((height, width)) / 255
+
   return input_, label_
 
 def prepare_data(sess, dataset):
@@ -110,13 +120,13 @@ def modcrop(image, scale=3):
 
 def train_input_worker(args):
   image_data, config = args
-  image_size, label_size, stride, scale = config
+  image_size, label_size, stride, scale, distort = config
 
   single_input_sequence, single_label_sequence = [], []
   padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
   label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7
 
-  input_, label_ = preprocess(image_data, scale)
+  input_, label_ = preprocess(image_data, scale, distort=distort)
 
   if len(input_.shape) == 3:
     h, w, _ = input_.shape
@@ -155,7 +165,7 @@ def thread_train_setup(config):
   pool = Pool(config.threads)
 
   # Distribute |images_per_thread| images across each worker process
-  config_values = [config.image_size, config.label_size, config.stride, config.scale]
+  config_values = [config.image_size, config.label_size, config.stride, config.scale, config.distort]
   images_per_thread = len(data) // config.threads
   workers = []
   for thread in range(config.threads):
@@ -202,7 +212,7 @@ def train_input_setup(config):
   label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7
 
   for i in range(len(data)):
-    input_, label_ = preprocess(data[i], scale)
+    input_, label_ = preprocess(data[i], scale, distort=config.distort)
 
     if len(input_.shape) == 3:
       h, w, _ = input_.shape