Updating cs_gan for bug fixes and open-sourcing ODE-GAN.

PiperOrigin-RevId: 352764474

commit dd89722a76 (parent 1fcac77dc8), committed by Diego de Las Casas

cs_gan/gan.py (+21 -11)
@@ -38,8 +38,9 @@ class GAN(object):
       Gaussian or uniform).
   """

-  def __init__(self, discriminator, generator, num_z_iters, z_step_size,
-               z_project_method, optimisation_cost_weight):
+  def __init__(self, discriminator, generator,
+               num_z_iters=None, z_step_size=None,
+               z_project_method=None, optimisation_cost_weight=None):
     """Constructs the module.

     Args:
@@ -57,10 +58,11 @@ class GAN(object):
     self.generator = generator
     self.num_z_iters = num_z_iters
     self.z_project_method = z_project_method
-    self._log_step_size_module = snt.TrainableVariable(
-        [],
-        initializers={'w': tf.constant_initializer(math.log(z_step_size))})
-    self.z_step_size = tf.exp(self._log_step_size_module())
+    if z_step_size:
+      self._log_step_size_module = snt.TrainableVariable(
+          [],
+          initializers={'w': tf.constant_initializer(math.log(z_step_size))})
+      self.z_step_size = tf.exp(self._log_step_size_module())
     self._optimisation_cost_weight = optimisation_cost_weight

   def connect(self, data, generator_inputs):
@@ -108,12 +110,13 @@ class GAN(object):
         optimisation_cost=optimisation_cost)

     debug_ops = {}
-    debug_ops['z_step_size'] = self.z_step_size
     debug_ops['disc_data_loss'] = disc_data_loss
     debug_ops['disc_sample_loss'] = disc_sample_loss
     debug_ops['disc_loss'] = disc_loss
     debug_ops['gen_loss'] = generator_loss
     debug_ops['opt_cost'] = optimisation_cost
+    if hasattr(self, 'z_step_size'):
+      debug_ops['z_step_size'] = self.z_step_size

     return utils.ModelOutputs(
         optimization_components, debug_ops)
@@ -134,17 +137,24 @@ class GAN(object):

     discriminator_vars = _get_and_check_variables(self._discriminator)
     generator_vars = _get_and_check_variables(self.generator)
-    step_vars = _get_and_check_variables(self._log_step_size_module)
+    if hasattr(self, '_log_step_size_module'):
+      step_vars = _get_and_check_variables(self._log_step_size_module)
+      generator_vars += step_vars

     optimization_components = collections.OrderedDict()
     optimization_components['disc'] = utils.OptimizationComponent(
         discriminator_loss, discriminator_vars)
+    if self._optimisation_cost_weight:
+      generator_loss += self._optimisation_cost_weight * optimisation_cost
     optimization_components['gen'] = utils.OptimizationComponent(
-        generator_loss
-        + self._optimisation_cost_weight * optimisation_cost,
-        generator_vars + step_vars)
+        generator_loss, generator_vars)
     return optimization_components

+  def get_variables(self):
+    disc_vars = _get_and_check_variables(self._discriminator)
+    gen_vars = _get_and_check_variables(self.generator)
+    return disc_vars, gen_vars


 def _get_and_check_variables(module):
   module_variables = module.get_all_variables()
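With the constructor arguments now optional, the same `GAN` module serves both training regimes. A minimal sketch of the two call patterns, assuming the cs_gan package and its TF1/Sonnet dependencies are installed; the CS-GAN hyperparameter values below are illustrative, not taken from this commit:

```python
from cs_gan import gan
from cs_gan import utils

generator = utils.get_generator('cifar')
metric_net = utils.get_metric_net('cifar', use_sn=False)

# ODE-GAN style: no latent optimisation, so no trainable step size is built.
ode_model = gan.GAN(metric_net, generator)

# CS-GAN style: latent optimisation enabled (hypothetical hyperparameters).
cs_model = gan.GAN(metric_net, generator,
                   num_z_iters=3, z_step_size=0.01,
                   z_project_method='norm', optimisation_cost_weight=3.0)
```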
cs_gan/main_ode.py (new, +367)
@@ -0,0 +1,367 @@
+# Copyright 2019 DeepMind Technologies Limited and Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Training script."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow.compat.v1 as tf
+
+from cs_gan import file_utils
+from cs_gan import gan
+from cs_gan import image_metrics
+from cs_gan import utils
+
+flags.DEFINE_integer(
+    'num_training_iterations', 1200000,
+    'Number of training iterations.')
+flags.DEFINE_string(
+    'ode_mode', 'rk4', 'Integration method (euler|rk2|rk4).')
+flags.DEFINE_integer(
+    'batch_size', 64, 'Training batch size.')
+flags.DEFINE_float(
+    'grad_reg_weight', 0.02, 'Weight of the gradient regulariser.')
+flags.DEFINE_string(
+    'opt_name', 'gd', 'Name of the optimiser (gd|adam).')
+flags.DEFINE_bool(
+    'schedule_lr', True, 'Whether to use a learning rate schedule.')
+flags.DEFINE_bool(
+    'reg_first_grad_only', True, 'Whether only to regularise the first grad.')
+flags.DEFINE_integer(
+    'num_latents', 128, 'The number of latents.')
+flags.DEFINE_integer(
+    'summary_every_step', 1000,
+    'The interval at which to log debug ops.')
+flags.DEFINE_integer(
+    'image_metrics_every_step', 1000,
+    'The interval at which to log (expensive) image metrics.')
+flags.DEFINE_integer(
+    'export_every', 10,
+    'The interval at which to export samples.')
+# Use 50k to reproduce scores from the paper. We default to 10k here to avoid
+# the runtime errors caused by building too large a graph with 50k samples on
+# some machines.
+flags.DEFINE_integer(
+    'num_eval_samples', 10000,
+    'The number of samples used to evaluate FID/IS.')
+flags.DEFINE_string(
+    'dataset', 'cifar', 'The dataset used for learning (cifar|mnist).')
+flags.DEFINE_string(
+    'output_dir', '/tmp/ode_gan/gan', 'Location where to save output files.')
+flags.DEFINE_float('disc_lr', 4e-2, 'Discriminator learning rate.')
+flags.DEFINE_float('gen_lr', 4e-2, 'Generator learning rate.')
+flags.DEFINE_bool(
+    'run_real_data_metrics', False,
+    'Whether or not to run image metrics on real data.')
+flags.DEFINE_bool(
+    'run_sample_metrics', True,
+    'Whether or not to run image metrics on samples.')
+
+
+FLAGS = flags.FLAGS
+
+# Log at info level (for Hooks).
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _copy_vars(v_list):
+  """Copy variables in v_list."""
+  t_list = []
+  for v in v_list:
+    t_list.append(tf.identity(v))
+  return t_list
+
+
+def _restore_vars(v_list, t_list):
+  """Restore variables in v_list from t_list."""
+  ops = []
+  for v, t in zip(v_list, t_list):
+    ops.append(v.assign(t))
+  return ops
+
+
+def _scale_vars(s, v_list):
+  """Scale all variables in v_list by s."""
+  return [s * v for v in v_list]
+
+
+def _acc_grads(g_sum, g_w, g):
+  """Accumulate gradients in g, weighted by g_w."""
+  return [g_sum_i + g_w * g_i for g_sum_i, g_i in zip(g_sum, g)]
+
+
+def _compute_reg_grads(gen_grads, disc_vars):
+  """Compute gradient norm (this is an upper-bound of the full-batch norm)."""
+  gen_norm = tf.accumulate_n([tf.reduce_sum(u * u) for u in gen_grads])
+  disc_reg_grads = tf.gradients(gen_norm, disc_vars)
+  return disc_reg_grads
+
+
+def run_model(prior, images, model, disc_reg_weight):
+  """Run the model with new data and samples.
+
+  Args:
+    prior: the noise source as the generator input.
+    images: images sampled from dataset.
+    model: a GAN model defined in gan.py.
+    disc_reg_weight: regularisation weight for discriminator gradients.
+
+  Returns:
+    debug_ops: statistics from the model, see gan.py for more details.
+    disc_grads: discriminator gradients.
+    gen_grads: generator gradients.
+  """
+  generator_inputs = prior.sample(FLAGS.batch_size)
+  model_output = model.connect(images, generator_inputs)
+  optimization_components = model_output.optimization_components
+
+  disc_grads = tf.gradients(
+      optimization_components['disc'].loss,
+      optimization_components['disc'].vars)
+
+  gen_grads = tf.gradients(
+      optimization_components['gen'].loss,
+      optimization_components['gen'].vars)
+
+  if disc_reg_weight > 0.0:
+    reg_grads = _compute_reg_grads(gen_grads,
+                                   optimization_components['disc'].vars)
+    disc_grads = _acc_grads(disc_grads, disc_reg_weight, reg_grads)
+
+  debug_ops = model_output.debug_ops
+
+  return debug_ops, disc_grads, gen_grads
+
+
+def update_model(model, disc_grads, gen_grads, disc_opt, gen_opt,
+                 global_step, update_scale):
+  """Update model with gradients."""
+  disc_vars, gen_vars = model.get_variables()
+
+  with tf.control_dependencies(gen_grads + disc_grads):
+    disc_update_op = disc_opt.apply_gradients(
+        zip(_scale_vars(update_scale, disc_grads),
+            disc_vars))
+    gen_update_op = gen_opt.apply_gradients(
+        zip(_scale_vars(update_scale, gen_grads),
+            gen_vars),
+        global_step=global_step)
+    update_op = tf.group([disc_update_op, gen_update_op])
+
+  return update_op
+
+
+def main(argv):
+  del argv
+
+  utils.make_output_dir(FLAGS.output_dir)
+  data_processor = utils.DataProcessor()
+  # Compute the batch-size multiplier: RK2/RK4 evaluate gradients on 2/4
+  # sub-batches per update.
+  if FLAGS.ode_mode == 'rk2':
+    batch_mul = 2
+  elif FLAGS.ode_mode == 'rk4':
+    batch_mul = 4
+  else:
+    batch_mul = 1
+  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
+                                   int(FLAGS.batch_size * batch_mul))
+  image_splits = tf.split(images, batch_mul)
+
+  logging.info('Generator learning rate: %f', FLAGS.gen_lr)
+  logging.info('Discriminator learning rate: %f', FLAGS.disc_lr)
+
+  global_step = tf.train.get_or_create_global_step()
+  # Construct optimizers.
+  if FLAGS.opt_name == 'adam':
+    disc_opt = tf.train.AdamOptimizer(FLAGS.disc_lr, beta1=0.5, beta2=0.999)
+    gen_opt = tf.train.AdamOptimizer(FLAGS.gen_lr, beta1=0.5, beta2=0.999)
+  elif FLAGS.opt_name == 'gd':
+    if FLAGS.schedule_lr:
+      gd_disc_lr = tf.train.piecewise_constant(
+          global_step,
+          values=[FLAGS.disc_lr / 4., FLAGS.disc_lr, FLAGS.disc_lr / 2.],
+          boundaries=[500, 400000])
+      gd_gen_lr = tf.train.piecewise_constant(
+          global_step,
+          values=[FLAGS.gen_lr / 4., FLAGS.gen_lr, FLAGS.gen_lr / 2.],
+          boundaries=[500, 400000])
+    else:
+      gd_disc_lr = FLAGS.disc_lr
+      gd_gen_lr = FLAGS.gen_lr
+    disc_opt = tf.train.GradientDescentOptimizer(gd_disc_lr)
+    gen_opt = tf.train.GradientDescentOptimizer(gd_gen_lr)
+  else:
+    raise ValueError('Unknown optimiser name!')
+
+  # Create the networks and models.
+  generator = utils.get_generator(FLAGS.dataset)
+  metric_net = utils.get_metric_net(FLAGS.dataset, use_sn=False)
+  model = gan.GAN(metric_net, generator)
+  prior = utils.make_prior(FLAGS.num_latents)
+
+  # Set up the ODE coefficients.
+  if FLAGS.ode_mode == 'rk2':
+    ode_grad_weights = [0.5, 0.5]
+    step_scale = [1.0]
+  elif FLAGS.ode_mode == 'rk4':
+    ode_grad_weights = [1. / 6., 1. / 3., 1. / 3., 1. / 6.]
+    step_scale = [0.5, 0.5, 1.]
+  elif FLAGS.ode_mode == 'euler':
+    # Euler update.
+    ode_grad_weights = [1.0]
+    step_scale = []
+  else:
+    raise ValueError('Unknown ODE mode!')
+
+  # Extra steps for RK updates.
+  num_extra_steps = len(step_scale)
+
+  if FLAGS.reg_first_grad_only:
+    first_reg_weight = FLAGS.grad_reg_weight / ode_grad_weights[0]
+    other_reg_weight = 0.0
+  else:
+    first_reg_weight = FLAGS.grad_reg_weight
+    other_reg_weight = FLAGS.grad_reg_weight
+
+  debug_ops, disc_grads, gen_grads = run_model(prior, image_splits[0],
+                                               model, first_reg_weight)
+
+  disc_vars, gen_vars = model.get_variables()
+
+  final_disc_grads = _scale_vars(ode_grad_weights[0], disc_grads)
+  final_gen_grads = _scale_vars(ode_grad_weights[0], gen_grads)
+
+  restore_ops = []
+  # Prepare for further RK steps.
+  if num_extra_steps > 0:
+    # Copy the variables before they are changed by update_op.
+    saved_disc_vars = _copy_vars(disc_vars)
+    saved_gen_vars = _copy_vars(gen_vars)
+
+    # Enter the RK loop.
+    with tf.control_dependencies(saved_disc_vars + saved_gen_vars):
+      step_deps = []
+      for i_step in range(num_extra_steps):
+        with tf.control_dependencies(step_deps):
+          # Compute gradient steps for intermediate updates.
+          update_op = update_model(
+              model, disc_grads, gen_grads, disc_opt, gen_opt,
+              None, step_scale[i_step])
+          with tf.control_dependencies([update_op]):
+            _, disc_grads, gen_grads = run_model(
+                prior, image_splits[i_step + 1], model, other_reg_weight)
+
+            # Accumulate gradients for the final update.
+            final_disc_grads = _acc_grads(final_disc_grads,
+                                          ode_grad_weights[i_step + 1],
+                                          disc_grads)
+            final_gen_grads = _acc_grads(final_gen_grads,
+                                         ode_grad_weights[i_step + 1],
+                                         gen_grads)
+
+            # Make new restore_ops for each step.
+            restore_ops = []
+            restore_ops += _restore_vars(disc_vars, saved_disc_vars)
+            restore_ops += _restore_vars(gen_vars, saved_gen_vars)
+
+          step_deps = restore_ops
+
+  with tf.control_dependencies(restore_ops):
+    update_op = update_model(
+        model, final_disc_grads, final_gen_grads, disc_opt, gen_opt,
+        global_step, 1.0)
+
+  samples = generator(prior.sample(FLAGS.batch_size), is_training=False)
+
+  # Get data needed to compute FID. We also compute metrics on
+  # real data as a sanity check and as a reference point.
+  eval_real_data = utils.get_real_data_for_eval(FLAGS.num_eval_samples,
+                                                FLAGS.dataset,
+                                                split='train')
+
+  def sample_fn(x):
+    return utils.optimise_and_sample(x, module=model,
+                                     data=None, is_training=False)[0]
+
+  if FLAGS.run_sample_metrics:
+    sample_metrics = image_metrics.get_image_metrics_for_samples(
+        eval_real_data, sample_fn,
+        prior, data_processor,
+        num_eval_samples=FLAGS.num_eval_samples)
+  else:
+    sample_metrics = {}
+
+  if FLAGS.run_real_data_metrics:
+    data_metrics = image_metrics.get_image_metrics(
+        eval_real_data, eval_real_data)
+  else:
+    data_metrics = {}
+
+  sample_exporter = file_utils.FileExporter(
+      os.path.join(FLAGS.output_dir, 'samples'))
+
+  # Hooks.
+  debug_ops['it'] = global_step
+  # Abort training on NaNs.
+  nan_disc_hook = tf.train.NanTensorHook(debug_ops['disc_loss'])
+  nan_gen_hook = tf.train.NanTensorHook(debug_ops['gen_loss'])
+  # Step counter.
+  step_counter_hook = tf.train.StepCounterHook()
+
+  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
+      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)
+
+  loss_summary_saver_hook = tf.train.SummarySaverHook(
+      save_steps=FLAGS.summary_every_step,
+      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
+      summary_op=utils.get_summaries(debug_ops))
+
+  metrics_summary_saver_hook = tf.train.SummarySaverHook(
+      save_steps=FLAGS.image_metrics_every_step,
+      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
+      summary_op=utils.get_summaries(sample_metrics))
+
+  hooks = [checkpoint_saver_hook, metrics_summary_saver_hook,
+           nan_disc_hook, nan_gen_hook, step_counter_hook,
+           loss_summary_saver_hook]
+
+  # Start training.
+  with tf.train.MonitoredSession(hooks=hooks) as sess:
+    logging.info('starting training')
+
+    for key, value in sess.run(data_metrics).items():
+      logging.info('%s: %f', key, value)
+
+    for i in range(FLAGS.num_training_iterations):
+      sess.run(update_op)
+
+      if i % FLAGS.export_every == 0:
+        samples_np, data_np = sess.run([samples, image_splits[0]])
+        # Post-process the data and samples before exporting them.
+        data_np = data_processor.postprocess(data_np)
+        samples_np = data_processor.postprocess(samples_np)
+        sample_exporter.save(samples_np, 'samples')
+        sample_exporter.save(data_np, 'data')
+
+
+if __name__ == '__main__':
+  tf.enable_resource_variables()
+  app.run(main)
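The `_compute_reg_grads` helper differentiates the squared norm of the generator's gradients with respect to the discriminator's parameters, which `_acc_grads` then folds into the discriminator update. A toy graph-mode sketch of that mechanic, using an illustrative bilinear loss and made-up variable names (not from this commit):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Toy bilinear game: gen_loss = theta_gen * theta_disc, so the generator
# gradient is theta_disc and its squared norm is theta_disc ** 2.
theta_gen = tf.get_variable('theta_gen', initializer=1.0)
theta_disc = tf.get_variable('theta_disc', initializer=1.0)
gen_loss = theta_gen * theta_disc

gen_grads = tf.gradients(gen_loss, [theta_gen])
# As in _compute_reg_grads: differentiate the generator-gradient norm
# with respect to the discriminator parameters.
gen_norm = tf.accumulate_n([tf.reduce_sum(u * u) for u in gen_grads])
reg_grads = tf.gradients(gen_norm, [theta_disc])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(reg_grads))  # [2.0], i.e. 2 * theta_disc
```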
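With `ode_mode='rk4'`, the update loop above is classic fourth-order Runge-Kutta applied to the training dynamics: `step_scale = [0.5, 0.5, 1.0]` places the intermediate gradient evaluations and `ode_grad_weights = [1/6, 1/3, 1/3, 1/6]` combines them. A minimal scalar sketch of the same scheme on a toy ODE (nothing here is GAN-specific):

```python
def rk4_step(f, y, h):
  """One classic RK4 step on dy/dt = f(y), mirroring main_ode's coefficients."""
  k1 = f(y)
  k2 = f(y + 0.5 * h * k1)  # step_scale[0] = 0.5
  k3 = f(y + 0.5 * h * k2)  # step_scale[1] = 0.5
  k4 = f(y + 1.0 * h * k3)  # step_scale[2] = 1.0
  # ode_grad_weights = [1/6, 1/3, 1/3, 1/6]
  return y + h * (k1 / 6. + k2 / 3. + k3 / 3. + k4 / 6.)

# Integrate dy/dt = -y from y(0) = 1 to t = 1; the exact answer is exp(-1).
y = 1.0
for _ in range(10):
  y = rk4_step(lambda v: -v, y, 0.1)
print(y)  # ~0.3678794, matching exp(-1) to about 1e-8
```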
cs_gan/nets.py (+17 -8)
@@ -30,11 +30,11 @@ def _sn_custom_getter():
       name_filter=name_filter)


-class SNGenNet(snt.AbstractModule):
+class ConvGenNet(snt.AbstractModule):
   """As in the SN paper."""

   def __init__(self, name='conv_gen'):
-    super(SNGenNet, self).__init__(name=name)
+    super(ConvGenNet, self).__init__(name=name)

   def _build(self, inputs, is_training):
     batch_size = inputs.get_shape().as_list()[0]
@@ -57,15 +57,17 @@ class SNGenNet(snt.AbstractModule):
     return tf.nn.tanh(output)


-class SNMetricNet(snt.AbstractModule):
-  """Spectral normalization discriminator (metric) architecture."""
+class ConvMetricNet(snt.AbstractModule):
+  """Convolutional discriminator (metric) architecture."""

-  def __init__(self, num_outputs=2, name='sn_metric'):
-    super(SNMetricNet, self).__init__(name=name)
+  def __init__(self, num_outputs=2, use_sn=True, name='sn_metric'):
+    super(ConvMetricNet, self).__init__(name=name)
     self._num_outputs = num_outputs
+    self._use_sn = use_sn

   def _build(self, inputs):
-    with tf.variable_scope('', custom_getter=_sn_custom_getter()):
+    def build_net():
       net = snt.nets.ConvNet2D(
           output_channels=[64, 64, 128, 128, 256, 256, 512],
           kernel_shapes=[
@@ -74,7 +76,14 @@ class SNMetricNet(snt.AbstractModule):
           paddings=[snt.SAME], activate_final=True,
           activation=functools.partial(tf.nn.leaky_relu, alpha=0.1))
       linear = snt.Linear(self._num_outputs)
       output = linear(snt.BatchFlatten()(net(inputs)))
+      return output
+
+    if self._use_sn:
+      with tf.variable_scope('', custom_getter=_sn_custom_getter()):
+        output = build_net()
+    else:
+      output = build_net()

     return output
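The `_build` refactor moves the layer definitions into a closure so spectral normalisation becomes purely a scoping decision. A tiny runnable sketch of that pattern, with an illustrative scaling getter standing in for the real spectral-norm getter:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def scaling_getter(getter, name, *args, **kwargs):
  # Halve every variable this scope returns; spectral normalisation hooks
  # in at exactly this point in the real code (_sn_custom_getter).
  return 0.5 * getter(name, *args, **kwargs)

def build_net():
  # The network body runs unchanged; only the surrounding scope decides
  # whether its variables are wrapped (as in ConvMetricNet._build).
  return tf.get_variable('w', initializer=2.0)

use_sn = True
if use_sn:
  with tf.variable_scope('', custom_getter=scaling_getter):
    out = build_net()
else:
  out = build_net()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(out))  # 1.0, i.e. 0.5 * w
```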
cs_gan/run_ode.sh (new, executable, +41)
@@ -0,0 +1,41 @@
+#!/bin/sh
+# Copyright 2019 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install python3.5 if it is not already available.
+which python3.5
+if [ $? -eq 1 ]; then
+  echo 'Installing python3.5'
+  (cd /usr/src/
+   sudo wget https://www.python.org/ftp/python/3.5.6/Python-3.5.6.tgz
+   sudo tar -xvzf Python-3.5.6.tgz
+   cd Python-3.5.6
+   ./configure --enable-loadable-sqlite-extensions --enable-optimizations
+   sudo make altinstall)
+fi
+
+# Fail on any error.
+set -e
+
+python3.5 -m venv cs_gan_venv
+echo 'Created venv'
+. cs_gan_venv/bin/activate
+
+echo 'Installing pip'
+curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+python3.5 get-pip.py pip==20.2.3
+
+echo 'Getting requirements.'
+pip install -r cs_gan/requirements.txt
+
+echo 'Starting training...'
+python3.5 -m cs_gan.main_ode
cs_gan/utils.py (+4 -4)
@@ -89,7 +89,7 @@ def cross_entropy_loss(logits, expected):
 def optimise_and_sample(init_z, module, data, is_training):
   """Optimising generator latent variables and sample."""

-  if module.num_z_iters == 0:
+  if module.num_z_iters is None or module.num_z_iters == 0:
     z_final = init_z
   else:
     init_loop_vars = (0, _project_z(init_z, module.z_project_method))
@@ -215,14 +215,14 @@ def get_generator(dataset):
   if dataset == 'mnist':
     return nets.MLPGeneratorNet()
   if dataset == 'cifar':
-    return nets.SNGenNet()
+    return nets.ConvGenNet()


-def get_metric_net(dataset, num_outputs=2):
+def get_metric_net(dataset, num_outputs=2, use_sn=True):
   if dataset == 'mnist':
     return nets.MLPMetricNet(num_outputs)
   if dataset == 'cifar':
-    return nets.SNMetricNet(num_outputs)
+    return nets.ConvMetricNet(num_outputs, use_sn)


 def make_prior(num_latents):
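With the new `is None` guard, sampling from a model built without latent optimisation reduces to a plain generator pass. A sketch mirroring `sample_fn` from the training script above, assuming `prior`, `model`, and `FLAGS` as constructed there:

```python
# For an ODE-GAN model, num_z_iters is None, so optimise_and_sample
# skips the latent-optimisation loop and samples directly from init_z.
init_z = prior.sample(FLAGS.batch_size)
samples = utils.optimise_and_sample(init_z, module=model,
                                    data=None, is_training=False)[0]
```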
ode_gan/README.md (+8 -4)
@@ -1,12 +1,16 @@
 # ODE-GAN: Training GANs by Solving Ordinary Differential Equations

+Mixture of Gaussians Example (Colab):
+
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/ode_gan/odegan_mog16.ipynb)

-This package contains a [Colaboratory notebook](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/ode_gan/odegan_mog16.ipynb)
-that demos the algorithm [ODE-GAN](https://arxiv.org/abs/2010.15040) for
-mixture of Gaussians with 16 modes.
+CIFAR-10 Example (TensorFlow):
+Launch training from
+https://github.com/deepmind/deepmind_research/tree/master/cs_gan/run_ode.sh

-If you make use of this Colab in your work, please cite:
+This package demos the algorithm [ODE-GAN](https://arxiv.org/abs/2010.15040).
+
+If you make use of any code in your work, please cite:

 ```
 @article{qin2020training,