mirror of
https://github.com/igv/FSRCNN-TensorFlow.git
synced 2026-02-05 22:52:08 +08:00
Adversarial training
This commit is contained in:
31
discriminator.py
Normal file
31
discriminator.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import tensorflow as tf
|
||||
|
||||
def discriminator(inputs, channels=64, is_training=True):
    """Build a patch-style image discriminator for adversarial training.

    Args:
        inputs: 4-D NHWC image tensor to classify as real or generated.
        channels: Base number of conv filters; doubled at each stage.
        is_training: Accepted for API symmetry but currently unused —
            instance_norm has no train/eval distinction here.
            NOTE(review): confirm callers do not expect it to gate anything.

    Returns:
        A 2-D tensor of shape [batch, 1] with raw (un-sigmoided) logits.
    """

    # NOTE: kernel_size default is a tuple, not a list, to avoid the
    # mutable-default-argument pitfall; conv2d accepts either form.
    def conv2(batch_input, num_outputs=64, kernel_size=(3, 3), stride=1, norm=True, scope=None):
        # Conv + leaky ReLU, optionally preceded by (affine-free) instance norm.
        # He-style initialization (scale=2.0) suits the leaky-ReLU activations.
        return tf.contrib.layers.conv2d(batch_input, num_outputs, kernel_size, stride, 'SAME', 'NHWC', scope=scope,
                                        activation_fn=tf.nn.leaky_relu,
                                        normalizer_fn=tf.contrib.layers.instance_norm if norm else None,
                                        normalizer_params={'center': False, 'scale': False, 'data_format': 'NHWC'},
                                        weights_initializer=tf.variance_scaling_initializer(scale=2.0))

    with tf.device('/gpu:0'):
        # Input stage: larger 5x5 receptive field, no normalization on raw pixels.
        net = conv2(inputs, channels, kernel_size=(5, 5), stride=2, norm=False, scope='input_stage')

        # Three stride-2 downsampling stages, doubling channels each time.
        net = conv2(net, channels*2, stride=2, scope='disblock_1')

        net = conv2(net, channels*4, stride=2, scope='disblock_2')

        net = conv2(net, channels*8, stride=2, scope='disblock_3')

        # Collapse spatial dimensions before the dense head.
        net = tf.layers.flatten(net)

        with tf.variable_scope('dense_layer_1'):
            net = tf.layers.dense(net, channels*16, activation=tf.nn.leaky_relu, use_bias=False,
                                  kernel_initializer=tf.variance_scaling_initializer(scale=2.0))

        with tf.variable_scope('dense_layer_2'):
            # Single raw logit per image; sigmoid is applied by the loss code.
            net = tf.layers.dense(net, 1, use_bias=False)

    return net
