mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
Fix requirements for cs_gan. Update to new TFGAN code.
PiperOrigin-RevId: 277907210
This commit is contained in:
committed by
Diego de Las Casas
parent
82063f13bc
commit
9cf75990be
+23
-125
@@ -16,135 +16,33 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_gan as tfgan
|
||||
|
||||
|
||||
def compute_inception_score(
    images, max_classifier_batch_size=16, assert_data_ranges=True):
  """Computes the classifier (Inception) score, using the given model.

  Note that this metric is highly sensitive to the number of images used
  to compute it: more is better. A standard number is 50K images.

  For more details see: https://arxiv.org/abs/1606.03498.

  Args:
    images: A 4D tensor.
    max_classifier_batch_size: An integer. The maximum batch size to pass
      through the classifier used to compute the metric.
    assert_data_ranges: Whether or not to assert the input is in a valid
      data range.

  Returns:
    A scalar `tf.Tensor` with the Inception score.
  """
  num_images = images.shape.as_list()[0]

  def _choose_batch_size(num_images, max_batch_size):
    # Largest divisor of num_images that is <= max_batch_size. Always
    # terminates with a value, since 1 divides every num_images.
    for size in range(max_batch_size, 0, -1):
      if num_images % size == 0:
        return size

  batch_size = _choose_batch_size(num_images, max_classifier_batch_size)
  tf.logging.debug('Using batch_size=%s for score classifier', batch_size)
  num_batches = num_images // batch_size

  if assert_data_ranges:
    # Only check the first batch, to keep the assertion cheap.
    images_batch = images[0:batch_size]
    min_data_range_check = tf.assert_greater_equal(images_batch, 0.0)
    max_data_range_check = tf.assert_less_equal(images_batch, 1.0)
    control_deps = [min_data_range_check, max_data_range_check]
  else:
    control_deps = []

  # Inception module does resizing, if necessary.
  with tf.control_dependencies(control_deps):
    return tfgan.eval.run_inception(images, num_batches=num_batches)
|
||||
|
||||
|
||||
def compute_fid(
    real_images, other, max_classifier_batch_size=16,
    assert_data_ranges=True):
  """Computes the Frechet Inception Distance (FID), using the given model.

  Note that this metric is highly sensitive to the number of images used
  to compute it: more is better. A standard number is 50K images.

  For more details see: https://arxiv.org/abs/1606.03498.

  Args:
    real_images: A 4D tensor. Samples from the true data distribution.
    other: A 4D tensor. Data for which to compute the FID.
    max_classifier_batch_size: An integer. The maximum batch size to pass
      through the classifier used to compute the metric.
    assert_data_ranges: Whether or not to assert the input is in a valid
      data range.

  Returns:
    A scalar `tf.Tensor` with the FID.
  """
  num_images = real_images.shape.as_list()[0]

  def _choose_batch_size(num_images, max_batch_size):
    # Largest divisor of num_images that is <= max_batch_size. Always
    # terminates with a value, since 1 divides every num_images.
    for size in range(max_batch_size, 0, -1):
      if num_images % size == 0:
        return size

  batch_size = _choose_batch_size(num_images, max_classifier_batch_size)
  tf.logging.debug('Using batch_size=%s for score classifier', batch_size)
  num_batches = num_images // batch_size

  if assert_data_ranges:
    # Only check the first batch of each input, to keep the assertion cheap.
    images_batch = tf.concat(
        [real_images[0:batch_size], other[0:batch_size]], axis=0)
    min_data_range_check = tf.assert_greater_equal(images_batch, 0.0)
    max_data_range_check = tf.assert_less_equal(images_batch, 1.0)
    control_deps = [min_data_range_check, max_data_range_check]
  else:
    control_deps = []

  # Inception module does resizing, if necessary.
  with tf.control_dependencies(control_deps):
    return tfgan.eval.frechet_inception_distance(
        real_images, other, num_batches=num_batches)
|
||||
|
||||
|
||||
def generate_big_batch(generator, generator_inputs, max_num_samples=100):
  """Generate samples when the number of samples is too big to do one pass."""
  total_samples = generator_inputs.shape.as_list()[0]
  max_num_samples = min(max_num_samples, total_samples)
  per_input_shape = generator_inputs.shape.as_list()[1:]
  batched_generator_inputs = tf.reshape(
      generator_inputs,
      [total_samples // max_num_samples, max_num_samples] + per_input_shape)

  # Create the samples by sequentially doing forward passes through the
  # generator. This ensures we do not run out of memory.
  output_samples = tf.map_fn(
      fn=generator,
      elems=batched_generator_inputs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True)

  # Fold the leading batch-of-batches dimension back into a flat batch.
  flat_shape = [total_samples] + output_samples.shape.as_list()[2:]
  return tf.reshape(output_samples, flat_shape)
|
||||
|
||||
|
||||
def get_image_metrics(real_images, other):
  """Returns a dict with the Inception score and FID for the given images."""
  inception_score = compute_inception_score(other)
  fid = compute_fid(real_images, other)
  return {
      'inception_score': inception_score,
      'fid': fid}
|
||||
|
||||
|
||||
def get_image_metrics_for_samples(
    real_images, generator, prior, data_processor, num_eval_samples):
  """Computes the Inception score and FID for generator samples.

  Args:
    real_images: A 4D tensor. Samples from the true data distribution.
    generator: A callable mapping latent inputs to image samples.
    prior: An object with a `sample(n)` method producing generator inputs.
    data_processor: An object whose `postprocess` maps samples into [0, 1].
    num_eval_samples: An integer. The total number of samples to evaluate;
      samples are drawn in fixed-size classifier batches, so any remainder
      modulo the batch size is dropped.

  Returns:
    A dict with scalar `tf.Tensor`s under the keys 'inception_score'
    and 'fid'.
  """
  max_classifier_batch = 10
  num_batches = num_eval_samples // max_classifier_batch

  def sample_fn(arg):
    del arg  # Unused: sampling is driven by the prior, not by the inputs.
    samples = generator(prior.sample(max_classifier_batch))
    # Ensure data is in [0, 1], as expected by TFGAN.
    # Resizing to appropriate size is done by TFGAN.
    return data_processor.postprocess(samples)

  fake_outputs = tfgan.eval.sample_and_run_inception(
      sample_fn,
      sample_inputs=[1.0] * num_batches)  # Dummy inputs.

  fake_logits = fake_outputs['logits']
  inception_score = tfgan.eval.classifier_score_from_logits(fake_logits)

  real_outputs = tfgan.eval.run_inception(real_images, num_batches=num_batches)
  fid = tfgan.eval.frechet_classifier_distance_from_activations(
      real_outputs['pool_3'], fake_outputs['pool_3'])

  return {
      'inception_score': inception_score,
      'fid': fid}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
absl-py==0.7.1
|
||||
dm-sonnet==1.34
|
||||
numpy==1.16.4
|
||||
pillow==4.3.0
|
||||
tensorflow==1.14
|
||||
Pillow
|
||||
tensorflow==1.15rc2
|
||||
tensorflow-probability==0.7.0
|
||||
tensorflow-gan==1.0.0.dev0
|
||||
tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow.
|
||||
tensorflow-gan==2.0.0
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
python3 -m venv cs_gan_venv
|
||||
source cs_gan_venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r cs_gan/requirements.txt
|
||||
|
||||
python -m cs_gan.main_cs
|
||||
|
||||
Reference in New Issue
Block a user