mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00
Internal change
PiperOrigin-RevId: 419592537
This commit is contained in:
committed by
Diego de Las Casas
parent
41256fd5e9
commit
ca3f9e9599
@@ -179,7 +179,7 @@ def load(
|
||||
|
||||
def cast_fn(batch):
|
||||
batch = dict(**batch)
|
||||
batch['images'] = tf.cast(batch['images'], _to_tf_dtype(dtype))
|
||||
batch['images'] = tf.cast(batch['images'], tf.dtypes.as_dtype(dtype))
|
||||
return batch
|
||||
|
||||
for i, batch_size in enumerate(reversed(batch_dims)):
|
||||
@@ -220,15 +220,7 @@ def load(
|
||||
|
||||
ds = ds.prefetch(AUTOTUNE)
|
||||
ds = tfds.as_numpy(ds)
|
||||
|
||||
if dtype == jnp.bfloat16:
|
||||
# JAX and TF disagree on the NumPy bfloat16 type so we need to reinterpret
|
||||
# tf_bfloat16->jnp.bfloat16.
|
||||
for batch in ds:
|
||||
batch['images'] = batch['images'].view(jnp.bfloat16)
|
||||
yield batch
|
||||
else:
|
||||
yield from ds
|
||||
yield from ds
|
||||
|
||||
|
||||
def cutmix_padding(h, w):
|
||||
@@ -329,13 +321,6 @@ def my_mixup_cutmix(batch):
|
||||
'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)}
|
||||
|
||||
|
||||
def _to_tf_dtype(jax_dtype: jnp.dtype) -> tf.DType:
|
||||
if jax_dtype == jnp.bfloat16:
|
||||
return tf.bfloat16
|
||||
else:
|
||||
return tf.dtypes.as_dtype(jax_dtype)
|
||||
|
||||
|
||||
def _to_tfds_split(split: Split) -> tfds.Split:
|
||||
"""Returns the TFDS split appropriately sharded."""
|
||||
if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):
|
||||
|
||||
Reference in New Issue
Block a user