updating cs_gan for bug fixing and ODE-GAN opensource.

PiperOrigin-RevId: 352764474
This commit is contained in:
Yan Wu
2021-01-20 12:10:46 +00:00
committed by Diego de Las Casas
parent 1fcac77dc8
commit dd89722a76
6 changed files with 458 additions and 27 deletions
+21 -11
View File
@@ -38,8 +38,9 @@ class GAN(object):
Gaussian or uniform).
"""
def __init__(self, discriminator, generator, num_z_iters, z_step_size,
z_project_method, optimisation_cost_weight):
def __init__(self, discriminator, generator,
num_z_iters=None, z_step_size=None,
z_project_method=None, optimisation_cost_weight=None):
"""Constructs the module.
Args:
@@ -57,10 +58,11 @@ class GAN(object):
self.generator = generator
self.num_z_iters = num_z_iters
self.z_project_method = z_project_method
self._log_step_size_module = snt.TrainableVariable(
[],
initializers={'w': tf.constant_initializer(math.log(z_step_size))})
self.z_step_size = tf.exp(self._log_step_size_module())
if z_step_size:
self._log_step_size_module = snt.TrainableVariable(
[],
initializers={'w': tf.constant_initializer(math.log(z_step_size))})
self.z_step_size = tf.exp(self._log_step_size_module())
self._optimisation_cost_weight = optimisation_cost_weight
def connect(self, data, generator_inputs):
@@ -108,12 +110,13 @@ class GAN(object):
optimisation_cost=optimisation_cost)
debug_ops = {}
debug_ops['z_step_size'] = self.z_step_size
debug_ops['disc_data_loss'] = disc_data_loss
debug_ops['disc_sample_loss'] = disc_sample_loss
debug_ops['disc_loss'] = disc_loss
debug_ops['gen_loss'] = generator_loss
debug_ops['opt_cost'] = optimisation_cost
if hasattr(self, 'z_step_size'):
debug_ops['z_step_size'] = self.z_step_size
return utils.ModelOutputs(
optimization_components, debug_ops)
@@ -134,17 +137,24 @@ class GAN(object):
discriminator_vars = _get_and_check_variables(self._discriminator)
generator_vars = _get_and_check_variables(self.generator)
step_vars = _get_and_check_variables(self._log_step_size_module)
if hasattr(self, '_log_step_size_module'):
step_vars = _get_and_check_variables(self._log_step_size_module)
generator_vars += step_vars
optimization_components = collections.OrderedDict()
optimization_components['disc'] = utils.OptimizationComponent(
discriminator_loss, discriminator_vars)
if self._optimisation_cost_weight:
generator_loss += self._optimisation_cost_weight * optimisation_cost
optimization_components['gen'] = utils.OptimizationComponent(
generator_loss
+ self._optimisation_cost_weight * optimisation_cost,
generator_vars + step_vars)
generator_loss, generator_vars)
return optimization_components
def get_variables(self):
disc_vars = _get_and_check_variables(self._discriminator)
gen_vars = _get_and_check_variables(self.generator)
return disc_vars, gen_vars
def _get_and_check_variables(module):
module_variables = module.get_all_variables()