Fix omniglot case to reflect tfds interface.

PiperOrigin-RevId: 297126345
This commit is contained in:
Dushyant Rao
2020-02-25 16:50:29 +00:00
committed by Diego de Las Casas
parent d0efbec03a
commit 16a11d4c78
+5 -5
View File
@@ -193,14 +193,14 @@ def get_data_sources(dataset, dataset_kwargs, batch_size, test_batch_size,
name=dataset, split=tfds.Split.VALIDATION, **dataset_kwargs)
num_valid_examples = ds_info.splits[tfds.Split.VALIDATION].num_examples
assert (num_valid_examples %
test_batch_size == 0), ('test_batch_size must be a multiple of %d' %
test_batch_size == 0), ('test_batch_size must be a divisor of %d' %
num_valid_examples)
valid_dataset = valid_dataset.repeat(1).batch(
test_batch_size, drop_remainder=True)
valid_dataset = valid_dataset.map(preprocess_data)
valid_iter = valid_dataset.make_initializable_iterator()
valid_data = valid_iter.get_next()
except KeyError:
except (KeyError, ValueError):
logging.warning('No validation set!!')
valid_iter = None
valid_data = None
@@ -210,7 +210,7 @@ def get_data_sources(dataset, dataset_kwargs, batch_size, test_batch_size,
name=dataset, split=tfds.Split.TEST, **dataset_kwargs)
num_test_examples = ds_info.splits['test'].num_examples
assert (num_test_examples %
test_batch_size == 0), ('test_batch_size must be a multiple of %d' %
test_batch_size == 0), ('test_batch_size must be a divisor of %d' %
num_test_examples)
test_dataset = test_dataset.repeat(1).batch(
test_batch_size, drop_remainder=True)
@@ -542,8 +542,8 @@ def run_training(
label_key = 'label'
elif dataset == 'omniglot':
batch_size = 15
test_batch_size = 8115
dataset_kwargs = {'split': 'instance', 'label': 'alphabet'}
test_batch_size = 1318
dataset_kwargs = {}
image_key = 'image'
label_key = 'alphabet'
else: