Use jax.tree_util.tree_map in place of deprecated tree_multimap.

The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461045645
This commit is contained in:
Jake VanderPlas
2022-07-14 22:03:46 +01:00
committed by Saran Tunyasuvunakool
parent 11c2ab53e8
commit 956c4b5d9c
4 changed files with 6 additions and 6 deletions
+2 -2
View File
@@ -372,9 +372,9 @@ class ScaleAndShiftDiagonal(CurvatureBlock):
raise ValueError("Neither `has_scale` nor `has_shift`.")
factors = jax.tree_map(lambda x: x + diagonal_weight, factors)
if exp == 1:
return jax.tree_multimap(jnp.multiply, vec, factors)
return jax.tree_map(jnp.multiply, vec, factors)
elif exp == -1:
return jax.tree_multimap(jnp.divide, vec, factors)
return jax.tree_map(jnp.divide, vec, factors)
else:
raise NotImplementedError()
+1 -1
View File
@@ -180,7 +180,7 @@ class CurvatureEstimator(utils.Stateful):
"""Executes func for each approximation block on vectors."""
per_block_vectors = self.vectors_to_blocks(parameter_structured_vector)
assert len(per_block_vectors) == len(self.blocks)
results = jax.tree_multimap(func, tuple(self.blocks.values()),
results = jax.tree_map(func, tuple(self.blocks.values()),
per_block_vectors)
parameter_structured_result = self.blocks_to_vectors(results)
utils.check_structure_shapes_and_dtype(parameter_structured_vector,
+2 -2
View File
@@ -462,7 +462,7 @@ class Optimizer(utils.Stateful):
)
# Update parameters: params = params + delta
params = jax.tree_multimap(jnp.add, params, delta)
params = jax.tree_map(jnp.add, params, delta)
# Optionally compute the reduction ratio and update the damping
self.estimator.damping = None
@@ -607,5 +607,5 @@ class Optimizer(utils.Stateful):
assert len(vectors) == len(coefficients)
delta = utils.scalar_mul(vectors[0], coefficients[0])
for vi, wi in zip(vectors[1:], coefficients[1:]):
delta = jax.tree_multimap(jnp.add, delta, utils.scalar_mul(vi, wi))
delta = jax.tree_map(jnp.add, delta, utils.scalar_mul(vi, wi))
return delta, delta
+1 -1
View File
@@ -120,7 +120,7 @@ def extract_func_outputs(
def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
if jax.tree_structure(obj1) != jax.tree_structure(obj2):
raise ValueError("The two structures are not identical.")
elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
elements_product = jax.tree_map(lambda x, y: jnp.sum(x * y), obj1, obj2)
return sum(jax.tree_flatten(elements_product)[0])