From fa33baaccaf40ed8e1d54d27577ddb92ab859057 Mon Sep 17 00:00:00 2001 From: Mihaela Rosca Date: Mon, 30 Sep 2019 11:52:41 +0100 Subject: [PATCH] Public release of Deep Compressed Sensing project. PiperOrigin-RevId: 272403580 --- README.md | 1 + cs_gan/README.md | 19 ++++ cs_gan/cs.py | 154 ++++++++++++++++++++++++++ cs_gan/file_utils.py | 82 ++++++++++++++ cs_gan/gan.py | 158 ++++++++++++++++++++++++++ cs_gan/image_metrics.py | 166 +++++++++++++++++++++++++++ cs_gan/main.py | 195 ++++++++++++++++++++++++++++++++ cs_gan/main_cs.py | 141 +++++++++++++++++++++++ cs_gan/nets.py | 106 ++++++++++++++++++ cs_gan/requirements.txt | 8 ++ cs_gan/run.sh | 22 ++++ cs_gan/tests/gan_test.py | 57 ++++++++++ cs_gan/utils.py | 234 +++++++++++++++++++++++++++++++++++++++ 13 files changed, 1343 insertions(+) create mode 100644 cs_gan/README.md create mode 100644 cs_gan/cs.py create mode 100644 cs_gan/file_utils.py create mode 100644 cs_gan/gan.py create mode 100644 cs_gan/image_metrics.py create mode 100644 cs_gan/main.py create mode 100644 cs_gan/main_cs.py create mode 100644 cs_gan/nets.py create mode 100644 cs_gan/requirements.txt create mode 100755 cs_gan/run.sh create mode 100644 cs_gan/tests/gan_test.py create mode 100644 cs_gan/utils.py diff --git a/README.md b/README.md index 673885e..fdd70fc 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ https://deepmind.com/research/publications/ ## Projects +* [Deep Compressed Sensing](cs_gan), ICML 2019 * [Side Effects Penalties](side_effects_penalties) * [PrediNet Architecture and Relations Game Datasets](PrediNet) * [Unsupervised Adversarial Training](unsupervised_adversarial_training) diff --git a/cs_gan/README.md b/cs_gan/README.md new file mode 100644 index 0000000..384c5d8 --- /dev/null +++ b/cs_gan/README.md @@ -0,0 +1,19 @@ +# Deep Compressed Sensing + +This is the example code for the following ICML 2019 paper. If you use the code +here please cite this paper. 
+ +> Yan Wu, Mihaela Rosca, Timothy Lillicrap + *Deep Compressed Sensing*. ICML 2019. [\[arXiv\]](https://arxiv.org/abs/1905.06723). + + +## Running the code + +The code contains: + + * the implementation of the compressed sensing algorithm (`cs.py`). + * the implementation of the GAN algorithm (`gan.py`). + * a main file (`main_cs.py`) to reproduce the Compressed Sensing results in + the paper. + * a main file (`main_gan.py`) to reproduce the GAN results in the paper + (the improvement over the SN-GAN baseline via latent optimization). diff --git a/cs_gan/cs.py b/cs_gan/cs.py new file mode 100644 index 0000000..ebe171e --- /dev/null +++ b/cs_gan/cs.py @@ -0,0 +1,154 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GAN modules.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import sonnet as snt +import tensorflow as tf + +from cs_gan import utils + + +class CS(object): + """Compressed Sensing Module.""" + + def __init__(self, metric_net, generator, + num_z_iters, z_step_size, z_project_method): + """Constructs the module. + + Args: + metric_net: the measurement network. + generator: The generator network. A sonnet module. For examples, see + `nets.py`. + num_z_iters: an integer, the number of latent optimisation steps. + z_step_size: an integer, latent optimisation step size. 
+ z_project_method: the method for projecting latent after optimisation, + a string from {'norm', 'clip'}. + """ + + self._measure = metric_net + self.generator = generator + self.num_z_iters = num_z_iters + self.z_project_method = z_project_method + self._log_step_size_module = snt.TrainableVariable( + [], + initializers={'w': tf.constant_initializer(math.log(z_step_size))}) + self.z_step_size = tf.exp(self._log_step_size_module()) + + def connect(self, data, generator_inputs): + """Connects the components and returns the losses, outputs and debug ops. + + Args: + data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the + rank + of this tensor, but it has to be compatible with the shapes expected + by the discriminator. + generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not + have to have the same batch size as the `data` tensor. There are not + constraints on the rank of this tensor, but it has to be compatible + with the shapes the generator network supports as inputs. + + Returns: + An `ModelOutputs` instance. + """ + + samples, optimised_z = utils.optimise_and_sample( + generator_inputs, self, data, is_training=True) + optimisation_cost = utils.get_optimisation_cost(generator_inputs, + optimised_z) + debug_ops = {} + + initial_samples = self.generator(generator_inputs, is_training=True) + generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples)) + # compute the RIP loss + # (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2 + # as a triplet loss for 3 pairs of images. 
+ + r1 = self._get_rip_loss(samples, initial_samples) + r2 = self._get_rip_loss(samples, data) + r3 = self._get_rip_loss(initial_samples, data) + rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0) + total_loss = generator_loss + rip_loss + optimization_components = self._build_optimization_components( + generator_loss=total_loss) + debug_ops['rip_loss'] = rip_loss + debug_ops['recons_loss'] = tf.reduce_mean( + tf.norm(snt.BatchFlatten()(samples) + - snt.BatchFlatten()(data), axis=-1)) + + debug_ops['z_step_size'] = self.z_step_size + debug_ops['opt_cost'] = optimisation_cost + debug_ops['gen_loss'] = generator_loss + + return utils.ModelOutputs( + optimization_components, debug_ops) + + def _get_rip_loss(self, img1, img2): + r"""Compute the RIP loss from two images. + + The RIP loss: (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2 + + Args: + img1: an image (x_1), 4D tensor of shape [batch_size, W, H, C]. + img2: an other image (x_2), 4D tensor of shape [batch_size, W, H, C]. + """ + + m1 = self._measure(img1) + m2 = self._measure(img2) + + img_diff_norm = tf.norm(snt.BatchFlatten()(img1) + - snt.BatchFlatten()(img2), axis=-1) + m_diff_norm = tf.norm(m1 - m2, axis=-1) + + return tf.square(img_diff_norm - m_diff_norm) + + def _get_measurement_error(self, target_img, sample_img): + """Compute the measurement error of sample images given the targets.""" + + m_targets = self._measure(target_img) + m_samples = self._measure(sample_img) + + return tf.reduce_sum(tf.square(m_targets - m_samples), -1) + + def gen_loss_fn(self, data, samples): + """Generator loss as latent optimisation's error function.""" + return self._get_measurement_error(data, samples) + + def _build_optimization_components( + self, generator_loss=None, discriminator_loss=None): + """Create the optimization components for this module.""" + + metric_vars = _get_and_check_variables(self._measure) + generator_vars = _get_and_check_variables(self.generator) + step_vars = 
_get_and_check_variables(self._log_step_size_module) + + assert discriminator_loss is None + optimization_components = utils.OptimizationComponent( + generator_loss, generator_vars + metric_vars + step_vars) + return optimization_components + + +def _get_and_check_variables(module): + module_variables = module.get_all_variables() + if not module_variables: + raise ValueError( + 'Module {} has no variables! Variables needed for training.'.format( + module.module_name)) + + # TensorFlow optimizers require lists to be passed in. + return list(module_variables) diff --git a/cs_gan/file_utils.py b/cs_gan/file_utils.py new file mode 100644 index 0000000..9890b53 --- /dev/null +++ b/cs_gan/file_utils.py @@ -0,0 +1,82 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File utilities.""" + +import math +import os +import numpy as np +from PIL import Image + + +class FileExporter(object): + """File exporter utilities.""" + + def __init__(self, path, grid_height=None, zoom=1): + """Constructor. + + Arguments: + path: The directory to save data to. + grid_height: How many data elements tall to make the grid, if appropriate. + The width will be chosen based on height. If None, automatically + determined. + zoom: How much to zoom in each data element by, if appropriate. 
+ """ + if not os.path.exists(path): + os.makedirs(path) + + self._path = path + self._zoom = zoom + self._grid_height = grid_height + + def _reshape(self, data): + """Reshape given data into image format.""" + batch_size, height, width, n_channels = data.shape + if self._grid_height: + grid_height = self._grid_height + else: + grid_height = int(math.floor(math.sqrt(batch_size))) + + grid_width = int(math.ceil(batch_size/grid_height)) + + if n_channels == 1: + data = np.tile(data, (1, 1, 1, 3)) + n_channels = 3 + + if n_channels != 3: + raise ValueError('Image batch must have either 1 or 3 channels, but ' + 'was {}'.format(n_channels)) + + shape = (height * grid_height, width * grid_width, n_channels) + buf = np.full(shape, 255, dtype=np.uint8) + multiplier = 1 if data.dtype in (np.int32, np.int64) else 255 + + for k in range(batch_size): + i = k // grid_width + j = k % grid_width + arr = data[k] + x, y = i * height, j * width + buf[x:x + height, y:y + width, :] = np.clip( + multiplier * arr, 0, 255).astype(np.uint8) + + if self._zoom > 1: + buf = buf.repeat(self._zoom, axis=0).repeat(self._zoom, axis=1) + return buf + + def save(self, data, name): + data = self._reshape(data) + relative_name = '{}_last.png'.format(name) + target_file = os.path.join(self._path, relative_name) + + img = Image.fromarray(data) + img.save(target_file, format='PNG') diff --git a/cs_gan/gan.py b/cs_gan/gan.py new file mode 100644 index 0000000..6f0ec97 --- /dev/null +++ b/cs_gan/gan.py @@ -0,0 +1,158 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GAN modules.""" +import collections +import math + +import sonnet as snt +import tensorflow as tf + +from cs_gan import utils + + +class GAN(object): + """Standard generative adversarial network setup. + + The aim of the generator is to generate samples which fool a discriminator. + Does not make any assumptions about the discriminator and generator loss + functions. + + Trained module components: + + * discriminator + * generator + + + For the standard GAN algorithm, generator_inputs is a vector of noise (either + Gaussian or uniform). + """ + + def __init__(self, discriminator, generator, num_z_iters, z_step_size, + z_project_method, optimisation_cost_weight): + """Constructs the module. + + Args: + discriminator: The discriminator network. A sonnet module. See `nets.py`. + generator: The generator network. A sonnet module. For examples, see + `nets.py`. + num_z_iters: an integer, the number of latent optimisation steps. + z_step_size: an integer, latent optimisation step size. + z_project_method: the method for projecting latent after optimisation, + a string from {'norm', 'clip'}. + optimisation_cost_weight: a float, how much to penalise the distance of z + moved by latent optimisation. 
+ """ + self._discriminator = discriminator + self.generator = generator + self.num_z_iters = num_z_iters + self.z_project_method = z_project_method + self._log_step_size_module = snt.TrainableVariable( + [], + initializers={'w': tf.constant_initializer(math.log(z_step_size))}) + self.z_step_size = tf.exp(self._log_step_size_module()) + self._optimisation_cost_weight = optimisation_cost_weight + + def connect(self, data, generator_inputs): + """Connects the components and returns the losses, outputs and debug ops. + + Args: + data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the + rank + of this tensor, but it has to be compatible with the shapes expected + by the discriminator. + generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not + have to have the same batch size as the `data` tensor. There are not + constraints on the rank of this tensor, but it has to be compatible + with the shapes the generator network supports as inputs. + + Returns: + An `ModelOutputs` instance. + """ + samples, optimised_z = utils.optimise_and_sample( + generator_inputs, self, data, is_training=True) + optimisation_cost = utils.get_optimisation_cost(generator_inputs, + optimised_z) + + # Pass in the labels to the discriminator in case we are using a + # discriminator which makes use of labels. The labels can be None. 
+ disc_data_logits = self._discriminator(data) + disc_sample_logits = self._discriminator(samples) + + disc_data_loss = utils.cross_entropy_loss( + disc_data_logits, + tf.ones(tf.shape(disc_data_logits[:, 0]), dtype=tf.int32)) + + disc_sample_loss = utils.cross_entropy_loss( + disc_sample_logits, + tf.zeros(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32)) + + disc_loss = disc_data_loss + disc_sample_loss + + generator_loss = utils.cross_entropy_loss( + disc_sample_logits, + tf.ones(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32)) + + optimization_components = self._build_optimization_components( + discriminator_loss=disc_loss, generator_loss=generator_loss, + optimisation_cost=optimisation_cost) + + debug_ops = {} + debug_ops['z_step_size'] = self.z_step_size + debug_ops['disc_data_loss'] = disc_data_loss + debug_ops['disc_sample_loss'] = disc_sample_loss + debug_ops['disc_loss'] = disc_loss + debug_ops['gen_loss'] = generator_loss + debug_ops['opt_cost'] = optimisation_cost + + return utils.ModelOutputs( + optimization_components, debug_ops) + + def gen_loss_fn(self, data, samples): + """Generator loss as latent optimisation's error function.""" + del data + disc_sample_logits = self._discriminator(samples) + generator_loss = utils.cross_entropy_loss( + disc_sample_logits, + tf.ones(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32)) + return generator_loss + + def _build_optimization_components( + self, generator_loss=None, discriminator_loss=None, + optimisation_cost=None): + """Create the optimization components for this module.""" + + discriminator_vars = _get_and_check_variables(self._discriminator) + generator_vars = _get_and_check_variables(self.generator) + step_vars = _get_and_check_variables(self._log_step_size_module) + + optimization_components = collections.OrderedDict() + optimization_components['disc'] = utils.OptimizationComponent( + discriminator_loss, discriminator_vars) + optimization_components['gen'] = utils.OptimizationComponent( + 
generator_loss + + self._optimisation_cost_weight * optimisation_cost, + generator_vars + step_vars) + return optimization_components + + +def _get_and_check_variables(module): + module_variables = module.get_all_variables() + if not module_variables: + raise ValueError( + 'Module {} has no variables! Variables needed for training.'.format( + module.module_name)) + + # TensorFlow optimizers require lists to be passed in. + return list(module_variables) + diff --git a/cs_gan/image_metrics.py b/cs_gan/image_metrics.py new file mode 100644 index 0000000..f174869 --- /dev/null +++ b/cs_gan/image_metrics.py @@ -0,0 +1,166 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compute image metrics: IS, FID.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow_gan as tfgan + + +def inception_preprocess_fn(images): + return tfgan.eval.preprocess_image(images * 255) + + +def compute_inception_score( + images, max_classifier_batch_size=16, assert_data_ranges=True): + """Computes the classifier score, using the given model. + + Note that this metric is highly sensitive to the number of images used + to compute it: more is better. A standard number is 50K images. + + For more details see: https://arxiv.org/abs/1606.03498. + + Args: + images: A 4D tensor. + max_classifier_batch_size: An integer. 
The maximum batch size size to pass + through the classifier used to compute the metric. + assert_data_ranges: Whether or not to assert the input is in a valid + data range. + + Returns: + A scalar `tf.Tensor` with the Inception score. + """ + num_images = images.shape.as_list()[0] + + def _choose_batch_size(num_images, max_batch_size): + for size in range(max_batch_size, 0, -1): + if num_images % size == 0: + return size + + batch_size = _choose_batch_size(num_images, max_classifier_batch_size) + tf.logging.debug('Using batch_size=%s for score classifier', batch_size) + num_batches = num_images // batch_size + + if assert_data_ranges: + images_batch = images[0:batch_size] + min_data_range_check = tf.assert_greater_equal(images_batch, 0.0) + max_data_range_check = tf.assert_less_equal(images_batch, 1.0) + control_deps = [min_data_range_check, max_data_range_check] + else: + control_deps = [] + + # Do the preprocessing in the fn function to avoid having to keep all the + # resized data in memory. + def classifier_fn(images): + return tfgan.eval.run_inception(inception_preprocess_fn(images)) + + with tf.control_dependencies(control_deps): + return tfgan.eval.classifier_score( + images, classifier_fn=classifier_fn, num_batches=num_batches) + + +def compute_fid( + real_images, other, max_classifier_batch_size=16, + assert_data_ranges=True): + """Computes the classifier score, using the given model. + + Note that this metric is highly sensitive to the number of images used + to compute it: more is better. A standard number is 50K images. + + For more details see: https://arxiv.org/abs/1606.03498. + + Args: + real_images: A 4D tensor. Samples from the true data distribution. + other: A 4D tensor. Data for which to compute the FID. + max_classifier_batch_size: An integer. The maximum batch size size to pass + through the classifier used to compute the metric. + assert_data_ranges: Whether or not to assert the input is in a valid + data range. 
+ + Returns: + A scalar `tf.Tensor` with the Inception score. + """ + num_images = real_images.shape.as_list()[0] + + def _choose_batch_size(num_images, max_batch_size): + for size in range(max_batch_size, 0, -1): + if num_images % size == 0: + return size + + batch_size = _choose_batch_size(num_images, max_classifier_batch_size) + tf.logging.debug('Using batch_size=%s for score classifier', batch_size) + num_batches = num_images // batch_size + + if assert_data_ranges: + images_batch = tf.concat( + [real_images[0:batch_size], other[0:batch_size]], axis=0) + min_data_range_check = tf.assert_greater_equal(images_batch, 0.0) + max_data_range_check = tf.assert_less_equal(images_batch, 1.0) + control_deps = [min_data_range_check, max_data_range_check] + else: + control_deps = [] + + # Do the preprocessing in the fn function to avoid having to keep all the + # resized data in memory. + def classifier_fn(images): + return tfgan.eval.run_inception( + inception_preprocess_fn(images), + output_tensor=tfgan.eval.INCEPTION_FINAL_POOL) + + with tf.control_dependencies(control_deps): + return tfgan.eval.frechet_classifier_distance( + real_images, other, + classifier_fn=classifier_fn, num_batches=num_batches) + + +def generate_big_batch(generator, generator_inputs, max_num_samples=100): + """Generate samples when the number of samples is too big to do one pass.""" + num_samples = generator_inputs.shape.as_list()[0] + max_num_samples = min(max_num_samples, num_samples) + batched_shape = [num_samples // max_num_samples, max_num_samples] + batched_shape += generator_inputs.shape.as_list()[1:] + batched_generator_inputs = tf.reshape(generator_inputs, batched_shape) + + # Create the samples by sequentially doing forward passes through the + # generator. This ensures we do not run out of memory. 
+ output_samples = tf.map_fn( + fn=generator, + elems=batched_generator_inputs, + parallel_iterations=1, + back_prop=False, + swap_memory=True) + + samples = tf.reshape(output_samples, + [num_samples] + output_samples.shape.as_list()[2:]) + return samples + + +def get_image_metrics(real_images, other): + return { + 'inception_score': compute_inception_score(other), + 'fid': compute_fid(real_images, other)} + + +def get_image_metrics_for_samples( + real_images, generator, prior, data_processor, num_eval_samples): + generator_inputs = prior.sample(num_eval_samples) + samples = generate_big_batch(generator, generator_inputs) + # Ensure the samples are in [0, 1]. + samples = data_processor.postprocess(samples) + + return { + 'inception_score': compute_inception_score(samples), + 'fid': compute_fid(real_images, samples)} diff --git a/cs_gan/main.py b/cs_gan/main.py new file mode 100644 index 0000000..c4ba3be --- /dev/null +++ b/cs_gan/main.py @@ -0,0 +1,195 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Training script.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf + +from cs_gan import file_utils +from cs_gan import gan +from cs_gan import image_metrics +from cs_gan import utils + +flags.DEFINE_integer( + 'num_training_iterations', 200000, + 'Number of training iterations.') +flags.DEFINE_integer( + 'batch_size', 64, 'Training batch size.') +flags.DEFINE_integer( + 'num_latents', 128, 'The number of latents') +flags.DEFINE_integer( + 'summary_every_step', 1000, + 'The interval at which to log debug ops.') +flags.DEFINE_integer( + 'image_metrics_every_step', 2000, + 'The interval at which to log (expensive) image metrics.') +flags.DEFINE_integer( + 'export_every', 10, + 'The interval at which to export samples.') +flags.DEFINE_integer( + 'num_eval_samples', 10000, + 'The number of samples used to evaluate FID/IS') +flags.DEFINE_string( + 'dataset', 'cifar', 'The dataset used for learning (cifar|mnist.') +flags.DEFINE_float( + 'optimisation_cost_weight', 3., 'weight for latent optimisation cost.') +flags.DEFINE_integer( + 'num_z_iters', 3, 'The number of latent optimisation steps.' 
+ 'It falls back to vanilla GAN when num_z_iters is set to 0.') +flags.DEFINE_float( + 'z_step_size', 0.01, 'Step size for latent optimisation.') +flags.DEFINE_string( + 'z_project_method', 'norm', 'The method to project z.') +flags.DEFINE_string( + 'output_dir', '/tmp/cs_gan/gan', 'Location where to save output files.') +flags.DEFINE_float('disc_lr', 2e-4, 'Discriminator Learning rate.') +flags.DEFINE_float('gen_lr', 2e-4, 'Generator Learning rate.') +flags.DEFINE_bool( + 'run_real_data_metrics', False, + 'Whether or not to run image metrics on real data.') +flags.DEFINE_bool( + 'run_sample_metrics', True, + 'Whether or not to run image metrics on samples.') + + +FLAGS = flags.FLAGS + +# Log info level (for Hooks). +tf.logging.set_verbosity(tf.logging.INFO) + + +def main(argv): + del argv + + utils.make_output_dir(FLAGS.output_dir) + data_processor = utils.DataProcessor() + images = utils.get_train_dataset(data_processor, FLAGS.dataset, + FLAGS.batch_size) + + logging.info('Generator learning rate: %d', FLAGS.gen_lr) + logging.info('Discriminator learning rate: %d', FLAGS.disc_lr) + + # Construct optimizers. + disc_optimizer = tf.train.AdamOptimizer(FLAGS.disc_lr, beta1=0.5, beta2=0.999) + gen_optimizer = tf.train.AdamOptimizer(FLAGS.gen_lr, beta1=0.5, beta2=0.999) + + # Create the networks and models. 
+ generator = utils.get_generator(FLAGS.dataset) + metric_net = utils.get_metric_net(FLAGS.dataset) + model = gan.GAN(metric_net, generator, + FLAGS.num_z_iters, FLAGS.z_step_size, + FLAGS.z_project_method, FLAGS.optimisation_cost_weight) + prior = utils.make_prior(FLAGS.num_latents) + generator_inputs = prior.sample(FLAGS.batch_size) + + model_output = model.connect(images, generator_inputs) + optimization_components = model_output.optimization_components + debug_ops = model_output.debug_ops + samples = generator(generator_inputs, is_training=False) + + global_step = tf.train.get_or_create_global_step() + # We pass the global step both to the disc and generator update ops. + # This means that the global step will not be the same as the number of + # iterations, but ensures that hooks which rely on global step work correctly. + disc_update_op = disc_optimizer.minimize( + optimization_components['disc'].loss, + var_list=optimization_components['disc'].vars, + global_step=global_step) + + gen_update_op = gen_optimizer.minimize( + optimization_components['gen'].loss, + var_list=optimization_components['gen'].vars, + global_step=global_step) + + # Get data needed to compute FID. We also compute metrics on + # real data as a sanity check and as a reference point. + eval_real_data = utils.get_real_data_for_eval(FLAGS.num_eval_samples, + FLAGS.dataset, + split='train') + + def sample_fn(x): + return utils.optimise_and_sample(x, module=model, + data=None, is_training=False)[0] + + if FLAGS.run_sample_metrics: + sample_metrics = image_metrics.get_image_metrics_for_samples( + eval_real_data, sample_fn, + prior, data_processor, + num_eval_samples=FLAGS.num_eval_samples) + else: + sample_metrics = {} + + if FLAGS.run_real_data_metrics: + data_metrics = image_metrics.get_image_metrics( + eval_real_data, eval_real_data) + else: + data_metrics = {} + + sample_exporter = file_utils.FileExporter( + os.path.join(FLAGS.output_dir, 'samples')) + + # Hooks. 
+ debug_ops['it'] = global_step + # Abort training on Nans. + nan_disc_hook = tf.train.NanTensorHook(optimization_components['disc'].loss) + nan_gen_hook = tf.train.NanTensorHook(optimization_components['gen'].loss) + # Step counter. + step_conter_hook = tf.train.StepCounterHook() + + checkpoint_saver_hook = tf.train.CheckpointSaverHook( + checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60) + + loss_summary_saver_hook = tf.train.SummarySaverHook( + save_steps=FLAGS.summary_every_step, + output_dir=os.path.join(FLAGS.output_dir, 'summaries'), + summary_op=utils.get_summaries(debug_ops)) + + metrics_summary_saver_hook = tf.train.SummarySaverHook( + save_steps=FLAGS.image_metrics_every_step, + output_dir=os.path.join(FLAGS.output_dir, 'summaries'), + summary_op=utils.get_summaries(sample_metrics)) + + hooks = [checkpoint_saver_hook, metrics_summary_saver_hook, + nan_disc_hook, nan_gen_hook, step_conter_hook, + loss_summary_saver_hook] + + # Start training. + with tf.train.MonitoredSession(hooks=hooks) as sess: + logging.info('starting training') + + for key, value in sess.run(data_metrics).items(): + logging.info('%s: %d', key, value) + + for i in range(FLAGS.num_training_iterations): + sess.run(disc_update_op) + sess.run(gen_update_op) + + if i % FLAGS.export_every == 0: + samples_np, data_np = sess.run([samples, images]) + # Create an object which gets data and does the processing. 
+ data_np = data_processor.postprocess(data_np) + samples_np = data_processor.postprocess(samples_np) + sample_exporter.save(samples_np, 'samples') + sample_exporter.save(data_np, 'data') + + +if __name__ == '__main__': + app.run(main) diff --git a/cs_gan/main_cs.py b/cs_gan/main_cs.py new file mode 100644 index 0000000..26b459d --- /dev/null +++ b/cs_gan/main_cs.py @@ -0,0 +1,141 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Training script.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import app +from absl import flags +from absl import logging + +import tensorflow as tf +import tensorflow_probability as tfp + +from cs_gan import cs +from cs_gan import file_utils +from cs_gan import utils + +tfd = tfp.distributions + +flags.DEFINE_string( + 'mode', 'recons', 'Model mode.') +flags.DEFINE_integer( + 'num_training_iterations', 10000000, + 'Number of training iterations.') +flags.DEFINE_integer( + 'batch_size', 64, 'Training batch size.') +flags.DEFINE_integer( + 'num_measurements', 25, 'The number of measurements') +flags.DEFINE_integer( + 'num_latents', 100, 'The number of latents') +flags.DEFINE_integer( + 'num_z_iters', 3, 'The number of latent optimisation steps.') +flags.DEFINE_float( + 'z_step_size', 0.01, 'Step size for latent optimisation.') +flags.DEFINE_string( + 'z_project_method', 'norm', 'The method to project z.') +flags.DEFINE_integer( + 'summary_every_step', 1000, + 'The interval at which to log debug ops.') +flags.DEFINE_integer( + 'export_every', 10, + 'The interval at which to export samples.') +flags.DEFINE_string( + 'dataset', 'mnist', 'The dataset used for learning (cifar|mnist.') +flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.') +flags.DEFINE_string( + 'output_dir', '/tmp/cs_gan/cs', 'Location where to save output files.') + + +FLAGS = flags.FLAGS + +# Log info level (for Hooks). +tf.logging.set_verbosity(tf.logging.INFO) + + +def main(argv): + del argv + + utils.make_output_dir(FLAGS.output_dir) + data_processor = utils.DataProcessor() + images = utils.get_train_dataset(data_processor, FLAGS.dataset, + FLAGS.batch_size) + + logging.info('Learning rate: %d', FLAGS.learning_rate) + + # Construct optimizers. + optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate) + + # Create the networks and models. 
+ generator = utils.get_generator(FLAGS.dataset) + metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements) + model = cs.CS(metric_net, generator, + FLAGS.num_z_iters, FLAGS.z_step_size, FLAGS.z_project_method) + prior = utils.make_prior(FLAGS.num_latents) + generator_inputs = prior.sample(FLAGS.batch_size) + + model_output = model.connect(images, generator_inputs) + optimization_components = model_output.optimization_components + debug_ops = model_output.debug_ops + reconstructions, _ = utils.optimise_and_sample( + generator_inputs, model, images, is_training=False) + + global_step = tf.train.get_or_create_global_step() + update_op = optimizer.minimize( + optimization_components.loss, + var_list=optimization_components.vars, + global_step=global_step) + + sample_exporter = file_utils.FileExporter( + os.path.join(FLAGS.output_dir, 'reconstructions')) + + # Hooks. + debug_ops['it'] = global_step + # Abort training on Nans. + nan_hook = tf.train.NanTensorHook(optimization_components.loss) + # Step counter. + step_conter_hook = tf.train.StepCounterHook() + + checkpoint_saver_hook = tf.train.CheckpointSaverHook( + checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60) + + loss_summary_saver_hook = tf.train.SummarySaverHook( + save_steps=FLAGS.summary_every_step, + output_dir=os.path.join(FLAGS.output_dir, 'summaries'), + summary_op=utils.get_summaries(debug_ops)) + + hooks = [checkpoint_saver_hook, nan_hook, step_conter_hook, + loss_summary_saver_hook] + + # Start training. + with tf.train.MonitoredSession(hooks=hooks) as sess: + logging.info('starting training') + + for i in range(FLAGS.num_training_iterations): + sess.run(update_op) + + if i % FLAGS.export_every == 0: + reconstructions_np, data_np = sess.run([reconstructions, images]) + # Create an object which gets data and does the processing. 
+ data_np = data_processor.postprocess(data_np) + reconstructions_np = data_processor.postprocess(reconstructions_np) + sample_exporter.save(reconstructions_np, 'reconstructions') + sample_exporter.save(data_np, 'data') + + +if __name__ == '__main__': + app.run(main) diff --git a/cs_gan/nets.py b/cs_gan/nets.py new file mode 100644 index 0000000..b6ccb41 --- /dev/null +++ b/cs_gan/nets.py @@ -0,0 +1,106 @@ +# Copyright 2019 DeepMind Technologies Limited and Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Network utilities.""" + +import functools +import re +import numpy as np +import sonnet as snt +import tensorflow as tf +import tensorflow_gan as tfgan + + +def _sn_custom_getter(): + def name_filter(name): + match = re.match(r'.*w(_.*)?$', name) + return match is not None + + return tfgan.features.spectral_normalization_custom_getter( + name_filter=name_filter) + + +class SNGenNet(snt.AbstractModule): + """As in the SN paper.""" + + def __init__(self, name='conv_gen'): + super(SNGenNet, self).__init__(name=name) + + def _build(self, inputs, is_training): + batch_size = inputs.get_shape().as_list()[0] + first_shape = [4, 4, 512] + norm_ctor = snt.BatchNormV2 + norm_ctor_config = {'scale': True} + up_tensor = snt.Linear(np.prod(first_shape))(inputs) + first_tensor = tf.reshape(up_tensor, shape=[batch_size] + first_shape) + + net = snt.nets.ConvNet2DTranspose( + output_channels=[256, 128, 64, 3], + output_shapes=[(8, 8), (16, 16), (32, 32), (32, 32)], + kernel_shapes=[(4, 4), (4, 4), (4, 4), (3, 3)], + strides=[2, 2, 2, 1], + normalization_ctor=norm_ctor, + normalization_kwargs=norm_ctor_config, + normalize_final=False, + paddings=[snt.SAME], activate_final=False, activation=tf.nn.relu) + output = net(first_tensor, is_training=is_training) + return tf.nn.tanh(output) + + +class SNMetricNet(snt.AbstractModule): + """Spectral normalization discriminator (metric) architecture.""" + + def __init__(self, num_outputs=2, name='sn_metric'): + super(SNMetricNet, self).__init__(name=name) + self._num_outputs = num_outputs + + def _build(self, inputs): + with tf.variable_scope('', custom_getter=_sn_custom_getter()): + net = snt.nets.ConvNet2D( + output_channels=[64, 64, 128, 128, 256, 256, 512], + kernel_shapes=[ + (3, 3), (4, 4), (3, 3), (4, 4), (3, 3), (4, 4), (3, 3)], + strides=[1, 2, 1, 2, 1, 2, 1], + paddings=[snt.SAME], activate_final=True, + activation=functools.partial(tf.nn.leaky_relu, alpha=0.1)) + linear = snt.Linear(self._num_outputs) + output = 
linear(snt.BatchFlatten()(net(inputs))) + return output + + +class MLPGeneratorNet(snt.AbstractModule): + """MNIST generator net.""" + + def __init__(self, name='mlp_generator'): + super(MLPGeneratorNet, self).__init__(name=name) + + def _build(self, inputs, is_training=True): + del is_training + net = snt.nets.MLP([500, 500, 784], activation=tf.nn.leaky_relu) + out = net(inputs) + out = tf.nn.tanh(out) + return snt.BatchReshape([28, 28, 1])(out) + + +class MLPMetricNet(snt.AbstractModule): + """Same as in Grover and Ermon, ICLR workshop 2017.""" + + def __init__(self, num_outputs=2, name='mlp_metric'): + super(MLPMetricNet, self).__init__(name=name) + self._layer_size = [500, 500, num_outputs] + + def _build(self, inputs): + net = snt.nets.MLP(self._layer_size, + activation=tf.nn.leaky_relu) + output = net(snt.BatchFlatten()(inputs)) + return output diff --git a/cs_gan/requirements.txt b/cs_gan/requirements.txt new file mode 100644 index 0000000..ec1d1a0 --- /dev/null +++ b/cs_gan/requirements.txt @@ -0,0 +1,8 @@ +absl-py==0.7.1 +dm-sonnet==1.34 +numpy==1.16.4 +pillow==4.3.0 +tensorflow==1.14 +tensorflow-probability==0.7.0 +tensorflow-gan==1.0.0.dev0 +tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. diff --git a/cs_gan/run.sh b/cs_gan/run.sh new file mode 100755 index 0000000..fd2997b --- /dev/null +++ b/cs_gan/run.sh @@ -0,0 +1,22 @@ +#!/bin/sh +# Copyright 2019 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the GAN module."""

import sonnet as snt
import tensorflow as tf

from cs_gan import gan


class DummyGenerator(snt.AbstractModule):
  """Minimal generator stub: a single linear layer; ignores `is_training`."""

  def __init__(self):
    super(DummyGenerator, self).__init__(name='dummy_generator')

  def _build(self, inputs, is_training):
    return snt.Linear(10)(inputs)


class GanTest(tf.test.TestCase):

  def testConnect(self):
    """Checks connect() exposes per-module optimization components and vars."""
    discriminator = snt.Linear(2)
    generator = DummyGenerator()
    model = gan.GAN(
        discriminator, generator,
        num_z_iters=0, z_step_size=0.1,
        z_project_method='none', optimisation_cost_weight=0.0)

    generator_inputs = tf.ones((16, 3), dtype=tf.float32)
    data = tf.ones((16, 10))
    opt_components, _ = model.connect(data, generator_inputs)

    self.assertIn('disc', opt_components)
    self.assertIn('gen', opt_components)

    # The discriminator component must train exactly the discriminator vars.
    self.assertCountEqual(
        opt_components['disc'].vars,
        discriminator.get_variables())
    # The generator component also trains the learned latent step size.
    self.assertCountEqual(
        opt_components['gen'].vars,
        generator.get_variables() +
        model._log_step_size_module.get_variables())


if __name__ == '__main__':
  tf.test.main()
# python3

# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for latent optimisation."""
import collections
import os

from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from cs_gan import nets

tfd = tfp.distributions


class ModelOutputs(
    collections.namedtuple('AdversarialModelOutputs',
                           ['optimization_components', 'debug_ops'])):
  """All the information produced by the adversarial module.

  Fields:

    * `optimization_components`: A dictionary. Each entry in this dictionary
      corresponds to a module to train using their own optimizer. The keys are
      names of the components, and the values are `common.OptimizationComponent`
      instances. The keys of this dict can be made keys of the configuration
      used by the main train loop, to define the configuration of the
      optimization details for each module.
    * `debug_ops`: A dictionary, from string to a scalar `tf.Tensor`. Quantities
      used for tracking training.
  """


class OptimizationComponent(
    collections.namedtuple('OptimizationComponent', ['loss', 'vars'])):
  """Information needed by the optimizer to train modules.

  Usage:
      `optimizer.minimize(
          opt_component.loss, var_list=opt_component.vars)`

  Fields:

    * `loss`: A `tf.Tensor` the loss of the module.
    * `vars`: A list of variables, the ones which will be used to minimize the
      loss.
  """


def cross_entropy_loss(logits, expected):
  """The cross entropy classification loss between logits and expected values.

  The loss proposed by the original GAN paper: https://arxiv.org/abs/1406.2661.

  Args:
    logits: a `tf.Tensor`, the model produced logits.
    expected: a `tf.Tensor`, the expected output.

  Returns:
    A scalar `tf.Tensor`, the average loss obtained on the given inputs.

  Raises:
    ValueError: if the logits do not have shape [batch_size, 2].
  """
  num_logits = logits.get_shape()[1]
  if num_logits != 2:
    raise ValueError(('Invalid number of logits for cross_entropy_loss! '
                      'cross_entropy_loss supports only 2 output logits!'))
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=expected))


