Use jax.tree_util.tree_map in place of deprecated tree_multimap.

The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
This commit is contained in:
Jake VanderPlas
2022-07-15 21:14:58 +01:00
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
19 changed files with 42 additions and 42 deletions
+3 -3
View File
@@ -323,7 +323,7 @@ class ByolExperiment:
# cross-device grad and logs reductions
grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
logs = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
learning_rate = schedules.learning_schedule(
global_step,
@@ -339,7 +339,7 @@ class ByolExperiment:
global_step,
base_ema=self._base_target_ema,
max_steps=self._max_steps)
target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
target_params = jax.tree_map(lambda x, y: x + (1 - tau) * (y - x),
target_params, online_params)
logs['tau'] = tau
logs['learning_rate'] = learning_rate
@@ -518,7 +518,7 @@ class ByolExperiment:
if summed_scalars is None:
summed_scalars = scalars
else:
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars
+1 -1
View File
@@ -471,7 +471,7 @@ class EvalExperiment:
if summed_scalars is None:
summed_scalars = scalars
else:
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars
+5 -5
View File
@@ -51,7 +51,7 @@ def _partial_update(updates: optax.Updates,
m = m.astype(g.dtype)
return g * (1. - m) + t * m
return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
return jax.tree_map(_update_fn, updates, new_updates, params_to_filter)
class ScaleByLarsState(NamedTuple):
@@ -78,7 +78,7 @@ def scale_by_lars(
"""
def init_fn(params: optax.Params) -> ScaleByLarsState:
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
mu = jax.tree_map(jnp.zeros_like, params) # momentum
return ScaleByLarsState(mu=mu)
def update_fn(updates: optax.Updates, state: ScaleByLarsState,
@@ -95,10 +95,10 @@ def scale_by_lars(
jnp.where(update_norm > 0,
(eta * param_norm / update_norm), 1.0), 1.0)
adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
adapted_updates = jax.tree_map(lars_adaptation, updates, params)
adapted_updates = _partial_update(updates, adapted_updates, params,
filter_fn)
mu = jax.tree_multimap(lambda g, t: momentum * g + t,
mu = jax.tree_map(lambda g, t: momentum * g + t,
state.mu, adapted_updates)
return mu, ScaleByLarsState(mu=mu)
@@ -130,7 +130,7 @@ def add_weight_decay(
state: AddWeightDecayState,
params: optax.Params,
) -> Tuple[optax.Updates, AddWeightDecayState]:
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
new_updates = jax.tree_map(lambda g, p: g + weight_decay * p, updates,
params)
new_updates = _partial_update(updates, new_updates, params, filter_fn)
return new_updates, state