mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-23 15:55:20 +08:00
Public release of Deep Compressed Sensing project.
PiperOrigin-RevId: 272403580
This commit is contained in:
committed by
Diego de Las Casas
parent
f119aac804
commit
fa33baacca
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
|
||||
|
||||
## Projects
|
||||
|
||||
* [Deep Compressed Sensing](cs_gan), ICML 2019
|
||||
* [Side Effects Penalties](side_effects_penalties)
|
||||
* [PrediNet Architecture and Relations Game Datasets](PrediNet)
|
||||
* [Unsupervised Adversarial Training](unsupervised_adversarial_training)
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Deep Compressed Sensing
|
||||
|
||||
This is the example code for the following ICML 2019 paper. If you use the code
|
||||
here, please cite this paper.
|
||||
|
||||
> Yan Wu, Mihaela Rosca, Timothy Lillicrap
|
||||
*Deep Compressed Sensing*. ICML 2019. [\[arXiv\]](https://arxiv.org/abs/1905.06723).
|
||||
|
||||
|
||||
## Running the code
|
||||
|
||||
The code contains:
|
||||
|
||||
* the implementation of the compressed sensing algorithm (`cs.py`).
|
||||
* the implementation of the GAN algorithm (`gan.py`).
|
||||
* a main file (`main_cs.py`) to reproduce the Compressed Sensing results in
|
||||
the paper.
|
||||
* a main file (`main_gan.py`) to reproduce the GAN results in the paper
|
||||
(the improvement over the SN-GAN baseline via latent optimization).
|
||||
+154
@@ -0,0 +1,154 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Compressed Sensing modules."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
|
||||
from cs_gan import utils
|
||||
|
||||
|
||||
class CS(object):
  """Compressed Sensing Module.

  Couples a measurement network with a generator. Samples are produced by
  optimising the generator latents (see `utils.optimise_and_sample`) using the
  measurement error between data and samples as the objective, and the
  measurement network is additionally trained with a RIP loss.
  """

  def __init__(self, metric_net, generator,
               num_z_iters, z_step_size, z_project_method):
    """Constructs the module.

    Args:
      metric_net: the measurement network.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: a positive float, the initial latent optimisation step
        size. Its logarithm initialises a trainable variable, so the step size
        itself is learned.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.
    """
    self._measure = metric_net
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method
    # The step size is parameterised in log-space and exponentiated below, so
    # the effective value stays positive while it is being trained.
    self._log_step_size_module = snt.TrainableVariable(
        [],
        initializers={'w': tf.constant_initializer(math.log(z_step_size))})
    self.z_step_size = tf.exp(self._log_step_size_module())

  def connect(self, data, generator_inputs):
    """Connects the components and returns the losses, outputs and debug ops.

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. It has to be compatible with
        the shapes expected by the measurement network.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It has to be
        compatible with the shapes the generator network supports as inputs.
        NOTE(review): samples and data are subtracted element-wise below, so
        the two batch sizes presumably have to match - confirm.

    Returns:
      An `utils.ModelOutputs` instance built from the optimization components
      and a dictionary of debug ops.
    """
    samples, optimised_z = utils.optimise_and_sample(
        generator_inputs, self, data, is_training=True)
    optimisation_cost = utils.get_optimisation_cost(generator_inputs,
                                                    optimised_z)
    debug_ops = {}

    # Samples generated from the latents before latent optimisation.
    initial_samples = self.generator(generator_inputs, is_training=True)
    generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples))

    # compute the RIP loss
    # (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2
    # as a triplet loss for 3 pairs of images.
    r1 = self._get_rip_loss(samples, initial_samples)
    r2 = self._get_rip_loss(samples, data)
    r3 = self._get_rip_loss(initial_samples, data)
    rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0)
    total_loss = generator_loss + rip_loss
    optimization_components = self._build_optimization_components(
        generator_loss=total_loss)

    debug_ops['rip_loss'] = rip_loss
    debug_ops['recons_loss'] = tf.reduce_mean(
        tf.norm(snt.BatchFlatten()(samples)
                - snt.BatchFlatten()(data), axis=-1))
    debug_ops['z_step_size'] = self.z_step_size
    debug_ops['opt_cost'] = optimisation_cost
    debug_ops['gen_loss'] = generator_loss

    return utils.ModelOutputs(
        optimization_components, debug_ops)

  def _get_rip_loss(self, img1, img2):
    r"""Compute the RIP loss from two images.

    The RIP loss: (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2

    Args:
      img1: an image (x_1), 4D tensor of shape [batch_size, W, H, C].
      img2: an other image (x_2), 4D tensor of shape [batch_size, W, H, C].

    Returns:
      The squared difference between the norm of the flattened image
      difference and the norm (over the last axis) of the measurement
      difference.
    """
    m1 = self._measure(img1)
    m2 = self._measure(img2)

    img_diff_norm = tf.norm(snt.BatchFlatten()(img1)
                            - snt.BatchFlatten()(img2), axis=-1)
    m_diff_norm = tf.norm(m1 - m2, axis=-1)

    return tf.square(img_diff_norm - m_diff_norm)

  def _get_measurement_error(self, target_img, sample_img):
    """Compute the measurement error of sample images given the targets.

    Args:
      target_img: the target images.
      sample_img: the sampled images, measured with the same network.

    Returns:
      The squared measurement difference summed over the last axis.
    """
    m_targets = self._measure(target_img)
    m_samples = self._measure(sample_img)

    return tf.reduce_sum(tf.square(m_targets - m_samples), -1)

  def gen_loss_fn(self, data, samples):
    """Generator loss as latent optimisation's error function."""
    return self._get_measurement_error(data, samples)

  def _build_optimization_components(
      self, generator_loss=None, discriminator_loss=None):
    """Create the optimization components for this module.

    Args:
      generator_loss: the loss minimised with respect to the generator,
        measurement network and step size variables.
      discriminator_loss: must be None; this module has no discriminator.

    Returns:
      An `utils.OptimizationComponent` holding the loss and the list of
      variables to optimise.

    Raises:
      ValueError: if `discriminator_loss` is not None.
    """
    # Validate with an exception rather than `assert`, which is stripped when
    # Python runs with -O.
    if discriminator_loss is not None:
      raise ValueError(
          'CS has no discriminator: discriminator_loss must be None.')

    metric_vars = _get_and_check_variables(self._measure)
    generator_vars = _get_and_check_variables(self.generator)
    step_vars = _get_and_check_variables(self._log_step_size_module)

    optimization_components = utils.OptimizationComponent(
        generator_loss, generator_vars + metric_vars + step_vars)
    return optimization_components
