diff --git a/kfac_ferminet_alpha/tests/graph_matcher_test.py b/kfac_ferminet_alpha/tests/graph_matcher_test.py index 4dbf35c..a79d487 100644 --- a/kfac_ferminet_alpha/tests/graph_matcher_test.py +++ b/kfac_ferminet_alpha/tests/graph_matcher_test.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest + from absl.testing import absltest import jax import jax.numpy as jnp import jax.random as jnr -import jax.test_util as jtu from kfac_ferminet_alpha import layers_and_loss_tags from kfac_ferminet_alpha import loss_functions @@ -44,8 +45,7 @@ def tagged_autoencoder(all_params, x_in): return [[h1, t2], [h2, t2]] -@jtu.with_config(jax_numpy_rank_promotion="allow") -class TestGraphMatcher(jtu.JaxTestCase): +class TestGraphMatcher(unittest.TestCase): """Class for running all of the tests for integrating the systems.""" def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape): diff --git a/kfac_ferminet_alpha/tests/tracer_test.py b/kfac_ferminet_alpha/tests/tracer_test.py index 51a739c..87e9065 100644 --- a/kfac_ferminet_alpha/tests/tracer_test.py +++ b/kfac_ferminet_alpha/tests/tracer_test.py @@ -12,11 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest + from absl.testing import absltest import jax import jax.numpy as jnp import jax.random as jnr -import jax.test_util as jtu +import numpy as np from kfac_ferminet_alpha import loss_functions from kfac_ferminet_alpha import tag_graph_matcher as tgm @@ -44,8 +46,7 @@ def autoencoder_aux(all_aux, all_params, x_in): return [l1, l2 * 0.1], layers_values -@jtu.with_config(jax_numpy_rank_promotion="allow") -class TestTracer(jtu.JaxTestCase): +class TestTracer(unittest.TestCase): """Class for running all of the tests for integrating the systems.""" @staticmethod @@ -66,13 +67,14 @@ class TestTracer(jtu.JaxTestCase): return params, data, p_tangents, h_tangents - def assertStructureAllClose(self, x, y, **kwargs): + def assertStructureAllClose(self, x, y, rtol=1E-5, atol=1E-5, **kwargs): x_v, x_tree = jax.tree_flatten(x) y_v, y_tree = jax.tree_flatten(y) self.assertEqual(x_tree, y_tree) for xi, yi in zip(x_v, y_v): self.assertEqual(xi.shape, yi.shape) - self.assertAllClose(xi, yi, check_dtypes=True, **kwargs) + self.assertEqual(xi.dtype, yi.dtype) + np.testing.assert_allclose(xi, yi, rtol=rtol, atol=atol, **kwargs) def test_tacer_jvp(self): init_func = common.init_autoencoder