mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-26 01:15:26 +08:00
Enable the checkpoint path containing the DM21 weights to be manually specified.
PiperOrigin-RevId: 424835062
This commit is contained in:
committed by
Diego de Las Casas
parent
7b427b5161
commit
bc869b25d2
+12
-3
@@ -167,17 +167,26 @@ class NeuralNumInt(numint.NumInt):
|
||||
mf.kernel()
|
||||
"""
|
||||
|
||||
def __init__(self, functional: Functional):
|
||||
def __init__(self,
|
||||
functional: Functional,
|
||||
*,
|
||||
checkpoint_path: Optional[str] = None):
|
||||
"""Constructs a NeuralNumInt object.
|
||||
|
||||
Args:
|
||||
functional: member of Functional enum giving the name of the
|
||||
functional.
|
||||
checkpoint_path: Optional path to specify the directory containing the
|
||||
checkpoints of the DM21 family of functionals. If not specified, attempt
|
||||
to find the checkpoints using a path relative to the source code.
|
||||
"""
|
||||
|
||||
self._functional_name = functional.name
|
||||
self._model_path = os.path.join(
|
||||
os.path.dirname(__file__), 'checkpoints', self._functional_name)
|
||||
if checkpoint_path:
|
||||
self._model_path = os.path.join(checkpoint_path, self._functional_name)
|
||||
else:
|
||||
self._model_path = os.path.join(
|
||||
os.path.dirname(__file__), 'checkpoints', self._functional_name)
|
||||
|
||||
# All DM21 functionals use local Hartree-Fock features with a non-range
|
||||
# separated 1/r kernel and a range-seperated kernel with \omega = 0.4.
|
||||
|
||||
Reference in New Issue
Block a user