|
||||
|
||||
|
||||
def _get_and_check_variables(module):
  """Returns all variables of `module` as a list.

  Args:
    module: an object exposing `get_all_variables()` and `module_name`.

  Returns:
    A list with the variables reported by the module.

  Raises:
    ValueError: if the module reports no variables.
  """
  found = module.get_all_variables()
  if found:
    # TensorFlow optimizers require lists to be passed in.
    return list(found)

  raise ValueError(
      'Module {} has no variables! Variables needed for training.'.format(
          module.module_name))
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""File utilities."""
|
||||
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class FileExporter(object):
  """File exporter utilities.

  Tiles a batch of images into a single grid and writes it to disk as a PNG.
  """

  def __init__(self, path, grid_height=None, zoom=1):
    """Constructor.

    Arguments:
      path: The directory to save data to. Created if it does not exist.
      grid_height: How many data elements tall to make the grid, if appropriate.
        The width will be chosen based on height. If None, automatically
        determined.
      zoom: How much to zoom in each data element by, if appropriate.
    """
    if not os.path.exists(path):
      os.makedirs(path)

    self._path = path
    self._zoom = zoom
    self._grid_height = grid_height

  def _reshape(self, data):
    """Reshape given data into image format.

    Args:
      data: an array of shape [batch_size, height, width, n_channels] with
        either 1 or 3 channels. Arrays of any integer dtype are used as pixel
        values directly; every other dtype is multiplied by 255 first. In both
        cases the result is clipped to [0, 255].

    Returns:
      A uint8 array holding the images tiled row by row on a grid. Grid cells
      without an image are white (255). When the zoom factor is larger than 1
      every pixel is repeated that many times along both spatial axes.

    Raises:
      ValueError: if the number of channels is neither 1 nor 3.
    """
    batch_size, height, width, n_channels = data.shape
    if self._grid_height:
      grid_height = self._grid_height
    else:
      grid_height = int(math.floor(math.sqrt(batch_size)))

    # Force true division: this module does not import `division` from
    # `__future__`, so `/` between two ints would floor under Python 2 and the
    # grid could end up too narrow to hold the whole batch.
    grid_width = int(math.ceil(batch_size / float(grid_height)))

    if n_channels == 1:
      # Replicate the single channel so that the output is always RGB.
      data = np.tile(data, (1, 1, 1, 3))
      n_channels = 3

    if n_channels != 3:
      raise ValueError('Image batch must have either 1 or 3 channels, but '
                       'was {}'.format(n_channels))

    shape = (height * grid_height, width * grid_width, n_channels)
    buf = np.full(shape, 255, dtype=np.uint8)
    # Treat every integer dtype (not only int32/int64) as already being pixel
    # values: multiplying e.g. uint8 data by 255 would wrap around.
    is_integer = np.issubdtype(data.dtype, np.integer)

    for k in range(batch_size):
      i, j = divmod(k, grid_width)
      x, y = i * height, j * width
      if is_integer:
        # Widen first so that clipping narrow integer types cannot overflow.
        pixels = data[k].astype(np.int64)
      else:
        pixels = 255 * data[k]
      buf[x:x + height, y:y + width, :] = np.clip(
          pixels, 0, 255).astype(np.uint8)

    if self._zoom > 1:
      buf = buf.repeat(self._zoom, axis=0).repeat(self._zoom, axis=1)
    return buf

  def save(self, data, name):
    """Saves a batch of images as `<name>_last.png` in the output directory.

    Args:
      data: the image batch, in the format accepted by `_reshape`.
      name: prefix of the file name; an existing file with the same name is
        overwritten.
    """
    data = self._reshape(data)
    relative_name = '{}_last.png'.format(name)
    target_file = os.path.join(self._path, relative_name)

    img = Image.fromarray(data)
    img.save(target_file, format='PNG')
|
||||
+158
@@ -0,0 +1,158 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""GAN modules."""
|
||||
import collections
|
||||
import math
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
|
||||
from cs_gan import utils
|
||||
|
||||
|
||||
class GAN(object):
  """Standard generative adversarial network setup.

  The generator is trained to produce samples which fool a discriminator,
  with the latents refined by latent optimisation before sampling.

  Trained module components:

    * discriminator
    * generator

  For the standard GAN algorithm, generator_inputs is a vector of noise (either
  Gaussian or uniform).
  """

  def __init__(self, discriminator, generator, num_z_iters, z_step_size,
               z_project_method, optimisation_cost_weight):
    """Constructs the module.

    Args:
      discriminator: The discriminator network. A sonnet module. See `nets.py`.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: a positive float, the initial latent optimisation step
        size; its logarithm initialises a trainable variable.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.
      optimisation_cost_weight: a float, how much to penalise the distance of z
        moved by latent optimisation.
    """
    self._discriminator = discriminator
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method
    log_step_init = tf.constant_initializer(math.log(z_step_size))
    self._log_step_size_module = snt.TrainableVariable(
        [], initializers={'w': log_step_init})
    self.z_step_size = tf.exp(self._log_step_size_module())
    self._optimisation_cost_weight = optimisation_cost_weight

  def connect(self, data, generator_inputs):
    """Connects the components and returns the losses, outputs and debug ops.

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the
        rank of this tensor, but it has to be compatible with the shapes
        expected by the discriminator.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It has to be
        compatible with the shapes the generator network supports as inputs.

    Returns:
      An `utils.ModelOutputs` instance built from the optimization components
      and a dictionary of debug ops.
    """
    fake, refined_z = utils.optimise_and_sample(
        generator_inputs, self, data, is_training=True)
    z_cost = utils.get_optimisation_cost(generator_inputs, refined_z)

    # Score real data and generated samples with the same discriminator.
    real_logits = self._discriminator(data)
    fake_logits = self._discriminator(fake)

    # Discriminator objective: label data as 1 and samples as 0.
    real_targets = tf.ones(tf.shape(real_logits[:, 0]), dtype=tf.int32)
    loss_on_real = utils.cross_entropy_loss(real_logits, real_targets)

    fake_targets = tf.zeros(tf.shape(fake_logits[:, 0]), dtype=tf.int32)
    loss_on_fake = utils.cross_entropy_loss(fake_logits, fake_targets)

    disc_total = loss_on_real + loss_on_fake

    # Generator objective: have the samples labelled as 1.
    fooling_targets = tf.ones(tf.shape(fake_logits[:, 0]), dtype=tf.int32)
    gen_total = utils.cross_entropy_loss(fake_logits, fooling_targets)

    components = self._build_optimization_components(
        discriminator_loss=disc_total, generator_loss=gen_total,
        optimisation_cost=z_cost)

    debug_ops = {
        'z_step_size': self.z_step_size,
        'disc_data_loss': loss_on_real,
        'disc_sample_loss': loss_on_fake,
        'disc_loss': disc_total,
        'gen_loss': gen_total,
        'opt_cost': z_cost,
    }

    return utils.ModelOutputs(components, debug_ops)

  def gen_loss_fn(self, data, samples):
    """Generator loss as latent optimisation's error function.

    Args:
      data: unused.
      samples: the generated samples to score with the discriminator.

    Returns:
      The cross entropy between the sample logits and all-ones labels.
    """
    del data
    logits = self._discriminator(samples)
    targets = tf.ones(tf.shape(logits[:, 0]), dtype=tf.int32)
    return utils.cross_entropy_loss(logits, targets)

  def _build_optimization_components(
      self, generator_loss=None, discriminator_loss=None,
      optimisation_cost=None):
    """Create the optimization components for this module.

    Args:
      generator_loss: the generator loss, before the latent cost is added.
      discriminator_loss: the discriminator loss.
      optimisation_cost: the latent optimisation cost, added to the generator
        loss scaled by the optimisation cost weight.

    Returns:
      An ordered dictionary with the 'disc' and 'gen'
      `utils.OptimizationComponent`s, in that order.
    """
    disc_vars = _get_and_check_variables(self._discriminator)
    gen_vars = _get_and_check_variables(self.generator)
    step_size_vars = _get_and_check_variables(self._log_step_size_module)

    disc_component = utils.OptimizationComponent(
        discriminator_loss, disc_vars)
    # The step size is trained together with the generator.
    gen_component = utils.OptimizationComponent(
        generator_loss + self._optimisation_cost_weight * optimisation_cost,
        gen_vars + step_size_vars)

    return collections.OrderedDict(
        [('disc', disc_component), ('gen', gen_component)])