def optimise_and_sample(init_z, module, data, is_training):
  """Optimises generator latent variables and samples the generator.

  Runs `module.num_z_iters` steps of projected gradient descent on the
  latents (starting from `init_z`) against `module.gen_loss_fn`, then
  samples the generator at the optimised latents.

  Args:
    init_z: the initial latents, shape [batch_size, num_latents].
    module: the model; provides `num_z_iters`, `z_step_size`,
      `z_project_method`, `generator` and `gen_loss_fn`.
    data: the target data batch used by the generator loss.
    is_training: a boolean, forwarded to the generator.

  Returns:
    A tuple of (generator samples at the optimised latents, optimised latents).
  """
  if module.num_z_iters == 0:
    # No latent optimisation requested; sample at the initial latents.
    z_final = init_z
  else:
    init_loop_vars = (0, _project_z(init_z, module.z_project_method))
    loop_cond = lambda i, _: i < module.num_z_iters

    def loop_body(i, z):
      # One step of gradient descent on the generator loss w.r.t. z,
      # followed by projection back onto the feasible set.
      loop_samples = module.generator(z, is_training)
      gen_loss = module.gen_loss_fn(data, loop_samples)
      z_grad = tf.gradients(gen_loss, z)[0]
      z -= module.z_step_size * z_grad
      z = _project_z(z, module.z_project_method)
      return i + 1, z

    # Use the following static loop for debugging:
    # z = init_z
    # for _ in range(module.num_z_iters):
    #   _, z = loop_body(0, z)
    # z_final = z

    _, z_final = tf.while_loop(loop_cond,
                               loop_body,
                               init_loop_vars)

  return module.generator(z_final, is_training), z_final


