From 0e5237df2a8fa7b6ed6a0f4ff2a7c46829ddfa6b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 18 Sep 2020 21:21:27 +0100 Subject: [PATCH] Use jax.api.device_put_sharded() in place of private JAX APIs. PiperOrigin-RevId: 332514384 --- byol/utils/helpers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/byol/utils/helpers.py b/byol/utils/helpers.py index 81a9224..3243bf6 100644 --- a/byol/utils/helpers.py +++ b/byol/utils/helpers.py @@ -119,9 +119,7 @@ def bcast_local_devices(value): def _replicate(x): """Replicate an object on each device.""" x = jnp.array(x) - aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype) - buffers = [jax.interpreters.xla.device_put(x, d) for d in devices] - return jax.pxla.ShardedDeviceArray(aval, buffers) + return jax.api.device_put_sharded(len(devices) * [x], devices) return jax.tree_util.tree_map(_replicate, value)