Enable the checkpoint path containing the DM21 weights to be manually specified.

PiperOrigin-RevId: 424835062
This commit is contained in:
James Spencer
2022-01-28 12:07:39 +00:00
committed by Diego de Las Casas
parent 7b427b5161
commit bc869b25d2
@@ -167,17 +167,26 @@ class NeuralNumInt(numint.NumInt):
mf.kernel()
"""
def __init__(self, functional: Functional):
def __init__(self,
functional: Functional,
*,
checkpoint_path: Optional[str] = None):
"""Constructs a NeuralNumInt object.
Args:
functional: member of Functional enum giving the name of the
functional.
checkpoint_path: Optional path to specify the directory containing the
checkpoints of the DM21 family of functionals. If not specified, attempt
to find the checkpoints using a path relative to the source code.
"""
self._functional_name = functional.name
self._model_path = os.path.join(
os.path.dirname(__file__), 'checkpoints', self._functional_name)
if checkpoint_path:
self._model_path = os.path.join(checkpoint_path, self._functional_name)
else:
self._model_path = os.path.join(
os.path.dirname(__file__), 'checkpoints', self._functional_name)
# All DM21 functionals use local Hartree-Fock features with a non-range
# separated 1/r kernel and a range-seperated kernel with \omega = 0.4.