Migrate away from using JaxTestCase in tests

Why? JaxTestCase is deprecated for use outside the JAX project as of version 0.3.1; see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-1-feb-18-2022

PiperOrigin-RevId: 432390267
This commit is contained in:
Jake VanderPlas
2022-03-04 09:44:03 +00:00
committed by alimuldal
parent f30fae8dbc
commit 255e4e1256
2 changed files with 10 additions and 8 deletions
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.test_util as jtu
from kfac_ferminet_alpha import layers_and_loss_tags
from kfac_ferminet_alpha import loss_functions
@@ -44,8 +45,7 @@ def tagged_autoencoder(all_params, x_in):
return [[h1, t2], [h2, t2]]
@jtu.with_config(jax_numpy_rank_promotion="allow")
class TestGraphMatcher(jtu.JaxTestCase):
class TestGraphMatcher(unittest.TestCase):
"""Class for running all of the tests for integrating the systems."""
def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape):
+7 -5
View File
@@ -12,11 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.test_util as jtu
import numpy as np
from kfac_ferminet_alpha import loss_functions
from kfac_ferminet_alpha import tag_graph_matcher as tgm
@@ -44,8 +46,7 @@ def autoencoder_aux(all_aux, all_params, x_in):
return [l1, l2 * 0.1], layers_values
@jtu.with_config(jax_numpy_rank_promotion="allow")
class TestTracer(jtu.JaxTestCase):
class TestTracer(unittest.TestCase):
"""Class for running all of the tests for integrating the systems."""
@staticmethod
@@ -66,13 +67,14 @@ class TestTracer(jtu.JaxTestCase):
return params, data, p_tangents, h_tangents
def assertStructureAllClose(self, x, y, **kwargs):
def assertStructureAllClose(self, x, y, rtol=1E-5, atol=1E-5, **kwargs):
x_v, x_tree = jax.tree_flatten(x)
y_v, y_tree = jax.tree_flatten(y)
self.assertEqual(x_tree, y_tree)
for xi, yi in zip(x_v, y_v):
self.assertEqual(xi.shape, yi.shape)
self.assertAllClose(xi, yi, check_dtypes=True, **kwargs)
self.assertEqual(xi.dtype, yi.dtype)
np.testing.assert_allclose(xi, yi, rtol=rtol, atol=atol, **kwargs)
def test_tacer_jvp(self):
init_func = common.init_autoencoder