|
||||
|
||||
|
||||
def _get_and_check_variables(module):
  """Returns all variables of `module` as a list.

  Args:
    module: an object exposing `get_all_variables()` and `module_name`.

  Returns:
    A list with the variables reported by the module.

  Raises:
    ValueError: if the module reports no variables.
  """
  found = module.get_all_variables()
  if found:
    # TensorFlow optimizers require lists to be passed in.
    return list(found)

  raise ValueError(
      'Module {} has no variables! Variables needed for training.'.format(
          module.module_name))
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Compute image metrics: IS, FID."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_gan as tfgan
|
||||
|
||||
|
||||
def inception_preprocess_fn(images):
  """Scales images by 255 and applies the TF-GAN image preprocessing.

  Args:
    images: the images to preprocess (presumably in [0, 1], which is the range
      the metric functions in this module assert - confirm for other callers).

  Returns:
    The output of `tfgan.eval.preprocess_image` on the rescaled images.
  """
  rescaled = images * 255
  return tfgan.eval.preprocess_image(rescaled)
|
||||
|
||||
|
||||
def compute_inception_score(
    images, max_classifier_batch_size=16, assert_data_ranges=True):
  """Computes the Inception score of a set of images.

  Note that this metric is highly sensitive to the number of images used
  to compute it: more is better. A standard number is 50K images.

  For more details see: https://arxiv.org/abs/1606.03498.

  Args:
    images: A 4D tensor with a statically known number of images.
    max_classifier_batch_size: An integer. The maximum batch size to pass
      through the classifier used to compute the metric.
    assert_data_ranges: Whether or not to assert the input is in a valid
      data range. Only the first classifier batch is checked against [0, 1].

  Returns:
    A scalar `tf.Tensor` with the Inception score.
  """
  num_images = images.shape.as_list()[0]

  # Use the largest batch size, not above the maximum, that divides the number
  # of images exactly, so that every batch has the same size.
  candidate_sizes = range(max_classifier_batch_size, 0, -1)
  batch_size = next(
      (size for size in candidate_sizes if num_images % size == 0), None)
  tf.logging.debug('Using batch_size=%s for score classifier', batch_size)
  num_batches = num_images // batch_size

  control_deps = []
  if assert_data_ranges:
    first_batch = images[0:batch_size]
    control_deps.append(tf.assert_greater_equal(first_batch, 0.0))
    control_deps.append(tf.assert_less_equal(first_batch, 1.0))

  # Do the preprocessing in the fn function to avoid having to keep all the
  # resized data in memory.
  def classifier_fn(images):
    preprocessed = inception_preprocess_fn(images)
    return tfgan.eval.run_inception(preprocessed)

  with tf.control_dependencies(control_deps):
    return tfgan.eval.classifier_score(
        images, classifier_fn=classifier_fn, num_batches=num_batches)
|
||||
|
||||
|
||||
def compute_fid(
    real_images, other, max_classifier_batch_size=16,
    assert_data_ranges=True):
  """Computes the Frechet Inception Distance (FID) between two image sets.

  Note that this metric is highly sensitive to the number of images used
  to compute it: more is better. A standard number is 50K images.

  For more details see: https://arxiv.org/abs/1706.08500.

  Args:
    real_images: A 4D tensor. Samples from the true data distribution. Its
      statically known first dimension determines the batching.
    other: A 4D tensor. Data for which to compute the FID.
      NOTE(review): presumably expected to hold as many images as
      `real_images`, since the same number of batches is used for both.
    max_classifier_batch_size: An integer. The maximum batch size to pass
      through the classifier used to compute the metric.
    assert_data_ranges: Whether or not to assert the input is in a valid
      data range. Only the first classifier batch of each input is checked
      against [0, 1].

  Returns:
    A scalar `tf.Tensor` with the Frechet Inception Distance.
  """
  num_images = real_images.shape.as_list()[0]

  # Use the largest batch size, not above the maximum, that divides the number
  # of images exactly, so that every batch has the same size.
  candidate_sizes = range(max_classifier_batch_size, 0, -1)
  batch_size = next(
      (size for size in candidate_sizes if num_images % size == 0), None)
  tf.logging.debug('Using batch_size=%s for score classifier', batch_size)
  num_batches = num_images // batch_size

  control_deps = []
  if assert_data_ranges:
    first_batches = tf.concat(
        [real_images[0:batch_size], other[0:batch_size]], axis=0)
    control_deps.append(tf.assert_greater_equal(first_batches, 0.0))
    control_deps.append(tf.assert_less_equal(first_batches, 1.0))

  # Do the preprocessing in the fn function to avoid having to keep all the
  # resized data in memory.
  def classifier_fn(images):
    preprocessed = inception_preprocess_fn(images)
    return tfgan.eval.run_inception(
        preprocessed, output_tensor=tfgan.eval.INCEPTION_FINAL_POOL)

  with tf.control_dependencies(control_deps):
    return tfgan.eval.frechet_classifier_distance(
        real_images, other,
        classifier_fn=classifier_fn, num_batches=num_batches)
