Fix requirements for cs_gan. Update to new TFGAN code.

PiperOrigin-RevId: 277907210
This commit is contained in:
Mihaela Rosca
2019-11-01 11:58:40 +00:00
committed by Diego de Las Casas
parent 82063f13bc
commit 9cf75990be
3 changed files with 27 additions and 129 deletions
+23 -125
View File
@@ -16,135 +16,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow_gan as tfgan
def compute_inception_score(
images, max_classifier_batch_size=16, assert_data_ranges=True):
"""Computes the classifier score, using the given model.
Note that this metric is highly sensitive to the number of images used
to compute it: more is better. A standard number is 50K images.
For more details see: https://arxiv.org/abs/1606.03498.
Args:
images: A 4D tensor.
max_classifier_batch_size: An integer. The maximum batch size size to pass
through the classifier used to compute the metric.
assert_data_ranges: Whether or not to assert the input is in a valid
data range.
Returns:
A scalar `tf.Tensor` with the Inception score.
"""
num_images = images.shape.as_list()[0]
def _choose_batch_size(num_images, max_batch_size):
for size in range(max_batch_size, 0, -1):
if num_images % size == 0:
return size
batch_size = _choose_batch_size(num_images, max_classifier_batch_size)
tf.logging.debug('Using batch_size=%s for score classifier', batch_size)
num_batches = num_images // batch_size
if assert_data_ranges:
images_batch = images[0:batch_size]
min_data_range_check = tf.assert_greater_equal(images_batch, 0.0)
max_data_range_check = tf.assert_less_equal(images_batch, 1.0)
control_deps = [min_data_range_check, max_data_range_check]
else:
control_deps = []
# Inception module does resizing, if necessary.
with tf.control_dependencies(control_deps):
return tfgan.eval.run_inception(images, num_batches=num_batches)
def compute_fid(
real_images, other, max_classifier_batch_size=16,
assert_data_ranges=True):
"""Computes the classifier score, using the given model.
Note that this metric is highly sensitive to the number of images used
to compute it: more is better. A standard number is 50K images.
For more details see: https://arxiv.org/abs/1606.03498.
Args:
real_images: A 4D tensor. Samples from the true data distribution.
other: A 4D tensor. Data for which to compute the FID.
max_classifier_batch_size: An integer. The maximum batch size size to pass
through the classifier used to compute the metric.
assert_data_ranges: Whether or not to assert the input is in a valid
data range.
Returns:
A scalar `tf.Tensor` with the Inception score.
"""
num_images = real_images.shape.as_list()[0]
def _choose_batch_size(num_images, max_batch_size):
for size in range(max_batch_size, 0, -1):
if num_images % size == 0:
return size
batch_size = _choose_batch_size(num_images, max_classifier_batch_size)
tf.logging.debug('Using batch_size=%s for score classifier', batch_size)
num_batches = num_images // batch_size
if assert_data_ranges:
images_batch = tf.concat(
[real_images[0:batch_size], other[0:batch_size]], axis=0)
min_data_range_check = tf.assert_greater_equal(images_batch, 0.0)
max_data_range_check = tf.assert_less_equal(images_batch, 1.0)
control_deps = [min_data_range_check, max_data_range_check]
else:
control_deps = []
# Inception module does resizing, if necessary.
with tf.control_dependencies(control_deps):
return tfgan.eval.frechet_inception_distance(
real_images, other, num_batches=num_batches)
def generate_big_batch(generator, generator_inputs, max_num_samples=100):
"""Generate samples when the number of samples is too big to do one pass."""
num_samples = generator_inputs.shape.as_list()[0]
max_num_samples = min(max_num_samples, num_samples)
batched_shape = [num_samples // max_num_samples, max_num_samples]
batched_shape += generator_inputs.shape.as_list()[1:]
batched_generator_inputs = tf.reshape(generator_inputs, batched_shape)
# Create the samples by sequentially doing forward passes through the
# generator. This ensures we do not run out of memory.
output_samples = tf.map_fn(
fn=generator,
elems=batched_generator_inputs,
parallel_iterations=1,
back_prop=False,
swap_memory=True)
samples = tf.reshape(output_samples,
[num_samples] + output_samples.shape.as_list()[2:])
return samples
def get_image_metrics(real_images, other):
return {
'inception_score': compute_inception_score(other),
'fid': compute_fid(real_images, other)}
def get_image_metrics_for_samples(
real_images, generator, prior, data_processor, num_eval_samples):
generator_inputs = prior.sample(num_eval_samples)
samples = generate_big_batch(generator, generator_inputs)
# Ensure the samples are in [0, 1].
samples = data_processor.postprocess(samples)
"""Compute inception score and FID."""
max_classifier_batch = 10
num_batches = num_eval_samples // max_classifier_batch
def sample_fn(arg):
del arg
samples = generator(prior.sample(max_classifier_batch))
# Ensure data is in [0, 1], as expected by TFGAN.
# Resizing to appropriate size is done by TFGAN.
return data_processor.postprocess(samples)
fake_outputs = tfgan.eval.sample_and_run_inception(
sample_fn,
sample_inputs=[1.0] * num_batches) # Dummy inputs.
fake_logits = fake_outputs['logits']
inception_score = tfgan.eval.classifier_score_from_logits(fake_logits)
real_outputs = tfgan.eval.run_inception(real_images, num_batches=num_batches)
fid = tfgan.eval.frechet_classifier_distance_from_activations(
real_outputs['pool_3'], fake_outputs['pool_3'])
return {
'inception_score': compute_inception_score(samples),
'fid': compute_fid(real_images, samples)}
'inception_score': inception_score,
'fid': fid}
+3 -4
View File
@@ -1,8 +1,7 @@
absl-py==0.7.1
dm-sonnet==1.34
numpy==1.16.4
pillow==4.3.0
tensorflow==1.14
Pillow
tensorflow==1.15rc2
tensorflow-probability==0.7.0
tensorflow-gan==1.0.0.dev0
tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow.
tensorflow-gan==2.0.0
+1
View File
@@ -15,6 +15,7 @@
python3 -m venv cs_gan_venv
source cs_gan_venv/bin/activate
pip install --upgrade pip
pip install -r cs_gan/requirements.txt
python -m cs_gan.main_cs