diff --git a/kfac_ferminet_alpha/tests/tracer_test.py b/kfac_ferminet_alpha/tests/tracer_test.py index 8236d4d..72d472b 100644 --- a/kfac_ferminet_alpha/tests/tracer_test.py +++ b/kfac_ferminet_alpha/tests/tracer_test.py @@ -122,7 +122,7 @@ class TestTracer(jtu.JaxTestCase): tracer_p_tangents = tracer_vjp_func(h_tangents) self.assertStructureAllClose(loss_vals, tracer_losses) - self.assertStructureAllClose(p_tangents, tracer_p_tangents) + self.assertStructureAllClose(p_tangents, tracer_p_tangents, atol=3e-6) def test_tracer_hvp(self): init_func = common.init_autoencoder @@ -191,7 +191,7 @@ class TestTracer(jtu.JaxTestCase): tracer_outputs = tracer_vjp_func((h_tangents[:1], h_tangents[1:])) self.assertStructureAllClose(loss_vals, tracer_losses) - self.assertStructureAllClose(tracer_outputs, layers_info) + self.assertStructureAllClose(tracer_outputs, layers_info, atol=3e-6) if __name__ == "__main__":