mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-03 21:37:40 +08:00
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op. PiperOrigin-RevId: 461229165
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
956c4b5d9c
commit
6fcb84268e
+10
-10
@@ -37,7 +37,7 @@ from absl import logging
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
from jax.tree_util import tree_multimap
|
||||
from jax.tree_util import tree_map
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
@@ -185,7 +185,7 @@ class Updater:
|
||||
'state', 'params', 'rng', 'replicated_step', 'opt_state']))
|
||||
state.update(extra_state)
|
||||
state['step'] += 1
|
||||
return state, tree_multimap(lambda x: x[0], out)
|
||||
return state, tree_map(lambda x: x[0], out)
|
||||
|
||||
def _eval(self, state, data, max_graph_size=None):
|
||||
"""Evaluates the current state on the given data."""
|
||||
@@ -209,7 +209,7 @@ class Updater:
|
||||
self._eval_fn, state, [data], keys=set([
|
||||
'state', 'params', 'rng', 'replicated_step']))
|
||||
state.update(extra_state)
|
||||
return state, tree_multimap(lambda x: x[0], out)
|
||||
return state, tree_map(lambda x: x[0], out)
|
||||
|
||||
def eval(self, state, data):
|
||||
"""Returns metrics without updating the model."""
|
||||
@@ -230,20 +230,20 @@ class Updater:
|
||||
prefix = (self._num_devices, x.shape[0] // self._num_devices)
|
||||
return np.reshape(x, prefix + x.shape[1:])
|
||||
|
||||
multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
|
||||
multi_inputs = tree_map(add_core_dimension, multi_inputs)
|
||||
return multi_inputs
|
||||
|
||||
def params(self, state):
|
||||
"""Returns model parameters."""
|
||||
return tree_multimap(lambda x: x[0], state['params'])
|
||||
return tree_map(lambda x: x[0], state['params'])
|
||||
|
||||
def opt_state(self, state):
|
||||
"""Returns the state of the optimiser."""
|
||||
return tree_multimap(lambda x: x[0], state['opt_state'])
|
||||
return tree_map(lambda x: x[0], state['opt_state'])
|
||||
|
||||
def network_state(self, state):
|
||||
"""Returns the model's state."""
|
||||
return tree_multimap(lambda x: x[0], state['state'])
|
||||
return tree_map(lambda x: x[0], state['state'])
|
||||
|
||||
def to_checkpoint_state(self, state):
|
||||
"""Transforms the updater state into a checkpointable state."""
|
||||
@@ -251,14 +251,14 @@ class Updater:
|
||||
# Wrapper around checkpoint_state['step'] so we can get [0].
|
||||
checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
|
||||
# Unstack the replicated contents.
|
||||
checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
|
||||
checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
|
||||
return checkpoint_state
|
||||
|
||||
def from_checkpoint_state(self, checkpoint_state):
|
||||
"""Initializes the updater state from the checkpointed state."""
|
||||
# Expand the checkpoint so we have a copy for each device.
|
||||
state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
|
||||
checkpoint_state)
|
||||
state = tree_map(lambda x: np.stack(jax.local_device_count() * [x]),
|
||||
checkpoint_state)
|
||||
state['step'] = state['step'][0] # Undo stacking for step.
|
||||
return state
|
||||
|
||||
|
||||
Reference in New Issue
Block a user