Remove users of jax.api.* symbols, in preparation for removing the deprecated jax.api name.

In most cases, use the public jax.* name instead.

PiperOrigin-RevId: 395812511
This commit is contained in:
Peter Hawkins
2021-09-10 00:32:39 +01:00
committed by Saran Tunyasuvunakool
parent 96a13847c6
commit a7d75013c9
2 changed files with 2 additions and 4 deletions

View File

@@ -119,7 +119,7 @@ def bcast_local_devices(value):
def _replicate(x):
"""Replicate an object on each device."""
x = jnp.array(x)
return jax.api.device_put_sharded(len(devices) * [x], devices)
return jax.device_put_sharded(len(devices) * [x], devices)
return jax.tree_util.tree_map(_replicate, value)

View File

@@ -91,7 +91,7 @@ def _replicate(x, devices=None):
x = jax.numpy.array(x)
if devices is None:
devices = jax.local_devices()
return jax.api.device_put_sharded(len(devices) * [x], devices)
return jax.device_put_sharded(len(devices) * [x], devices)
def broadcast(obj):
@@ -124,5 +124,3 @@ def flatten_haiku_tree(haiku_dict):
out_key = f'{out_module}.{key}'
out[out_key] = haiku_dict[module][key]
return out