Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target.

PiperOrigin-RevId: 436950450
This commit is contained in:
DeepMind Team
2022-03-24 10:07:15 +00:00
committed by alimuldal
parent 464939ede1
commit 92a307a920
+5 -4
View File
@@ -40,6 +40,7 @@ from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tree
@@ -372,7 +373,7 @@ def get_one_step_estimator_fn(data_path,
'one_step_position_mse': tf.metrics.mean_squared_error(
predicted_next_position, target_next_position)
}
return tf.estimator.EstimatorSpec(
return tf_estimator.EstimatorSpec(
mode=mode,
train_op=train_op,
loss=loss,
@@ -414,7 +415,7 @@ def get_rollout_estimator_fn(data_path,
# Add a leading axis, since Estimator's predict method insists that all
# tensors have a shared leading batch axis fo the same dims.
rollout_op = tree.map_structure(lambda x: x[tf.newaxis], rollout_op)
return tf.estimator.EstimatorSpec(
return tf_estimator.EstimatorSpec(
mode=mode,
train_op=None,
loss=loss,
@@ -433,7 +434,7 @@ def main(_):
"""Train or evaluates the model."""
if FLAGS.mode in ['train', 'eval']:
estimator = tf.estimator.Estimator(
estimator = tf_estimator.Estimator(
get_one_step_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
model_dir=FLAGS.model_path)
if FLAGS.mode == 'train':
@@ -452,7 +453,7 @@ def main(_):
elif FLAGS.mode == 'eval_rollout':
if not FLAGS.output_path:
raise ValueError('A rollout path must be provided.')
rollout_estimator = tf.estimator.Estimator(
rollout_estimator = tf_estimator.Estimator(
get_rollout_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
model_dir=FLAGS.model_path)