diff --git a/discriminator.py b/discriminator.py
new file mode 100644
index 0000000..86ab512
--- /dev/null
+++ b/discriminator.py
@@ -0,0 +1,31 @@
+import tensorflow as tf
+
+def discriminator(inputs, channels=64, is_training=True):
+
+    def conv2(batch_input, num_outputs=64, kernel_size=[3, 3], stride=1, norm=True, scope=None):
+        return tf.contrib.layers.conv2d(batch_input, num_outputs, kernel_size, stride, 'SAME', 'NHWC', scope=scope,
+                                        activation_fn=tf.nn.leaky_relu,
+                                        normalizer_fn=tf.contrib.layers.instance_norm if norm else None,
+                                        normalizer_params={'center': False, 'scale': False, 'data_format': 'NHWC'},
+                                        weights_initializer=tf.variance_scaling_initializer(scale=2.0))
+
+    with tf.device('/gpu:0'):
+        net = conv2(inputs, channels, kernel_size=[5, 5], stride=2, norm=False, scope='input_stage')
+
+        net = conv2(net, channels*2, stride=2, scope='disblock_1')
+
+        net = conv2(net, channels*4, stride=2, scope='disblock_2')
+
+        net = conv2(net, channels*8, stride=2, scope='disblock_3')
+
+        net = tf.layers.flatten(net)
+
+        with tf.variable_scope('dense_layer_1'):
+            net = tf.layers.dense(net, channels*16, activation=tf.nn.leaky_relu, use_bias=False,
+                                  kernel_initializer=tf.variance_scaling_initializer(scale=2.0))
+
+        with tf.variable_scope('dense_layer_2'):
+            net = tf.layers.dense(net, 1, use_bias=False)
+
+    return net
+
diff --git a/main.py b/main.py
index be7ba29..e2378c6 100644
--- a/main.py
+++ b/main.py
@@ -18,6 +18,7 @@ flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint director
 flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
 flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
 flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
+flags.DEFINE_boolean("adversarial", False, "Adversarial training [False]")
 flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
 flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
 
diff --git a/model.py b/model.py
index c4c3240..9c2f123 100644
--- a/model.py
+++ b/model.py
@@ -25,6 +25,7 @@ class Model(object):
         self.arch = config.arch
         self.fast = config.fast
         self.train = config.train
+        self.adversarial = config.adversarial
         self.epoch = config.epoch
         self.scale = config.scale
         self.radius = config.radius
@@ -56,24 +57,54 @@ class Model(object):
         # Batch size differs in training vs testing
         self.batch = tf.placeholder(tf.int32, shape=[], name='batch')
 
-        model = importlib.import_module(self.arch)
-        self.model = model.Model(self)
+        with tf.variable_scope('generator'):
+            model = importlib.import_module(self.arch)
+            self.model = model.Model(self)
 
-        self.pred = self.model.model()
+            self.pred = self.model.model()
 
         model_dir = "%s_%s_%s_%s" % (self.model.name.lower(), self.label_size, '-'.join(str(i) for i in self.model.model_params), "r"+str(self.radius))
         self.model_dir = os.path.join(self.checkpoint_dir, model_dir)
 
         self.loss = self.model.loss(self.labels, self.pred)
 
-        self.saver = tf.train.Saver()
+        if self.adversarial and self.train and not self.params:
+            from discriminator import discriminator
+            with tf.variable_scope('discriminator', reuse=False):
+                discrim_fake_output = discriminator(self.pred)
+            with tf.variable_scope('discriminator', reuse=True):
+                discrim_real_output = discriminator(self.labels)
+
+            avg_fake = tf.reduce_mean(discrim_fake_output)
+            avg_real = tf.reduce_mean(discrim_real_output)
+            real_tilde = tf.nn.sigmoid(discrim_real_output - avg_fake)
+            fake_tilde = tf.nn.sigmoid(discrim_fake_output - avg_real)
+            self.discrim_loss = - tf.reduce_mean(tf.log(real_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - fake_tilde + 1e-12))
+            self.adversarial_loss = - tf.reduce_mean(tf.log(fake_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - real_tilde + 1e-12))
+
+            self.dis_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator'))
+
+        self.saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator'))
 
     def run(self):
         global_step = tf.Variable(0, trainable=False)
-        optimizer = tf.train.AdamOptimizer(self.learning_rate)
+        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
 
         deconv_mult = lambda grads: list(map(lambda x: (x[0] * 1.0, x[1]) if 'deconv' in x[1].name else x, grads))
-        grads = deconv_mult(optimizer.compute_gradients(self.loss))
-        self.train_op = optimizer.apply_gradients(grads, global_step=global_step)
+        if self.adversarial and self.train and not self.params:
+            with tf.variable_scope('discriminator_train'):
+                discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+                discrim_grads = optimizer.compute_gradients(self.discrim_loss, discrim_tvars)
+                discrim_train = optimizer.apply_gradients(discrim_grads)
+
+            with tf.variable_scope('generator_train'):
+                # The generator update must wait for the discriminator train step to run
+                with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+                    gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+                    gen_grads = deconv_mult(optimizer.compute_gradients(self.loss + 5e-3 * self.adversarial_loss, gen_tvars))
+                    self.train_op = optimizer.apply_gradients(gen_grads, global_step=global_step)
+        else:
+            grads = deconv_mult(optimizer.compute_gradients(self.loss))
+            self.train_op = optimizer.apply_gradients(grads, global_step=global_step)
 
         tf.global_variables_initializer().run()
@@ -179,6 +210,10 @@ class Model(object):
         self.saver.save(self.sess,
                         os.path.join(self.model_dir, model_name),
                         global_step=step)
+        if self.adversarial:
+            self.dis_saver.save(self.sess,
+                                os.path.join(os.path.join(self.model_dir, "discriminator"), "discriminator.model"),
+                                global_step=step)
 
     def load(self):
         print(" [*] Reading checkpoints...")
@@ -187,6 +222,10 @@ class Model(object):
         if ckpt and ckpt.model_checkpoint_path:
             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
             self.saver.restore(self.sess, os.path.join(self.model_dir, ckpt_name))
+            ckpt = tf.train.get_checkpoint_state(os.path.join(self.model_dir, "discriminator"))
+            if self.adversarial and self.train and not self.params and ckpt and ckpt.model_checkpoint_path:
+                ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+                self.dis_saver.restore(self.sess, os.path.join(os.path.join(self.model_dir, "discriminator"), ckpt_name))
             return True
         else:
             return False
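
Note on the loss terms (not part of the patch): the objectives added in model.py follow a relativistic average GAN formulation, where each discriminator logit is shifted by the mean logit of the opposite batch before the sigmoid. Below is a minimal standalone sketch of just that computation so the formulas can be sanity-checked outside the training graph; the variable names mirror the patch, but the logit values are made up for illustration.

import tensorflow as tf

# Dummy raw discriminator logits standing in for discriminator(self.labels)
# and discriminator(self.pred) from the patch above.
discrim_real_output = tf.constant([[2.0], [1.5], [0.5]])
discrim_fake_output = tf.constant([[-1.0], [0.2], [-0.3]])

avg_real = tf.reduce_mean(discrim_real_output)
avg_fake = tf.reduce_mean(discrim_fake_output)

# Each logit is judged relative to the average logit of the opposite class.
real_tilde = tf.nn.sigmoid(discrim_real_output - avg_fake)
fake_tilde = tf.nn.sigmoid(discrim_fake_output - avg_real)

# The discriminator pushes real_tilde toward 1 and fake_tilde toward 0;
# the generator term is the mirror image and is added to the pixel loss
# with weight 5e-3 in run().
discrim_loss = -tf.reduce_mean(tf.log(real_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - fake_tilde + 1e-12))
adversarial_loss = -tf.reduce_mean(tf.log(fake_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - real_tilde + 1e-12))

with tf.Session() as sess:
    print(sess.run([discrim_loss, adversarial_loss]))

With the new flag in main.py, this code path would be exercised with something like python main.py --adversarial=True (training is already the default); without the flag, the original single-loss training path is unchanged.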