|
||||
|
||||
|
||||
def generate_big_batch(generator, generator_inputs, max_num_samples=100):
  """Generate samples when the number of samples is too big to do one pass.

  Args:
    generator: a callable mapping a batch of generator inputs to samples.
    generator_inputs: a tensor whose statically known first dimension is the
      total number of samples to generate.
    max_num_samples: the maximum number of samples generated per forward pass.
      The total number of samples has to be divisible by the effective chunk
      size, `min(max_num_samples, num_samples)`.

  Returns:
    A tensor with all samples, with the total number of samples as its first
    dimension.

  Raises:
    ValueError: if the number of samples cannot be split into equally sized
      chunks.
  """
  num_samples = generator_inputs.shape.as_list()[0]
  max_num_samples = min(max_num_samples, num_samples)
  # Fail early with a clear message: otherwise the reshape below fails with an
  # error that does not mention the chunk size.
  if num_samples % max_num_samples != 0:
    raise ValueError(
        'The number of samples ({}) must be divisible by the chunk size '
        '({}).'.format(num_samples, max_num_samples))

  batched_shape = [num_samples // max_num_samples, max_num_samples]
  batched_shape += generator_inputs.shape.as_list()[1:]
  batched_generator_inputs = tf.reshape(generator_inputs, batched_shape)

  # Create the samples by sequentially doing forward passes through the
  # generator. This ensures we do not run out of memory.
  output_samples = tf.map_fn(
      fn=generator,
      elems=batched_generator_inputs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True)

  # Merge the chunk and per-chunk dimensions back into one batch dimension.
  samples = tf.reshape(output_samples,
                       [num_samples] + output_samples.shape.as_list()[2:])
  return samples
|
||||
|
||||
|
||||
def get_image_metrics(real_images, other):
  """Builds the Inception score of `other` and its FID against real images.

  Args:
    real_images: samples from the true data distribution.
    other: the images to evaluate.

  Returns:
    A dictionary with the keys 'inception_score' and 'fid'.
  """
  inception_score = compute_inception_score(other)
  fid = compute_fid(real_images, other)
  return {'inception_score': inception_score, 'fid': fid}
|
||||
|
||||
|
||||
def get_image_metrics_for_samples(
    real_images, generator, prior, data_processor, num_eval_samples):
  """Samples from the generator and builds image metrics for the samples.

  Args:
    real_images: samples from the true data distribution, used for the FID.
    generator: a callable mapping generator inputs to samples.
    prior: an object whose `sample(n)` method returns `n` generator inputs.
    data_processor: an object whose `postprocess` method is applied to the
      generated samples before the metrics are computed.
    num_eval_samples: the number of samples to generate.

  Returns:
    A dictionary with the keys 'inception_score' and 'fid'.
  """
  latents = prior.sample(num_eval_samples)
  raw_samples = generate_big_batch(generator, latents)
  # Ensure the samples are in [0, 1].
  samples = data_processor.postprocess(raw_samples)

  inception_score = compute_inception_score(samples)
  fid = compute_fid(real_images, samples)
  return {'inception_score': inception_score, 'fid': fid}
|
||||
+195
@@ -0,0 +1,195 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Training script."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from cs_gan import file_utils
|
||||
from cs_gan import gan
|
||||
from cs_gan import image_metrics
|
||||
from cs_gan import utils
|
||||
|
||||
# Command line flags of the GAN training script.
flags.DEFINE_integer(
    'num_training_iterations', 200000,
    'Number of training iterations.')
flags.DEFINE_integer(
    'batch_size', 64, 'Training batch size.')
flags.DEFINE_integer(
    'num_latents', 128, 'The number of latents')
flags.DEFINE_integer(
    'summary_every_step', 1000,
    'The interval at which to log debug ops.')
flags.DEFINE_integer(
    'image_metrics_every_step', 2000,
    'The interval at which to log (expensive) image metrics.')
flags.DEFINE_integer(
    'export_every', 10,
    'The interval at which to export samples.')
flags.DEFINE_integer(
    'num_eval_samples', 10000,
    'The number of samples used to evaluate FID/IS')
# Help text fix: the parenthesis was never closed.
flags.DEFINE_string(
    'dataset', 'cifar', 'The dataset used for learning (cifar|mnist).')
flags.DEFINE_float(
    'optimisation_cost_weight', 3., 'weight for latent optimisation cost.')
# Help text fix: the two adjacent literals are concatenated, so the first one
# needs a trailing space.
flags.DEFINE_integer(
    'num_z_iters', 3, 'The number of latent optimisation steps. '
    'It falls back to vanilla GAN when num_z_iters is set to 0.')
flags.DEFINE_float(
    'z_step_size', 0.01, 'Step size for latent optimisation.')
flags.DEFINE_string(
    'z_project_method', 'norm', 'The method to project z.')
flags.DEFINE_string(
    'output_dir', '/tmp/cs_gan/gan', 'Location where to save output files.')
flags.DEFINE_float('disc_lr', 2e-4, 'Discriminator Learning rate.')
flags.DEFINE_float('gen_lr', 2e-4, 'Generator Learning rate.')
flags.DEFINE_bool(
    'run_real_data_metrics', False,
    'Whether or not to run image metrics on real data.')
flags.DEFINE_bool(
    'run_sample_metrics', True,
    'Whether or not to run image metrics on samples.')


FLAGS = flags.FLAGS

# Log info level (for Hooks).
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
|
||||
def main(argv):
  """Builds the GAN training graph and runs the training loop.

  Besides the discriminator and generator updates, the graph contains the
  image metrics selected through the flags. Sample and data grids are exported
  to `<output_dir>/samples` every `export_every` iterations.

  Args:
    argv: unused command line arguments.
  """
  del argv

  utils.make_output_dir(FLAGS.output_dir)
  data_processor = utils.DataProcessor()
  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                   FLAGS.batch_size)

  # The learning rates are floats: formatting them with '%d' would truncate
  # small values such as 2e-4 to 0 in the log.
  logging.info('Generator learning rate: %g', FLAGS.gen_lr)
  logging.info('Discriminator learning rate: %g', FLAGS.disc_lr)

  # Construct optimizers.
  disc_optimizer = tf.train.AdamOptimizer(FLAGS.disc_lr, beta1=0.5, beta2=0.999)
  gen_optimizer = tf.train.AdamOptimizer(FLAGS.gen_lr, beta1=0.5, beta2=0.999)

  # Create the networks and models.
  generator = utils.get_generator(FLAGS.dataset)
  # The metric network is passed to the GAN as its discriminator.
  metric_net = utils.get_metric_net(FLAGS.dataset)
  model = gan.GAN(metric_net, generator,
                  FLAGS.num_z_iters, FLAGS.z_step_size,
                  FLAGS.z_project_method, FLAGS.optimisation_cost_weight)
  prior = utils.make_prior(FLAGS.num_latents)
  generator_inputs = prior.sample(FLAGS.batch_size)

  model_output = model.connect(images, generator_inputs)
  optimization_components = model_output.optimization_components
  debug_ops = model_output.debug_ops
  samples = generator(generator_inputs, is_training=False)

  global_step = tf.train.get_or_create_global_step()
  # We pass the global step both to the disc and generator update ops.
  # This means that the global step will not be the same as the number of
  # iterations, but ensures that hooks which rely on global step work correctly.
  disc_update_op = disc_optimizer.minimize(
      optimization_components['disc'].loss,
      var_list=optimization_components['disc'].vars,
      global_step=global_step)

  gen_update_op = gen_optimizer.minimize(
      optimization_components['gen'].loss,
      var_list=optimization_components['gen'].vars,
      global_step=global_step)

  # Get data needed to compute FID. We also compute metrics on
  # real data as a sanity check and as a reference point.
  eval_real_data = utils.get_real_data_for_eval(FLAGS.num_eval_samples,
                                                FLAGS.dataset,
                                                split='train')

  def sample_fn(x):
    """Returns samples for the latents `x` after latent optimisation."""
    return utils.optimise_and_sample(x, module=model,
                                     data=None, is_training=False)[0]

  if FLAGS.run_sample_metrics:
    sample_metrics = image_metrics.get_image_metrics_for_samples(
        eval_real_data, sample_fn,
        prior, data_processor,
        num_eval_samples=FLAGS.num_eval_samples)
  else:
    sample_metrics = {}

  if FLAGS.run_real_data_metrics:
    data_metrics = image_metrics.get_image_metrics(
        eval_real_data, eval_real_data)
  else:
    data_metrics = {}

  sample_exporter = file_utils.FileExporter(
      os.path.join(FLAGS.output_dir, 'samples'))

  # Hooks.
  debug_ops['it'] = global_step
  # Abort training on Nans.
  nan_disc_hook = tf.train.NanTensorHook(optimization_components['disc'].loss)
  nan_gen_hook = tf.train.NanTensorHook(optimization_components['gen'].loss)
  # Step counter.
  step_counter_hook = tf.train.StepCounterHook()

  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

  loss_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.summary_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(debug_ops))

  metrics_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.image_metrics_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(sample_metrics))

  hooks = [checkpoint_saver_hook, metrics_summary_saver_hook,
           nan_disc_hook, nan_gen_hook, step_counter_hook,
           loss_summary_saver_hook]

  # Start training.
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    logging.info('starting training')

    # The metrics are floating point values, so do not format them with '%d'.
    for key, value in sess.run(data_metrics).items():
      logging.info('%s: %g', key, value)

    for i in range(FLAGS.num_training_iterations):
      sess.run(disc_update_op)
      sess.run(gen_update_op)

      if i % FLAGS.export_every == 0:
        samples_np, data_np = sess.run([samples, images])
        # Post-process data and samples before exporting them as images.
        data_np = data_processor.postprocess(data_np)
        samples_np = data_processor.postprocess(samples_np)
        sample_exporter.save(samples_np, 'samples')
        sample_exporter.save(data_np, 'data')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,141 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Training script."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
