mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 23:07:29 +08:00
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op. PiperOrigin-RevId: 461045645
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
11c2ab53e8
commit
956c4b5d9c
@@ -372,9 +372,9 @@ class ScaleAndShiftDiagonal(CurvatureBlock):
|
||||
raise ValueError("Neither `has_scale` nor `has_shift`.")
|
||||
factors = jax.tree_map(lambda x: x + diagonal_weight, factors)
|
||||
if exp == 1:
|
||||
return jax.tree_multimap(jnp.multiply, vec, factors)
|
||||
return jax.tree_map(jnp.multiply, vec, factors)
|
||||
elif exp == -1:
|
||||
return jax.tree_multimap(jnp.divide, vec, factors)
|
||||
return jax.tree_map(jnp.divide, vec, factors)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ class CurvatureEstimator(utils.Stateful):
|
||||
"""Executes func for each approximation block on vectors."""
|
||||
per_block_vectors = self.vectors_to_blocks(parameter_structured_vector)
|
||||
assert len(per_block_vectors) == len(self.blocks)
|
||||
results = jax.tree_multimap(func, tuple(self.blocks.values()),
|
||||
results = jax.tree_map(func, tuple(self.blocks.values()),
|
||||
per_block_vectors)
|
||||
parameter_structured_result = self.blocks_to_vectors(results)
|
||||
utils.check_structure_shapes_and_dtype(parameter_structured_vector,
|
||||
|
||||
@@ -462,7 +462,7 @@ class Optimizer(utils.Stateful):
|
||||
)
|
||||
|
||||
# Update parameters: params = params + delta
|
||||
params = jax.tree_multimap(jnp.add, params, delta)
|
||||
params = jax.tree_map(jnp.add, params, delta)
|
||||
|
||||
# Optionally compute the reduction ratio and update the damping
|
||||
self.estimator.damping = None
|
||||
@@ -607,5 +607,5 @@ class Optimizer(utils.Stateful):
|
||||
assert len(vectors) == len(coefficients)
|
||||
delta = utils.scalar_mul(vectors[0], coefficients[0])
|
||||
for vi, wi in zip(vectors[1:], coefficients[1:]):
|
||||
delta = jax.tree_multimap(jnp.add, delta, utils.scalar_mul(vi, wi))
|
||||
delta = jax.tree_map(jnp.add, delta, utils.scalar_mul(vi, wi))
|
||||
return delta, delta
|
||||
|
||||
@@ -120,7 +120,7 @@ def extract_func_outputs(
|
||||
def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
|
||||
if jax.tree_structure(obj1) != jax.tree_structure(obj2):
|
||||
raise ValueError("The two structures are not identical.")
|
||||
elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
|
||||
elements_product = jax.tree_map(lambda x, y: jnp.sum(x * y), obj1, obj2)
|
||||
return sum(jax.tree_flatten(elements_product)[0])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user