diff --git a/cs_gan/image_metrics.py b/cs_gan/image_metrics.py index b601a26..1cf2f41 100644 --- a/cs_gan/image_metrics.py +++ b/cs_gan/image_metrics.py @@ -28,9 +28,9 @@ def get_image_metrics_for_samples( def sample_fn(arg): del arg samples = generator(prior.sample(max_classifier_batch)) - # Ensure data is in [0, 1], as expected by TFGAN. + # Samples must be in [-1, 1], as expected by TFGAN. # Resizing to appropriate size is done by TFGAN. - return data_processor.postprocess(samples) + return samples fake_outputs = tfgan.eval.sample_and_run_inception( sample_fn, @@ -39,7 +39,8 @@ def get_image_metrics_for_samples( fake_logits = fake_outputs['logits'] inception_score = tfgan.eval.classifier_score_from_logits(fake_logits) - real_outputs = tfgan.eval.run_inception(real_images, num_batches=num_batches) + real_outputs = tfgan.eval.run_inception( + data_processor.preprocess(real_images), num_batches=num_batches) fid = tfgan.eval.frechet_classifier_distance_from_activations( real_outputs['pool_3'], fake_outputs['pool_3'])