Use jax.tree_util.tree_map in place of deprecated tree_multimap.

tree_multimap is a simple alias of tree_map, so this change is a no-op.
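
For illustration only (not part of this commit): a minimal sketch showing that tree_map already accepts multiple trees, so every tree_multimap(f, a, b) call in the diffs below can be written as tree_map(f, a, b) with identical results. The example values and names are hypothetical.

import jax
import jax.numpy as jnp

# Two pytrees with matching structure (hypothetical example values).
params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
grads = {'w': jnp.full((2, 2), 0.1), 'b': jnp.full((2,), 0.1)}

# Multi-tree mapping with tree_map, the same pattern used throughout the diffs below.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# On JAX versions that still ship the deprecated alias, the legacy spelling
# returns the same tree:
# updated_legacy = jax.tree_util.tree_multimap(lambda p, g: p - 0.1 * g, params, grads)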

PiperOrigin-RevId: 461229165
Jake VanderPlas
2022-07-15 21:14:58 +01:00
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
19 changed files with 42 additions and 42 deletions

View File

@@ -373,7 +373,7 @@ class Experiment(experiment.AbstractExperiment):
if summed_scalars is None:
summed_scalars = scalars
else:
-summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars

View File

@@ -124,7 +124,7 @@ def ema_update(step: chex.Array,
def _weighted_average(p1, p2):
d = decay.astype(p1.dtype)
return (1 - d) * p1 + d * p2
-return jax.tree_multimap(_weighted_average, new_params, avg_params)
+return jax.tree_map(_weighted_average, new_params, avg_params)
def cutmix(rng: chex.PRNGKey,

View File

@@ -323,7 +323,7 @@ class ByolExperiment:
# cross-device grad and logs reductions
grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
-logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
+logs = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
learning_rate = schedules.learning_schedule(
global_step,
@@ -339,7 +339,7 @@ class ByolExperiment:
global_step,
base_ema=self._base_target_ema,
max_steps=self._max_steps)
-target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
+target_params = jax.tree_map(lambda x, y: x + (1 - tau) * (y - x),
target_params, online_params)
logs['tau'] = tau
logs['learning_rate'] = learning_rate
@@ -518,7 +518,7 @@ class ByolExperiment:
if summed_scalars is None:
summed_scalars = scalars
else:
-summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars

View File

@@ -471,7 +471,7 @@ class EvalExperiment:
if summed_scalars is None:
summed_scalars = scalars
else:
-summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars

View File

@@ -51,7 +51,7 @@ def _partial_update(updates: optax.Updates,
m = m.astype(g.dtype)
return g * (1. - m) + t * m
-return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
+return jax.tree_map(_update_fn, updates, new_updates, params_to_filter)
class ScaleByLarsState(NamedTuple):
@@ -78,7 +78,7 @@ def scale_by_lars(
"""
def init_fn(params: optax.Params) -> ScaleByLarsState:
-mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
+mu = jax.tree_map(jnp.zeros_like, params) # momentum
return ScaleByLarsState(mu=mu)
def update_fn(updates: optax.Updates, state: ScaleByLarsState,
@@ -95,10 +95,10 @@ def scale_by_lars(
jnp.where(update_norm > 0,
(eta * param_norm / update_norm), 1.0), 1.0)
-adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
+adapted_updates = jax.tree_map(lars_adaptation, updates, params)
adapted_updates = _partial_update(updates, adapted_updates, params,
filter_fn)
-mu = jax.tree_multimap(lambda g, t: momentum * g + t,
+mu = jax.tree_map(lambda g, t: momentum * g + t,
state.mu, adapted_updates)
return mu, ScaleByLarsState(mu=mu)
@@ -130,7 +130,7 @@ def add_weight_decay(
state: AddWeightDecayState,
params: optax.Params,
) -> Tuple[optax.Updates, AddWeightDecayState]:
-new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
+new_updates = jax.tree_map(lambda g, p: g + weight_decay * p, updates,
params)
new_updates = _partial_update(updates, new_updates, params, filter_fn)
return new_updates, state

View File

@@ -72,7 +72,7 @@ def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
# Maximum allowable norm
max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
# If grad norm > clipping * param_norm, rescale
-updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
+updates = jax.tree_map(my_clip, g_norm, max_norm, updates)
return updates, state
return optax.GradientTransformation(init_fn, update_fn)

View File

@@ -270,8 +270,8 @@ class Experiment(experiment.AbstractExperiment):
if ema_params is not None:
ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
-ema_params = jax.tree_multimap(ema, ema_params, params)
-ema_states = jax.tree_multimap(ema, ema_states, states)
+ema_params = jax.tree_map(ema, ema_params, params)
+ema_states = jax.tree_map(ema, ema_states, states)
return {
'params': params,
'states': states,
@@ -354,7 +354,7 @@ class Experiment(experiment.AbstractExperiment):
if summed_metrics is None:
summed_metrics = metrics
else:
-summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
+summed_metrics = jax.tree_map(jnp.add, summed_metrics, metrics)
mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
return jax.device_get(mean_metrics)

View File

@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
def _map_concat(nests):
concat = lambda *args: np.concatenate(args)
-return tree.tree_multimap(concat, *nests)
+return tree.tree_map(concat, *nests)
return jraph.GraphsTuple(
n_node=np.concatenate([g.n_node for g in graphs]),

View File

@@ -487,8 +487,8 @@ class Experiment(experiment.AbstractExperiment):
))
ema_fn = (lambda x, y: # pylint:disable=g-long-lambda
schedules.apply_ema_decay(x, y, ema_rate))
-ema_params = jax.tree_multimap(ema_fn, ema_params, params)
-ema_network_state = jax.tree_multimap(
+ema_params = jax.tree_map(ema_fn, ema_params, params)
+ema_network_state = jax.tree_map(
ema_fn,
ema_network_state,
network_state,

View File

@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
def _map_concat(nests):
concat = lambda *args: np.concatenate(args)
-return tree.tree_multimap(concat, *nests)
+return tree.tree_map(concat, *nests)
return jraph.GraphsTuple(
n_node=np.concatenate([g.n_node for g in graphs]),

View File

@@ -229,8 +229,8 @@ class Experiment(experiment.AbstractExperiment):
params = optax.apply_updates(params, updates)
if ema_params is not None:
ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step)
-ema_params = jax.tree_multimap(ema, ema_params, params)
-ema_network_state = jax.tree_multimap(ema, ema_network_state,
+ema_params = jax.tree_map(ema, ema_params, params)
+ema_network_state = jax.tree_map(ema, ema_network_state,
network_state)
return params, ema_params, network_state, ema_network_state, opt_state, scalars

View File

@@ -527,7 +527,7 @@ class Experiment(experiment.AbstractExperiment):
if summed_scalars is None:
summed_scalars = scalars
else:
-summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars

View File

@@ -192,7 +192,7 @@ def add_weight_decay(
u_ex, u_in = hk.data_structures.partition(exclude, updates)
_, p_in = hk.data_structures.partition(exclude, params)
-u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
+u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
updates = hk.data_structures.merge(u_ex, u_in)
return updates, state

View File

@@ -204,7 +204,7 @@ def solve_ivp_dt(
for t_and_dt_i in zip(t, dt):
y.append(loop_body(y[-1], t_and_dt_i)[0])
# Note that we do not return the initial point
-return t_eval, jax.tree_multimap(lambda *args: jnp.stack(args, axis=0),
+return t_eval, jax.tree_map(lambda *args: jnp.stack(args, axis=0),
*y[1:])
@@ -252,7 +252,7 @@ def solve_ivp_dt_two_directions(
)[1]
yt.append(yt_fwd)
if len(yt) > 1:
-return jax.tree_multimap(lambda *a: jnp.concatenate(a, axis=0), *yt)
+return jax.tree_map(lambda *a: jnp.concatenate(a, axis=0), *yt)
else:
return yt[0]

View File

@@ -192,7 +192,7 @@ class HGNExperiment(experiment.AbstractExperiment):
new_state = utils.pmean_if_pmap(new_state, axis_name="i")
new_state = hk.data_structures.to_mutable_dict(new_state)
new_state = hk.data_structures.to_immutable_dict(new_state)
-return jax.tree_multimap(jnp.add, new_state, state)
+return jax.tree_map(jnp.add, new_state, state)
# _
# _____ ____ _| |

View File

@@ -829,7 +829,7 @@ class DiscreteDynamicsNetwork(hk.Module):
if len(yt) == 1:
yt = yt[0][:, None]
else:
-yt = jax.tree_multimap(lambda args: jnp.stack(args, 1), yt)
+yt = jax.tree_map(lambda args: jnp.stack(args, 1), yt)
if return_stats:
return yt, dict()
else:

View File

@@ -220,11 +220,11 @@ class MultiBatchAccumulator(object):
self._obj = jax.tree_map(lambda y: y * num_samples, averaged_values)
self._num_samples = num_samples
else:
-self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
+self._obj_max = jax.tree_map(jnp.maximum, self._obj_max,
averaged_values)
-self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
+self._obj_min = jax.tree_map(jnp.minimum, self._obj_min,
averaged_values)
-self._obj = jax.tree_multimap(lambda x, y: x + y * num_samples, self._obj,
+self._obj = jax.tree_map(lambda x, y: x + y * num_samples, self._obj,
averaged_values)
self._num_samples += num_samples
@@ -249,7 +249,7 @@ register_pytree_node(
def inner_product(x: Any, y: Any) -> jnp.ndarray:
-products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
+products = jax.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
return sum(jax.tree_leaves(products))

View File

@@ -277,7 +277,7 @@ class UnbiasedExponentialWeightedAverageAgentTracker:
# 0 * NaN == NaN just replace self._statistics on the first step.
self._statistics = dict(agent.statistics)
else:
-self._statistics = jax.tree_multimap(
+self._statistics = jax.tree_map(
lambda s, x: (1 - final_step_size) * s + final_step_size * x,
self._statistics, agent.statistics)

View File

@@ -37,7 +37,7 @@ from absl import logging
import haiku as hk
import jax
-from jax.tree_util import tree_multimap
+from jax.tree_util import tree_map
import numpy as np
import optax
@@ -185,7 +185,7 @@ class Updater:
'state', 'params', 'rng', 'replicated_step', 'opt_state']))
state.update(extra_state)
state['step'] += 1
-return state, tree_multimap(lambda x: x[0], out)
+return state, tree_map(lambda x: x[0], out)
def _eval(self, state, data, max_graph_size=None):
"""Evaluates the current state on the given data."""
@@ -209,7 +209,7 @@ class Updater:
self._eval_fn, state, [data], keys=set([
'state', 'params', 'rng', 'replicated_step']))
state.update(extra_state)
-return state, tree_multimap(lambda x: x[0], out)
+return state, tree_map(lambda x: x[0], out)
def eval(self, state, data):
"""Returns metrics without updating the model."""
@@ -230,20 +230,20 @@ class Updater:
prefix = (self._num_devices, x.shape[0] // self._num_devices)
return np.reshape(x, prefix + x.shape[1:])
-multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
+multi_inputs = tree_map(add_core_dimension, multi_inputs)
return multi_inputs
def params(self, state):
"""Returns model parameters."""
-return tree_multimap(lambda x: x[0], state['params'])
+return tree_map(lambda x: x[0], state['params'])
def opt_state(self, state):
"""Returns the state of the optimiser."""
-return tree_multimap(lambda x: x[0], state['opt_state'])
+return tree_map(lambda x: x[0], state['opt_state'])
def network_state(self, state):
"""Returns the model's state."""
-return tree_multimap(lambda x: x[0], state['state'])
+return tree_map(lambda x: x[0], state['state'])
def to_checkpoint_state(self, state):
"""Transforms the updater state into a checkpointable state."""
@@ -251,14 +251,14 @@ class Updater:
# Wrapper around checkpoint_state['step'] so we can get [0].
checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
# Unstack the replicated contents.
-checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
+checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
return checkpoint_state
def from_checkpoint_state(self, checkpoint_state):
"""Initializes the updater state from the checkpointed state."""
# Expand the checkpoint so we have a copy for each device.
-state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
-checkpoint_state)
+state = tree_map(lambda x: np.stack(jax.local_device_count() * [x]),
+checkpoint_state)
state['step'] = state['step'][0] # Undo stacking for step.
return state