Copybara import of the project:

--
ef2021392eedb9242636241d42625eed51c696d4 by Sharad Vikram <sharad.vikram@gmail.com>:

Adds simple effect types to jaxprs

PiperOrigin-RevId: 441083960
This commit is contained in:
Sharad Vikram
2022-04-12 06:04:39 +01:00
committed by alimuldal
parent 34b2486902
commit 81dedafa1d
@@ -80,6 +80,12 @@ class LossTag(jax_core.Primitive):
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))
if jax_version > (0, 3, 4):
return (self.get_outputs(*args, weight=weight, return_loss=return_loss),
jax_core.no_effects)
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
def xla_translation(
@@ -176,6 +182,11 @@ class LayerTag(jax_core.Primitive):
return self.get_outputs(*operands, **kwargs)
def abstract_eval(self, *abstract_operands, **kwargs):
jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))
if jax_version > (0, 3, 4):
return self.get_outputs(*abstract_operands, **kwargs), jax_core.no_effects
return self.get_outputs(*abstract_operands, **kwargs)
def batching(self, batched_operands, batched_dims, **kwargs):