def get_optimisation_cost(initial_z, optimised_z):
  """Returns the mean squared L2 distance between initial and optimised z."""
  optimisation_cost = tf.reduce_mean(
      tf.reduce_sum((optimised_z - initial_z)**2, -1))
  return optimisation_cost


def _project_z(z, project_method='clip'):
  """To be used for projected gradient descent over z.

  Args:
    z: the latent variables tensor.
    project_method: 'norm' for L2-normalisation, 'clip' to clip each
      coordinate to [-1, 1].

  Returns:
    The projected latents.

  Raises:
    ValueError: for an unknown `project_method`.
  """
  if project_method == 'norm':
    z_p = tf.nn.l2_normalize(z, axis=-1)
  elif project_method == 'clip':
    z_p = tf.clip_by_value(z, -1, 1)
  else:
    raise ValueError('Unknown project_method: {}'.format(project_method))
  return z_p


class DataProcessor(object):
  """Maps image data between [0, 1] and the model's [-1, 1] range."""

  def preprocess(self, x):
    """Rescales `x` from [0, 1] to [-1, 1]."""
    return x * 2 - 1

  def postprocess(self, x):
    """Rescales `x` from [-1, 1] back to [0, 1]."""
    return (x + 1) / 2.


def _get_np_data(data_processor, dataset, split='train'):
  """Get the dataset as numpy arrays.

  Args:
    data_processor: an optional `DataProcessor`; when given, its `preprocess`
      is applied to the loaded data.
    dataset: 'mnist' or 'cifar'.
    split: 'train' for the training split, anything else for the test split.

  Returns:
    A float32 numpy array of images in NHWC format, scaled to [0, 1]
    (or preprocessed to [-1, 1] when `data_processor` is given).

  Raises:
    ValueError: for an unknown `dataset`.
  """
  index = 0 if split == 'train' else 1
  if dataset == 'mnist':
    # Construct the dataset.
    x, _ = tf.keras.datasets.mnist.load_data()[index]
    # The data is loaded as uint8; convert to floats in [0, 1].
    x = x.astype(np.float32)
    x = x / 255.
    x = x.reshape((-1, 28, 28, 1))
  elif dataset == 'cifar':
    x, _ = tf.keras.datasets.cifar10.load_data()[index]
    x = x.astype(np.float32)
    x = x / 255.
  else:
    # Fail loudly rather than returning unbound/undefined data.
    raise ValueError('Unknown dataset: {}'.format(dataset))

  if data_processor:
    # Normalize data if a processor is given.
    x = data_processor.preprocess(x)
  return x


