mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-19 19:01:30 +08:00
Migrate away from using JaxTestCase in tests
Why? JaxTestCase is deprecated for use outside the JAX project as of version 0.3.1; see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-1-feb-18-2022 PiperOrigin-RevId: 432390267
This commit is contained in:
committed by
alimuldal
parent
f30fae8dbc
commit
255e4e1256
@@ -12,11 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jnr
|
||||
import jax.test_util as jtu
|
||||
|
||||
from kfac_ferminet_alpha import layers_and_loss_tags
|
||||
from kfac_ferminet_alpha import loss_functions
|
||||
@@ -44,8 +45,7 @@ def tagged_autoencoder(all_params, x_in):
|
||||
return [[h1, t2], [h2, t2]]
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="allow")
|
||||
class TestGraphMatcher(jtu.JaxTestCase):
|
||||
class TestGraphMatcher(unittest.TestCase):
|
||||
"""Class for running all of the tests for integrating the systems."""
|
||||
|
||||
def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape):
|
||||
|
||||
@@ -12,11 +12,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jnr
|
||||
import jax.test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
from kfac_ferminet_alpha import loss_functions
|
||||
from kfac_ferminet_alpha import tag_graph_matcher as tgm
|
||||
@@ -44,8 +46,7 @@ def autoencoder_aux(all_aux, all_params, x_in):
|
||||
return [l1, l2 * 0.1], layers_values
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="allow")
|
||||
class TestTracer(jtu.JaxTestCase):
|
||||
class TestTracer(unittest.TestCase):
|
||||
"""Class for running all of the tests for integrating the systems."""
|
||||
|
||||
@staticmethod
|
||||
@@ -66,13 +67,14 @@ class TestTracer(jtu.JaxTestCase):
|
||||
|
||||
return params, data, p_tangents, h_tangents
|
||||
|
||||
def assertStructureAllClose(self, x, y, **kwargs):
|
||||
def assertStructureAllClose(self, x, y, rtol=1E-5, atol=1E-5, **kwargs):
|
||||
x_v, x_tree = jax.tree_flatten(x)
|
||||
y_v, y_tree = jax.tree_flatten(y)
|
||||
self.assertEqual(x_tree, y_tree)
|
||||
for xi, yi in zip(x_v, y_v):
|
||||
self.assertEqual(xi.shape, yi.shape)
|
||||
self.assertAllClose(xi, yi, check_dtypes=True, **kwargs)
|
||||
self.assertEqual(xi.dtype, yi.dtype)
|
||||
np.testing.assert_allclose(xi, yi, rtol=rtol, atol=atol, **kwargs)
|
||||
|
||||
def test_tacer_jvp(self):
|
||||
init_func = common.init_autoencoder
|
||||
|
||||
Reference in New Issue
Block a user