mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-02 22:46:46 +08:00
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op. PiperOrigin-RevId: 461229165
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
956c4b5d9c
commit
6fcb84268e
@@ -373,7 +373,7 @@ class Experiment(experiment.AbstractExperiment):
|
||||
if summed_scalars is None:
|
||||
summed_scalars = scalars
|
||||
else:
|
||||
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
|
||||
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
|
||||
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
|
||||
return mean_scalars
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ def ema_update(step: chex.Array,
|
||||
def _weighted_average(p1, p2):
|
||||
d = decay.astype(p1.dtype)
|
||||
return (1 - d) * p1 + d * p2
|
||||
return jax.tree_multimap(_weighted_average, new_params, avg_params)
|
||||
return jax.tree_map(_weighted_average, new_params, avg_params)
|
||||
|
||||
|
||||
def cutmix(rng: chex.PRNGKey,
|
||||
|
||||
Reference in New Issue
Block a user