mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-02 14:45:25 +08:00
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op. PiperOrigin-RevId: 461229165
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
956c4b5d9c
commit
6fcb84268e
@@ -323,7 +323,7 @@ class ByolExperiment:
|
||||
|
||||
# cross-device grad and logs reductions
|
||||
grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
|
||||
logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
|
||||
logs = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
|
||||
|
||||
learning_rate = schedules.learning_schedule(
|
||||
global_step,
|
||||
@@ -339,7 +339,7 @@ class ByolExperiment:
|
||||
global_step,
|
||||
base_ema=self._base_target_ema,
|
||||
max_steps=self._max_steps)
|
||||
target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
|
||||
target_params = jax.tree_map(lambda x, y: x + (1 - tau) * (y - x),
|
||||
target_params, online_params)
|
||||
logs['tau'] = tau
|
||||
logs['learning_rate'] = learning_rate
|
||||
@@ -518,7 +518,7 @@ class ByolExperiment:
|
||||
if summed_scalars is None:
|
||||
summed_scalars = scalars
|
||||
else:
|
||||
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
|
||||
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
|
||||
|
||||
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
|
||||
return mean_scalars
|
||||
|
||||
@@ -471,7 +471,7 @@ class EvalExperiment:
|
||||
if summed_scalars is None:
|
||||
summed_scalars = scalars
|
||||
else:
|
||||
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
|
||||
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
|
||||
|
||||
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
|
||||
return mean_scalars
|
||||
|
||||
@@ -51,7 +51,7 @@ def _partial_update(updates: optax.Updates,
|
||||
m = m.astype(g.dtype)
|
||||
return g * (1. - m) + t * m
|
||||
|
||||
return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
|
||||
return jax.tree_map(_update_fn, updates, new_updates, params_to_filter)
|
||||
|
||||
|
||||
class ScaleByLarsState(NamedTuple):
|
||||
@@ -78,7 +78,7 @@ def scale_by_lars(
|
||||
"""
|
||||
|
||||
def init_fn(params: optax.Params) -> ScaleByLarsState:
|
||||
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
|
||||
mu = jax.tree_map(jnp.zeros_like, params) # momentum
|
||||
return ScaleByLarsState(mu=mu)
|
||||
|
||||
def update_fn(updates: optax.Updates, state: ScaleByLarsState,
|
||||
@@ -95,10 +95,10 @@ def scale_by_lars(
|
||||
jnp.where(update_norm > 0,
|
||||
(eta * param_norm / update_norm), 1.0), 1.0)
|
||||
|
||||
adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
|
||||
adapted_updates = jax.tree_map(lars_adaptation, updates, params)
|
||||
adapted_updates = _partial_update(updates, adapted_updates, params,
|
||||
filter_fn)
|
||||
mu = jax.tree_multimap(lambda g, t: momentum * g + t,
|
||||
mu = jax.tree_map(lambda g, t: momentum * g + t,
|
||||
state.mu, adapted_updates)
|
||||
return mu, ScaleByLarsState(mu=mu)
|
||||
|
||||
@@ -130,7 +130,7 @@ def add_weight_decay(
|
||||
state: AddWeightDecayState,
|
||||
params: optax.Params,
|
||||
) -> Tuple[optax.Updates, AddWeightDecayState]:
|
||||
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
|
||||
new_updates = jax.tree_map(lambda g, p: g + weight_decay * p, updates,
|
||||
params)
|
||||
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
||||
return new_updates, state
|
||||
|
||||
Reference in New Issue
Block a user