mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 04:45:31 +08:00
Use jax.api.device_put_sharded() in place of private JAX APIs.
PiperOrigin-RevId: 332514384
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
1d763e0beb
commit
0e5237df2a
@@ -119,9 +119,7 @@ def bcast_local_devices(value):
|
|||||||
def _replicate(x):
|
def _replicate(x):
|
||||||
"""Replicate an object on each device."""
|
"""Replicate an object on each device."""
|
||||||
x = jnp.array(x)
|
x = jnp.array(x)
|
||||||
aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
|
return jax.api.device_put_sharded(len(devices) * [x], devices)
|
||||||
buffers = [jax.interpreters.xla.device_put(x, d) for d in devices]
|
|
||||||
return jax.pxla.ShardedDeviceArray(aval, buffers)
|
|
||||||
|
|
||||||
return jax.tree_util.tree_map(_replicate, value)
|
return jax.tree_util.tree_map(_replicate, value)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user