mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 14:33:42 +08:00
[JAX] Relax test tolerances.
These tests fail with an upcoming change to JAX that adds `jit` decorators around a number of JAX standard library functions. This appears to cause some very small numerical changes, most likely related to slightly different optimization opportunities available to the compiler. PiperOrigin-RevId: 389039057
This commit is contained in:
committed by
Diego de Las Casas
parent
3eddada361
commit
ac7a19a9b2
@@ -122,7 +122,7 @@ class TestTracer(jtu.JaxTestCase):
|
||||
tracer_p_tangents = tracer_vjp_func(h_tangents)
|
||||
|
||||
self.assertStructureAllClose(loss_vals, tracer_losses)
|
||||
self.assertStructureAllClose(p_tangents, tracer_p_tangents)
|
||||
self.assertStructureAllClose(p_tangents, tracer_p_tangents, atol=3e-6)
|
||||
|
||||
def test_tracer_hvp(self):
|
||||
init_func = common.init_autoencoder
|
||||
@@ -191,7 +191,7 @@ class TestTracer(jtu.JaxTestCase):
|
||||
tracer_outputs = tracer_vjp_func((h_tangents[:1], h_tangents[1:]))
|
||||
|
||||
self.assertStructureAllClose(loss_vals, tracer_losses)
|
||||
self.assertStructureAllClose(tracer_outputs, layers_info)
|
||||
self.assertStructureAllClose(tracer_outputs, layers_info, atol=3e-6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user