diff --git a/byol/utils/helpers.py b/byol/utils/helpers.py index 81a9224..3243bf6 100644 --- a/byol/utils/helpers.py +++ b/byol/utils/helpers.py @@ -119,9 +119,7 @@ def bcast_local_devices(value): def _replicate(x): """Replicate an object on each device.""" x = jnp.array(x) - aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype) - buffers = [jax.interpreters.xla.device_put(x, d) for d in devices] - return jax.pxla.ShardedDeviceArray(aval, buffers) + return jax.api.device_put_sharded(len(devices) * [x], devices) return jax.tree_util.tree_map(_replicate, value)