mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 21:15:21 +08:00
Enable the checkpoint path containing the DM21 weights to be manually specified.
PiperOrigin-RevId: 424835062
This commit is contained in:
committed by
Diego de Las Casas
parent
7b427b5161
commit
bc869b25d2
+12
-3
@@ -167,17 +167,26 @@ class NeuralNumInt(numint.NumInt):
|
|||||||
mf.kernel()
|
mf.kernel()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, functional: Functional):
|
def __init__(self,
|
||||||
|
functional: Functional,
|
||||||
|
*,
|
||||||
|
checkpoint_path: Optional[str] = None):
|
||||||
"""Constructs a NeuralNumInt object.
|
"""Constructs a NeuralNumInt object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
functional: member of Functional enum giving the name of the
|
functional: member of Functional enum giving the name of the
|
||||||
functional.
|
functional.
|
||||||
|
checkpoint_path: Optional path to specify the directory containing the
|
||||||
|
checkpoints of the DM21 family of functionals. If not specified, attempt
|
||||||
|
to find the checkpoints using a path relative to the source code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._functional_name = functional.name
|
self._functional_name = functional.name
|
||||||
self._model_path = os.path.join(
|
if checkpoint_path:
|
||||||
os.path.dirname(__file__), 'checkpoints', self._functional_name)
|
self._model_path = os.path.join(checkpoint_path, self._functional_name)
|
||||||
|
else:
|
||||||
|
self._model_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), 'checkpoints', self._functional_name)
|
||||||
|
|
||||||
# All DM21 functionals use local Hartree-Fock features with a non-range
|
# All DM21 functionals use local Hartree-Fock features with a non-range
|
||||||
# separated 1/r kernel and a range-seperated kernel with \omega = 0.4.
|
# separated 1/r kernel and a range-seperated kernel with \omega = 0.4.
|
||||||
|
|||||||
Reference in New Issue
Block a user