mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 04:45:31 +08:00
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op. PiperOrigin-RevId: 461229165
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
956c4b5d9c
commit
6fcb84268e
@@ -527,7 +527,7 @@ class Experiment(experiment.AbstractExperiment):
|
||||
if summed_scalars is None:
|
||||
summed_scalars = scalars
|
||||
else:
|
||||
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
|
||||
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
|
||||
|
||||
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
|
||||
return mean_scalars
|
||||
|
||||
@@ -192,7 +192,7 @@ def add_weight_decay(
|
||||
|
||||
u_ex, u_in = hk.data_structures.partition(exclude, updates)
|
||||
_, p_in = hk.data_structures.partition(exclude, params)
|
||||
u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
|
||||
u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
|
||||
updates = hk.data_structures.merge(u_ex, u_in)
|
||||
return updates, state
|
||||
|
||||
|
||||
Reference in New Issue
Block a user