from cs_gan import cs
|
||||
from cs_gan import file_utils
|
||||
from cs_gan import utils
|
||||
|
||||
tfd = tfp.distributions
|
||||
|
||||
# Command line flags configuring the compressed sensing training run.
flags.DEFINE_string(
    'mode', 'recons', 'Model mode.')
flags.DEFINE_integer(
    'num_training_iterations', 10000000,
    'Number of training iterations.')
flags.DEFINE_integer(
    'batch_size', 64, 'Training batch size.')
flags.DEFINE_integer(
    'num_measurements', 25, 'The number of measurements')
flags.DEFINE_integer(
    'num_latents', 100, 'The number of latents')
# Latent optimisation settings, forwarded to the `cs.CS` model in `main`.
flags.DEFINE_integer(
    'num_z_iters', 3, 'The number of latent optimisation steps.')
flags.DEFINE_float(
    'z_step_size', 0.01, 'Step size for latent optimisation.')
flags.DEFINE_string(
    'z_project_method', 'norm', 'The method to project z.')
# Logging and export cadence, both measured in training iterations.
flags.DEFINE_integer(
    'summary_every_step', 1000,
    'The interval at which to log debug ops.')
flags.DEFINE_integer(
    'export_every', 10,
    'The interval at which to export samples.')
flags.DEFINE_string(
    'dataset', 'mnist', 'The dataset used for learning (cifar|mnist).')
flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
flags.DEFINE_string(
    'output_dir', '/tmp/cs_gan/cs', 'Location where to save output files.')
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# Log info level (for Hooks).
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
|
||||
def main(argv):
  """Trains the compressed sensing model and periodically exports images.

  Builds the input pipeline, the generator and the measurement (metric)
  network and the `cs.CS` model from the command line flags, then runs the
  training loop inside a `tf.train.MonitoredSession`. Every
  `FLAGS.export_every` iterations a batch of reconstructions and the data
  batch fetched in the same `sess.run` call are post-processed and passed to
  the `file_utils.FileExporter`.

  Args:
    argv: Positional command line arguments forwarded by `app.run`; unused.
  """
  del argv

  utils.make_output_dir(FLAGS.output_dir)
  data_processor = utils.DataProcessor()
  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                   FLAGS.batch_size)

  # The learning rate is a float flag: formatting it with '%d' would truncate
  # the default 1e-4 to 0, so use a float conversion instead.
  logging.info('Learning rate: %g', FLAGS.learning_rate)

  # Construct optimizers.
  optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

  # Create the networks and models.
  generator = utils.get_generator(FLAGS.dataset)
  metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
  model = cs.CS(metric_net, generator,
                FLAGS.num_z_iters, FLAGS.z_step_size, FLAGS.z_project_method)
  prior = utils.make_prior(FLAGS.num_latents)
  generator_inputs = prior.sample(FLAGS.batch_size)

  model_output = model.connect(images, generator_inputs)
  optimization_components = model_output.optimization_components
  debug_ops = model_output.debug_ops
  # Reconstructions used for export are built with is_training=False.
  reconstructions, _ = utils.optimise_and_sample(
      generator_inputs, model, images, is_training=False)

  global_step = tf.train.get_or_create_global_step()
  update_op = optimizer.minimize(
      optimization_components.loss,
      var_list=optimization_components.vars,
      global_step=global_step)

  sample_exporter = file_utils.FileExporter(
      os.path.join(FLAGS.output_dir, 'reconstructions'))

  # Hooks.
  # Report the iteration counter together with the other debug quantities.
  debug_ops['it'] = global_step
  # Abort training on Nans.
  nan_hook = tf.train.NanTensorHook(optimization_components.loss)
  # Step counter.
  step_counter_hook = tf.train.StepCounterHook()

  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

  loss_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.summary_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(debug_ops))

  hooks = [checkpoint_saver_hook, nan_hook, step_counter_hook,
           loss_summary_saver_hook]

  # Start training.
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    logging.info('starting training')

    for i in range(FLAGS.num_training_iterations):
      sess.run(update_op)

      if i % FLAGS.export_every == 0:
        reconstructions_np, data_np = sess.run([reconstructions, images])
        # Undo the training-time normalisation before exporting.
        data_np = data_processor.postprocess(data_np)
        reconstructions_np = data_processor.postprocess(reconstructions_np)
        sample_exporter.save(reconstructions_np, 'reconstructions')
        sample_exporter.save(data_np, 'data')
