369 Commits

Author SHA1 Message Date
David Pfau f5de0ede84 Add code for generating s3o4d data
PiperOrigin-RevId: 535584807
2023-06-02 18:04:50 +01:00
Rebecca Chen 9176a9f23c Silence some pytype errors.
PiperOrigin-RevId: 527929776
2023-06-02 18:04:36 +01:00
Alvaro Sanchez-Gonzalez fb1d757863 pytype fix
PiperOrigin-RevId: 525432152
2023-06-02 18:04:22 +01:00
Jake VanderPlas f905943c13 Remove references to deprecated jax.ShapedArray
This is deprecated as of https://github.com/google/jax/pull/15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.

PiperOrigin-RevId: 520846838
2023-06-02 18:04:08 +01:00
Jake VanderPlas 9d01171d43 Replace references to deprecated jax.curry function.
This is deprecated as of https://github.com/google/jax/pull/15263

PiperOrigin-RevId: 520269385
2023-06-02 18:03:54 +01:00
Peter Hawkins 82a347438f Fix pytype failures related to teaching pytype about NumPy scalar types.
PiperOrigin-RevId: 519205179
2023-06-02 18:03:40 +01:00
John Cater e2d21540e8 Internal change
PiperOrigin-RevId: 517935523
2023-06-02 18:03:27 +01:00
Peter Hawkins d988ff1bf2 Replaces references to jax.numpy.DeviceArray with jax.Array.
PiperOrigin-RevId: 515678285
2023-06-02 18:03:13 +01:00
Rebecca Chen 0824c28deb Silence some pytype errors.
PiperOrigin-RevId: 515579955
2023-06-02 18:02:59 +01:00
Peter Hawkins 784f67565e Suppress some pytype errors related to jnp.DeviceArray == jax.Array.
PiperOrigin-RevId: 514986169
2023-06-02 18:02:45 +01:00
Peter Hawkins 6f0ddef7da Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 512349622
2023-06-02 18:02:30 +01:00
Peter Hawkins c051e6a51d Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
PiperOrigin-RevId: 511294746
2023-06-02 18:02:16 +01:00
Peter Hawkins 797ea3c71d [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-06-02 18:01:57 +01:00
James Spencer cb555c241b Update DM21 pinned requirements and bazel config for compiling functional to C++.
PiperOrigin-RevId: 505176617
2023-01-31 17:17:39 +00:00
Peter Hawkins c7e2ef28be [NumPy] Remove references to deprecated NumPy type aliases.
This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.

PiperOrigin-RevId: 501824490
2023-01-31 17:17:28 +00:00
Rebecca Chen 3c1aa70723 Silence some pytype errors.
PiperOrigin-RevId: 496746673
2023-01-31 17:17:16 +00:00
Peter Hawkins 4caaad8b0a [NumPy] Remove references to deprecated NumPy type aliases.
This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.

PiperOrigin-RevId: 495851585
2023-01-31 17:17:04 +00:00
Rebecca Chen ea772958de Silence some pytype errors.
PiperOrigin-RevId: 491299111
2023-01-31 17:16:51 +00:00
James Spencer e1e065d63e Update bazel configuration and dependencies for DM21
Resolves https://github.com/deepmind/deepmind-research/issues/391

PiperOrigin-RevId: 483933786
2023-01-31 17:16:31 +00:00
Alvaro Sanchez-Gonzalez a318c69018 Fixing bug in open source implementation of OGB-LSC/PCQ.
PiperOrigin-RevId: 483319710
2022-10-24 11:16:37 +00:00
James Spencer a748a7c817 Set small_rho_cutoff in DM21 tests to match pyscf 2.0 behaviour.
PiperOrigin-RevId: 480303500
2022-10-12 10:32:48 +00:00
James Spencer a8dab395c8 Lift restriction on h5py now pyscf 2 is widely available.
PiperOrigin-RevId: 480303233
2022-10-12 10:32:38 +00:00
James Spencer affa162ac7 Return exc, vxc as float64 arrays.
PiperOrigin-RevId: 480303057
2022-10-12 10:32:27 +00:00
James Spencer 789bc38a2e Add unused omega argument to NeuralNumint.eval_xc to match pyscf interface
PiperOrigin-RevId: 480302784
2022-10-12 10:32:17 +00:00
Yash Katariya da0f2de14d Replace jnp.DeviceArray with the new public type jax.Array.
PiperOrigin-RevId: 477558995
2022-10-12 10:32:06 +00:00
Yilei Yang 58fb45db7e Make this code compatible with Python 3.10.
PiperOrigin-RevId: 473700513
2022-10-12 10:31:51 +00:00
Yilei Yang 3af71dd9a7 Make this code compatible with Python 3.10.
PiperOrigin-RevId: 473457191
2022-10-12 10:31:41 +00:00
Victoria Krakovna 586a1c55de fix side effects baseline test flakiness due to corner case
PiperOrigin-RevId: 463419361
2022-10-12 10:31:17 +00:00
Jake VanderPlas 6fcb84268e Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
2022-07-24 17:53:28 +01:00
Jake VanderPlas 956c4b5d9c Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461045645
2022-07-24 17:53:14 +01:00
Kyle Taylor 11c2ab53e8 Update enformer-training.ipynb to show how to restore from checkpoint.
PiperOrigin-RevId: 450894644
2022-05-26 17:45:33 +01:00
Kyle Taylor 57456a001d Minor changes to match trained SOTA model.
PiperOrigin-RevId: 450716647
2022-05-26 17:45:21 +01:00
Saran Tunyasuvunakool d436681054 Remove dependency on kernel-pruner.
PiperOrigin-RevId: 449460339
2022-05-26 17:45:11 +01:00
James Spencer a97d8b1807 Remove conda install instructions for DM21
PiperOrigin-RevId: 447710907
2022-05-26 17:45:01 +01:00
Yilei Yang f4916dadab Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 443298684
2022-05-26 17:44:51 +01:00
Yilei Yang 94702704c8 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 441399237
2022-05-26 17:44:40 +01:00
Yilei Yang 5501e3237b Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 441399134
2022-05-26 17:44:30 +01:00
Sharad Vikram 81dedafa1d Copybara import of the project:
--
ef2021392eedb9242636241d42625eed51c696d4 by Sharad Vikram <sharad.vikram@gmail.com>:

Adds simple effect types to jaxprs

PiperOrigin-RevId: 441083960
2022-05-26 17:44:18 +01:00
Yilei Yang 34b2486902 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 440907738
2022-05-26 17:44:09 +01:00
Yilei Yang 152ac280f2 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 440896914
2022-05-26 17:43:58 +01:00
Yilei Yang 9a526bd35b Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 440880484
2022-05-26 17:43:48 +01:00
DeepMind Team 92a307a920 Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target.
PiperOrigin-RevId: 436950450
2022-05-26 17:43:38 +01:00
Rebecca Chen 464939ede1 Silence some pytype errors.
PiperOrigin-RevId: 435606303
2022-05-26 17:43:25 +01:00
Sven Gowal 8e24fbbb29 Added new models.
PiperOrigin-RevId: 434780632
2022-05-26 17:43:12 +01:00
Sven Gowal b2fa23c838 Allow alternative functions to load extra data.
PiperOrigin-RevId: 434495304
2022-05-26 17:43:03 +01:00
Nimrod Gileadi 3cb13ea0b9 Use new MuJoCo python bindings to implement dm_control/mujoco.
PiperOrigin-RevId: 432958654
2022-05-26 17:42:53 +01:00
Jake VanderPlas 255e4e1256 Migrate away from using JaxTestCase in tests
Why? JaxTestCase is deprecated for use outside the JAX project as of version 0.3.1; see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-1-feb-18-2022

PiperOrigin-RevId: 432390267
2022-05-26 17:42:43 +01:00
Piotr Stanczyk f30fae8dbc Internal change
PiperOrigin-RevId: 431903167
2022-05-26 17:42:33 +01:00
Albin Cassirer ded41df440 Expose times_sampled in SampleInfo.
PiperOrigin-RevId: 430170716
2022-05-26 17:42:11 +01:00
Diego de Las Casas 1642ae3499 Release fusion_tcv
PiperOrigin-RevId: 429049702
2022-02-16 16:11:04 +00:00