diff --git a/cs_gan/image_metrics.py b/cs_gan/image_metrics.py index 34ba3e1..b601a26 100644 --- a/cs_gan/image_metrics.py +++ b/cs_gan/image_metrics.py @@ -16,135 +16,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow as tf import tensorflow_gan as tfgan -def compute_inception_score( - images, max_classifier_batch_size=16, assert_data_ranges=True): - """Computes the classifier score, using the given model. - - Note that this metric is highly sensitive to the number of images used - to compute it: more is better. A standard number is 50K images. - - For more details see: https://arxiv.org/abs/1606.03498. - - Args: - images: A 4D tensor. - max_classifier_batch_size: An integer. The maximum batch size size to pass - through the classifier used to compute the metric. - assert_data_ranges: Whether or not to assert the input is in a valid - data range. - - Returns: - A scalar `tf.Tensor` with the Inception score. - """ - num_images = images.shape.as_list()[0] - - def _choose_batch_size(num_images, max_batch_size): - for size in range(max_batch_size, 0, -1): - if num_images % size == 0: - return size - - batch_size = _choose_batch_size(num_images, max_classifier_batch_size) - tf.logging.debug('Using batch_size=%s for score classifier', batch_size) - num_batches = num_images // batch_size - - if assert_data_ranges: - images_batch = images[0:batch_size] - min_data_range_check = tf.assert_greater_equal(images_batch, 0.0) - max_data_range_check = tf.assert_less_equal(images_batch, 1.0) - control_deps = [min_data_range_check, max_data_range_check] - else: - control_deps = [] - - # Inception module does resizing, if necessary. - with tf.control_dependencies(control_deps): - return tfgan.eval.run_inception(images, num_batches=num_batches) - - -def compute_fid( - real_images, other, max_classifier_batch_size=16, - assert_data_ranges=True): - """Computes the classifier score, using the given model. - - Note that this metric is highly sensitive to the number of images used - to compute it: more is better. A standard number is 50K images. - - For more details see: https://arxiv.org/abs/1606.03498. - - Args: - real_images: A 4D tensor. Samples from the true data distribution. - other: A 4D tensor. Data for which to compute the FID. - max_classifier_batch_size: An integer. The maximum batch size size to pass - through the classifier used to compute the metric. - assert_data_ranges: Whether or not to assert the input is in a valid - data range. - - Returns: - A scalar `tf.Tensor` with the Inception score. - """ - num_images = real_images.shape.as_list()[0] - - def _choose_batch_size(num_images, max_batch_size): - for size in range(max_batch_size, 0, -1): - if num_images % size == 0: - return size - - batch_size = _choose_batch_size(num_images, max_classifier_batch_size) - tf.logging.debug('Using batch_size=%s for score classifier', batch_size) - num_batches = num_images // batch_size - - if assert_data_ranges: - images_batch = tf.concat( - [real_images[0:batch_size], other[0:batch_size]], axis=0) - min_data_range_check = tf.assert_greater_equal(images_batch, 0.0) - max_data_range_check = tf.assert_less_equal(images_batch, 1.0) - control_deps = [min_data_range_check, max_data_range_check] - else: - control_deps = [] - - # Inception module does resizing, if necessary. - with tf.control_dependencies(control_deps): - return tfgan.eval.frechet_inception_distance( - real_images, other, num_batches=num_batches) - - -def generate_big_batch(generator, generator_inputs, max_num_samples=100): - """Generate samples when the number of samples is too big to do one pass.""" - num_samples = generator_inputs.shape.as_list()[0] - max_num_samples = min(max_num_samples, num_samples) - batched_shape = [num_samples // max_num_samples, max_num_samples] - batched_shape += generator_inputs.shape.as_list()[1:] - batched_generator_inputs = tf.reshape(generator_inputs, batched_shape) - - # Create the samples by sequentially doing forward passes through the - # generator. This ensures we do not run out of memory. - output_samples = tf.map_fn( - fn=generator, - elems=batched_generator_inputs, - parallel_iterations=1, - back_prop=False, - swap_memory=True) - - samples = tf.reshape(output_samples, - [num_samples] + output_samples.shape.as_list()[2:]) - return samples - - -def get_image_metrics(real_images, other): - return { - 'inception_score': compute_inception_score(other), - 'fid': compute_fid(real_images, other)} - - def get_image_metrics_for_samples( real_images, generator, prior, data_processor, num_eval_samples): - generator_inputs = prior.sample(num_eval_samples) - samples = generate_big_batch(generator, generator_inputs) - # Ensure the samples are in [0, 1]. - samples = data_processor.postprocess(samples) + """Compute inception score and FID.""" + max_classifier_batch = 10 + num_batches = num_eval_samples // max_classifier_batch + + def sample_fn(arg): + del arg + samples = generator(prior.sample(max_classifier_batch)) + # Ensure data is in [0, 1], as expected by TFGAN. + # Resizing to appropriate size is done by TFGAN. + return data_processor.postprocess(samples) + + fake_outputs = tfgan.eval.sample_and_run_inception( + sample_fn, + sample_inputs=[1.0] * num_batches) # Dummy inputs. + + fake_logits = fake_outputs['logits'] + inception_score = tfgan.eval.classifier_score_from_logits(fake_logits) + + real_outputs = tfgan.eval.run_inception(real_images, num_batches=num_batches) + fid = tfgan.eval.frechet_classifier_distance_from_activations( + real_outputs['pool_3'], fake_outputs['pool_3']) return { - 'inception_score': compute_inception_score(samples), - 'fid': compute_fid(real_images, samples)} + 'inception_score': inception_score, + 'fid': fid} diff --git a/cs_gan/requirements.txt b/cs_gan/requirements.txt index ec1d1a0..d571542 100644 --- a/cs_gan/requirements.txt +++ b/cs_gan/requirements.txt @@ -1,8 +1,7 @@ absl-py==0.7.1 dm-sonnet==1.34 numpy==1.16.4 -pillow==4.3.0 -tensorflow==1.14 +Pillow +tensorflow==1.15rc2 tensorflow-probability==0.7.0 -tensorflow-gan==1.0.0.dev0 -tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. +tensorflow-gan==2.0.0 diff --git a/cs_gan/run.sh b/cs_gan/run.sh index fd2997b..2a2afef 100755 --- a/cs_gan/run.sh +++ b/cs_gan/run.sh @@ -15,6 +15,7 @@ python3 -m venv cs_gan_venv source cs_gan_venv/bin/activate +pip install --upgrade pip pip install -r cs_gan/requirements.txt python -m cs_gan.main_cs