mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-30 20:35:25 +08:00
Copybara import of the project:
-- ef2021392eedb9242636241d42625eed51c696d4 by Sharad Vikram <sharad.vikram@gmail.com>: Adds simple effect types to jaxprs PiperOrigin-RevId: 441083960
This commit is contained in:
@@ -80,6 +80,12 @@ class LossTag(jax_core.Primitive):
|
|||||||
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
||||||
|
|
||||||
def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
|
def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
|
||||||
|
jax_version = (
|
||||||
|
jax.__version_info__ if hasattr(jax, "__version_info__")
|
||||||
|
else tuple(map(int, jax.__version__.split("."))))
|
||||||
|
if jax_version > (0, 3, 4):
|
||||||
|
return (self.get_outputs(*args, weight=weight, return_loss=return_loss),
|
||||||
|
jax_core.no_effects)
|
||||||
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
||||||
|
|
||||||
def xla_translation(
|
def xla_translation(
|
||||||
@@ -176,6 +182,11 @@ class LayerTag(jax_core.Primitive):
|
|||||||
return self.get_outputs(*operands, **kwargs)
|
return self.get_outputs(*operands, **kwargs)
|
||||||
|
|
||||||
def abstract_eval(self, *abstract_operands, **kwargs):
|
def abstract_eval(self, *abstract_operands, **kwargs):
|
||||||
|
jax_version = (
|
||||||
|
jax.__version_info__ if hasattr(jax, "__version_info__")
|
||||||
|
else tuple(map(int, jax.__version__.split("."))))
|
||||||
|
if jax_version > (0, 3, 4):
|
||||||
|
return self.get_outputs(*abstract_operands, **kwargs), jax_core.no_effects
|
||||||
return self.get_outputs(*abstract_operands, **kwargs)
|
return self.get_outputs(*abstract_operands, **kwargs)
|
||||||
|
|
||||||
def batching(self, batched_operands, batched_dims, **kwargs):
|
def batching(self, batched_operands, batched_dims, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user