mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 23:07:29 +08:00
[JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback. PiperOrigin-RevId: 507473737
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
cb555c241b
commit
797ea3c71d
@@ -164,11 +164,19 @@ def abstract_args(args):
|
||||
return jax.tree_map(abstract_single_value, args)
|
||||
|
||||
|
||||
def _extract_call_jaxpr(primitive, params):
|
||||
if not (primitive.call_primitive or primitive.map_primitive):
|
||||
return None, params
|
||||
else:
|
||||
params = dict(params)
|
||||
return params.pop("call_jaxpr"), params
|
||||
|
||||
|
||||
def evaluate_eqn(eqn, in_values, write_func):
|
||||
"""Evaluate a single Jax equation and writes the outputs."""
|
||||
in_values = list(in_values)
|
||||
# This is logic specifically to handle `xla_call`
|
||||
call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||
if call_jaxpr:
|
||||
subfuns = [
|
||||
jax.core.lu.wrap_init(
|
||||
@@ -224,12 +232,10 @@ def broadcast_merger(f):
|
||||
|
||||
# Bind args and consts to environment
|
||||
flat_args = jax.tree_flatten(func_args)[0]
|
||||
write(jax.core.unitvar, jax.core.unit)
|
||||
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||
|
||||
# Bind args and consts to environment
|
||||
write(jax.core.unitvar, jax.core.unit)
|
||||
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||
|
||||
@@ -354,8 +360,6 @@ def var_to_str(var):
|
||||
"""Returns a string representation of the variable of a Jax expression."""
|
||||
if isinstance(var, jax.core.Literal):
|
||||
return str(var)
|
||||
elif isinstance(var, jax.core.UnitVar):
|
||||
return "*"
|
||||
elif not isinstance(var, jax.core.Var):
|
||||
raise ValueError(f"Idk what to do with this {type(var)}?")
|
||||
c = int(var.count)
|
||||
@@ -388,7 +392,7 @@ def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None):
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vars = []
|
||||
for v in eqn.invars:
|
||||
if isinstance(v, (jax.core.Literal, jax.core.UnitVar)):
|
||||
if isinstance(v, jax.core.Literal):
|
||||
in_vars.append(var_to_str(v))
|
||||
else:
|
||||
in_vars.append(in_map.get(v, var_to_str(v)))
|
||||
@@ -406,8 +410,7 @@ def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None):
|
||||
|
||||
# Create incoming edges
|
||||
for v, name in zip(eqn.invars, in_vars):
|
||||
if (not isinstance(v, jax.core.Literal) and
|
||||
not isinstance(v, jax.core.UnitVar)):
|
||||
if not isinstance(v, jax.core.Literal):
|
||||
graph.add_edge(name, node_c)
|
||||
|
||||
# Create output nodes and edges
|
||||
@@ -573,7 +576,6 @@ def auto_register_tags(func,
|
||||
env[var] = val
|
||||
|
||||
# Bind args and consts to environment
|
||||
write(jax.core.unitvar, jax.core.unit)
|
||||
jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
|
||||
jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)
|
||||
|
||||
|
||||
@@ -61,10 +61,8 @@ class TestGraphMatcher(unittest.TestCase):
|
||||
self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars))
|
||||
self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars))
|
||||
for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
|
||||
eq_in_vars = [v for v in eq.invars if not isinstance(v, jax.core.UnitVar)]
|
||||
tagged_in_vars = [
|
||||
v for v in tagged_eq.invars if not isinstance(v, jax.core.UnitVar)
|
||||
]
|
||||
eq_in_vars = [v for v in eq.invars]
|
||||
tagged_in_vars = [v for v in tagged_eq.invars]
|
||||
self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
|
||||
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
|
||||
self.assertEqual(eq.primitive, tagged_eq.primitive)
|
||||
|
||||
@@ -64,7 +64,6 @@ def construct_compute_losses_inputs(
|
||||
write = functools.partial(tgm.write_env, env)
|
||||
|
||||
# Bind args and consts to environment
|
||||
write(jax.core.unitvar, jax.core.unit)
|
||||
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||
|
||||
@@ -230,7 +229,6 @@ def trace_estimator_vjp(tagged_func: _Function) -> _Function:
|
||||
write = functools.partial(tgm.write_env, env)
|
||||
|
||||
# Bind args and consts to environment
|
||||
write(jax.core.unitvar, jax.core.unit)
|
||||
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
|
||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||
|
||||
@@ -253,12 +251,11 @@ def trace_estimator_vjp(tagged_func: _Function) -> _Function:
|
||||
env = dict()
|
||||
read = functools.partial(tgm.read_env, env)
|
||||
def write(var, val):
|
||||
if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
|
||||
if not isinstance(var, jax.core.Literal):
|
||||
val = val + aux[var] if var in aux else val
|
||||
env[var] = val
|
||||
|
||||
# Bind args and consts to environment
|
||||
write(jax.core.unitvar, jax.core.unit)
|
||||
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
|
||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user