mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 21:15:21 +08:00
Add a colab for generating figures.
Export training curves to file and fix some inconsistencies. PiperOrigin-RevId: 324825810
This commit is contained in:
committed by
Diego de Las Casas
parent
99aaa6930a
commit
60550a5bc6
@@ -135,7 +135,7 @@ class EnvironmentWithKeyboard(dm_env.Environment):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Terminate option.
|
# Terminate option.
|
||||||
if self._compute_reward(option, action_step.observation) > 0:
|
if self._should_terminate(option, action_step.observation):
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self._call_and_return:
|
if not self._call_and_return:
|
||||||
@@ -143,6 +143,16 @@ class EnvironmentWithKeyboard(dm_env.Environment):
|
|||||||
|
|
||||||
return option_step
|
return option_step
|
||||||
|
|
||||||
|
def _should_terminate(self, option, obs):
|
||||||
|
if self._compute_reward(option, obs) > 0:
|
||||||
|
return True
|
||||||
|
elif np.all(self._options_np[option] <= 0):
|
||||||
|
# TODO(shaobohou) A hack ensure option with non-positive weights
|
||||||
|
# terminates after one step
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def action_spec(self):
|
def action_spec(self):
|
||||||
return dm_env.specs.DiscreteArray(
|
return dm_env.specs.DiscreteArray(
|
||||||
num_values=self._options_np.shape[0], name="action")
|
num_values=self._options_np.shape[0], name="action")
|
||||||
@@ -228,7 +238,7 @@ class EnvironmentWithKeyboardDirect(dm_env.Environment):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Terminate option.
|
# Terminate option.
|
||||||
if self._compute_reward(option, action_step.observation) > 0:
|
if self._should_terminate(option, action_step.observation):
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self._call_and_return:
|
if not self._call_and_return:
|
||||||
@@ -236,6 +246,16 @@ class EnvironmentWithKeyboardDirect(dm_env.Environment):
|
|||||||
|
|
||||||
return option_step
|
return option_step
|
||||||
|
|
||||||
|
def _should_terminate(self, option, obs):
|
||||||
|
if self._compute_reward(option, obs) > 0:
|
||||||
|
return True
|
||||||
|
elif np.all(option <= 0):
|
||||||
|
# TODO(shaobohou) A hack ensure option with non-positive weights
|
||||||
|
# terminates after one step
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def action_spec(self):
|
def action_spec(self):
|
||||||
return dm_env.specs.BoundedArray(shape=(self._keyboard.num_cumulants,),
|
return dm_env.specs.BoundedArray(shape=(self._keyboard.num_cumulants,),
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
@@ -280,10 +300,7 @@ def _discretize_actions(num_actions_per_dim,
|
|||||||
|
|
||||||
# Remove options with all zeros.
|
# Remove options with all zeros.
|
||||||
non_zero_entries = np.sum(np.square(discretized_actions), axis=-1) != 0.0
|
non_zero_entries = np.sum(np.square(discretized_actions), axis=-1) != 0.0
|
||||||
# Remove options with no positive elements.
|
discretized_actions = discretized_actions[non_zero_entries]
|
||||||
non_negative_entries = np.any(discretized_actions > 0, axis=-1)
|
|
||||||
discretized_actions = discretized_actions[np.logical_and(
|
|
||||||
non_zero_entries, non_negative_entries)]
|
|
||||||
logging.info("Total number of discretized actions: %s",
|
logging.info("Total number of discretized actions: %s",
|
||||||
len(discretized_actions))
|
len(discretized_actions))
|
||||||
logging.info("Discretized actions: %s", discretized_actions)
|
logging.info("Discretized actions: %s", discretized_actions)
|
||||||
|
|||||||
@@ -16,7 +16,10 @@
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""A simple training loop."""
|
"""A simple training loop."""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
from tensorflow.compat.v1.io import gfile
|
||||||
|
|
||||||
|
|
||||||
def _ema(base, val, decay=0.995):
|
def _ema(base, val, decay=0.995):
|
||||||
@@ -32,31 +35,42 @@ def run(env, agent, num_episodes, report_every=200, num_eval_reps=1):
|
|||||||
num_episodes: Number of episodes to train for.
|
num_episodes: Number of episodes to train for.
|
||||||
report_every: Frequency at which training progress are reported (episodes).
|
report_every: Frequency at which training progress are reported (episodes).
|
||||||
num_eval_reps: Number of eval episodes to run per training episode.
|
num_eval_reps: Number of eval episodes to run per training episode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of dicts containing training and evaluation returns, and a list of
|
||||||
|
reported returns smoothed by EMA.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
train_returns = []
|
returns = []
|
||||||
|
logged_returns = []
|
||||||
train_return_ema = 0.
|
train_return_ema = 0.
|
||||||
eval_returns = []
|
|
||||||
eval_return_ema = 0.
|
eval_return_ema = 0.
|
||||||
for episode_id in range(num_episodes):
|
for episode in range(num_episodes):
|
||||||
|
returns.append(dict(episode=episode))
|
||||||
|
|
||||||
# Run a training episode.
|
# Run a training episode.
|
||||||
train_episode_return = run_episode(env, agent, is_training=True)
|
train_episode_return = run_episode(env, agent, is_training=True)
|
||||||
train_returns.append(train_episode_return)
|
|
||||||
train_return_ema = _ema(train_return_ema, train_episode_return)
|
train_return_ema = _ema(train_return_ema, train_episode_return)
|
||||||
|
returns[-1]["train"] = train_episode_return
|
||||||
|
|
||||||
# Run an evaluation episode.
|
# Run an evaluation episode.
|
||||||
|
returns[-1]["eval"] = []
|
||||||
for _ in range(num_eval_reps):
|
for _ in range(num_eval_reps):
|
||||||
eval_episode_return = run_episode(env, agent, is_training=False)
|
eval_episode_return = run_episode(env, agent, is_training=False)
|
||||||
eval_returns.append(eval_episode_return)
|
|
||||||
eval_return_ema = _ema(eval_return_ema, eval_episode_return)
|
eval_return_ema = _ema(eval_return_ema, eval_episode_return)
|
||||||
|
returns[-1]["eval"].append(eval_episode_return)
|
||||||
|
|
||||||
if ((episode_id + 1) % report_every) == 0:
|
if ((episode + 1) % report_every) == 0 or episode == 0:
|
||||||
|
logged_returns.append(
|
||||||
|
dict(episode=episode, train=train_return_ema, eval=[eval_return_ema]))
|
||||||
logging.info("Episode %s, avg train return %.3f, avg eval return %.3f",
|
logging.info("Episode %s, avg train return %.3f, avg eval return %.3f",
|
||||||
episode_id + 1, train_return_ema, eval_return_ema)
|
episode + 1, train_return_ema, eval_return_ema)
|
||||||
if hasattr(agent, "get_logs"):
|
if hasattr(agent, "get_logs"):
|
||||||
logging.info("Episode %s, agent logs: %s", episode_id + 1,
|
logging.info("Episode %s, agent logs: %s", episode + 1,
|
||||||
agent.get_logs())
|
agent.get_logs())
|
||||||
|
|
||||||
|
return returns, logged_returns
|
||||||
|
|
||||||
|
|
||||||
def run_episode(environment, agent, is_training=False):
|
def run_episode(environment, agent, is_training=False):
|
||||||
"""Run a single episode."""
|
"""Run a single episode."""
|
||||||
@@ -75,3 +89,14 @@ def run_episode(environment, agent, is_training=False):
|
|||||||
episode_return = environment.episode_return
|
episode_return = environment.episode_return
|
||||||
|
|
||||||
return episode_return
|
return episode_return
|
||||||
|
|
||||||
|
|
||||||
|
def write_returns_to_file(path, returns):
  """Write returns to a space-delimited CSV file.

  The file starts with a header row `episode train eval_0 ... eval_{k-1}`,
  where k is the number of eval entries in the first row of `returns`. One
  line is then written per entry of `returns`.

  Args:
    path: Destination path, opened for writing through `gfile.GFile`.
    returns: A sequence of dicts with keys "episode", "train" and "eval",
      where "eval" holds a sequence of values. May be empty, in which case
      only the `episode train` header is written.
  """
  # An empty run has no first row to infer the eval columns from, so fall
  # back to zero eval columns instead of failing on `returns[0]`.
  num_evals = len(returns[0]["eval"]) if returns else 0
  header = ["episode", "train"] + [f"eval_{idx}" for idx in range(num_evals)]

  with gfile.GFile(path, "w") as file:
    writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
    writer.writerow(header)
    for row in returns:
      writer.writerow([row["episode"], row["train"]] + list(row["eval"]))
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ python3 train_keyboard.py --logtostderr --policy_weights_name=5
|
|||||||
|
|
||||||
Then generate the polar plot data as follows:
|
Then generate the polar plot data as follows:
|
||||||
|
|
||||||
python3 eval_keyboard_fig5a.py --logtostderr \
|
python3 eval_keyboard_fig5.py --logtostderr \
|
||||||
--keyboard_paths=/tmp/option_keyboard/keyboard_12/tfhub,/tmp/option_keyboard/keyboard_34/tfhub,/tmp/option_keyboard/keyboard_5/tfhub \
|
--keyboard_paths=/tmp/option_keyboard/keyboard_12/tfhub,/tmp/option_keyboard/keyboard_34/tfhub,/tmp/option_keyboard/keyboard_5/tfhub \
|
||||||
--num_episodes=1000
|
--num_episodes=1000
|
||||||
|
|
||||||
@@ -57,11 +57,14 @@ Example outout:
|
|||||||
[ 0.099 0.349 0.055 ]]
|
[ 0.099 0.349 0.055 ]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compat.v1.io import gfile
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
from option_keyboard import configs
|
from option_keyboard import configs
|
||||||
@@ -75,20 +78,12 @@ from option_keyboard.gpe_gpi_experiments import regressed_agent
|
|||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
flags.DEFINE_list("keyboard_paths", [], "Path to keyboard model.")
|
flags.DEFINE_list("keyboard_paths", [], "Path to keyboard model.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out returns.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_keyboard(keyboard_path):
|
def evaluate_keyboard(keyboard_path, weights_to_sweep):
|
||||||
"""Evaluate a keyboard."""
|
"""Evaluate a keyboard."""
|
||||||
|
|
||||||
angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
|
|
||||||
weights_to_sweep = np.stack(
|
|
||||||
[np.cos(angles_to_sweep),
|
|
||||||
np.sin(angles_to_sweep)], axis=-1)
|
|
||||||
weights_to_sweep /= np.sum(
|
|
||||||
np.maximum(weights_to_sweep, 0.0), axis=-1, keepdims=True)
|
|
||||||
weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
|
|
||||||
tf.logging.info(weights_to_sweep)
|
|
||||||
|
|
||||||
# Load the keyboard.
|
# Load the keyboard.
|
||||||
keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))
|
keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))
|
||||||
|
|
||||||
@@ -124,20 +119,41 @@ def evaluate_keyboard(keyboard_path):
|
|||||||
f"{FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
f"{FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||||
all_returns.append(returns)
|
all_returns.append(returns)
|
||||||
|
|
||||||
return all_returns, weights_to_sweep
|
return all_returns
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
del argv
|
del argv
|
||||||
|
|
||||||
|
angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
|
||||||
|
weights_to_sweep = np.stack(
|
||||||
|
[np.sin(angles_to_sweep),
|
||||||
|
np.cos(angles_to_sweep)], axis=-1)
|
||||||
|
weights_to_sweep /= np.sum(
|
||||||
|
np.maximum(weights_to_sweep, 0.0), axis=-1, keepdims=True)
|
||||||
|
weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
|
||||||
|
tf.logging.info(weights_to_sweep)
|
||||||
|
|
||||||
all_returns = []
|
all_returns = []
|
||||||
for keyboard_path in FLAGS.keyboard_paths:
|
for keyboard_path in FLAGS.keyboard_paths:
|
||||||
returns, _ = evaluate_keyboard(keyboard_path)
|
returns = evaluate_keyboard(keyboard_path, weights_to_sweep)
|
||||||
all_returns.append(returns)
|
all_returns.append(returns)
|
||||||
|
|
||||||
print("Results:")
|
print("Results:")
|
||||||
print(np.mean(all_returns, axis=-1).T)
|
print(np.mean(all_returns, axis=-1).T)
|
||||||
|
|
||||||
|
if FLAGS.output_path:
|
||||||
|
with gfile.GFile(FLAGS.output_path, "w") as file:
|
||||||
|
writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
|
||||||
|
writer.writerow(["angle", "return", "idx"])
|
||||||
|
for idx, returns in enumerate(all_returns):
|
||||||
|
for row in np.array(returns).T.tolist():
|
||||||
|
assert len(angles_to_sweep) == len(row)
|
||||||
|
for ang, val in zip(angles_to_sweep, row):
|
||||||
|
ang = "{:.4g}".format(ang)
|
||||||
|
val = "{:.4g}".format(val)
|
||||||
|
writer.writerow([ang, val, idx])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.disable_v2_behavior()
|
tf.disable_v2_behavior()
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -32,6 +32,9 @@ from option_keyboard import scavenger
|
|||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_integer("report_every", 5,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -56,7 +59,13 @@ def main(argv):
|
|||||||
optimizer_name="AdamOptimizer",
|
optimizer_name="AdamOptimizer",
|
||||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
_, ema_returns = experiment.run(
|
||||||
|
env,
|
||||||
|
agent,
|
||||||
|
num_episodes=FLAGS.num_episodes,
|
||||||
|
report_every=FLAGS.report_every)
|
||||||
|
if FLAGS.output_path:
|
||||||
|
experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -32,7 +32,10 @@ from option_keyboard import scavenger
|
|||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
flags.DEFINE_list("test_w", [], "The w to test.")
|
flags.DEFINE_list("test_w", None, "The w to test.")
|
||||||
|
flags.DEFINE_integer("report_every", 200,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -58,7 +61,13 @@ def main(argv):
|
|||||||
optimizer_name="AdamOptimizer",
|
optimizer_name="AdamOptimizer",
|
||||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
_, ema_returns = experiment.run(
|
||||||
|
env,
|
||||||
|
agent,
|
||||||
|
num_episodes=FLAGS.num_episodes,
|
||||||
|
report_every=FLAGS.report_every)
|
||||||
|
if FLAGS.output_path:
|
||||||
|
experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
r"""Run an experiment.
|
||||||
|
|
||||||
|
Run GPE/GPI on task (1, -1) with w obtained by regression.
|
||||||
|
|
||||||
|
|
||||||
|
For example, first train a keyboard:
|
||||||
|
|
||||||
|
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12 \
|
||||||
|
--export_path=/tmp/option_keyboard/keyboard
|
||||||
|
|
||||||
|
|
||||||
|
Then, evaluate the keyboard with w by regression.
|
||||||
|
|
||||||
|
python3 run_regressed_w_fig4b.py -- --logtostderr \
|
||||||
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 4000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_integer("report_every", 5,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
  """Train a regressed-w agent on the fig. 4 task with a loaded keyboard.

  Loads the keyboard from `FLAGS.keyboard_path`, wraps the logged Scavenger
  task with it, runs `experiment.run` for `FLAGS.num_episodes` episodes, and
  writes the second value returned by `experiment.run` to
  `FLAGS.output_path` when that flag is set.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv

  # Load the keyboard.
  keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))

  # Create the task environment, with logging on top.
  task_env = environment_wrappers.EnvironmentWithLogging(
      scavenger.Scavenger(**configs.get_fig4_task_config()))

  # Wrap the task environment with the keyboard.
  keyboard_env = environment_wrappers.EnvironmentWithKeyboardDirect(
      env=task_env,
      keyboard=keyboard,
      keyboard_ckpt_path=None,
      additional_discount=0.9,
      call_and_return=False)

  # Create the player agent, starting from a small random w.
  initial_w = np.random.normal(size=keyboard.num_cumulants) * 0.1
  player = regressed_agent.Agent(
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs={"learning_rate": 3e-2},
      init_w=initial_w,
  )

  _, ema_returns = experiment.run(
      keyboard_env,
      player,
      num_episodes=FLAGS.num_episodes,
      report_every=FLAGS.report_every,
      num_eval_reps=20)

  # Nothing to export unless an output path was requested.
  if not FLAGS.output_path:
    return
  experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
+10
-5
@@ -27,7 +27,7 @@ python3 train_keyboard.py -- --logtostderr --policy_weights_name=12 \
|
|||||||
|
|
||||||
Then, evaluate the keyboard with w by regression.
|
Then, evaluate the keyboard with w by regression.
|
||||||
|
|
||||||
python3 run_regressed_w_fig4.py -- --logtostderr \
|
python3 run_regressed_w_fig4c.py -- --logtostderr \
|
||||||
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -47,8 +47,11 @@ from option_keyboard import smart_module
|
|||||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 100, "Number of training episodes.")
|
||||||
|
flags.DEFINE_integer("report_every", 1,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -75,16 +78,18 @@ def main(argv):
|
|||||||
agent = regressed_agent.Agent(
|
agent = regressed_agent.Agent(
|
||||||
batch_size=10,
|
batch_size=10,
|
||||||
optimizer_name="AdamOptimizer",
|
optimizer_name="AdamOptimizer",
|
||||||
optimizer_kwargs=dict(learning_rate=1e-1,),
|
optimizer_kwargs=dict(learning_rate=3e-2,),
|
||||||
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
experiment.run(
|
_, ema_returns = experiment.run(
|
||||||
env,
|
env,
|
||||||
agent,
|
agent,
|
||||||
num_episodes=FLAGS.num_episodes,
|
num_episodes=FLAGS.num_episodes,
|
||||||
report_every=2,
|
report_every=FLAGS.report_every,
|
||||||
num_eval_reps=100)
|
num_eval_reps=100)
|
||||||
|
if FLAGS.output_path:
|
||||||
|
experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -35,7 +35,7 @@ python3 train_keyboard_with_phi.py -- --logtostderr \
|
|||||||
|
|
||||||
Finally, evaluate the keyboard with w by regression.
|
Finally, evaluate the keyboard with w by regression.
|
||||||
|
|
||||||
python3 run_regressed_w_with_phi_fig4b.py -- --logtostderr \
|
python3 run_regressed_w_with_phi_fig4c.py -- --logtostderr \
|
||||||
--phi_model_path=/tmp/option_keyboard/phi_model_3d \
|
--phi_model_path=/tmp/option_keyboard/phi_model_3d \
|
||||||
--keyboard_path=/tmp/option_keyboard/keyboard_3d/tfhub
|
--keyboard_path=/tmp/option_keyboard/keyboard_3d/tfhub
|
||||||
"""
|
"""
|
||||||
@@ -56,9 +56,12 @@ from option_keyboard import smart_module
|
|||||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 100, "Number of training episodes.")
|
||||||
|
flags.DEFINE_integer("report_every", 1,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
flags.DEFINE_string("phi_model_path", None, "Path to phi model.")
|
flags.DEFINE_string("phi_model_path", None, "Path to phi model.")
|
||||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -88,16 +91,18 @@ def main(argv):
|
|||||||
agent = regressed_agent.Agent(
|
agent = regressed_agent.Agent(
|
||||||
batch_size=10,
|
batch_size=10,
|
||||||
optimizer_name="AdamOptimizer",
|
optimizer_name="AdamOptimizer",
|
||||||
optimizer_kwargs=dict(learning_rate=1e-1,),
|
optimizer_kwargs=dict(learning_rate=3e-2,),
|
||||||
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
experiment.run(
|
_, ema_returns = experiment.run(
|
||||||
env,
|
env,
|
||||||
agent,
|
agent,
|
||||||
num_episodes=FLAGS.num_episodes,
|
num_episodes=FLAGS.num_episodes,
|
||||||
report_every=2,
|
report_every=FLAGS.report_every,
|
||||||
num_eval_reps=100)
|
num_eval_reps=100)
|
||||||
|
if FLAGS.output_path:
|
||||||
|
experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -30,11 +30,14 @@ python3 run_true_w_fig4.py -- --logtostderr \
|
|||||||
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compat.v1.io import gfile
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
from option_keyboard import configs
|
from option_keyboard import configs
|
||||||
@@ -48,6 +51,7 @@ from option_keyboard.gpe_gpi_experiments import regressed_agent
|
|||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out returns.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -86,6 +90,13 @@ def main(argv):
|
|||||||
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||||
tf.logging.info("#" * 80)
|
tf.logging.info("#" * 80)
|
||||||
|
|
||||||
|
if FLAGS.output_path:
|
||||||
|
with gfile.GFile(FLAGS.output_path, "w") as file:
|
||||||
|
writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
|
||||||
|
writer.writerow(["return"])
|
||||||
|
for val in returns:
|
||||||
|
writer.writerow([val])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.disable_v2_behavior()
|
tf.disable_v2_behavior()
|
||||||
|
|||||||
@@ -26,15 +26,18 @@ python3 train_keyboard.py -- --logtostderr --policy_weights_name=12
|
|||||||
|
|
||||||
Then, evaluate the keyboard with a fixed w.
|
Then, evaluate the keyboard with a fixed w.
|
||||||
|
|
||||||
python3 run_true_w_fig4.py -- --logtostderr \
|
python3 run_true_w_fig6.py -- --logtostderr \
|
||||||
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compat.v1.io import gfile
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
from option_keyboard import configs
|
from option_keyboard import configs
|
||||||
@@ -48,7 +51,8 @@ from option_keyboard.gpe_gpi_experiments import regressed_agent
|
|||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
flags.DEFINE_list("test_w", [], "The w to test.")
|
flags.DEFINE_list("test_w", None, "The w to test.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out returns.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -87,6 +91,13 @@ def main(argv):
|
|||||||
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||||
tf.logging.info("#" * 80)
|
tf.logging.info("#" * 80)
|
||||||
|
|
||||||
|
if FLAGS.output_path:
|
||||||
|
with gfile.GFile(FLAGS.output_path, "w") as file:
|
||||||
|
writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
|
||||||
|
writer.writerow(["return"])
|
||||||
|
for val in returns:
|
||||||
|
writer.writerow([val])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.disable_v2_behavior()
|
tf.disable_v2_behavior()
|
||||||
|
|||||||
@@ -139,10 +139,8 @@ def main(argv):
|
|||||||
tasks = [
|
tasks = [
|
||||||
[1.0, 0.0],
|
[1.0, 0.0],
|
||||||
[0.0, 1.0],
|
[0.0, 1.0],
|
||||||
[-1.0, 0.0],
|
[1.0, 1.0],
|
||||||
[0.0, -1.0],
|
[-1.0, 1.0],
|
||||||
[0.7, 0.3],
|
|
||||||
[-0.3, -0.7],
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if FLAGS.normalisation == "L1":
|
if FLAGS.normalisation == "L1":
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ from option_keyboard import scavenger
|
|||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_integer("report_every", 200,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -53,7 +56,13 @@ def main(argv):
|
|||||||
optimizer_name="AdamOptimizer",
|
optimizer_name="AdamOptimizer",
|
||||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
_, ema_returns = experiment.run(
|
||||||
|
env,
|
||||||
|
agent,
|
||||||
|
num_episodes=FLAGS.num_episodes,
|
||||||
|
report_every=FLAGS.report_every)
|
||||||
|
if FLAGS.output_path:
|
||||||
|
experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -36,7 +36,10 @@ FLAGS = flags.FLAGS
|
|||||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
||||||
"Number of pretraining episodes.")
|
"Number of pretraining episodes.")
|
||||||
|
flags.DEFINE_integer("report_every", 200,
|
||||||
|
"Frequency at which metrics are reported.")
|
||||||
flags.DEFINE_string("keyboard_path", None, "Path to pretrained keyboard model.")
|
flags.DEFINE_string("keyboard_path", None, "Path to pretrained keyboard model.")
|
||||||
|
flags.DEFINE_string("output_path", None, "Path to write out training curves.")
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
@@ -84,7 +87,13 @@ def main(argv):
|
|||||||
optimizer_name="AdamOptimizer",
|
optimizer_name="AdamOptimizer",
|
||||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
_, ema_returns = experiment.run(
|
||||||
|
env,
|
||||||
|
agent,
|
||||||
|
num_episodes=FLAGS.num_episodes,
|
||||||
|
report_every=FLAGS.report_every)
|
||||||
|
if FLAGS.output_path:
|
||||||
|
experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -17,11 +17,10 @@
|
|||||||
"""Smart module export/import utilities."""
|
"""Smart module export/import utilities."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
import shutil
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compat.v1.io import gfile
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
import tree as nest
|
import tree as nest
|
||||||
import wrapt
|
import wrapt
|
||||||
@@ -164,9 +163,9 @@ class SmartModuleExport(object):
|
|||||||
module_session.run(
|
module_session.run(
|
||||||
assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))
|
assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))
|
||||||
|
|
||||||
if overwrite and os.path.exists(path):
|
if overwrite and gfile.exists(path):
|
||||||
shutil.rmtree(path)
|
gfile.rmtree(path)
|
||||||
os.makedirs(path)
|
gfile.makedirs(path)
|
||||||
hub_module.export(path, module_session)
|
hub_module.export(path, module_session)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user