mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-30 04:05:27 +08:00
updating cs_gan for bug fixing and ODE-GAN opensource.
PiperOrigin-RevId: 352764474
This commit is contained in:
committed by
Diego de Las Casas
parent
1fcac77dc8
commit
dd89722a76
+21
-11
@@ -38,8 +38,9 @@ class GAN(object):
|
||||
Gaussian or uniform).
|
||||
"""
|
||||
|
||||
def __init__(self, discriminator, generator, num_z_iters, z_step_size,
|
||||
z_project_method, optimisation_cost_weight):
|
||||
def __init__(self, discriminator, generator,
|
||||
num_z_iters=None, z_step_size=None,
|
||||
z_project_method=None, optimisation_cost_weight=None):
|
||||
"""Constructs the module.
|
||||
|
||||
Args:
|
||||
@@ -57,10 +58,11 @@ class GAN(object):
|
||||
self.generator = generator
|
||||
self.num_z_iters = num_z_iters
|
||||
self.z_project_method = z_project_method
|
||||
self._log_step_size_module = snt.TrainableVariable(
|
||||
[],
|
||||
initializers={'w': tf.constant_initializer(math.log(z_step_size))})
|
||||
self.z_step_size = tf.exp(self._log_step_size_module())
|
||||
if z_step_size:
|
||||
self._log_step_size_module = snt.TrainableVariable(
|
||||
[],
|
||||
initializers={'w': tf.constant_initializer(math.log(z_step_size))})
|
||||
self.z_step_size = tf.exp(self._log_step_size_module())
|
||||
self._optimisation_cost_weight = optimisation_cost_weight
|
||||
|
||||
def connect(self, data, generator_inputs):
|
||||
@@ -108,12 +110,13 @@ class GAN(object):
|
||||
optimisation_cost=optimisation_cost)
|
||||
|
||||
debug_ops = {}
|
||||
debug_ops['z_step_size'] = self.z_step_size
|
||||
debug_ops['disc_data_loss'] = disc_data_loss
|
||||
debug_ops['disc_sample_loss'] = disc_sample_loss
|
||||
debug_ops['disc_loss'] = disc_loss
|
||||
debug_ops['gen_loss'] = generator_loss
|
||||
debug_ops['opt_cost'] = optimisation_cost
|
||||
if hasattr(self, 'z_step_size'):
|
||||
debug_ops['z_step_size'] = self.z_step_size
|
||||
|
||||
return utils.ModelOutputs(
|
||||
optimization_components, debug_ops)
|
||||
@@ -134,17 +137,24 @@ class GAN(object):
|
||||
|
||||
discriminator_vars = _get_and_check_variables(self._discriminator)
|
||||
generator_vars = _get_and_check_variables(self.generator)
|
||||
step_vars = _get_and_check_variables(self._log_step_size_module)
|
||||
if hasattr(self, '_log_step_size_module'):
|
||||
step_vars = _get_and_check_variables(self._log_step_size_module)
|
||||
generator_vars += step_vars
|
||||
|
||||
optimization_components = collections.OrderedDict()
|
||||
optimization_components['disc'] = utils.OptimizationComponent(
|
||||
discriminator_loss, discriminator_vars)
|
||||
if self._optimisation_cost_weight:
|
||||
generator_loss += self._optimisation_cost_weight * optimisation_cost
|
||||
optimization_components['gen'] = utils.OptimizationComponent(
|
||||
generator_loss
|
||||
+ self._optimisation_cost_weight * optimisation_cost,
|
||||
generator_vars + step_vars)
|
||||
generator_loss, generator_vars)
|
||||
return optimization_components
|
||||
|
||||
def get_variables(self):
|
||||
disc_vars = _get_and_check_variables(self._discriminator)
|
||||
gen_vars = _get_and_check_variables(self.generator)
|
||||
return disc_vars, gen_vars
|
||||
|
||||
|
||||
def _get_and_check_variables(module):
|
||||
module_variables = module.get_all_variables()
|
||||
|
||||
Reference in New Issue
Block a user