absl-py==0.10.0
dm-haiku
jax>=0.2.13
jraph
nltk>=3.6.2
numpy>=1.19.5
optax>=0.0.6
scikit-learn>=0.24.2