|
||||
|
||||
|
||||
# Script entry point: when executed directly, hand `main` over to `app.run`.
if __name__ == '__main__':
  app.run(main)
|
||||
+106
@@ -0,0 +1,106 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Network utilities."""
|
||||
|
||||
import functools
|
||||
import re
|
||||
import numpy as np
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
import tensorflow_gan as tfgan
|
||||
|
||||
|
||||
def _sn_custom_getter():
  """Returns the tfgan spectral normalization custom getter for weights.

  The getter is built with a name filter that accepts variable names ending
  in `w`, optionally followed by an underscore suffix (e.g. `.../w` or
  `.../w_1`).
  """
  weight_name_pattern = r'.*w(_.*)?$'

  def name_filter(name):
    # True only for names that match the weight naming pattern.
    return re.match(weight_name_pattern, name) is not None

  return tfgan.features.spectral_normalization_custom_getter(
      name_filter=name_filter)
|
||||
|
||||
|
||||
class SNGenNet(snt.AbstractModule):
  """Convolutional generator, as in the SN paper."""

  def __init__(self, name='conv_gen'):
    super(SNGenNet, self).__init__(name=name)

  def _build(self, inputs, is_training):
    """Maps a batch of latents to tanh-squashed 3-channel outputs.

    Args:
      inputs: Latent tensor whose first (batch) dimension is statically known.
      is_training: Forwarded to the transposed convolutional network (which
        is configured with batch normalisation).

    Returns:
      The output of the transposed convolutional stack passed through `tanh`.
    """
    seed_shape = [4, 4, 512]
    num_examples = inputs.get_shape().as_list()[0]

    # Project the latents linearly and fold them into a 4x4x512 feature map.
    projected = snt.Linear(np.prod(seed_shape))(inputs)
    seed_features = tf.reshape(projected, shape=[num_examples] + seed_shape)

    upsampler = snt.nets.ConvNet2DTranspose(
        output_channels=[256, 128, 64, 3],
        output_shapes=[(8, 8), (16, 16), (32, 32), (32, 32)],
        kernel_shapes=[(4, 4), (4, 4), (4, 4), (3, 3)],
        strides=[2, 2, 2, 1],
        normalization_ctor=snt.BatchNormV2,
        normalization_kwargs={'scale': True},
        normalize_final=False,
        paddings=[snt.SAME],
        activate_final=False,
        activation=tf.nn.relu)
    pre_activation = upsampler(seed_features, is_training=is_training)
    return tf.nn.tanh(pre_activation)
|
||||
|
||||
|
||||
class SNMetricNet(snt.AbstractModule):
  """Spectral normalization discriminator (metric) architecture."""

  def __init__(self, num_outputs=2, name='sn_metric'):
    """Stores the size of the final linear layer.

    Args:
      num_outputs: Number of units produced by the last linear layer.
      name: Sonnet module name.
    """
    super(SNMetricNet, self).__init__(name=name)
    self._num_outputs = num_outputs

  def _build(self, inputs):
    """Applies the convolutional stack, flattens and projects linearly."""
    leaky_relu = functools.partial(tf.nn.leaky_relu, alpha=0.1)
    # Variables are created under the spectral normalization custom getter.
    with tf.variable_scope('', custom_getter=_sn_custom_getter()):
      conv_stack = snt.nets.ConvNet2D(
          output_channels=[64, 64, 128, 128, 256, 256, 512],
          kernel_shapes=[
              (3, 3), (4, 4), (3, 3), (4, 4), (3, 3), (4, 4), (3, 3)],
          strides=[1, 2, 1, 2, 1, 2, 1],
          paddings=[snt.SAME],
          activate_final=True,
          activation=leaky_relu)
      readout = snt.Linear(self._num_outputs)
      flat_features = snt.BatchFlatten()(conv_stack(inputs))
      result = readout(flat_features)
    return result
|
||||
|
||||
|
||||
class MLPGeneratorNet(snt.AbstractModule):
  """MNIST generator net."""

  def __init__(self, name='mlp_generator'):
    super(MLPGeneratorNet, self).__init__(name=name)

  def _build(self, inputs, is_training=True):
    """Maps latents to tanh-squashed outputs reshaped to 28x28x1.

    Args:
      inputs: Batch of latent vectors.
      is_training: Ignored; accepted so the call signature matches the other
        generators.

    Returns:
      The MLP output after `tanh`, batch-reshaped to [28, 28, 1].
    """
    del is_training
    mlp = snt.nets.MLP([500, 500, 784], activation=tf.nn.leaky_relu)
    squashed = tf.nn.tanh(mlp(inputs))
    to_image = snt.BatchReshape([28, 28, 1])
    return to_image(squashed)
|
||||
|
||||
|
||||
class MLPMetricNet(snt.AbstractModule):
  """Same as in Grover and Ermon, ICLR workshop 2017."""

  def __init__(self, num_outputs=2, name='mlp_metric'):
    """Configures a two hidden layer MLP with `num_outputs` output units.

    Args:
      num_outputs: Size of the final layer.
      name: Sonnet module name.
    """
    super(MLPMetricNet, self).__init__(name=name)
    self._layer_size = [500, 500, num_outputs]

  def _build(self, inputs):
    """Flattens each example and applies the MLP."""
    mlp = snt.nets.MLP(self._layer_size, activation=tf.nn.leaky_relu)
    flat_inputs = snt.BatchFlatten()(inputs)
    return mlp(flat_inputs)
|
||||
@@ -0,0 +1,8 @@
|
||||
absl-py==0.7.1
|
||||
dm-sonnet==1.34
|
||||
numpy==1.16.4
|
||||
pillow==4.3.0
|
||||
tensorflow==1.14
|
||||
tensorflow-probability==0.7.0
|
||||
tensorflow-gan==1.0.0.dev0
|
||||
tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow.
|
||||
Executable
+22
@@ -0,0 +1,22 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Create an isolated environment and install the pinned dependencies.
python3 -m venv cs_gan_venv
# `source` is a bash built-in; the POSIX `.` command also works when this
# script is interpreted by a plain /bin/sh (as requested by the shebang).
. cs_gan_venv/bin/activate
pip install -r cs_gan/requirements.txt

# Compressed sensing experiment.
python -m cs_gan.main_cs

# GAN experiment.
python -m cs_gan.main
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
|
||||
from cs_gan import gan
|
||||
|
||||
|
||||
class DummyGenerator(snt.AbstractModule):
  """Minimal generator used by the tests: a single 10-unit linear layer."""

  def __init__(self):
    super(DummyGenerator, self).__init__(name='dummy_generator')

  def _build(self, inputs, is_training):
    # The flag is accepted for signature compatibility but has no effect.
    del is_training
    projection = snt.Linear(10)
    return projection(inputs)
|
||||
|
||||
|
||||
class GanTest(tf.test.TestCase):
  """Checks the optimisation components returned by `gan.GAN.connect`."""

  def testConnect(self):
    """Both components exist and own the expected variables."""
    discriminator = snt.Linear(2)
    generator = DummyGenerator()
    model = gan.GAN(
        discriminator, generator,
        num_z_iters=0, z_step_size=0.1,
        z_project_method='none', optimisation_cost_weight=0.0)

    latents = tf.ones((16, 3), dtype=tf.float32)
    real_batch = tf.ones((16, 10))
    opt_components, _ = model.connect(real_batch, latents)

    for component_name in ('disc', 'gen'):
      self.assertIn(component_name, opt_components)

    # The discriminator component trains only the discriminator variables.
    self.assertCountEqual(
        opt_components['disc'].vars,
        discriminator.get_variables())

    # The generator component also trains the learnt step size module.
    expected_gen_vars = (
        generator.get_variables() +
        model._log_step_size_module.get_variables())
    self.assertCountEqual(opt_components['gen'].vars, expected_gen_vars)
|
||||
|
||||
|
||||
# Run the test cases in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
||||
+234
@@ -0,0 +1,234 @@
|
||||
# python3
|
||||
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tools for latent optimisation."""
|
||||
import collections
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
from cs_gan import nets
|
||||
|
||||
tfd = tfp.distributions
|
||||
|
||||
|
||||
class ModelOutputs(
    collections.namedtuple(
        typename='AdversarialModelOutputs',
        field_names=('optimization_components', 'debug_ops'))):
  """All the information produced by a model's `connect` call.

  Fields:

  * `optimization_components`: What has to be optimised. The GAN training
      script indexes this as a dictionary from component name (`'disc'`,
      `'gen'`) to `OptimizationComponent`, each trained by its own update op;
      the compressed sensing training script reads `.loss` and `.vars` from
      it directly.
  * `debug_ops`: A dictionary, from string to a scalar `tf.Tensor`. Quantities
      used for tracking training.
  """
