Switch dataset to the new tf.data.Dataset API.

PiperOrigin-RevId: 369679305
This commit is contained in:
Florent Altché
2021-04-21 17:07:23 +00:00
committed by Louise Deason
parent 3dc0baece1
commit 2e866f1937
+3 -1
View File
@@ -87,7 +87,7 @@ def load(split: Split,
split=tfds_split,
decoders={'image': tfds.decode.SkipDecoding()})
options = ds.options()
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = 48
options.experimental_threading.max_intra_op_parallelism = 1
@@ -103,6 +103,8 @@ def load(split: Split,
if split.num_examples % total_batch_size != 0:
raise ValueError(f'Test/valid must be divisible by {total_batch_size}')
ds = ds.with_options(options)
def preprocess_pretrain(example):
view1 = _preprocess_image(example['image'], mode=preprocess_mode)
view2 = _preprocess_image(example['image'], mode=preprocess_mode)