Adversarial training

Author: igv
Date: 2017-11-29 02:00:23 +02:00
Parent: aef9256d75
Commit: db7f6c4499

3 changed files with 78 additions and 7 deletions

discriminator.py (new file)

@@ -0,0 +1,31 @@
+import tensorflow as tf
+
+def discriminator(inputs, channels=64, is_training=True):
+    def conv2(batch_input, num_outputs=64, kernel_size=[3, 3], stride=1, norm=True, scope=None):
+        return tf.contrib.layers.conv2d(batch_input, num_outputs, kernel_size, stride, 'SAME', 'NHWC', scope=scope,
+                                        activation_fn=tf.nn.leaky_relu,
+                                        normalizer_fn=tf.contrib.layers.instance_norm if norm else None,
+                                        normalizer_params={'center': False, 'scale': False, 'data_format': 'NHWC'},
+                                        weights_initializer=tf.variance_scaling_initializer(scale=2.0))
+
+    with tf.device('/gpu:0'):
+        net = conv2(inputs, channels, kernel_size=[5, 5], stride=2, norm=False, scope='input_stage')
+
+        net = conv2(net, channels*2, stride=2, scope='disblock_1')
+        net = conv2(net, channels*4, stride=2, scope='disblock_2')
+        net = conv2(net, channels*8, stride=2, scope='disblock_3')
+
+        net = tf.layers.flatten(net)
+        with tf.variable_scope('dense_layer_1'):
+            net = tf.layers.dense(net, channels*16, activation=tf.nn.leaky_relu, use_bias=False,
+                                  kernel_initializer=tf.variance_scaling_initializer(scale=2.0))
+
+        with tf.variable_scope('dense_layer_2'):
+            net = tf.layers.dense(net, 1, use_bias=False)
+
+    return net
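
For orientation: the discriminator is a standard strided-conv classifier. Four stride-2 stages halve the spatial resolution while the channel count grows to channels*8, then a flatten and two dense layers reduce everything to a single unnormalized logit per image (the is_training argument is currently unused). A minimal shape check, as a sketch rather than part of the commit, assuming TensorFlow 1.x with tf.contrib available and an illustrative 64x64 RGB input:

    import tensorflow as tf
    from discriminator import discriminator

    images = tf.placeholder(tf.float32, [None, 64, 64, 3])
    with tf.variable_scope('discriminator'):
        logits = discriminator(images)
    # 64 -> 32 -> 16 -> 8 -> 4 spatially, 512 channels, flattened to
    # 4*4*512 = 8192 features, then 1024 -> 1 logit.
    print(logits.get_shape())  # (?, 1)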

main.py

@@ -18,6 +18,7 @@ flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint director
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
flags.DEFINE_boolean("adversarial", False, "Adversarial training [False]")
flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")

model.py

@@ -25,6 +25,7 @@ class Model(object):
         self.arch = config.arch
         self.fast = config.fast
         self.train = config.train
+        self.adversarial = config.adversarial
         self.epoch = config.epoch
         self.scale = config.scale
         self.radius = config.radius
@@ -56,24 +57,54 @@ class Model(object):
         # Batch size differs in training vs testing
         self.batch = tf.placeholder(tf.int32, shape=[], name='batch')

-        model = importlib.import_module(self.arch)
-        self.model = model.Model(self)
-        self.pred = self.model.model()
+        with tf.variable_scope('generator'):
+            model = importlib.import_module(self.arch)
+            self.model = model.Model(self)
+            self.pred = self.model.model()

         model_dir = "%s_%s_%s_%s" % (self.model.name.lower(), self.label_size, '-'.join(str(i) for i in self.model.model_params), "r"+str(self.radius))
         self.model_dir = os.path.join(self.checkpoint_dir, model_dir)

         self.loss = self.model.loss(self.labels, self.pred)

-        self.saver = tf.train.Saver()
+        if self.adversarial and self.train and not self.params:
+            from discriminator import discriminator
+            with tf.variable_scope('discriminator', reuse=False):
+                discrim_fake_output = discriminator(self.pred)
+            with tf.variable_scope('discriminator', reuse=True):
+                discrim_real_output = discriminator(self.labels)
+
+            avg_fake = tf.reduce_mean(discrim_fake_output)
+            avg_real = tf.reduce_mean(discrim_real_output)
+            real_tilde = tf.nn.sigmoid(discrim_real_output - avg_fake)
+            fake_tilde = tf.nn.sigmoid(discrim_fake_output - avg_real)
+            self.discrim_loss = -tf.reduce_mean(tf.log(real_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - fake_tilde + 1e-12))
+            self.adversarial_loss = -tf.reduce_mean(tf.log(fake_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - real_tilde + 1e-12))
+
+            self.dis_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator'))
+        self.saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator'))

     def run(self):
         global_step = tf.Variable(0, trainable=False)
-        optimizer = tf.train.AdamOptimizer(self.learning_rate)
+        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

         deconv_mult = lambda grads: list(map(lambda x: (x[0] * 1.0, x[1]) if 'deconv' in x[1].name else x, grads))
-        grads = deconv_mult(optimizer.compute_gradients(self.loss))
-        self.train_op = optimizer.apply_gradients(grads, global_step=global_step)
+
+        if self.adversarial and self.train and not self.params:
+            with tf.variable_scope('discriminator_train'):
+                discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+                discrim_grads = optimizer.compute_gradients(self.discrim_loss, discrim_tvars)
+                discrim_train = optimizer.apply_gradients(discrim_grads)
+
+            with tf.variable_scope('generator_train'):
+                # Wait for the discriminator train step before updating the generator
+                with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+                    gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+                    gen_grads = deconv_mult(optimizer.compute_gradients(self.loss + 5e-3 * self.adversarial_loss, gen_tvars))
+                    self.train_op = optimizer.apply_gradients(gen_grads, global_step=global_step)
+        else:
+            grads = deconv_mult(optimizer.compute_gradients(self.loss))
+            self.train_op = optimizer.apply_gradients(grads, global_step=global_step)

         tf.global_variables_initializer().run()
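
The two losses above form a relativistic-average-style objective: each logit is shifted by the mean logit of the opposite batch before the sigmoid, so real samples only need to look more real than the average fake (and vice versa). The generator then minimizes self.loss + 5e-3 * self.adversarial_loss, with the control dependency guaranteeing that one discriminator step runs before each generator step. The same arithmetic as a standalone NumPy sketch (logit values invented purely for illustration):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    real_logits = np.array([1.2, 0.8, 1.5])    # discriminator(labels)
    fake_logits = np.array([-0.3, 0.1, -0.9])  # discriminator(pred)

    real_tilde = sigmoid(real_logits - fake_logits.mean())
    fake_tilde = sigmoid(fake_logits - real_logits.mean())

    eps = 1e-12  # matches the 1e-12 guard inside tf.log above
    discrim_loss = -np.mean(np.log(real_tilde + eps)) - np.mean(np.log(1 - fake_tilde + eps))
    adversarial_loss = -np.mean(np.log(fake_tilde + eps)) - np.mean(np.log(1 - real_tilde + eps))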
@@ -179,6 +210,10 @@ class Model(object):
         self.saver.save(self.sess,
                         os.path.join(self.model_dir, model_name),
                         global_step=step)
+        if self.adversarial:
+            self.dis_saver.save(self.sess,
+                                os.path.join(os.path.join(self.model_dir, "discriminator"), "discriminator.model"),
+                                global_step=step)

     def load(self):
         print(" [*] Reading checkpoints...")
@@ -187,6 +222,10 @@ class Model(object):
         if ckpt and ckpt.model_checkpoint_path:
             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
             self.saver.restore(self.sess, os.path.join(self.model_dir, ckpt_name))
+            ckpt = tf.train.get_checkpoint_state(os.path.join(self.model_dir, "discriminator"))
+            if self.adversarial and self.train and not self.params and ckpt and ckpt.model_checkpoint_path:
+                ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+                self.dis_saver.restore(self.sess, os.path.join(os.path.join(self.model_dir, "discriminator"), ckpt_name))
             return True
         else:
             return False
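
Generator and discriminator checkpoints are kept separate: self.saver now covers only variables under the 'generator' scope, while self.dis_saver writes to a "discriminator" subdirectory of the model directory and is restored only when adversarial training is active. Schematically (directory and file names below are placeholders; the actual model_dir name depends on the architecture):

    checkpoint/<model_dir>/
        <model_name>-<step>.*                # generator variables only
        discriminator/
            discriminator.model-<step>.*     # discriminator variables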