mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-27 18:25:49 +08:00
Remove users of jax.api.* symbols, in preparation for removing the deprecated jax.api name.
In most cases, use the public jax.* name instead. PiperOrigin-RevId: 395812511
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
96a13847c6
commit
a7d75013c9
@@ -119,7 +119,7 @@ def bcast_local_devices(value):
|
|||||||
def _replicate(x):
|
def _replicate(x):
|
||||||
"""Replicate an object on each device."""
|
"""Replicate an object on each device."""
|
||||||
x = jnp.array(x)
|
x = jnp.array(x)
|
||||||
return jax.api.device_put_sharded(len(devices) * [x], devices)
|
return jax.device_put_sharded(len(devices) * [x], devices)
|
||||||
|
|
||||||
return jax.tree_util.tree_map(_replicate, value)
|
return jax.tree_util.tree_map(_replicate, value)
|
||||||
|
|
||||||
|
|||||||
+1
-3
@@ -91,7 +91,7 @@ def _replicate(x, devices=None):
|
|||||||
x = jax.numpy.array(x)
|
x = jax.numpy.array(x)
|
||||||
if devices is None:
|
if devices is None:
|
||||||
devices = jax.local_devices()
|
devices = jax.local_devices()
|
||||||
return jax.api.device_put_sharded(len(devices) * [x], devices)
|
return jax.device_put_sharded(len(devices) * [x], devices)
|
||||||
|
|
||||||
|
|
||||||
def broadcast(obj):
|
def broadcast(obj):
|
||||||
@@ -124,5 +124,3 @@ def flatten_haiku_tree(haiku_dict):
|
|||||||
out_key = f'{out_module}.{key}'
|
out_key = f'{out_module}.{key}'
|
||||||
out[out_key] = haiku_dict[module][key]
|
out[out_key] = haiku_dict[module][key]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user