mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 19:31:14 +08:00
[JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback. PiperOrigin-RevId: 507473737
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
cb555c241b
commit
797ea3c71d
@@ -164,11 +164,19 @@ def abstract_args(args):
|
|||||||
return jax.tree_map(abstract_single_value, args)
|
return jax.tree_map(abstract_single_value, args)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_call_jaxpr(primitive, params):
|
||||||
|
if not (primitive.call_primitive or primitive.map_primitive):
|
||||||
|
return None, params
|
||||||
|
else:
|
||||||
|
params = dict(params)
|
||||||
|
return params.pop("call_jaxpr"), params
|
||||||
|
|
||||||
|
|
||||||
def evaluate_eqn(eqn, in_values, write_func):
|
def evaluate_eqn(eqn, in_values, write_func):
|
||||||
"""Evaluate a single Jax equation and writes the outputs."""
|
"""Evaluate a single Jax equation and writes the outputs."""
|
||||||
in_values = list(in_values)
|
in_values = list(in_values)
|
||||||
# This is logic specifically to handle `xla_call`
|
# This is logic specifically to handle `xla_call`
|
||||||
call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
|
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||||
if call_jaxpr:
|
if call_jaxpr:
|
||||||
subfuns = [
|
subfuns = [
|
||||||
jax.core.lu.wrap_init(
|
jax.core.lu.wrap_init(
|
||||||
@@ -224,12 +232,10 @@ def broadcast_merger(f):
|
|||||||
|
|
||||||
# Bind args and consts to environment
|
# Bind args and consts to environment
|
||||||
flat_args = jax.tree_flatten(func_args)[0]
|
flat_args = jax.tree_flatten(func_args)[0]
|
||||||
write(jax.core.unitvar, jax.core.unit)
|
|
||||||
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
||||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||||
|
|
||||||
# Bind args and consts to environment
|
# Bind args and consts to environment
|
||||||
write(jax.core.unitvar, jax.core.unit)
|
|
||||||
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
||||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||||
|
|
||||||
@@ -354,8 +360,6 @@ def var_to_str(var):
|
|||||||
"""Returns a string representation of the variable of a Jax expression."""
|
"""Returns a string representation of the variable of a Jax expression."""
|
||||||
if isinstance(var, jax.core.Literal):
|
if isinstance(var, jax.core.Literal):
|
||||||
return str(var)
|
return str(var)
|
||||||
elif isinstance(var, jax.core.UnitVar):
|
|
||||||
return "*"
|
|
||||||
elif not isinstance(var, jax.core.Var):
|
elif not isinstance(var, jax.core.Var):
|
||||||
raise ValueError(f"Idk what to do with this {type(var)}?")
|
raise ValueError(f"Idk what to do with this {type(var)}?")
|
||||||
c = int(var.count)
|
c = int(var.count)
|
||||||
@@ -388,7 +392,7 @@ def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None):
|
|||||||
for eqn in jaxpr.eqns:
|
for eqn in jaxpr.eqns:
|
||||||
in_vars = []
|
in_vars = []
|
||||||
for v in eqn.invars:
|
for v in eqn.invars:
|
||||||
if isinstance(v, (jax.core.Literal, jax.core.UnitVar)):
|
if isinstance(v, jax.core.Literal):
|
||||||
in_vars.append(var_to_str(v))
|
in_vars.append(var_to_str(v))
|
||||||
else:
|
else:
|
||||||
in_vars.append(in_map.get(v, var_to_str(v)))
|
in_vars.append(in_map.get(v, var_to_str(v)))
|
||||||
@@ -406,8 +410,7 @@ def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None):
|
|||||||
|
|
||||||
# Create incoming edges
|
# Create incoming edges
|
||||||
for v, name in zip(eqn.invars, in_vars):
|
for v, name in zip(eqn.invars, in_vars):
|
||||||
if (not isinstance(v, jax.core.Literal) and
|
if not isinstance(v, jax.core.Literal):
|
||||||
not isinstance(v, jax.core.UnitVar)):
|
|
||||||
graph.add_edge(name, node_c)
|
graph.add_edge(name, node_c)
|
||||||
|
|
||||||
# Create output nodes and edges
|
# Create output nodes and edges
|
||||||
@@ -573,7 +576,6 @@ def auto_register_tags(func,
|
|||||||
env[var] = val
|
env[var] = val
|
||||||
|
|
||||||
# Bind args and consts to environment
|
# Bind args and consts to environment
|
||||||
write(jax.core.unitvar, jax.core.unit)
|
|
||||||
jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
|
jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
|
||||||
jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)
|
jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)
|
||||||
|
|
||||||
|
|||||||
@@ -61,10 +61,8 @@ class TestGraphMatcher(unittest.TestCase):
|
|||||||
self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars))
|
self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars))
|
||||||
self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars))
|
self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars))
|
||||||
for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
|
for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
|
||||||
eq_in_vars = [v for v in eq.invars if not isinstance(v, jax.core.UnitVar)]
|
eq_in_vars = [v for v in eq.invars]
|
||||||
tagged_in_vars = [
|
tagged_in_vars = [v for v in tagged_eq.invars]
|
||||||
v for v in tagged_eq.invars if not isinstance(v, jax.core.UnitVar)
|
|
||||||
]
|
|
||||||
self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
|
self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
|
||||||
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
|
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
|
||||||
self.assertEqual(eq.primitive, tagged_eq.primitive)
|
self.assertEqual(eq.primitive, tagged_eq.primitive)
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ def construct_compute_losses_inputs(
|
|||||||
write = functools.partial(tgm.write_env, env)
|
write = functools.partial(tgm.write_env, env)
|
||||||
|
|
||||||
# Bind args and consts to environment
|
# Bind args and consts to environment
|
||||||
write(jax.core.unitvar, jax.core.unit)
|
|
||||||
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
jax_util.safe_map(write, jaxpr.invars, flat_args)
|
||||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||||
|
|
||||||
@@ -230,7 +229,6 @@ def trace_estimator_vjp(tagged_func: _Function) -> _Function:
|
|||||||
write = functools.partial(tgm.write_env, env)
|
write = functools.partial(tgm.write_env, env)
|
||||||
|
|
||||||
# Bind args and consts to environment
|
# Bind args and consts to environment
|
||||||
write(jax.core.unitvar, jax.core.unit)
|
|
||||||
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
|
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
|
||||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||||
|
|
||||||
@@ -253,12 +251,11 @@ def trace_estimator_vjp(tagged_func: _Function) -> _Function:
|
|||||||
env = dict()
|
env = dict()
|
||||||
read = functools.partial(tgm.read_env, env)
|
read = functools.partial(tgm.read_env, env)
|
||||||
def write(var, val):
|
def write(var, val):
|
||||||
if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
|
if not isinstance(var, jax.core.Literal):
|
||||||
val = val + aux[var] if var in aux else val
|
val = val + aux[var] if var in aux else val
|
||||||
env[var] = val
|
env[var] = val
|
||||||
|
|
||||||
# Bind args and consts to environment
|
# Bind args and consts to environment
|
||||||
write(jax.core.unitvar, jax.core.unit)
|
|
||||||
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
|
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
|
||||||
jax_util.safe_map(write, jaxpr.constvars, consts)
|
jax_util.safe_map(write, jaxpr.constvars, consts)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user