From 81dedafa1d52ceb633defee389a7d73ce58f4c4e Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 12 Apr 2022 06:04:39 +0100 Subject: [PATCH] Copybara import of the project: -- ef2021392eedb9242636241d42625eed51c696d4 by Sharad Vikram : Adds simple effect types to jaxprs PiperOrigin-RevId: 441083960 --- kfac_ferminet_alpha/layers_and_loss_tags.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/kfac_ferminet_alpha/layers_and_loss_tags.py b/kfac_ferminet_alpha/layers_and_loss_tags.py index 92fcfdb..b396bab 100644 --- a/kfac_ferminet_alpha/layers_and_loss_tags.py +++ b/kfac_ferminet_alpha/layers_and_loss_tags.py @@ -80,6 +80,12 @@ class LossTag(jax_core.Primitive): return self.get_outputs(*args, weight=weight, return_loss=return_loss) def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs): + jax_version = ( + jax.__version_info__ if hasattr(jax, "__version_info__") + else tuple(map(int, jax.__version__.split(".")))) + if jax_version > (0, 3, 4): + return (self.get_outputs(*args, weight=weight, return_loss=return_loss), + jax_core.no_effects) return self.get_outputs(*args, weight=weight, return_loss=return_loss) def xla_translation( @@ -176,6 +182,11 @@ class LayerTag(jax_core.Primitive): return self.get_outputs(*operands, **kwargs) def abstract_eval(self, *abstract_operands, **kwargs): + jax_version = ( + jax.__version_info__ if hasattr(jax, "__version_info__") + else tuple(map(int, jax.__version__.split(".")))) + if jax_version > (0, 3, 4): + return self.get_outputs(*abstract_operands, **kwargs), jax_core.no_effects return self.get_outputs(*abstract_operands, **kwargs) def batching(self, batched_operands, batched_dims, **kwargs):