[JAX] Relax test tolerances.

These tests fail with an upcoming change to JAX that adds `jit` decorators around a number of JAX standard library functions. This appears to cause some very small numerical changes, most likely related to slightly different optimization opportunities available to the compiler.

PiperOrigin-RevId: 389039057
This commit is contained in:
Peter Hawkins
2021-08-05 23:30:48 +01:00
committed by Diego de Las Casas
parent 3eddada361
commit ac7a19a9b2
+2 -2
View File
@@ -122,7 +122,7 @@ class TestTracer(jtu.JaxTestCase):
tracer_p_tangents = tracer_vjp_func(h_tangents)
self.assertStructureAllClose(loss_vals, tracer_losses)
self.assertStructureAllClose(p_tangents, tracer_p_tangents)
self.assertStructureAllClose(p_tangents, tracer_p_tangents, atol=3e-6)
def test_tracer_hvp(self):
init_func = common.init_autoencoder
@@ -191,7 +191,7 @@ class TestTracer(jtu.JaxTestCase):
tracer_outputs = tracer_vjp_func((h_tangents[:1], h_tangents[1:]))
self.assertStructureAllClose(loss_vals, tracer_losses)
self.assertStructureAllClose(tracer_outputs, layers_info)
self.assertStructureAllClose(tracer_outputs, layers_info, atol=3e-6)
if __name__ == "__main__":