[JAX] Remove obsolete unit type declarations in jax.core.

Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
This commit is contained in:
Peter Hawkins
2023-02-06 15:25:23 +00:00
committed by Saran Tunyasuvunakool
parent cb555c241b
commit 797ea3c71d
3 changed files with 14 additions and 17 deletions
+11 -9
View File
@@ -164,11 +164,19 @@ def abstract_args(args):
return jax.tree_map(abstract_single_value, args)
def _extract_call_jaxpr(primitive, params):
if not (primitive.call_primitive or primitive.map_primitive):
return None, params
else:
params = dict(params)
return params.pop("call_jaxpr"), params
def evaluate_eqn(eqn, in_values, write_func):
"""Evaluate a single Jax equation and writes the outputs."""
in_values = list(in_values)
# This is logic specifically to handle `xla_call`
call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, eqn.params)
if call_jaxpr:
subfuns = [
jax.core.lu.wrap_init(
@@ -224,12 +232,10 @@ def broadcast_merger(f):
# Bind args and consts to environment
flat_args = jax.tree_flatten(func_args)[0]
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, flat_args)
jax_util.safe_map(write, jaxpr.constvars, consts)
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, flat_args)
jax_util.safe_map(write, jaxpr.constvars, consts)
@@ -354,8 +360,6 @@ def var_to_str(var):
"""Returns a string representation of the variable of a Jax expression."""
if isinstance(var, jax.core.Literal):
return str(var)
elif isinstance(var, jax.core.UnitVar):
return "*"
elif not isinstance(var, jax.core.Var):
raise ValueError(f"Idk what to do with this {type(var)}?")
c = int(var.count)
@@ -388,7 +392,7 @@ def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None):
for eqn in jaxpr.eqns:
in_vars = []
for v in eqn.invars:
if isinstance(v, (jax.core.Literal, jax.core.UnitVar)):
if isinstance(v, jax.core.Literal):
in_vars.append(var_to_str(v))
else:
in_vars.append(in_map.get(v, var_to_str(v)))
@@ -406,8 +410,7 @@ def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None):
# Create incoming edges
for v, name in zip(eqn.invars, in_vars):
if (not isinstance(v, jax.core.Literal) and
not isinstance(v, jax.core.UnitVar)):
if not isinstance(v, jax.core.Literal):
graph.add_edge(name, node_c)
# Create output nodes and edges
@@ -573,7 +576,6 @@ def auto_register_tags(func,
env[var] = val
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)
@@ -61,10 +61,8 @@ class TestGraphMatcher(unittest.TestCase):
self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars))
self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars))
for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
eq_in_vars = [v for v in eq.invars if not isinstance(v, jax.core.UnitVar)]
tagged_in_vars = [
v for v in tagged_eq.invars if not isinstance(v, jax.core.UnitVar)
]
eq_in_vars = [v for v in eq.invars]
tagged_in_vars = [v for v in tagged_eq.invars]
self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
self.assertEqual(eq.primitive, tagged_eq.primitive)
+1 -4
View File
@@ -64,7 +64,6 @@ def construct_compute_losses_inputs(
write = functools.partial(tgm.write_env, env)
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, flat_args)
jax_util.safe_map(write, jaxpr.constvars, consts)
@@ -230,7 +229,6 @@ def trace_estimator_vjp(tagged_func: _Function) -> _Function:
write = functools.partial(tgm.write_env, env)
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
jax_util.safe_map(write, jaxpr.constvars, consts)
@@ -253,12 +251,11 @@ def trace_estimator_vjp(tagged_func: _Function) -> _Function:
env = dict()
read = functools.partial(tgm.read_env, env)
def write(var, val):
if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
if not isinstance(var, jax.core.Literal):
val = val + aux[var] if var in aux else val
env[var] = val
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
jax_util.safe_map(write, jaxpr.constvars, consts)