Internal change

PiperOrigin-RevId: 419592537
This commit is contained in:
Tom Hennigan
2022-01-04 15:45:11 +00:00
committed by Diego de Las Casas
parent 41256fd5e9
commit ca3f9e9599

View File

@@ -179,7 +179,7 @@ def load(
def cast_fn(batch):
batch = dict(**batch)
batch['images'] = tf.cast(batch['images'], _to_tf_dtype(dtype))
batch['images'] = tf.cast(batch['images'], tf.dtypes.as_dtype(dtype))
return batch
for i, batch_size in enumerate(reversed(batch_dims)):
@@ -220,15 +220,7 @@ def load(
ds = ds.prefetch(AUTOTUNE)
ds = tfds.as_numpy(ds)
if dtype == jnp.bfloat16:
# JAX and TF disagree on the NumPy bfloat16 type so we need to reinterpret
# tf_bfloat16->jnp.bfloat16.
for batch in ds:
batch['images'] = batch['images'].view(jnp.bfloat16)
yield batch
else:
yield from ds
yield from ds
def cutmix_padding(h, w):
@@ -329,13 +321,6 @@ def my_mixup_cutmix(batch):
'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)}
def _to_tf_dtype(jax_dtype: jnp.dtype) -> tf.DType:
if jax_dtype == jnp.bfloat16:
return tf.bfloat16
else:
return tf.dtypes.as_dtype(jax_dtype)
def _to_tfds_split(split: Split) -> tfds.Split:
"""Returns the TFDS split appropriately sharded."""
if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):