Use jax.tree_util.tree_map in place of deprecated tree_multimap.

The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
This commit is contained in:
Jake VanderPlas
2022-07-15 21:14:58 +01:00
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
19 changed files with 42 additions and 42 deletions
+1 -1
View File
@@ -277,7 +277,7 @@ class UnbiasedExponentialWeightedAverageAgentTracker:
# 0 * NaN == NaN just replace self._statistics on the first step.
self._statistics = dict(agent.statistics)
else:
self._statistics = jax.tree_multimap(
self._statistics = jax.tree_map(
lambda s, x: (1 - final_step_size) * s + final_step_size * x,
self._statistics, agent.statistics)