Use jax.tree_util.tree_map in place of deprecated tree_multimap.

The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
This commit is contained in:
Jake VanderPlas
2022-07-15 21:14:58 +01:00
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
19 changed files with 42 additions and 42 deletions
+1 -1
View File
@@ -527,7 +527,7 @@ class Experiment(experiment.AbstractExperiment):
if summed_scalars is None:
summed_scalars = scalars
else:
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars
+1 -1
View File
@@ -192,7 +192,7 @@ def add_weight_decay(
u_ex, u_in = hk.data_structures.partition(exclude, updates)
_, p_in = hk.data_structures.partition(exclude, params)
u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
updates = hk.data_structures.merge(u_ex, u_in)
return updates, state