Mirror of https://github.com/google-deepmind/deepmind-research.git
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
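For context, `jax.tree_util.tree_multimap` was a thin alias of `jax.tree_util.tree_map`: both take a function followed by one or more pytrees and apply the function to corresponding leaves. A minimal sketch of the substitution this commit performs throughout the repository (the pytrees and the update rule below are illustrative, not taken from the repository):

import jax
import jax.numpy as jnp

# Two pytrees with the same structure (illustrative values only).
params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
grads = {'w': jnp.full((2, 2), 0.1), 'b': jnp.full((2,), 0.1)}

# Deprecated spelling (later removed from JAX):
#   new_params = jax.tree_util.tree_multimap(lambda p, g: p - 0.1 * g, params, grads)
# Equivalent spelling; tree_map also accepts extra trees after the first one:
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)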
@@ -373,7 +373,7 @@ class Experiment(experiment.AbstractExperiment):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)

     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
@@ -124,7 +124,7 @@ def ema_update(step: chex.Array,
   def _weighted_average(p1, p2):
     d = decay.astype(p1.dtype)
     return (1 - d) * p1 + d * p2
-  return jax.tree_multimap(_weighted_average, new_params, avg_params)
+  return jax.tree_map(_weighted_average, new_params, avg_params)


 def cutmix(rng: chex.PRNGKey,
@@ -323,7 +323,7 @@ class ByolExperiment:

     # cross-device grad and logs reductions
     grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
-    logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
+    logs = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)

     learning_rate = schedules.learning_schedule(
         global_step,
@@ -339,7 +339,7 @@ class ByolExperiment:
         global_step,
         base_ema=self._base_target_ema,
         max_steps=self._max_steps)
-    target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
+    target_params = jax.tree_map(lambda x, y: x + (1 - tau) * (y - x),
                                       target_params, online_params)
     logs['tau'] = tau
     logs['learning_rate'] = learning_rate
@@ -518,7 +518,7 @@ class ByolExperiment:
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)

     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
@@ -471,7 +471,7 @@ class EvalExperiment:
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)

     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
@@ -51,7 +51,7 @@ def _partial_update(updates: optax.Updates,
     m = m.astype(g.dtype)
     return g * (1. - m) + t * m

-  return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
+  return jax.tree_map(_update_fn, updates, new_updates, params_to_filter)


 class ScaleByLarsState(NamedTuple):
@@ -78,7 +78,7 @@ def scale_by_lars(
   """

   def init_fn(params: optax.Params) -> ScaleByLarsState:
-    mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
+    mu = jax.tree_map(jnp.zeros_like, params)  # momentum
     return ScaleByLarsState(mu=mu)

   def update_fn(updates: optax.Updates, state: ScaleByLarsState,
@@ -95,10 +95,10 @@ def scale_by_lars(
           jnp.where(update_norm > 0,
                     (eta * param_norm / update_norm), 1.0), 1.0)

-    adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
+    adapted_updates = jax.tree_map(lars_adaptation, updates, params)
     adapted_updates = _partial_update(updates, adapted_updates, params,
                                       filter_fn)
-    mu = jax.tree_multimap(lambda g, t: momentum * g + t,
+    mu = jax.tree_map(lambda g, t: momentum * g + t,
                            state.mu, adapted_updates)
     return mu, ScaleByLarsState(mu=mu)
@@ -130,7 +130,7 @@ def add_weight_decay(
       state: AddWeightDecayState,
       params: optax.Params,
   ) -> Tuple[optax.Updates, AddWeightDecayState]:
-    new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
+    new_updates = jax.tree_map(lambda g, p: g + weight_decay * p, updates,
                                     params)
     new_updates = _partial_update(updates, new_updates, params, filter_fn)
     return new_updates, state
@@ -72,7 +72,7 @@ def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
     # Maximum allowable norm
     max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
     # If grad norm > clipping * param_norm, rescale
-    updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
+    updates = jax.tree_map(my_clip, g_norm, max_norm, updates)
     return updates, state

   return optax.GradientTransformation(init_fn, update_fn)
@@ -270,8 +270,8 @@ class Experiment(experiment.AbstractExperiment):
     if ema_params is not None:
       ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
       ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
-      ema_params = jax.tree_multimap(ema, ema_params, params)
-      ema_states = jax.tree_multimap(ema, ema_states, states)
+      ema_params = jax.tree_map(ema, ema_params, params)
+      ema_states = jax.tree_map(ema, ema_states, states)
     return {
         'params': params,
         'states': states,
@@ -354,7 +354,7 @@ class Experiment(experiment.AbstractExperiment):
       if summed_metrics is None:
         summed_metrics = metrics
       else:
-        summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
+        summed_metrics = jax.tree_map(jnp.add, summed_metrics, metrics)
     mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
     return jax.device_get(mean_metrics)
@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:

   def _map_concat(nests):
     concat = lambda *args: np.concatenate(args)
-    return tree.tree_multimap(concat, *nests)
+    return tree.tree_map(concat, *nests)

   return jraph.GraphsTuple(
       n_node=np.concatenate([g.n_node for g in graphs]),
@@ -487,8 +487,8 @@ class Experiment(experiment.AbstractExperiment):
       ))
       ema_fn = (lambda x, y:  # pylint:disable=g-long-lambda
                 schedules.apply_ema_decay(x, y, ema_rate))
-      ema_params = jax.tree_multimap(ema_fn, ema_params, params)
-      ema_network_state = jax.tree_multimap(
+      ema_params = jax.tree_map(ema_fn, ema_params, params)
+      ema_network_state = jax.tree_map(
          ema_fn,
          ema_network_state,
          network_state,
@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:

   def _map_concat(nests):
     concat = lambda *args: np.concatenate(args)
-    return tree.tree_multimap(concat, *nests)
+    return tree.tree_map(concat, *nests)

   return jraph.GraphsTuple(
       n_node=np.concatenate([g.n_node for g in graphs]),
@@ -229,8 +229,8 @@ class Experiment(experiment.AbstractExperiment):
     params = optax.apply_updates(params, updates)
     if ema_params is not None:
       ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step)
-      ema_params = jax.tree_multimap(ema, ema_params, params)
-      ema_network_state = jax.tree_multimap(ema, ema_network_state,
+      ema_params = jax.tree_map(ema, ema_params, params)
+      ema_network_state = jax.tree_map(ema, ema_network_state,
                                             network_state)
     return params, ema_params, network_state, ema_network_state, opt_state, scalars
@@ -527,7 +527,7 @@ class Experiment(experiment.AbstractExperiment):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)

     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
@@ -192,7 +192,7 @@ def add_weight_decay(

     u_ex, u_in = hk.data_structures.partition(exclude, updates)
     _, p_in = hk.data_structures.partition(exclude, params)
-    u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
+    u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
     updates = hk.data_structures.merge(u_ex, u_in)
     return updates, state

@@ -204,7 +204,7 @@ def solve_ivp_dt(
   for t_and_dt_i in zip(t, dt):
     y.append(loop_body(y[-1], t_and_dt_i)[0])
   # Note that we do not return the initial point
-  return t_eval, jax.tree_multimap(lambda *args: jnp.stack(args, axis=0),
+  return t_eval, jax.tree_map(lambda *args: jnp.stack(args, axis=0),
                                    *y[1:])

@@ -252,7 +252,7 @@ def solve_ivp_dt_two_directions(
     )[1]
     yt.append(yt_fwd)
   if len(yt) > 1:
-    return jax.tree_multimap(lambda *a: jnp.concatenate(a, axis=0), *yt)
+    return jax.tree_map(lambda *a: jnp.concatenate(a, axis=0), *yt)
   else:
     return yt[0]

@@ -192,7 +192,7 @@ class HGNExperiment(experiment.AbstractExperiment):
     new_state = utils.pmean_if_pmap(new_state, axis_name="i")
     new_state = hk.data_structures.to_mutable_dict(new_state)
     new_state = hk.data_structures.to_immutable_dict(new_state)
-    return jax.tree_multimap(jnp.add, new_state, state)
+    return jax.tree_map(jnp.add, new_state, state)

   # _
   # _____ ____ _| |
@@ -829,7 +829,7 @@ class DiscreteDynamicsNetwork(hk.Module):
     if len(yt) == 1:
       yt = yt[0][:, None]
     else:
-      yt = jax.tree_multimap(lambda args: jnp.stack(args, 1), yt)
+      yt = jax.tree_map(lambda args: jnp.stack(args, 1), yt)
     if return_stats:
       return yt, dict()
     else:
@@ -220,11 +220,11 @@ class MultiBatchAccumulator(object):
       self._obj = jax.tree_map(lambda y: y * num_samples, averaged_values)
       self._num_samples = num_samples
     else:
-      self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
+      self._obj_max = jax.tree_map(jnp.maximum, self._obj_max,
                                         averaged_values)
-      self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
+      self._obj_min = jax.tree_map(jnp.minimum, self._obj_min,
                                         averaged_values)
-      self._obj = jax.tree_multimap(lambda x, y: x + y * num_samples, self._obj,
+      self._obj = jax.tree_map(lambda x, y: x + y * num_samples, self._obj,
                                     averaged_values)
       self._num_samples += num_samples
@@ -249,7 +249,7 @@ register_pytree_node(


 def inner_product(x: Any, y: Any) -> jnp.ndarray:
-  products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
+  products = jax.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
   return sum(jax.tree_leaves(products))


@@ -277,7 +277,7 @@ class UnbiasedExponentialWeightedAverageAgentTracker:
       # 0 * NaN == NaN just replace self._statistics on the first step.
       self._statistics = dict(agent.statistics)
     else:
-      self._statistics = jax.tree_multimap(
+      self._statistics = jax.tree_map(
          lambda s, x: (1 - final_step_size) * s + final_step_size * x,
          self._statistics, agent.statistics)

@@ -37,7 +37,7 @@ from absl import logging

 import haiku as hk
 import jax
-from jax.tree_util import tree_multimap
+from jax.tree_util import tree_map
 import numpy as np
 import optax

@@ -185,7 +185,7 @@ class Updater:
         'state', 'params', 'rng', 'replicated_step', 'opt_state']))
     state.update(extra_state)
     state['step'] += 1
-    return state, tree_multimap(lambda x: x[0], out)
+    return state, tree_map(lambda x: x[0], out)

   def _eval(self, state, data, max_graph_size=None):
     """Evaluates the current state on the given data."""
@@ -209,7 +209,7 @@ class Updater:
         self._eval_fn, state, [data], keys=set([
             'state', 'params', 'rng', 'replicated_step']))
     state.update(extra_state)
-    return state, tree_multimap(lambda x: x[0], out)
+    return state, tree_map(lambda x: x[0], out)

   def eval(self, state, data):
     """Returns metrics without updating the model."""
@@ -230,20 +230,20 @@ class Updater:
       prefix = (self._num_devices, x.shape[0] // self._num_devices)
       return np.reshape(x, prefix + x.shape[1:])

-    multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
+    multi_inputs = tree_map(add_core_dimension, multi_inputs)
     return multi_inputs

   def params(self, state):
     """Returns model parameters."""
-    return tree_multimap(lambda x: x[0], state['params'])
+    return tree_map(lambda x: x[0], state['params'])

   def opt_state(self, state):
     """Returns the state of the optimiser."""
-    return tree_multimap(lambda x: x[0], state['opt_state'])
+    return tree_map(lambda x: x[0], state['opt_state'])

   def network_state(self, state):
     """Returns the model's state."""
-    return tree_multimap(lambda x: x[0], state['state'])
+    return tree_map(lambda x: x[0], state['state'])

   def to_checkpoint_state(self, state):
     """Transforms the updater state into a checkpointable state."""
@@ -251,14 +251,14 @@ class Updater:
     # Wrapper around checkpoint_state['step'] so we can get [0].
     checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
     # Unstack the replicated contents.
-    checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
+    checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
     return checkpoint_state

   def from_checkpoint_state(self, checkpoint_state):
     """Initializes the updater state from the checkpointed state."""
     # Expand the checkpoint so we have a copy for each device.
-    state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
-                          checkpoint_state)
+    state = tree_map(lambda x: np.stack(jax.local_device_count() * [x]),
+                     checkpoint_state)
     state['step'] = state['step'][0]  # Undo stacking for step.
     return state