Internal change

PiperOrigin-RevId: 419592537
This commit is contained in:
Tom Hennigan
2022-01-04 15:45:11 +00:00
committed by Diego de Las Casas
parent 41256fd5e9
commit ca3f9e9599
+2 -17
View File
@@ -179,7 +179,7 @@ def load(
def cast_fn(batch): def cast_fn(batch):
batch = dict(**batch) batch = dict(**batch)
batch['images'] = tf.cast(batch['images'], _to_tf_dtype(dtype)) batch['images'] = tf.cast(batch['images'], tf.dtypes.as_dtype(dtype))
return batch return batch
for i, batch_size in enumerate(reversed(batch_dims)): for i, batch_size in enumerate(reversed(batch_dims)):
@@ -220,15 +220,7 @@ def load(
ds = ds.prefetch(AUTOTUNE) ds = ds.prefetch(AUTOTUNE)
ds = tfds.as_numpy(ds) ds = tfds.as_numpy(ds)
yield from ds
if dtype == jnp.bfloat16:
# JAX and TF disagree on the NumPy bfloat16 type so we need to reinterpret
# tf_bfloat16->jnp.bfloat16.
for batch in ds:
batch['images'] = batch['images'].view(jnp.bfloat16)
yield batch
else:
yield from ds
def cutmix_padding(h, w): def cutmix_padding(h, w):
@@ -329,13 +321,6 @@ def my_mixup_cutmix(batch):
'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)} 'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)}
def _to_tf_dtype(jax_dtype: jnp.dtype) -> tf.DType:
if jax_dtype == jnp.bfloat16:
return tf.bfloat16
else:
return tf.dtypes.as_dtype(jax_dtype)
def _to_tfds_split(split: Split) -> tfds.Split: def _to_tfds_split(split: Split) -> tfds.Split:
"""Returns the TFDS split appropriately sharded.""" """Returns the TFDS split appropriately sharded."""
if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID): if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):