|
||||
|
||||
1
main.py
1
main.py
@@ -18,6 +18,7 @@ flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint director
|
||||
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
|
||||
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
|
||||
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
|
||||
flags.DEFINE_boolean("adversarial", False, "Adversarial training [False]")
|
||||
flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
|
||||
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
|
||||
|
||||
|
||||
53
model.py
53
model.py
@@ -25,6 +25,7 @@ class Model(object):
|
||||
self.arch = config.arch
|
||||
self.fast = config.fast
|
||||
self.train = config.train
|
||||
self.adversarial = config.adversarial
|
||||
self.epoch = config.epoch
|
||||
self.scale = config.scale
|
||||
self.radius = config.radius
|
||||
@@ -56,24 +57,54 @@ class Model(object):
|
||||
# Batch size differs in training vs testing
|
||||
self.batch = tf.placeholder(tf.int32, shape=[], name='batch')
|
||||
|
||||
model = importlib.import_module(self.arch)
|
||||
self.model = model.Model(self)
|
||||
with tf.variable_scope('generator'):
|
||||
model = importlib.import_module(self.arch)
|
||||
self.model = model.Model(self)
|
||||
|
||||
self.pred = self.model.model()
|
||||
self.pred = self.model.model()
|
||||
|
||||
model_dir = "%s_%s_%s_%s" % (self.model.name.lower(), self.label_size, '-'.join(str(i) for i in self.model.model_params), "r"+str(self.radius))
|
||||
self.model_dir = os.path.join(self.checkpoint_dir, model_dir)
|
||||
|
||||
self.loss = self.model.loss(self.labels, self.pred)
|
||||
|
||||
self.saver = tf.train.Saver()
|
||||
if self.adversarial and self.train and not self.params:
|
||||
from discriminator import discriminator
|
||||
with tf.variable_scope('discriminator', reuse=False):
|
||||
discrim_fake_output = discriminator(self.pred)
|
||||
with tf.variable_scope('discriminator', reuse=True):
|
||||
discrim_real_output = discriminator(self.labels)
|
||||
|
||||
avg_fake = tf.reduce_mean(discrim_fake_output)
|
||||
avg_real = tf.reduce_mean(discrim_real_output)
|
||||
real_tilde = tf.nn.sigmoid(discrim_real_output - avg_fake)
|
||||
fake_tilde = tf.nn.sigmoid(discrim_fake_output - avg_real)
|
||||
self.discrim_loss = - tf.reduce_mean(tf.log(real_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - fake_tilde + 1e-12))
|
||||
self.adversarial_loss = - tf.reduce_mean(tf.log(fake_tilde + 1e-12)) - tf.reduce_mean(tf.log(1 - real_tilde + 1e-12))
|
||||
|
||||
self.dis_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator'))
|
||||
|
||||
self.saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator'))
|
||||
|
||||
def run(self):
|
||||
global_step = tf.Variable(0, trainable=False)
|
||||
optimizer = tf.train.AdamOptimizer(self.learning_rate)
|
||||
optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
|
||||
deconv_mult = lambda grads: list(map(lambda x: (x[0] * 1.0, x[1]) if 'deconv' in x[1].name else x, grads))
|
||||
grads = deconv_mult(optimizer.compute_gradients(self.loss))
|
||||
self.train_op = optimizer.apply_gradients(grads, global_step=global_step)
|
||||
if self.adversarial and self.train and not self.params:
|
||||
with tf.variable_scope('dicriminator_train'):
|
||||
discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
|
||||
discrim_grads = optimizer.compute_gradients(self.discrim_loss, discrim_tvars)
|
||||
discrim_train = optimizer.apply_gradients(discrim_grads)
|
||||
|
||||
with tf.variable_scope('generator_train'):
|
||||
# Need to wait for the discriminator to perform its train step first
|
||||
with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
|
||||
gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
|
||||
gen_grads = deconv_mult(optimizer.compute_gradients(self.loss + 5e-3 * self.adversarial_loss, gen_tvars))
|
||||
self.train_op = optimizer.apply_gradients(gen_grads, global_step=global_step)
|
||||
else:
|
||||
grads = deconv_mult(optimizer.compute_gradients(self.loss))
|
||||
self.train_op = optimizer.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
tf.global_variables_initializer().run()
|
||||
|
||||
@@ -179,6 +210,10 @@ class Model(object):
|
||||
self.saver.save(self.sess,
|
||||
os.path.join(self.model_dir, model_name),
|
||||
global_step=step)
|
||||
if self.adversarial:
|
||||
self.dis_saver.save(self.sess,
|
||||
os.path.join(os.path.join(self.model_dir, "discriminator"), "discriminator.model"),
|
||||
global_step=step)
|
||||
|
||||
def load(self):
|
||||
print(" [*] Reading checkpoints...")
|
||||
@@ -187,6 +222,10 @@ class Model(object):
|
||||
if ckpt and ckpt.model_checkpoint_path:
|
||||
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
|
||||
self.saver.restore(self.sess, os.path.join(self.model_dir, ckpt_name))
|
||||
ckpt = tf.train.get_checkpoint_state(os.path.join(self.model_dir, "discriminator"))
|
||||
if self.adversarial and self.train and not self.params and ckpt and ckpt.model_checkpoint_path:
|
||||
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
|
||||
self.dis_saver.restore(self.sess, os.path.join(os.path.join(self.model_dir, "discriminator"), ckpt_name))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user