mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-11 22:27:15 +08:00
Copybara import of the project:
-- ef2021392eedb9242636241d42625eed51c696d4 by Sharad Vikram <sharad.vikram@gmail.com>: Adds simple effect types to jaxprs PiperOrigin-RevId: 441083960
This commit is contained in:
@@ -80,6 +80,12 @@ class LossTag(jax_core.Primitive):
|
||||
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
||||
|
||||
def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
|
||||
jax_version = (
|
||||
jax.__version_info__ if hasattr(jax, "__version_info__")
|
||||
else tuple(map(int, jax.__version__.split("."))))
|
||||
if jax_version > (0, 3, 4):
|
||||
return (self.get_outputs(*args, weight=weight, return_loss=return_loss),
|
||||
jax_core.no_effects)
|
||||
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
||||
|
||||
def xla_translation(
|
||||
@@ -176,6 +182,11 @@ class LayerTag(jax_core.Primitive):
|
||||
return self.get_outputs(*operands, **kwargs)
|
||||
|
||||
def abstract_eval(self, *abstract_operands, **kwargs):
|
||||
jax_version = (
|
||||
jax.__version_info__ if hasattr(jax, "__version_info__")
|
||||
else tuple(map(int, jax.__version__.split("."))))
|
||||
if jax_version > (0, 3, 4):
|
||||
return self.get_outputs(*abstract_operands, **kwargs), jax_core.no_effects
|
||||
return self.get_outputs(*abstract_operands, **kwargs)
|
||||
|
||||
def batching(self, batched_operands, batched_dims, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user