Peter Hawkins f6da52cb38 [JAX] Replace uses of deprecated jax.ops.index_update(x, idx, y) APIs with their up-to-date, more succinct equivalent x.at[idx].set(y).
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...

have long been deprecated in favor of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...
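A minimal sketch of the migration for the two operators listed above, assuming a one-dimensional float array and a scalar index (the values are illustrative only). JAX arrays are immutable, so `.at[idx].set(y)` and `.at[idx].add(y)` return new arrays and leave `x` unchanged, just as the old `jax.ops` functions did.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]
idx = 2
y = 7.0

# Old form: jax.ops.index_update(x, jax.ops.index[idx], y)
updated = x.at[idx].set(y)

# Old form: jax.ops.index_add(x, jax.ops.index[idx], y)
accumulated = x.at[idx].add(y)

print(updated.tolist())      # [0.0, 1.0, 7.0, 3.0, 4.0]
print(accumulated.tolist())  # [0.0, 1.0, 9.0, 3.0, 4.0]
print(x.tolist())            # [0.0, 1.0, 2.0, 3.0, 4.0]  (original is untouched)
```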

This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.

The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.
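A minimal sketch of that cast, assuming the input may be a NumPy array; `set_entry` is a hypothetical helper used only for illustration and is not part of JAX. A plain NumPy `ndarray` has no `.at` attribute, so calling `.at[...]` on it directly would raise `AttributeError`.

```python
import numpy as np
import jax.numpy as jnp


def set_entry(x, idx, y):
    """Hypothetical helper: return a copy of `x` with x[idx] replaced by y.

    `.at[...]` only exists on JAX arrays, so the input (which may be a NumPy
    array or a Python list) is converted with jnp.asarray first.
    """
    return jnp.asarray(x).at[idx].set(y)


np_x = np.ones(3, dtype=np.float32)
print(hasattr(np_x, "at"))  # False: NumPy arrays do not support `.at[...]`

result = set_entry(np_x, 0, 5.0)
print(result.tolist())  # [5.0, 1.0, 1.0]
print(np_x.tolist())    # [1.0, 1.0, 1.0]  (the NumPy input is not modified)
```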

PiperOrigin-RevId: 401233414
2021-10-26 18:22:43 +01:00