Use jax.tree_util.tree_map in place of deprecated tree_multimap.

The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
This commit is contained in:
Jake VanderPlas
2022-07-15 21:14:58 +01:00
committed by Saran Tunyasuvunakool
parent 956c4b5d9c
commit 6fcb84268e
19 changed files with 42 additions and 42 deletions
+10 -10
View File
@@ -37,7 +37,7 @@ from absl import logging
import haiku as hk
import jax
from jax.tree_util import tree_multimap
from jax.tree_util import tree_map
import numpy as np
import optax
@@ -185,7 +185,7 @@ class Updater:
'state', 'params', 'rng', 'replicated_step', 'opt_state']))
state.update(extra_state)
state['step'] += 1
return state, tree_multimap(lambda x: x[0], out)
return state, tree_map(lambda x: x[0], out)
def _eval(self, state, data, max_graph_size=None):
"""Evaluates the current state on the given data."""
@@ -209,7 +209,7 @@ class Updater:
self._eval_fn, state, [data], keys=set([
'state', 'params', 'rng', 'replicated_step']))
state.update(extra_state)
return state, tree_multimap(lambda x: x[0], out)
return state, tree_map(lambda x: x[0], out)
def eval(self, state, data):
"""Returns metrics without updating the model."""
@@ -230,20 +230,20 @@ class Updater:
prefix = (self._num_devices, x.shape[0] // self._num_devices)
return np.reshape(x, prefix + x.shape[1:])
multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
multi_inputs = tree_map(add_core_dimension, multi_inputs)
return multi_inputs
def params(self, state):
"""Returns model parameters."""
return tree_multimap(lambda x: x[0], state['params'])
return tree_map(lambda x: x[0], state['params'])
def opt_state(self, state):
"""Returns the state of the optimiser."""
return tree_multimap(lambda x: x[0], state['opt_state'])
return tree_map(lambda x: x[0], state['opt_state'])
def network_state(self, state):
"""Returns the model's state."""
return tree_multimap(lambda x: x[0], state['state'])
return tree_map(lambda x: x[0], state['state'])
def to_checkpoint_state(self, state):
"""Transforms the updater state into a checkpointable state."""
@@ -251,14 +251,14 @@ class Updater:
# Wrapper around checkpoint_state['step'] so we can get [0].
checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
# Unstack the replicated contents.
checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
return checkpoint_state
def from_checkpoint_state(self, checkpoint_state):
"""Initializes the updater state from the checkpointed state."""
# Expand the checkpoint so we have a copy for each device.
state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
checkpoint_state)
state = tree_map(lambda x: np.stack(jax.local_device_count() * [x]),
checkpoint_state)
state['step'] = state['step'][0] # Undo stacking for step.
return state