|
||||
|
||||
|
||||
class OptimizationComponent(
    collections.namedtuple(
        typename='OptimizationComponent',
        field_names=('loss', 'vars'))):
  """Information needed by the optimizer to train modules.

  Usage:
    `optimizer.minimize(
        opt_component.loss, var_list=opt_component.vars)`

  Fields:

  * `loss`: A `tf.Tensor`, the loss of the module.
  * `vars`: A list of variables, the ones which will be updated to minimize
      the loss.
  """
|
||||
|
||||
|
||||
def cross_entropy_loss(logits, expected):
  """The cross entropy classification loss between logits and expected values.

  The loss proposed by the original GAN paper: https://arxiv.org/abs/1406.2661.

  Args:
    logits: a `tf.Tensor`, the model produced logits.
    expected: a `tf.Tensor`, the expected output, used as sparse labels.

  Returns:
    A scalar `tf.Tensor`, the average loss obtained on the given inputs.

  Raises:
    ValueError: if the logits do not have shape [batch_size, 2].
  """
  # Only the binary (two logit) case is supported.
  if logits.get_shape()[1] != 2:
    raise ValueError(('Invalid number of logits for cross_entropy_loss! '
                      'cross_entropy_loss supports only 2 output logits!'))

  per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=expected)
  return tf.reduce_mean(per_example_loss)
|
||||
|
||||
|
||||
def optimise_and_sample(init_z, module, data, is_training):
  """Optimises the generator latents and samples from the result.

  Args:
    init_z: Initial latents.
    module: Object exposing `num_z_iters`, `z_step_size`, `z_project_method`,
      `generator` and `gen_loss_fn`.
    data: Data batch passed to `module.gen_loss_fn` together with the samples.
    is_training: Forwarded to `module.generator`.

  Returns:
    A tuple `(samples, z_final)`: the generator output for the final latents
    and the final latents themselves.
  """
  if module.num_z_iters == 0:
    # No latent optimisation: the initial latents are used as they are,
    # without projection.
    return module.generator(init_z, is_training), init_z

  def _should_continue(step, _):
    return step < module.num_z_iters

  def _gradient_step(step, z):
    """One projected gradient descent step on the generator loss w.r.t. z."""
    step_samples = module.generator(z, is_training)
    step_loss = module.gen_loss_fn(data, step_samples)
    grad_wrt_z = tf.gradients(step_loss, z)[0]
    updated_z = z - module.z_step_size * grad_wrt_z
    return step + 1, _project_z(updated_z, module.z_project_method)

  # For debugging, the loop can be unrolled statically by calling
  # `_gradient_step` `module.num_z_iters` times in a Python loop instead.
  start_state = (0, _project_z(init_z, module.z_project_method))
  _, z_final = tf.while_loop(_should_continue, _gradient_step, start_state)

  return module.generator(z_final, is_training), z_final
|
||||
|
||||
|
||||
def get_optimisation_cost(initial_z, optimised_z):
  """Mean over the batch of the squared distance moved by the latents.

  Args:
    initial_z: Latents before optimisation.
    optimised_z: Latents after optimisation.

  Returns:
    The squared differences summed over the last axis, then averaged.
  """
  displacement = optimised_z - initial_z
  squared_distance = tf.reduce_sum(displacement**2, -1)
  return tf.reduce_mean(squared_distance)
