mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 23:07:29 +08:00
Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target.
PiperOrigin-RevId: 436950450
This commit is contained in:
@@ -40,6 +40,7 @@ from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import estimator as tf_estimator
|
||||
import tree
|
||||
|
||||
|
||||
@@ -372,7 +373,7 @@ def get_one_step_estimator_fn(data_path,
|
||||
'one_step_position_mse': tf.metrics.mean_squared_error(
|
||||
predicted_next_position, target_next_position)
|
||||
}
|
||||
return tf.estimator.EstimatorSpec(
|
||||
return tf_estimator.EstimatorSpec(
|
||||
mode=mode,
|
||||
train_op=train_op,
|
||||
loss=loss,
|
||||
@@ -414,7 +415,7 @@ def get_rollout_estimator_fn(data_path,
|
||||
# Add a leading axis, since Estimator's predict method insists that all
|
||||
# tensors have a shared leading batch axis fo the same dims.
|
||||
rollout_op = tree.map_structure(lambda x: x[tf.newaxis], rollout_op)
|
||||
return tf.estimator.EstimatorSpec(
|
||||
return tf_estimator.EstimatorSpec(
|
||||
mode=mode,
|
||||
train_op=None,
|
||||
loss=loss,
|
||||
@@ -433,7 +434,7 @@ def main(_):
|
||||
"""Train or evaluates the model."""
|
||||
|
||||
if FLAGS.mode in ['train', 'eval']:
|
||||
estimator = tf.estimator.Estimator(
|
||||
estimator = tf_estimator.Estimator(
|
||||
get_one_step_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
|
||||
model_dir=FLAGS.model_path)
|
||||
if FLAGS.mode == 'train':
|
||||
@@ -452,7 +453,7 @@ def main(_):
|
||||
elif FLAGS.mode == 'eval_rollout':
|
||||
if not FLAGS.output_path:
|
||||
raise ValueError('A rollout path must be provided.')
|
||||
rollout_estimator = tf.estimator.Estimator(
|
||||
rollout_estimator = tf_estimator.Estimator(
|
||||
get_rollout_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
|
||||
model_dir=FLAGS.model_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user