def make_output_dir(output_dir):
  """Creates `output_dir` if it does not already exist."""
  logging.info('Creating output dir %s', output_dir)
  if not tf.gfile.IsDirectory(output_dir):
    tf.gfile.MakeDirs(output_dir)


def get_ckpt_dir(output_dir):
  """Returns the checkpoint directory under `output_dir`, creating it if needed."""
  ckpt_dir = os.path.join(output_dir, 'ckpt')
  if not tf.gfile.IsDirectory(ckpt_dir):
    tf.gfile.MakeDirs(ckpt_dir)
  return ckpt_dir


def get_real_data_for_eval(num_eval_samples, dataset, split='valid'):
  """Returns the first `num_eval_samples` unprocessed images as a tf constant."""
  data = _get_np_data(data_processor=None, dataset=dataset, split=split)
  data = data[:num_eval_samples]
  return tf.constant(data)


def get_summaries(ops):
  """Builds scalar summaries which also print their values when evaluated."""
  summaries = []
  for name, op in ops.items():
    # Ensure to log the value ops before writing them in the summary.
    # We do this instead of a hook to ensure IS/FID are never computed twice.
    print_op = tf.print(name, [op], output_stream=tf.logging.info)
    with tf.control_dependencies([print_op]):
      summary = tf.summary.scalar(name, op)
      summaries.append(summary)
  return summaries


