From a7d75013c9a52cbfed56c3916bef469229088096 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 10 Sep 2021 00:32:39 +0100 Subject: [PATCH] Remove users of jax.api.* symbols, in preparation for removing the deprecated jax.api name. In most cases, use the public jax.* name instead. PiperOrigin-RevId: 395812511 --- byol/utils/helpers.py | 2 +- nfnets/utils.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/byol/utils/helpers.py b/byol/utils/helpers.py index 3243bf6..1a66c3c 100644 --- a/byol/utils/helpers.py +++ b/byol/utils/helpers.py @@ -119,7 +119,7 @@ def bcast_local_devices(value): def _replicate(x): """Replicate an object on each device.""" x = jnp.array(x) - return jax.api.device_put_sharded(len(devices) * [x], devices) + return jax.device_put_sharded(len(devices) * [x], devices) return jax.tree_util.tree_map(_replicate, value) diff --git a/nfnets/utils.py b/nfnets/utils.py index f344a76..a2afd1c 100644 --- a/nfnets/utils.py +++ b/nfnets/utils.py @@ -91,7 +91,7 @@ def _replicate(x, devices=None): x = jax.numpy.array(x) if devices is None: devices = jax.local_devices() - return jax.api.device_put_sharded(len(devices) * [x], devices) + return jax.device_put_sharded(len(devices) * [x], devices) def broadcast(obj): @@ -124,5 +124,3 @@ def flatten_haiku_tree(haiku_dict): out_key = f'{out_module}.{key}' out[out_key] = haiku_dict[module][key] return out - -