|
||||
|
||||
|
||||
def _project_z(z, project_method='clip'):
|
||||
"""To be used for projected gradient descent over z."""
|
||||
if project_method == 'norm':
|
||||
z_p = tf.nn.l2_normalize(z, axis=-1)
|
||||
elif project_method == 'clip':
|
||||
z_p = tf.clip_by_value(z, -1, 1)
|
||||
else:
|
||||
raise ValueError('Unknown project_method: {}'.format(project_method))
|
||||
return z_p
|
||||
|
||||
|
||||
class DataProcessor(object):
  """Affine rescaling between the [0, 1] and [-1, 1] value ranges."""

  def preprocess(self, x):
    """Maps `x` to `2 * x - 1`, so that [0, 1] becomes [-1, 1]."""
    rescaled = x * 2 - 1
    return rescaled

  def postprocess(self, x):
    """Inverse of `preprocess`: maps `x` to `(x + 1) / 2`."""
    restored = (x + 1) / 2.
    return restored
|
||||
|
||||
|
||||
def _get_np_data(data_processor, dataset, split='train'):
|
||||
"""Get the dataset as numpy arrays."""
|
||||
index = 0 if split == 'train' else 1
|
||||
if dataset == 'mnist':
|
||||
# Construct the dataset.
|
||||
x, _ = tf.keras.datasets.mnist.load_data()[index]
|
||||
# Note: tf dataset is binary so we convert it to float.
|
||||
x = x.astype(np.float32)
|
||||
x = x / 255.
|
||||
x = x.reshape((-1, 28, 28, 1))
|
||||
|
||||
if dataset == 'cifar':
|
||||
x, _ = tf.keras.datasets.cifar10.load_data()[index]
|
||||
x = x.astype(np.float32)
|
||||
x = x / 255.
|
||||
|
||||
if data_processor:
|
||||
# Normalize data if a processor is given.
|
||||
x = data_processor.preprocess(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_output_dir(output_dir):
  """Logs the output directory and creates it when it does not exist yet."""
  logging.info('Creating output dir %s', output_dir)
  if tf.gfile.IsDirectory(output_dir):
    return
  tf.gfile.MakeDirs(output_dir)
|
||||
|
||||
|
||||
def get_ckpt_dir(output_dir):
  """Returns `<output_dir>/ckpt`, creating the directory if necessary."""
  checkpoint_path = os.path.join(output_dir, 'ckpt')
  already_there = tf.gfile.IsDirectory(checkpoint_path)
  if not already_there:
    tf.gfile.MakeDirs(checkpoint_path)
  return checkpoint_path
|
||||
|
||||
|
||||
def get_real_data_for_eval(num_eval_samples, dataset, split='valid'):
  """Returns the first `num_eval_samples` real images as a constant tensor.

  No data processor is applied, so the images are not passed through
  `preprocess`.

  Args:
    num_eval_samples: Number of leading examples to keep.
    dataset: Dataset name understood by `_get_np_data`.
    split: Split name forwarded to `_get_np_data`.

  Returns:
    A `tf.constant` holding the selected examples.
  """
  all_examples = _get_np_data(
      data_processor=None, dataset=dataset, split=split)
  return tf.constant(all_examples[:num_eval_samples])
|
||||
|
||||
|
||||
def get_summaries(ops):
  """Builds one scalar summary per entry of `ops`, logging values on the way.

  Args:
    ops: Dictionary from summary name to the op holding its value.

  Returns:
    A list of scalar summaries, in the iteration order of `ops`.
  """

  def _logged_scalar(name, op):
    # Ensure to log the value ops before writing them in the summary.
    # We do this instead of a hook to ensure IS/FID are never computed twice.
    print_op = tf.print(name, [op], output_stream=tf.logging.info)
    with tf.control_dependencies([print_op]):
      return tf.summary.scalar(name, op)

  return [_logged_scalar(name, op) for name, op in ops.items()]
|
||||
|
||||
|
||||
def get_train_dataset(data_processor, dataset, batch_size):
  """Creates the training data tensors.

  Args:
    data_processor: Processor handed to `_get_np_data` to normalise the data.
    dataset: Dataset name understood by `_get_np_data`.
    batch_size: Number of examples per batch.

  Returns:
    The `get_next()` tensor of a one shot iterator over the shuffled,
    repeated and batched training split.
  """
  train_images = _get_np_data(data_processor, dataset, split='train')
  # Create the TF dataset.
  examples = tf.data.Dataset.from_tensor_slices(train_images)

  # Shuffle and repeat the dataset for training.
  # This is required because we want to do multiple passes through the entire
  # dataset when training.
  examples = examples.shuffle(100000)
  examples = examples.repeat()

  # Batch the data and return the data batch.
  batches = examples.batch(batch_size)
  return batches.make_one_shot_iterator().get_next()
|
||||
|
||||
|
||||
def get_generator(dataset):
  """Returns the generator network used for the given dataset.

  Args:
    dataset: `'mnist'` or `'cifar'`.

  Returns:
    A `nets.MLPGeneratorNet` for MNIST or a `nets.SNGenNet` for CIFAR.

  Raises:
    ValueError: if `dataset` is not supported. Raising here avoids silently
      returning `None` and failing later, far from the actual mistake.
  """
  if dataset == 'mnist':
    return nets.MLPGeneratorNet()
  if dataset == 'cifar':
    return nets.SNGenNet()
  raise ValueError('Unknown dataset: {}'.format(dataset))
|
||||
|
||||
|
||||
def get_metric_net(dataset, num_outputs=2):
  """Returns the metric (discriminator / measurement) network for a dataset.

  Args:
    dataset: `'mnist'` or `'cifar'`.
    num_outputs: Number of outputs of the network.

  Returns:
    A `nets.MLPMetricNet` for MNIST or a `nets.SNMetricNet` for CIFAR.

  Raises:
    ValueError: if `dataset` is not supported. Raising here avoids silently
      returning `None` and failing later, far from the actual mistake.
  """
  if dataset == 'mnist':
    return nets.MLPMetricNet(num_outputs)
  if dataset == 'cifar':
    return nets.SNMetricNet(num_outputs)
  raise ValueError('Unknown dataset: {}'.format(dataset))
|
||||
|
||||
|
||||
def make_prior(num_latents):
  """Returns a `tfd.Normal` prior with zero mean and unit scale.

  Args:
    num_latents: Number of latent dimensions.

  Returns:
    A `tfd.Normal` whose `loc` is all zeros and whose `scale` is all ones,
    both float32 and created with shape `num_latents`.
  """
  # Zero mean, unit variance prior.
  loc = tf.zeros(shape=num_latents, dtype=tf.float32)
  scale = tf.ones(shape=num_latents, dtype=tf.float32)
  return tfd.Normal(loc=loc, scale=scale)
|
||||
|
||||
Reference in New Issue
Block a user