diff --git a/byol/utils/dataset.py b/byol/utils/dataset.py index 7386f88..9a42df4 100644 --- a/byol/utils/dataset.py +++ b/byol/utils/dataset.py @@ -87,7 +87,7 @@ def load(split: Split, split=tfds_split, decoders={'image': tfds.decode.SkipDecoding()}) - options = ds.options() + options = tf.data.Options() options.experimental_threading.private_threadpool_size = 48 options.experimental_threading.max_intra_op_parallelism = 1 @@ -103,6 +103,8 @@ def load(split: Split, if split.num_examples % total_batch_size != 0: raise ValueError(f'Test/valid must be divisible by {total_batch_size}') + ds = ds.with_options(options) + def preprocess_pretrain(example): view1 = _preprocess_image(example['image'], mode=preprocess_mode) view2 = _preprocess_image(example['image'], mode=preprocess_mode)