Use jax.api.device_put_sharded() in place of private JAX APIs.

PiperOrigin-RevId: 332514384
This commit is contained in:
Jake VanderPlas
2020-09-18 21:21:27 +01:00
committed by Saran Tunyasuvunakool
parent 1d763e0beb
commit 0e5237df2a
+1 -3
View File
@@ -119,9 +119,7 @@ def bcast_local_devices(value):
def _replicate(x):
"""Replicate an object on each device."""
x = jnp.array(x)
aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
buffers = [jax.interpreters.xla.device_put(x, d) for d in devices]
return jax.pxla.ShardedDeviceArray(aval, buffers)
return jax.api.device_put_sharded(len(devices) * [x], devices)
return jax.tree_util.tree_map(_replicate, value)