Use jax.tree_util.tree_map in place of deprecated tree_multimap.

The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
This commit is contained in:
Jake VanderPlas
2022-07-15 21:14:58 +01:00
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
19 changed files with 42 additions and 42 deletions
+1 -1
View File
@@ -373,7 +373,7 @@ class Experiment(experiment.AbstractExperiment):
if summed_scalars is None:
summed_scalars = scalars
else:
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars
+1 -1
View File
@@ -124,7 +124,7 @@ def ema_update(step: chex.Array,
def _weighted_average(p1, p2):
d = decay.astype(p1.dtype)
return (1 - d) * p1 + d * p2
return jax.tree_multimap(_weighted_average, new_params, avg_params)
return jax.tree_map(_weighted_average, new_params, avg_params)
def cutmix(rng: chex.PRNGKey,