def get_train_dataset(data_processor, dataset, batch_size):
  """Creates the training data tensors."""
  x_train = _get_np_data(data_processor, dataset, split='train')
  # Create the TF dataset.
  dataset = tf.data.Dataset.from_tensor_slices(x_train)

  # Shuffle and repeat the dataset for training.
  # This is required because we want to do multiple passes through the entire
  # dataset when training.
  dataset = dataset.shuffle(100000).repeat()

  # Batch the data and return the data batch.
  one_shot_iterator = dataset.batch(batch_size).make_one_shot_iterator()
  data_batch = one_shot_iterator.get_next()
  return data_batch


def get_generator(dataset):
  """Returns the generator network for `dataset` (mnist|cifar)."""
  if dataset == 'mnist':
    return nets.MLPGeneratorNet()
  if dataset == 'cifar':
    return nets.SNGenNet()
  # Fail loudly rather than silently returning None.
  raise ValueError('Unknown dataset: {}'.format(dataset))


def get_metric_net(dataset, num_outputs=2):
  """Returns the metric (discriminator/measurement) net for `dataset`."""
  if dataset == 'mnist':
    return nets.MLPMetricNet(num_outputs)
  if dataset == 'cifar':
    return nets.SNMetricNet(num_outputs)
  # Fail loudly rather than silently returning None.
  raise ValueError('Unknown dataset: {}'.format(dataset))


def make_prior(num_latents):
  """Returns a zero-mean, unit-variance normal prior over the latents."""
  prior_mean = tf.zeros(shape=(num_latents,), dtype=tf.float32)
  prior_scale = tf.ones(shape=(num_latents,), dtype=tf.float32)

  return tfd.Normal(loc=prior_mean, scale=prior_scale)