Enable the checkpoint path containing the DM21 weights to be manually specified.

PiperOrigin-RevId: 424835062
This commit is contained in:
James Spencer
2022-01-28 12:07:39 +00:00
committed by Diego de Las Casas
parent 7b427b5161
commit bc869b25d2
@@ -167,15 +167,24 @@ class NeuralNumInt(numint.NumInt):
mf.kernel() mf.kernel()
""" """
def __init__(self, functional: Functional): def __init__(self,
functional: Functional,
*,
checkpoint_path: Optional[str] = None):
"""Constructs a NeuralNumInt object. """Constructs a NeuralNumInt object.
Args: Args:
functional: member of Functional enum giving the name of the functional: member of Functional enum giving the name of the
functional. functional.
checkpoint_path: Optional path to specify the directory containing the
checkpoints of the DM21 family of functionals. If not specified, attempt
to find the checkpoints using a path relative to the source code.
""" """
self._functional_name = functional.name self._functional_name = functional.name
if checkpoint_path:
self._model_path = os.path.join(checkpoint_path, self._functional_name)
else:
self._model_path = os.path.join( self._model_path = os.path.join(
os.path.dirname(__file__), 'checkpoints', self._functional_name) os.path.dirname(__file__), 'checkpoints', self._functional_name)