mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 13:05:40 +08:00
f6da52cb38
The JAX operators:

  jax.ops.index_update(x, jax.ops.index[idx], y)
  jax.ops.index_add(x, jax.ops.index[idx], y)
  ...

have long been deprecated in favor of their more succinct counterparts:

  x.at[idx].set(y)
  x.at[idx].add(y)
  ...

This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.

The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.

PiperOrigin-RevId: 401233414
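
For illustration only, here is a minimal sketch of the migration described above. The helper name `set_and_add` and the sample values are hypothetical and do not come from this change. It assumes a JAX version that provides the `.at[...]` property on JAX arrays, and it shows the `jnp.asarray` cast applied to a NumPy input before the indexed update.

```python
import numpy as np
import jax.numpy as jnp


def set_and_add(x, idx, y):
    """Hypothetical helper showing the `.at[...]` replacements.

    Old form (deprecated): jax.ops.index_update(x, jax.ops.index[idx], y)
    New form:              x.at[idx].set(y)
    """
    # `.at[...]` is only defined on JAX arrays, so cast first. This is
    # harmless when `x` already is a JAX array.
    x = jnp.asarray(x)
    updated = x.at[idx].set(y)      # replaces jax.ops.index_update
    accumulated = x.at[idx].add(y)  # replaces jax.ops.index_add
    return updated, accumulated


# A plain NumPy array has no `.at[...]` indexed-update property,
# so the cast inside the helper is what makes this call work.
x_np = np.ones(4, dtype=np.float32)
updated, accumulated = set_and_add(x_np, 1, 5.0)
```

As with the deprecated functions, both updates are functional: they return new arrays and leave the input unchanged.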