mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 05:24:03 +08:00
Add GPE/GPI experiments.
PiperOrigin-RevId: 323750949
This commit is contained in:
committed by
Diego de Las Casas
parent
59c0cf5044
commit
a24bda5ed0
@@ -21,7 +21,7 @@ def get_task_config():
|
|||||||
return dict(
|
return dict(
|
||||||
arena_size=11,
|
arena_size=11,
|
||||||
num_channels=2,
|
num_channels=2,
|
||||||
max_num_steps=50, # 5o for the actual task.
|
max_num_steps=50, # 50 for the actual task.
|
||||||
num_init_objects=10,
|
num_init_objects=10,
|
||||||
object_priors=[0.5, 0.5],
|
object_priors=[0.5, 0.5],
|
||||||
egocentric=True,
|
egocentric=True,
|
||||||
@@ -39,3 +39,27 @@ def get_pretrain_config():
|
|||||||
egocentric=True,
|
egocentric=True,
|
||||||
default_w=(1, 1),
|
default_w=(1, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fig4_task_config():
|
||||||
|
return dict(
|
||||||
|
arena_size=11,
|
||||||
|
num_channels=2,
|
||||||
|
max_num_steps=50, # 50 for the actual task.
|
||||||
|
num_init_objects=10,
|
||||||
|
object_priors=[0.5, 0.5],
|
||||||
|
egocentric=True,
|
||||||
|
default_w=(1, -1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fig5_task_config(default_w):
|
||||||
|
return dict(
|
||||||
|
arena_size=11,
|
||||||
|
num_channels=2,
|
||||||
|
max_num_steps=50, # 50 for the actual task.
|
||||||
|
num_init_objects=10,
|
||||||
|
object_priors=[0.5, 0.5],
|
||||||
|
egocentric=True,
|
||||||
|
default_w=default_w,
|
||||||
|
)
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ import dm_env
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
import tree
|
||||||
|
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentWithLogging(dm_env.Environment):
|
class EnvironmentWithLogging(dm_env.Environment):
|
||||||
@@ -93,6 +97,7 @@ class EnvironmentWithKeyboard(dm_env.Environment):
|
|||||||
self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph])
|
self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph])
|
||||||
session.run(tf.global_variables_initializer())
|
session.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
|
if keyboard_ckpt_path:
|
||||||
saver = tf.train.Saver(var_list=keyboard.variables)
|
saver = tf.train.Saver(var_list=keyboard.variables)
|
||||||
saver.restore(session, keyboard_ckpt_path)
|
saver.restore(session, keyboard_ckpt_path)
|
||||||
|
|
||||||
@@ -152,6 +157,102 @@ class EnvironmentWithKeyboard(dm_env.Environment):
|
|||||||
return getattr(self._env, name)
|
return getattr(self._env, name)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentWithKeyboardDirect(dm_env.Environment):
|
||||||
|
"""Wraps an environment with a keyboard.
|
||||||
|
|
||||||
|
This is different from EnvironmentWithKeyboard as the actions space is not
|
||||||
|
discretized.
|
||||||
|
|
||||||
|
TODO(shaobohou) Merge the two implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
env,
|
||||||
|
keyboard,
|
||||||
|
keyboard_ckpt_path,
|
||||||
|
additional_discount,
|
||||||
|
call_and_return=False):
|
||||||
|
self._env = env
|
||||||
|
self._keyboard = keyboard
|
||||||
|
self._discount = additional_discount
|
||||||
|
self._call_and_return = call_and_return
|
||||||
|
|
||||||
|
obs_spec = self._extract_observation(env.observation_spec())
|
||||||
|
obs_ph = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
|
||||||
|
option_ph = tf.placeholder(
|
||||||
|
shape=(keyboard.num_cumulants,), dtype=tf.float32)
|
||||||
|
gpi_action = self._keyboard.gpi(obs_ph, option_ph)
|
||||||
|
|
||||||
|
session = tf.Session()
|
||||||
|
self._gpi_action = session.make_callable(gpi_action, [obs_ph, option_ph])
|
||||||
|
self._keyboard_action = session.make_callable(
|
||||||
|
self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph])
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
|
if keyboard_ckpt_path:
|
||||||
|
saver = tf.train.Saver(var_list=keyboard.variables)
|
||||||
|
saver.restore(session, keyboard_ckpt_path)
|
||||||
|
|
||||||
|
def _compute_reward(self, option, obs):
|
||||||
|
assert option.shape == obs["cumulants"].shape
|
||||||
|
return np.sum(option * obs["cumulants"])
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self._env.reset()
|
||||||
|
|
||||||
|
def step(self, option):
|
||||||
|
"""Take a step in the keyboard, then the environment."""
|
||||||
|
|
||||||
|
step_count = 0
|
||||||
|
option_step = None
|
||||||
|
while True:
|
||||||
|
obs = self._extract_observation(self._env.observation())
|
||||||
|
action = self._gpi_action(obs, option)
|
||||||
|
action_step = self._env.step(action)
|
||||||
|
step_count += 1
|
||||||
|
|
||||||
|
if option_step is None:
|
||||||
|
option_step = action_step
|
||||||
|
else:
|
||||||
|
new_discount = (
|
||||||
|
option_step.discount * self._discount * action_step.discount)
|
||||||
|
new_reward = (
|
||||||
|
option_step.reward + new_discount * action_step.reward)
|
||||||
|
option_step = option_step._replace(
|
||||||
|
observation=action_step.observation,
|
||||||
|
reward=new_reward,
|
||||||
|
discount=new_discount,
|
||||||
|
step_type=action_step.step_type)
|
||||||
|
|
||||||
|
if action_step.last():
|
||||||
|
break
|
||||||
|
|
||||||
|
# Terminate option.
|
||||||
|
if self._compute_reward(option, action_step.observation) > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not self._call_and_return:
|
||||||
|
break
|
||||||
|
|
||||||
|
return option_step
|
||||||
|
|
||||||
|
def action_spec(self):
|
||||||
|
return dm_env.specs.BoundedArray(shape=(self._keyboard.num_cumulants,),
|
||||||
|
dtype=np.float32,
|
||||||
|
minimum=-1.0,
|
||||||
|
maximum=1.0,
|
||||||
|
name="action")
|
||||||
|
|
||||||
|
def _extract_observation(self, obs):
|
||||||
|
return obs["arena"]
|
||||||
|
|
||||||
|
def observation_spec(self):
|
||||||
|
return self._env.observation_spec()
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._env, name)
|
||||||
|
|
||||||
|
|
||||||
def _discretize_actions(num_actions_per_dim,
|
def _discretize_actions(num_actions_per_dim,
|
||||||
action_space_dim,
|
action_space_dim,
|
||||||
min_val=-1.0,
|
min_val=-1.0,
|
||||||
@@ -188,3 +289,71 @@ def _discretize_actions(num_actions_per_dim,
|
|||||||
logging.info("Discretized actions: %s", discretized_actions)
|
logging.info("Discretized actions: %s", discretized_actions)
|
||||||
|
|
||||||
return discretized_actions
|
return discretized_actions
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentWithLearnedPhi(dm_env.Environment):
|
||||||
|
"""Wraps an environment with learned phi model."""
|
||||||
|
|
||||||
|
def __init__(self, env, model_path):
|
||||||
|
self._env = env
|
||||||
|
|
||||||
|
create_ph = lambda x: tf.placeholder(shape=x.shape, dtype=x.dtype)
|
||||||
|
add_batch = lambda x: tf.expand_dims(x, axis=0)
|
||||||
|
|
||||||
|
# Make session and callables.
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
model = smart_module.SmartModuleImport(hub.Module(model_path))
|
||||||
|
|
||||||
|
obs_spec = env.observation_spec()
|
||||||
|
obs_ph = tree.map_structure(create_ph, obs_spec)
|
||||||
|
action_ph = tf.placeholder(shape=(), dtype=tf.int32)
|
||||||
|
phis = model(tree.map_structure(add_batch, obs_ph), add_batch(action_ph))
|
||||||
|
|
||||||
|
self.num_phis = phis.shape.as_list()[-1]
|
||||||
|
self._last_phis = np.zeros((self.num_phis,), dtype=np.float32)
|
||||||
|
|
||||||
|
session = tf.Session()
|
||||||
|
self._session = session
|
||||||
|
self._phis_fn = session.make_callable(
|
||||||
|
phis[0], tree.flatten([obs_ph, action_ph]))
|
||||||
|
self._session.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._last_phis = np.zeros((self.num_phis,), dtype=np.float32)
|
||||||
|
return self._env.reset()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
"""Take action in the environment and do some logging."""
|
||||||
|
|
||||||
|
phis = self._phis_fn(*tree.flatten([self._env.observation(), action]))
|
||||||
|
step = self._env.step(action)
|
||||||
|
|
||||||
|
if step.first():
|
||||||
|
phis = self._phis_fn(*tree.flatten([self._env.observation(), action]))
|
||||||
|
step = self._env.step(action)
|
||||||
|
|
||||||
|
step.observation["cumulants"] = phis
|
||||||
|
self._last_phis = phis
|
||||||
|
|
||||||
|
return step
|
||||||
|
|
||||||
|
def action_spec(self):
|
||||||
|
return self._env.action_spec()
|
||||||
|
|
||||||
|
def observation(self):
|
||||||
|
obs = self._env.observation()
|
||||||
|
obs["cumulants"] = self._last_phis
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def observation_spec(self):
|
||||||
|
obs_spec = self._env.observation_spec()
|
||||||
|
obs_spec["cumulants"] = dm_env.specs.BoundedArray(
|
||||||
|
shape=(self.num_phis,),
|
||||||
|
dtype=np.float32,
|
||||||
|
minimum=-1e9,
|
||||||
|
maximum=1e9,
|
||||||
|
name="collected_resources")
|
||||||
|
return obs_spec
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._env, name)
|
||||||
|
|||||||
@@ -18,37 +18,44 @@
|
|||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
def _ema(base, val, decay=0.995):
|
||||||
|
return base * decay + (1 - decay) * val
|
||||||
|
|
||||||
|
|
||||||
def run(environment, agent, num_episodes, report_every=200):
|
def run(env, agent, num_episodes, report_every=200, num_eval_reps=1):
|
||||||
"""Runs an agent on an environment.
|
"""Runs an agent on an environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
environment: The environment.
|
env: The environment.
|
||||||
agent: The agent.
|
agent: The agent.
|
||||||
num_episodes: Number of episodes to train for.
|
num_episodes: Number of episodes to train for.
|
||||||
report_every: Frequency at which training progress are reported (episodes).
|
report_every: Frequency at which training progress are reported (episodes).
|
||||||
|
num_eval_reps: Number of eval episodes to run per training episode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
train_returns = []
|
train_returns = []
|
||||||
|
train_return_ema = 0.
|
||||||
eval_returns = []
|
eval_returns = []
|
||||||
|
eval_return_ema = 0.
|
||||||
for episode_id in range(num_episodes):
|
for episode_id in range(num_episodes):
|
||||||
# Run a training episode.
|
# Run a training episode.
|
||||||
train_episode_return = run_episode(environment, agent, is_training=True)
|
train_episode_return = run_episode(env, agent, is_training=True)
|
||||||
train_returns.append(train_episode_return)
|
train_returns.append(train_episode_return)
|
||||||
|
train_return_ema = _ema(train_return_ema, train_episode_return)
|
||||||
|
|
||||||
# Run an evaluation episode.
|
# Run an evaluation episode.
|
||||||
eval_episode_return = run_episode(environment, agent, is_training=False)
|
for _ in range(num_eval_reps):
|
||||||
|
eval_episode_return = run_episode(env, agent, is_training=False)
|
||||||
eval_returns.append(eval_episode_return)
|
eval_returns.append(eval_episode_return)
|
||||||
|
eval_return_ema = _ema(eval_return_ema, eval_episode_return)
|
||||||
|
|
||||||
if ((episode_id + 1) % report_every) == 0:
|
if ((episode_id + 1) % report_every) == 0:
|
||||||
logging.info(
|
logging.info("Episode %s, avg train return %.3f, avg eval return %.3f",
|
||||||
"Episode %s, avg train return %.3f, avg eval return %.3f",
|
episode_id + 1, train_return_ema, eval_return_ema)
|
||||||
episode_id + 1,
|
if hasattr(agent, "get_logs"):
|
||||||
np.mean(train_returns[-report_every:]),
|
logging.info("Episode %s, agent logs: %s", episode_id + 1,
|
||||||
np.mean(eval_returns[-report_every:]),
|
agent.get_logs())
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run_episode(environment, agent, is_training=False):
|
def run_episode(environment, agent, is_training=False):
|
||||||
|
|||||||
@@ -0,0 +1,144 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# pylint: disable=line-too-long
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
r"""Run an experiment.
|
||||||
|
|
||||||
|
This script generates the raw data for the polar plots used to visualise how
|
||||||
|
well a trained keyboard covers the space of w.
|
||||||
|
|
||||||
|
|
||||||
|
For example, train 3 separate keyboards with different base policies:
|
||||||
|
|
||||||
|
python3 train_keyboard.py --logtostderr --policy_weights_name=12
|
||||||
|
python3 train_keyboard.py --logtostderr --policy_weights_name=34
|
||||||
|
python3 train_keyboard.py --logtostderr --policy_weights_name=5
|
||||||
|
|
||||||
|
|
||||||
|
Then generate the polar plot data as follows:
|
||||||
|
|
||||||
|
python3 eval_keyboard_fig5a.py --logtostderr \
|
||||||
|
--keyboard_paths=/tmp/option_keyboard/keyboard_12/tfhub,/tmp/option_keyboard/keyboard_34/tfhub,/tmp/option_keyboard/keyboard_5/tfhub \
|
||||||
|
--num_episodes=1000
|
||||||
|
|
||||||
|
|
||||||
|
Example outout:
|
||||||
|
[[ 0.11 0.261 -0.933 ]
|
||||||
|
[ 1.302 3.955 0.54 ]
|
||||||
|
[ 2.398 4.434 1.2105359 ]
|
||||||
|
[ 3.459 4.606 2.087 ]
|
||||||
|
[ 4.09026795 4.60911325 3.06106882]
|
||||||
|
[ 4.55499485 4.71947818 3.8123229 ]
|
||||||
|
[ 4.715 4.835 4.395 ]
|
||||||
|
[ 4.75743564 4.64095528 4.46330207]
|
||||||
|
[ 4.82518207 4.71232378 4.56190708]
|
||||||
|
[ 4.831 4.7155 4.5735 ]
|
||||||
|
[ 4.78074425 4.6754641 4.58312762]
|
||||||
|
[ 4.70154374 4.5416429 4.47850417]
|
||||||
|
[ 4.694 4.631 4.427 ]
|
||||||
|
[ 4.25085125 4.56606664 3.68157677]
|
||||||
|
[ 3.61726795 4.4838453 2.68154403]
|
||||||
|
[ 2.714 4.43 1.554 ]
|
||||||
|
[ 1.69 4.505 0.9635359 ]
|
||||||
|
[ 0.894 4.043 0.424 ]
|
||||||
|
[ 0.099 0.349 0.055 ]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_list("keyboard_paths", [], "Path to keyboard model.")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_keyboard(keyboard_path):
|
||||||
|
"""Evaluate a keyboard."""
|
||||||
|
|
||||||
|
angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
|
||||||
|
weights_to_sweep = np.stack(
|
||||||
|
[np.cos(angles_to_sweep),
|
||||||
|
np.sin(angles_to_sweep)], axis=-1)
|
||||||
|
weights_to_sweep /= np.sum(
|
||||||
|
np.maximum(weights_to_sweep, 0.0), axis=-1, keepdims=True)
|
||||||
|
weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
|
||||||
|
tf.logging.info(weights_to_sweep)
|
||||||
|
|
||||||
|
# Load the keyboard.
|
||||||
|
keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
all_returns = []
|
||||||
|
for w_to_sweep in weights_to_sweep.tolist():
|
||||||
|
base_env_config = configs.get_fig5_task_config(w_to_sweep)
|
||||||
|
base_env = scavenger.Scavenger(**base_env_config)
|
||||||
|
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||||
|
|
||||||
|
# Wrap the task environment with the keyboard.
|
||||||
|
with tf.variable_scope(None, default_name="inner_loop"):
|
||||||
|
additional_discount = 0.9
|
||||||
|
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||||
|
env=base_env,
|
||||||
|
keyboard=keyboard,
|
||||||
|
keyboard_ckpt_path=None,
|
||||||
|
additional_discount=additional_discount,
|
||||||
|
call_and_return=False)
|
||||||
|
|
||||||
|
# Create the player agent.
|
||||||
|
agent = regressed_agent.Agent(
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
# Disable training.
|
||||||
|
optimizer_kwargs=dict(learning_rate=0.0,),
|
||||||
|
init_w=w_to_sweep)
|
||||||
|
|
||||||
|
returns = []
|
||||||
|
for _ in range(FLAGS.num_episodes):
|
||||||
|
returns.append(experiment.run_episode(env, agent))
|
||||||
|
tf.logging.info(f"Task: {w_to_sweep}, mean returns over "
|
||||||
|
f"{FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||||
|
all_returns.append(returns)
|
||||||
|
|
||||||
|
return all_returns, weights_to_sweep
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
all_returns = []
|
||||||
|
for keyboard_path in FLAGS.keyboard_paths:
|
||||||
|
returns, _ = evaluate_keyboard(keyboard_path)
|
||||||
|
all_returns.append(returns)
|
||||||
|
|
||||||
|
print("Results:")
|
||||||
|
print(np.mean(all_returns, axis=-1).T)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Regressed agent."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
|
class Agent():
|
||||||
|
"""A DQN Agent."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
optimizer_name,
|
||||||
|
optimizer_kwargs,
|
||||||
|
init_w,
|
||||||
|
):
|
||||||
|
"""A simple DQN agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Size of update batch.
|
||||||
|
optimizer_name: Name of an optimizer from tf.train
|
||||||
|
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||||
|
init_w: The initial cumulant weight.
|
||||||
|
"""
|
||||||
|
self._batch_size = batch_size
|
||||||
|
self._init_w = np.array(init_w)
|
||||||
|
self._replay = []
|
||||||
|
|
||||||
|
# Regress w by gradient descent, could also use closed-form solution.
|
||||||
|
self._n_cumulants = len(init_w)
|
||||||
|
self._regressed_w = tf.get_variable(
|
||||||
|
"regressed_w",
|
||||||
|
dtype=tf.float32,
|
||||||
|
initializer=lambda: tf.to_float(init_w))
|
||||||
|
cumulants_ph = tf.placeholder(
|
||||||
|
shape=(None, self._n_cumulants), dtype=tf.float32)
|
||||||
|
rewards_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
|
||||||
|
predicted_rewards = tf.reduce_sum(
|
||||||
|
tf.multiply(self._regressed_w, cumulants_ph), axis=-1)
|
||||||
|
loss = tf.reduce_sum(tf.square(predicted_rewards - rewards_ph))
|
||||||
|
|
||||||
|
with tf.variable_scope("optimizer"):
|
||||||
|
self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
|
||||||
|
train_op = self._optimizer.minimize(loss)
|
||||||
|
|
||||||
|
# Make session and callables.
|
||||||
|
session = tf.Session()
|
||||||
|
self._update_fn = session.make_callable(train_op,
|
||||||
|
[cumulants_ph, rewards_ph])
|
||||||
|
self._action = session.make_callable(self._regressed_w.read_value(), [])
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
|
def step(self, timestep, is_training=False):
|
||||||
|
"""Select actions according to epsilon-greedy policy."""
|
||||||
|
del timestep
|
||||||
|
|
||||||
|
if is_training:
|
||||||
|
# Can also just use random actions at environment level.
|
||||||
|
return np.random.uniform(low=-1.0, high=1.0, size=(self._n_cumulants,))
|
||||||
|
|
||||||
|
return self._action()
|
||||||
|
|
||||||
|
def update(self, step_tm1, action, step_t):
|
||||||
|
"""Takes in a transition from the environment."""
|
||||||
|
del step_tm1, action
|
||||||
|
|
||||||
|
transition = [
|
||||||
|
step_t.observation["cumulants"],
|
||||||
|
step_t.reward,
|
||||||
|
]
|
||||||
|
self._replay.append(transition)
|
||||||
|
|
||||||
|
if len(self._replay) == self._batch_size:
|
||||||
|
batch = list(zip(*self._replay))
|
||||||
|
self._update_fn(*batch)
|
||||||
|
self._replay = [] # Just a queue.
|
||||||
|
|
||||||
|
def get_logs(self):
|
||||||
|
return dict(regressed=self._action())
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Run an experiment.
|
||||||
|
|
||||||
|
Run a q-learning agent on task (1, -1).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import dqn_agent
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
env_config = configs.get_fig4_task_config()
|
||||||
|
env = scavenger.Scavenger(**env_config)
|
||||||
|
env = environment_wrappers.EnvironmentWithLogging(env)
|
||||||
|
|
||||||
|
# Create the flat agent.
|
||||||
|
agent = dqn_agent.Agent(
|
||||||
|
obs_spec=env.observation_spec(),
|
||||||
|
action_spec=env.action_spec(),
|
||||||
|
network_kwargs=dict(
|
||||||
|
output_sizes=(64, 128),
|
||||||
|
activate_final=True,
|
||||||
|
),
|
||||||
|
epsilon=0.1,
|
||||||
|
additional_discount=0.9,
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
|
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Run an experiment.
|
||||||
|
|
||||||
|
Run a q-learning agent on a task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import dqn_agent
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_list("test_w", [], "The w to test.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
test_w = [float(x) for x in FLAGS.test_w]
|
||||||
|
env_config = configs.get_fig5_task_config(test_w)
|
||||||
|
env = scavenger.Scavenger(**env_config)
|
||||||
|
env = environment_wrappers.EnvironmentWithLogging(env)
|
||||||
|
|
||||||
|
# Create the flat agent.
|
||||||
|
agent = dqn_agent.Agent(
|
||||||
|
obs_spec=env.observation_spec(),
|
||||||
|
action_spec=env.action_spec(),
|
||||||
|
network_kwargs=dict(
|
||||||
|
output_sizes=(64, 128),
|
||||||
|
activate_final=True,
|
||||||
|
),
|
||||||
|
epsilon=0.1,
|
||||||
|
additional_discount=0.9,
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
|
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
r"""Run an experiment.
|
||||||
|
|
||||||
|
Run GPE/GPI on task (1, -1) with w obtained by regression.
|
||||||
|
|
||||||
|
|
||||||
|
For example, first train a keyboard:
|
||||||
|
|
||||||
|
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12 \
|
||||||
|
--export_path=/tmp/option_keyboard/keyboard
|
||||||
|
|
||||||
|
|
||||||
|
Then, evaluate the keyboard with w by regression.
|
||||||
|
|
||||||
|
python3 run_regressed_w_fig4.py -- --logtostderr \
|
||||||
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
# Load the keyboard.
|
||||||
|
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
base_env_config = configs.get_fig4_task_config()
|
||||||
|
base_env = scavenger.Scavenger(**base_env_config)
|
||||||
|
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||||
|
|
||||||
|
# Wrap the task environment with the keyboard.
|
||||||
|
additional_discount = 0.9
|
||||||
|
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||||
|
env=base_env,
|
||||||
|
keyboard=keyboard,
|
||||||
|
keyboard_ckpt_path=None,
|
||||||
|
additional_discount=additional_discount,
|
||||||
|
call_and_return=False)
|
||||||
|
|
||||||
|
# Create the player agent.
|
||||||
|
agent = regressed_agent.Agent(
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
optimizer_kwargs=dict(learning_rate=1e-1,),
|
||||||
|
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
experiment.run(
|
||||||
|
env,
|
||||||
|
agent,
|
||||||
|
num_episodes=FLAGS.num_episodes,
|
||||||
|
report_every=2,
|
||||||
|
num_eval_reps=100)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
r"""Run an experiment.
|
||||||
|
|
||||||
|
Run GPE/GPI on task (1, -1) with a learned phi model and w by regression.
|
||||||
|
|
||||||
|
|
||||||
|
For example, first train a phi model with 3 dimenional phi:
|
||||||
|
|
||||||
|
python3 train_phi_model.py -- --logtostderr --use_random_tasks \
|
||||||
|
--export_path=/tmp/option_keyboard/phi_model_3d --num_phis=3
|
||||||
|
|
||||||
|
|
||||||
|
Then train a keyboard:
|
||||||
|
|
||||||
|
python3 train_keyboard_with_phi.py -- --logtostderr \
|
||||||
|
--export_path=/tmp/option_keyboard/keyboard_3d \
|
||||||
|
--phi_model_path=/tmp/option_keyboard/phi_model_3d \
|
||||||
|
--num_phis=2
|
||||||
|
|
||||||
|
|
||||||
|
Finally, evaluate the keyboard with w by regression.
|
||||||
|
|
||||||
|
python3 run_regressed_w_with_phi_fig4b.py -- --logtostderr \
|
||||||
|
--phi_model_path=/tmp/option_keyboard/phi_model_3d \
|
||||||
|
--keyboard_path=/tmp/option_keyboard/keyboard_3d/tfhub
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_string("phi_model_path", None, "Path to phi model.")
|
||||||
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
# Load the keyboard.
|
||||||
|
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
base_env_config = configs.get_fig4_task_config()
|
||||||
|
base_env = scavenger.Scavenger(**base_env_config)
|
||||||
|
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||||
|
|
||||||
|
base_env = environment_wrappers.EnvironmentWithLearnedPhi(
|
||||||
|
base_env, FLAGS.phi_model_path)
|
||||||
|
|
||||||
|
# Wrap the task environment with the keyboard.
|
||||||
|
additional_discount = 0.9
|
||||||
|
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||||
|
env=base_env,
|
||||||
|
keyboard=keyboard,
|
||||||
|
keyboard_ckpt_path=None,
|
||||||
|
additional_discount=additional_discount,
|
||||||
|
call_and_return=False)
|
||||||
|
|
||||||
|
# Create the player agent.
|
||||||
|
agent = regressed_agent.Agent(
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
optimizer_kwargs=dict(learning_rate=1e-1,),
|
||||||
|
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
experiment.run(
|
||||||
|
env,
|
||||||
|
agent,
|
||||||
|
num_episodes=FLAGS.num_episodes,
|
||||||
|
report_every=2,
|
||||||
|
num_eval_reps=100)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
r"""Run an experiment.
|
||||||
|
|
||||||
|
Run GPE/GPI on task (1, -1) with the groundtruth w.
|
||||||
|
|
||||||
|
|
||||||
|
For example, first train a keyboard:
|
||||||
|
|
||||||
|
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12
|
||||||
|
|
||||||
|
|
||||||
|
Then, evaluate the keyboard with groundtruth w.
|
||||||
|
|
||||||
|
python3 run_true_w_fig4.py -- --logtostderr \
|
||||||
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
# Load the keyboard.
|
||||||
|
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
base_env_config = configs.get_fig4_task_config()
|
||||||
|
base_env = scavenger.Scavenger(**base_env_config)
|
||||||
|
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||||
|
|
||||||
|
# Wrap the task environment with the keyboard.
|
||||||
|
additional_discount = 0.9
|
||||||
|
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||||
|
env=base_env,
|
||||||
|
keyboard=keyboard,
|
||||||
|
keyboard_ckpt_path=None,
|
||||||
|
additional_discount=additional_discount,
|
||||||
|
call_and_return=False)
|
||||||
|
|
||||||
|
# Create the player agent.
|
||||||
|
agent = regressed_agent.Agent(
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
# Disable training.
|
||||||
|
optimizer_kwargs=dict(learning_rate=0.0,),
|
||||||
|
init_w=[1., -1.])
|
||||||
|
|
||||||
|
returns = []
|
||||||
|
for _ in range(FLAGS.num_episodes):
|
||||||
|
returns.append(experiment.run_episode(env, agent))
|
||||||
|
tf.logging.info("#" * 80)
|
||||||
|
tf.logging.info(
|
||||||
|
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||||
|
tf.logging.info("#" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
r"""Run an experiment.
|
||||||
|
|
||||||
|
Run GPE/GPI on the "balancing" task with a fixed w
|
||||||
|
|
||||||
|
|
||||||
|
For example, first train a keyboard:
|
||||||
|
|
||||||
|
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12
|
||||||
|
|
||||||
|
|
||||||
|
Then, evaluate the keyboard with a fixed w.
|
||||||
|
|
||||||
|
python3 run_true_w_fig4.py -- --logtostderr \
|
||||||
|
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||||
|
"""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||||
|
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||||
|
flags.DEFINE_list("test_w", [], "The w to test.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
# Load the keyboard.
|
||||||
|
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||||
|
|
||||||
|
# Create the task environment.
|
||||||
|
base_env_config = configs.get_task_config()
|
||||||
|
base_env = scavenger.Scavenger(**base_env_config)
|
||||||
|
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||||
|
|
||||||
|
# Wrap the task environment with the keyboard.
|
||||||
|
additional_discount = 0.9
|
||||||
|
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||||
|
env=base_env,
|
||||||
|
keyboard=keyboard,
|
||||||
|
keyboard_ckpt_path=None,
|
||||||
|
additional_discount=additional_discount,
|
||||||
|
call_and_return=False)
|
||||||
|
|
||||||
|
# Create the player agent.
|
||||||
|
agent = regressed_agent.Agent(
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
# Disable training.
|
||||||
|
optimizer_kwargs=dict(learning_rate=0.0,),
|
||||||
|
init_w=[float(x) for x in FLAGS.test_w])
|
||||||
|
|
||||||
|
returns = []
|
||||||
|
for _ in range(FLAGS.num_episodes):
|
||||||
|
returns.append(experiment.run_episode(env, agent))
|
||||||
|
tf.logging.info("#" * 80)
|
||||||
|
tf.logging.info(
|
||||||
|
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||||
|
tf.logging.info("#" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Train a keyboard."""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from option_keyboard import keyboard_utils
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
||||||
|
"Number of pretraining episodes.")
|
||||||
|
flags.DEFINE_string("export_path", None,
|
||||||
|
"Where to save the keyboard checkpoints.")
|
||||||
|
flags.DEFINE_string("policy_weights_name", None,
|
||||||
|
"A string repsenting the policy weights.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
all_policy_weights = {
|
||||||
|
"1": [1., 0.],
|
||||||
|
"2": [0., 1.],
|
||||||
|
"3": [1., -1.],
|
||||||
|
"4": [-1., 1.],
|
||||||
|
"5": [1., 1.],
|
||||||
|
}
|
||||||
|
if FLAGS.policy_weights_name:
|
||||||
|
policy_weights = np.array(
|
||||||
|
[all_policy_weights[v] for v in FLAGS.policy_weights_name])
|
||||||
|
num_episodes = ((FLAGS.num_pretrain_episodes // 2) *
|
||||||
|
max(2, len(policy_weights)))
|
||||||
|
export_path = FLAGS.export_path + "_" + FLAGS.policy_weights_name
|
||||||
|
else:
|
||||||
|
policy_weights = None
|
||||||
|
num_episodes = FLAGS.num_pretrain_episodes
|
||||||
|
export_path = FLAGS.export_path
|
||||||
|
|
||||||
|
keyboard_utils.create_and_train_keyboard(
|
||||||
|
num_episodes=num_episodes,
|
||||||
|
policy_weights=policy_weights,
|
||||||
|
export_path=export_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Train a keyboard."""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from option_keyboard import keyboard_utils
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
||||||
|
"Number of pretraining episodes.")
|
||||||
|
flags.DEFINE_integer("num_phis", None, "Size of phi")
|
||||||
|
flags.DEFINE_string("phi_model_path", None,
|
||||||
|
"Where to load the phi model checkpoints.")
|
||||||
|
flags.DEFINE_string("export_path", None,
|
||||||
|
"Where to save the keyboard checkpoints.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
keyboard_utils.create_and_train_keyboard_with_phi(
|
||||||
|
num_episodes=FLAGS.num_pretrain_episodes,
|
||||||
|
phi_model_path=FLAGS.phi_model_path,
|
||||||
|
policy_weights=np.eye(FLAGS.num_phis, dtype=np.float32),
|
||||||
|
export_path=FLAGS.export_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,270 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Train simple phi model."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import random
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sonnet as snt
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tree
|
||||||
|
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_integer("num_phis", 2, "Dimensionality of phis.")
|
||||||
|
flags.DEFINE_integer("num_train_steps", 2000, "Number of training steps.")
|
||||||
|
flags.DEFINE_integer("num_replay_steps", 500, "Number of replay steps.")
|
||||||
|
flags.DEFINE_integer("min_replay_size", 1000,
|
||||||
|
"Minimum replay size before starting training.")
|
||||||
|
flags.DEFINE_integer("num_train_repeats", 10, "Number of training repeats.")
|
||||||
|
flags.DEFINE_float("learning_rate", 3e-3, "Learning rate.")
|
||||||
|
flags.DEFINE_bool("use_random_tasks", False, "Use random tasks.")
|
||||||
|
flags.DEFINE_string("normalisation", "L2",
|
||||||
|
"Normalisation method for cumulant weights.")
|
||||||
|
flags.DEFINE_string("export_path", None, "Export path.")
|
||||||
|
|
||||||
|
|
||||||
|
StepOutput = collections.namedtuple("StepOutput",
|
||||||
|
["obs", "actions", "rewards", "next_obs"])
|
||||||
|
|
||||||
|
|
||||||
|
def collect_experience(env, num_episodes, verbose=False):
|
||||||
|
"""Collect experience."""
|
||||||
|
|
||||||
|
num_actions = env.action_spec().maximum + 1
|
||||||
|
|
||||||
|
observations = []
|
||||||
|
actions = []
|
||||||
|
rewards = []
|
||||||
|
next_observations = []
|
||||||
|
|
||||||
|
for _ in range(num_episodes):
|
||||||
|
timestep = env.reset()
|
||||||
|
episode_return = 0
|
||||||
|
while not timestep.last():
|
||||||
|
action = np.random.randint(num_actions)
|
||||||
|
observations.append(timestep.observation)
|
||||||
|
actions.append(action)
|
||||||
|
|
||||||
|
timestep = env.step(action)
|
||||||
|
rewards.append(timestep.observation["aux_tasks_reward"])
|
||||||
|
episode_return += timestep.reward
|
||||||
|
|
||||||
|
next_observations.append(timestep.observation)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
logging.info("Total return for episode: %f", episode_return)
|
||||||
|
|
||||||
|
observation_spec = tree.map_structure(lambda _: None, observations[0])
|
||||||
|
|
||||||
|
def stack_observations(obs_list):
|
||||||
|
obs_list = [
|
||||||
|
np.stack(obs) for obs in zip(*[tree.flatten(obs) for obs in obs_list])
|
||||||
|
]
|
||||||
|
obs_dict = tree.unflatten_as(observation_spec, obs_list)
|
||||||
|
obs_dict.pop("aux_tasks_reward")
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
observations = stack_observations(observations)
|
||||||
|
actions = np.array(actions, dtype=np.int32)
|
||||||
|
rewards = np.stack(rewards)
|
||||||
|
next_observations = stack_observations(next_observations)
|
||||||
|
|
||||||
|
return StepOutput(observations, actions, rewards, next_observations)
|
||||||
|
|
||||||
|
|
||||||
|
class PhiModel(snt.AbstractModule):
|
||||||
|
"""A model for learning phi."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_actions,
|
||||||
|
n_phis,
|
||||||
|
network_kwargs,
|
||||||
|
final_activation="sigmoid",
|
||||||
|
name="PhiModel"):
|
||||||
|
super(PhiModel, self).__init__(name=name)
|
||||||
|
self._n_actions = n_actions
|
||||||
|
self._n_phis = n_phis
|
||||||
|
self._network_kwargs = network_kwargs
|
||||||
|
self._final_activation = final_activation
|
||||||
|
|
||||||
|
def _build(self, observation, actions):
|
||||||
|
obs = observation["arena"]
|
||||||
|
|
||||||
|
n_outputs = self._n_actions * self._n_phis
|
||||||
|
flat_obs = snt.BatchFlatten()(obs)
|
||||||
|
net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
|
||||||
|
net = snt.Linear(output_size=n_outputs)(net)
|
||||||
|
net = snt.BatchReshape((self._n_actions, self._n_phis))(net)
|
||||||
|
|
||||||
|
indices = tf.stack([tf.range(tf.shape(actions)[0]), actions], axis=1)
|
||||||
|
values = tf.gather_nd(net, indices)
|
||||||
|
if self._final_activation:
|
||||||
|
values = getattr(tf.nn, self._final_activation)(values)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def create_ph(tensor):
|
||||||
|
return tf.placeholder(shape=(None,) + tensor.shape[1:], dtype=tensor.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
if FLAGS.use_random_tasks:
|
||||||
|
tasks = np.random.normal(size=(8, 2))
|
||||||
|
else:
|
||||||
|
tasks = [
|
||||||
|
[1.0, 0.0],
|
||||||
|
[0.0, 1.0],
|
||||||
|
[-1.0, 0.0],
|
||||||
|
[0.0, -1.0],
|
||||||
|
[0.7, 0.3],
|
||||||
|
[-0.3, -0.7],
|
||||||
|
]
|
||||||
|
|
||||||
|
if FLAGS.normalisation == "L1":
|
||||||
|
tasks /= np.sum(np.abs(tasks), axis=-1, keepdims=True)
|
||||||
|
elif FLAGS.normalisation == "L2":
|
||||||
|
tasks /= np.linalg.norm(tasks, axis=-1, keepdims=True)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown normlisation_method {}".format(
|
||||||
|
FLAGS.normalisation))
|
||||||
|
|
||||||
|
logging.info("Tasks: %s", tasks)
|
||||||
|
|
||||||
|
env_config = dict(
|
||||||
|
arena_size=11,
|
||||||
|
num_channels=2,
|
||||||
|
max_num_steps=100,
|
||||||
|
num_init_objects=10,
|
||||||
|
object_priors=[1.0, 1.0],
|
||||||
|
egocentric=True,
|
||||||
|
default_w=None,
|
||||||
|
aux_tasks_w=tasks)
|
||||||
|
env = scavenger.Scavenger(**env_config)
|
||||||
|
num_actions = env.action_spec().maximum + 1
|
||||||
|
|
||||||
|
model_config = dict(
|
||||||
|
n_actions=num_actions,
|
||||||
|
n_phis=FLAGS.num_phis,
|
||||||
|
network_kwargs=dict(
|
||||||
|
output_sizes=(64, 128),
|
||||||
|
activate_final=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = smart_module.SmartModuleExport(lambda: PhiModel(**model_config))
|
||||||
|
|
||||||
|
dummy_steps = collect_experience(env, num_episodes=10, verbose=True)
|
||||||
|
num_rewards = dummy_steps.rewards.shape[-1]
|
||||||
|
|
||||||
|
# Placeholders
|
||||||
|
steps_ph = tree.map_structure(create_ph, dummy_steps)
|
||||||
|
|
||||||
|
phis = model(steps_ph.obs, steps_ph.actions)
|
||||||
|
phis_to_rewards = snt.Linear(
|
||||||
|
num_rewards, initializers=dict(w=tf.zeros), use_bias=False)
|
||||||
|
preds = phis_to_rewards(phis)
|
||||||
|
loss_per_batch = tf.square(preds - steps_ph.rewards)
|
||||||
|
loss_op = tf.reduce_mean(loss_per_batch)
|
||||||
|
|
||||||
|
replay = []
|
||||||
|
|
||||||
|
# Optimizer and train op.
|
||||||
|
with tf.variable_scope("optimizer"):
|
||||||
|
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
|
||||||
|
train_op = optimizer.minimize(loss_op)
|
||||||
|
# Add normalisation of weights in phis_to_rewards
|
||||||
|
if FLAGS.normalisation == "L1":
|
||||||
|
w_norm = tf.reduce_sum(tf.abs(phis_to_rewards.w), axis=0, keepdims=True)
|
||||||
|
elif FLAGS.normalisation == "L2":
|
||||||
|
w_norm = tf.norm(phis_to_rewards.w, axis=0, keepdims=True)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown normlisation_method {}".format(
|
||||||
|
FLAGS.normalisation))
|
||||||
|
|
||||||
|
normalise_w = tf.assign(phis_to_rewards.w,
|
||||||
|
phis_to_rewards.w / tf.maximum(w_norm, 1e-6))
|
||||||
|
|
||||||
|
def filter_steps(steps):
|
||||||
|
mask = np.sum(np.abs(steps.rewards), axis=-1) > 0.1
|
||||||
|
nonzero_inds = np.where(mask)[0]
|
||||||
|
zero_inds = np.where(np.logical_not(mask))[0]
|
||||||
|
zero_inds = np.random.choice(
|
||||||
|
zero_inds, size=len(nonzero_inds), replace=False)
|
||||||
|
selected_inds = np.concatenate([nonzero_inds, zero_inds])
|
||||||
|
selected_steps = tree.map_structure(lambda x: x[selected_inds], steps)
|
||||||
|
return selected_steps, selected_inds
|
||||||
|
|
||||||
|
with tf.Session() as sess:
|
||||||
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
|
step = 0
|
||||||
|
while step < FLAGS.num_train_steps:
|
||||||
|
step += 1
|
||||||
|
steps_output = collect_experience(env, num_episodes=10)
|
||||||
|
selected_step_outputs, selected_inds = filter_steps(steps_output)
|
||||||
|
|
||||||
|
if len(replay) > FLAGS.min_replay_size:
|
||||||
|
# Do training.
|
||||||
|
for _ in range(FLAGS.num_train_repeats):
|
||||||
|
train_samples = random.choices(replay, k=128)
|
||||||
|
train_samples = tree.map_structure(
|
||||||
|
lambda *x: np.stack(x, axis=0), *train_samples)
|
||||||
|
train_samples = tree.unflatten_as(steps_ph, train_samples)
|
||||||
|
feed_dict = dict(
|
||||||
|
zip(tree.flatten(steps_ph), tree.flatten(train_samples)))
|
||||||
|
_, train_loss = sess.run([train_op, loss_op], feed_dict=feed_dict)
|
||||||
|
sess.run(normalise_w)
|
||||||
|
|
||||||
|
# Do evaluation.
|
||||||
|
if step % 50 == 0:
|
||||||
|
feed_dict = dict(
|
||||||
|
zip(tree.flatten(steps_ph), tree.flatten(selected_step_outputs)))
|
||||||
|
eval_loss = sess.run(loss_op, feed_dict=feed_dict)
|
||||||
|
logging.info("Step %d, train loss %f, eval loss %f, replay %s",
|
||||||
|
step, train_loss, eval_loss, len(replay))
|
||||||
|
print(sess.run(phis_to_rewards.get_variables())[0].T)
|
||||||
|
|
||||||
|
values = dict(step=step, train_loss=train_loss, eval_loss=eval_loss)
|
||||||
|
logging.info(values)
|
||||||
|
|
||||||
|
# Add to replay.
|
||||||
|
if step <= FLAGS.num_replay_steps:
|
||||||
|
def select_fn(ind):
|
||||||
|
return lambda x: x[ind]
|
||||||
|
for idx in range(len(selected_inds)):
|
||||||
|
replay.append(
|
||||||
|
tree.flatten(
|
||||||
|
tree.map_structure(select_fn(idx), selected_step_outputs)))
|
||||||
|
|
||||||
|
# Export trained model.
|
||||||
|
if FLAGS.export_path:
|
||||||
|
model.export(FLAGS.export_path, sess, overwrite=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.disable_v2_behavior()
|
||||||
|
app.run(main)
|
||||||
@@ -16,10 +16,14 @@
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Keyboard agent."""
|
"""Keyboard agent."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sonnet as snt
|
import sonnet as snt
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
|
|
||||||
class Agent():
|
class Agent():
|
||||||
"""An Option Keyboard Agent."""
|
"""An Option Keyboard Agent."""
|
||||||
@@ -51,6 +55,7 @@ class Agent():
|
|||||||
optimizer_kwargs: Keyword arguments for the optimizer.
|
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tf.logging.info(policy_weights)
|
||||||
self._policy_weights = tf.convert_to_tensor(
|
self._policy_weights = tf.convert_to_tensor(
|
||||||
policy_weights, dtype=tf.float32)
|
policy_weights, dtype=tf.float32)
|
||||||
self._current_policy = None
|
self._current_policy = None
|
||||||
@@ -61,13 +66,16 @@ class Agent():
|
|||||||
|
|
||||||
self._n_actions = action_spec.num_values
|
self._n_actions = action_spec.num_values
|
||||||
self._n_policies, self._n_cumulants = policy_weights.shape
|
self._n_policies, self._n_cumulants = policy_weights.shape
|
||||||
self._network = OptionValueNet(
|
|
||||||
|
def create_network():
|
||||||
|
return OptionValueNet(
|
||||||
self._n_policies,
|
self._n_policies,
|
||||||
self._n_cumulants,
|
self._n_cumulants,
|
||||||
self._n_actions,
|
self._n_actions,
|
||||||
network_kwargs=network_kwargs,
|
network_kwargs=network_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._network = smart_module.SmartModuleExport(create_network)
|
||||||
self._replay = []
|
self._replay = []
|
||||||
|
|
||||||
obs_spec = self._extract_observation(obs_spec)
|
obs_spec = self._extract_observation(obs_spec)
|
||||||
@@ -103,6 +111,12 @@ class Agent():
|
|||||||
td_error = tf.stop_gradient(c_t + g * qa_t) - qa_tm1
|
td_error = tf.stop_gradient(c_t + g * qa_t) - qa_tm1
|
||||||
loss = tf.reduce_sum(tf.square(td_error) / 2)
|
loss = tf.reduce_sum(tf.square(td_error) / 2)
|
||||||
|
|
||||||
|
# Dummy calls to keyboard for SmartModule
|
||||||
|
_ = self._network.gpi(o_tm1[0], c_t[0])
|
||||||
|
_ = self._network.num_cumulants
|
||||||
|
_ = self._network.num_policies
|
||||||
|
_ = self._network.num_actions
|
||||||
|
|
||||||
with tf.variable_scope("optimizer"):
|
with tf.variable_scope("optimizer"):
|
||||||
self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
|
self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
|
||||||
train_op = self._optimizer.minimize(loss)
|
train_op = self._optimizer.minimize(loss)
|
||||||
@@ -155,7 +169,9 @@ class Agent():
|
|||||||
|
|
||||||
def export(self, path):
|
def export(self, path):
|
||||||
tf.logging.info("Exporting keyboard to %s", path)
|
tf.logging.info("Exporting keyboard to %s", path)
|
||||||
self._saver.save(self._session, path)
|
self._network.export(
|
||||||
|
os.path.join(path, "tfhub"), self._session, overwrite=True)
|
||||||
|
self._saver.save(self._session, os.path.join(path, "checkpoints"))
|
||||||
|
|
||||||
|
|
||||||
class OptionValueNet(snt.AbstractModule):
|
class OptionValueNet(snt.AbstractModule):
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Keyboard utils."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from option_keyboard import configs
|
||||||
|
from option_keyboard import environment_wrappers
|
||||||
|
from option_keyboard import experiment
|
||||||
|
from option_keyboard import keyboard_agent
|
||||||
|
from option_keyboard import scavenger
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_train_keyboard(num_episodes,
|
||||||
|
policy_weights=None,
|
||||||
|
export_path=None):
|
||||||
|
"""Train an option keyboard."""
|
||||||
|
if policy_weights is None:
|
||||||
|
policy_weights = np.eye(2, dtype=np.float32)
|
||||||
|
|
||||||
|
env_config = configs.get_pretrain_config()
|
||||||
|
env = scavenger.Scavenger(**env_config)
|
||||||
|
env = environment_wrappers.EnvironmentWithLogging(env)
|
||||||
|
|
||||||
|
agent = keyboard_agent.Agent(
|
||||||
|
obs_spec=env.observation_spec(),
|
||||||
|
action_spec=env.action_spec(),
|
||||||
|
policy_weights=policy_weights,
|
||||||
|
network_kwargs=dict(
|
||||||
|
output_sizes=(64, 128),
|
||||||
|
activate_final=True,
|
||||||
|
),
|
||||||
|
epsilon=0.1,
|
||||||
|
additional_discount=0.9,
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
|
if num_episodes:
|
||||||
|
experiment.run(env, agent, num_episodes=num_episodes)
|
||||||
|
agent.export(export_path)
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_train_keyboard_with_phi(num_episodes,
|
||||||
|
phi_model_path,
|
||||||
|
policy_weights,
|
||||||
|
export_path=None):
|
||||||
|
"""Train an option keyboard."""
|
||||||
|
env_config = configs.get_pretrain_config()
|
||||||
|
env = scavenger.Scavenger(**env_config)
|
||||||
|
env = environment_wrappers.EnvironmentWithLogging(env)
|
||||||
|
env = environment_wrappers.EnvironmentWithLearnedPhi(env, phi_model_path)
|
||||||
|
|
||||||
|
agent = keyboard_agent.Agent(
|
||||||
|
obs_spec=env.observation_spec(),
|
||||||
|
action_spec=env.action_spec(),
|
||||||
|
policy_weights=policy_weights,
|
||||||
|
network_kwargs=dict(
|
||||||
|
output_sizes=(64, 128),
|
||||||
|
activate_final=True,
|
||||||
|
),
|
||||||
|
epsilon=0.1,
|
||||||
|
additional_discount=0.9,
|
||||||
|
batch_size=10,
|
||||||
|
optimizer_name="AdamOptimizer",
|
||||||
|
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||||
|
|
||||||
|
if num_episodes:
|
||||||
|
experiment.run(env, agent, num_episodes=num_episodes)
|
||||||
|
agent.export(export_path)
|
||||||
|
|
||||||
|
return agent
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
absl-py
|
absl-py
|
||||||
dm-env==1.2
|
dm-env==1.2
|
||||||
dm-sonnet==1.34
|
dm-sonnet==1.34
|
||||||
|
dm-tree
|
||||||
numpy==1.16.4
|
numpy==1.16.4
|
||||||
tensorflow==1.13.2
|
tensorflow==1.13.2
|
||||||
|
tensorflow_hub==0.7.0
|
||||||
tensorflow_probability==0.6.0
|
tensorflow_probability==0.6.0
|
||||||
|
wrapt
|
||||||
|
|||||||
+20
-36
@@ -16,60 +16,44 @@
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Run an experiment."""
|
"""Run an experiment."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
from option_keyboard import configs
|
from option_keyboard import configs
|
||||||
from option_keyboard import dqn_agent
|
from option_keyboard import dqn_agent
|
||||||
from option_keyboard import environment_wrappers
|
from option_keyboard import environment_wrappers
|
||||||
from option_keyboard import experiment
|
from option_keyboard import experiment
|
||||||
from option_keyboard import keyboard_agent
|
from option_keyboard import keyboard_utils
|
||||||
from option_keyboard import scavenger
|
from option_keyboard import scavenger
|
||||||
|
from option_keyboard import smart_module
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||||
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
||||||
"Number of pretraining episodes.")
|
"Number of pretraining episodes.")
|
||||||
|
flags.DEFINE_string("keyboard_path", None, "Path to pretrained keyboard model.")
|
||||||
|
|
||||||
def _train_keyboard(num_episodes):
|
|
||||||
"""Train an option keyboard."""
|
|
||||||
env_config = configs.get_pretrain_config()
|
|
||||||
env = scavenger.Scavenger(**env_config)
|
|
||||||
env = environment_wrappers.EnvironmentWithLogging(env)
|
|
||||||
|
|
||||||
agent = keyboard_agent.Agent(
|
|
||||||
obs_spec=env.observation_spec(),
|
|
||||||
action_spec=env.action_spec(),
|
|
||||||
policy_weights=np.array([
|
|
||||||
[1.0, 0.0],
|
|
||||||
[0.0, 1.0],
|
|
||||||
]),
|
|
||||||
network_kwargs=dict(
|
|
||||||
output_sizes=(64, 128),
|
|
||||||
activate_final=True,
|
|
||||||
),
|
|
||||||
epsilon=0.1,
|
|
||||||
additional_discount=0.9,
|
|
||||||
batch_size=10,
|
|
||||||
optimizer_name="AdamOptimizer",
|
|
||||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
|
||||||
|
|
||||||
experiment.run(env, agent, num_episodes=num_episodes)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
del argv
|
del argv
|
||||||
|
|
||||||
# Pretrain the keyboard and save a checkpoint.
|
# Pretrain the keyboard and save a checkpoint.
|
||||||
pretrain_agent = _train_keyboard(num_episodes=FLAGS.num_pretrain_episodes)
|
if FLAGS.keyboard_path:
|
||||||
keyboard_ckpt_path = "/tmp/option_keyboard/keyboard.ckpt"
|
keyboard_path = FLAGS.keyboard_path
|
||||||
pretrain_agent.export(keyboard_ckpt_path)
|
else:
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
export_path = "/tmp/option_keyboard/keyboard"
|
||||||
|
_ = keyboard_utils.create_and_train_keyboard(
|
||||||
|
num_episodes=FLAGS.num_pretrain_episodes, export_path=export_path)
|
||||||
|
keyboard_path = os.path.join(export_path, "tfhub")
|
||||||
|
|
||||||
|
# Load the keyboard.
|
||||||
|
keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))
|
||||||
|
|
||||||
# Create the task environment.
|
# Create the task environment.
|
||||||
base_env_config = configs.get_task_config()
|
base_env_config = configs.get_task_config()
|
||||||
@@ -80,11 +64,11 @@ def main(argv):
|
|||||||
additional_discount = 0.9
|
additional_discount = 0.9
|
||||||
env = environment_wrappers.EnvironmentWithKeyboard(
|
env = environment_wrappers.EnvironmentWithKeyboard(
|
||||||
env=base_env,
|
env=base_env,
|
||||||
keyboard=pretrain_agent.keyboard,
|
keyboard=keyboard,
|
||||||
keyboard_ckpt_path=keyboard_ckpt_path,
|
keyboard_ckpt_path=None,
|
||||||
n_actions_per_dim=3,
|
n_actions_per_dim=3,
|
||||||
additional_discount=additional_discount,
|
additional_discount=additional_discount,
|
||||||
call_and_return=True)
|
call_and_return=False)
|
||||||
|
|
||||||
# Create the player agent.
|
# Create the player agent.
|
||||||
agent = dqn_agent.Agent(
|
agent = dqn_agent.Agent(
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ class Scavenger(auto_reset_environment.Base):
|
|||||||
num_init_objects=15,
|
num_init_objects=15,
|
||||||
object_priors=None,
|
object_priors=None,
|
||||||
egocentric=True,
|
egocentric=True,
|
||||||
rewarder=None):
|
rewarder=None,
|
||||||
|
aux_tasks_w=None):
|
||||||
self._arena_size = arena_size
|
self._arena_size = arena_size
|
||||||
self._num_channels = num_channels
|
self._num_channels = num_channels
|
||||||
self._max_num_steps = max_num_steps
|
self._max_num_steps = max_num_steps
|
||||||
@@ -64,12 +65,13 @@ class Scavenger(auto_reset_environment.Base):
|
|||||||
self._egocentric = egocentric
|
self._egocentric = egocentric
|
||||||
self._rewarder = (
|
self._rewarder = (
|
||||||
getattr(this_module, rewarder)() if rewarder is not None else None)
|
getattr(this_module, rewarder)() if rewarder is not None else None)
|
||||||
|
self._aux_tasks_w = aux_tasks_w
|
||||||
|
|
||||||
if object_priors is None:
|
if object_priors is None:
|
||||||
self._object_priors = np.ones(num_channels) / num_channels
|
self._object_priors = np.ones(num_channels) / num_channels
|
||||||
else:
|
else:
|
||||||
assert len(object_priors) == num_channels
|
assert len(object_priors) == num_channels
|
||||||
self._object_priors = np.array(object_priors)
|
self._object_priors = np.array(object_priors) / np.sum(object_priors)
|
||||||
|
|
||||||
if default_w is None:
|
if default_w is None:
|
||||||
self._default_w = np.ones(shape=(num_channels,))
|
self._default_w = np.ones(shape=(num_channels,))
|
||||||
@@ -203,10 +205,15 @@ class Scavenger(auto_reset_environment.Base):
|
|||||||
|
|
||||||
collected_resources = np.copy(self._prev_collected).astype(np.float32)
|
collected_resources = np.copy(self._prev_collected).astype(np.float32)
|
||||||
|
|
||||||
return dict(
|
obs = dict(
|
||||||
arena=arena,
|
arena=arena,
|
||||||
cumulants=collected_resources,
|
cumulants=collected_resources,
|
||||||
)
|
)
|
||||||
|
if self._aux_tasks_w is not None:
|
||||||
|
obs["aux_tasks_reward"] = np.dot(
|
||||||
|
np.array(self._aux_tasks_w), self._prev_collected).astype(np.float32)
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
def observation_spec(self):
|
def observation_spec(self):
|
||||||
arena = dm_env.specs.BoundedArray(
|
arena = dm_env.specs.BoundedArray(
|
||||||
@@ -222,10 +229,19 @@ class Scavenger(auto_reset_environment.Base):
|
|||||||
maximum=1e9,
|
maximum=1e9,
|
||||||
name="collected_resources")
|
name="collected_resources")
|
||||||
|
|
||||||
return dict(
|
obs_spec = dict(
|
||||||
arena=arena,
|
arena=arena,
|
||||||
cumulants=collected_resources,
|
cumulants=collected_resources,
|
||||||
)
|
)
|
||||||
|
if self._aux_tasks_w is not None:
|
||||||
|
obs_spec["aux_tasks_reward"] = dm_env.specs.BoundedArray(
|
||||||
|
shape=(len(self._aux_tasks_w),),
|
||||||
|
dtype=np.float32,
|
||||||
|
minimum=-1e9,
|
||||||
|
maximum=1e9,
|
||||||
|
name="aux_tasks_reward")
|
||||||
|
|
||||||
|
return obs_spec
|
||||||
|
|
||||||
def action_spec(self):
|
def action_spec(self):
|
||||||
return dm_env.specs.DiscreteArray(num_values=len(Action), name="action")
|
return dm_env.specs.DiscreteArray(num_values=len(Action), name="action")
|
||||||
|
|||||||
@@ -0,0 +1,228 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# pylint: disable=g-bad-file-header
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Smart module export/import utilities."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
import tree as nest
|
||||||
|
import wrapt
|
||||||
|
|
||||||
|
|
||||||
|
_ALLOWED_TYPES = (bool, float, int, str)
|
||||||
|
|
||||||
|
|
||||||
|
def _getcallargs(signature, *args, **kwargs):
|
||||||
|
bound_args = signature.bind(*args, **kwargs)
|
||||||
|
bound_args.apply_defaults()
|
||||||
|
inputs = bound_args.arguments
|
||||||
|
inputs.pop("self", None)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
def _to_placeholder(arg):
|
||||||
|
if arg is None or isinstance(arg, bool):
|
||||||
|
return arg
|
||||||
|
|
||||||
|
arg = tf.convert_to_tensor(arg)
|
||||||
|
return tf.placeholder(dtype=arg.dtype, shape=arg.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class SmartModuleExport(object):
|
||||||
|
"""Helper class for exporting TF-Hub modules."""
|
||||||
|
|
||||||
|
def __init__(self, object_factory):
|
||||||
|
self._object_factory = object_factory
|
||||||
|
self._wrapped_object = self._object_factory()
|
||||||
|
self._variable_scope = tf.get_variable_scope()
|
||||||
|
self._captured_calls = {}
|
||||||
|
self._captured_attrs = {}
|
||||||
|
|
||||||
|
def _create_captured_method(self, method_name):
|
||||||
|
"""Creates a wrapped method that captures its inputs."""
|
||||||
|
with tf.variable_scope(self._variable_scope):
|
||||||
|
method_ = getattr(self._wrapped_object, method_name)
|
||||||
|
|
||||||
|
@wrapt.decorator
|
||||||
|
def wrapper(method, instance, args, kwargs):
|
||||||
|
"""Wrapped method to capture inputs."""
|
||||||
|
del instance
|
||||||
|
|
||||||
|
specs = inspect.signature(method)
|
||||||
|
inputs = _getcallargs(specs, *args, **kwargs)
|
||||||
|
|
||||||
|
with tf.variable_scope(self._variable_scope):
|
||||||
|
output = method(*args, **kwargs)
|
||||||
|
|
||||||
|
self._captured_calls[method_name] = [inputs, specs]
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
return wrapper(method_) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
"""Helper method for accessing an attributes of the wrapped object."""
|
||||||
|
# if "_wrapped_object" not in self.__dict__:
|
||||||
|
# return super(ExportableModule, self).__getattr__(name)
|
||||||
|
|
||||||
|
with tf.variable_scope(self._variable_scope):
|
||||||
|
attr = getattr(self._wrapped_object, name)
|
||||||
|
|
||||||
|
if inspect.ismethod(attr) or inspect.isfunction(attr):
|
||||||
|
return self._create_captured_method(name)
|
||||||
|
else:
|
||||||
|
if all([isinstance(v, _ALLOWED_TYPES) for v in nest.flatten(attr)]):
|
||||||
|
self._captured_attrs[name] = attr
|
||||||
|
return attr
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self._create_captured_method("__call__")(*args, **kwargs)
|
||||||
|
|
||||||
|
def export(self, path, session, overwrite=False):
|
||||||
|
"""Build the TF-Hub spec, module and sync ops."""
|
||||||
|
|
||||||
|
method_specs = {}
|
||||||
|
|
||||||
|
def module_fn():
|
||||||
|
"""A module_fn for use with hub.create_module_spec()."""
|
||||||
|
# We will use a copy of the original object to build the graph.
|
||||||
|
wrapped_object = self._object_factory()
|
||||||
|
|
||||||
|
for method_name, method_info in self._captured_calls.items():
|
||||||
|
captured_inputs, captured_specs = method_info
|
||||||
|
tensor_inputs = nest.map_structure(_to_placeholder, captured_inputs)
|
||||||
|
method_to_call = getattr(wrapped_object, method_name)
|
||||||
|
tensor_outputs = method_to_call(**tensor_inputs)
|
||||||
|
|
||||||
|
flat_tensor_inputs = nest.flatten(tensor_inputs)
|
||||||
|
flat_tensor_inputs = {
|
||||||
|
str(k): v for k, v in zip(
|
||||||
|
range(len(flat_tensor_inputs)), flat_tensor_inputs)
|
||||||
|
}
|
||||||
|
flat_tensor_outputs = nest.flatten(tensor_outputs)
|
||||||
|
flat_tensor_outputs = {
|
||||||
|
str(k): v for k, v in zip(
|
||||||
|
range(len(flat_tensor_outputs)), flat_tensor_outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
method_specs[method_name] = dict(
|
||||||
|
specs=captured_specs,
|
||||||
|
inputs=nest.map_structure(lambda _: None, tensor_inputs),
|
||||||
|
outputs=nest.map_structure(lambda _: None, tensor_outputs))
|
||||||
|
|
||||||
|
signature_name = ("default"
|
||||||
|
if method_name == "__call__" else method_name)
|
||||||
|
hub.add_signature(signature_name, flat_tensor_inputs,
|
||||||
|
flat_tensor_outputs)
|
||||||
|
|
||||||
|
hub.attach_message(
|
||||||
|
"methods", tf.train.BytesList(value=[pickle.dumps(method_specs)]))
|
||||||
|
hub.attach_message(
|
||||||
|
"properties",
|
||||||
|
tf.train.BytesList(value=[pickle.dumps(self._captured_attrs)]))
|
||||||
|
|
||||||
|
# Create the spec that will be later used in export.
|
||||||
|
hub_spec = hub.create_module_spec(module_fn, drop_collections=["sonnet"])
|
||||||
|
|
||||||
|
# Get variables values
|
||||||
|
module_weights = [
|
||||||
|
session.run(v) for v in self._wrapped_object.get_all_variables()
|
||||||
|
]
|
||||||
|
|
||||||
|
# create the sync ops
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
hub_module = hub.Module(hub_spec, trainable=True, name="hub")
|
||||||
|
|
||||||
|
assign_ops = []
|
||||||
|
assign_phs = []
|
||||||
|
for _, v in sorted(hub_module.variable_map.items()):
|
||||||
|
ph = tf.placeholder(shape=v.shape, dtype=v.dtype)
|
||||||
|
assign_phs.append(ph)
|
||||||
|
assign_ops.append(tf.assign(v, ph))
|
||||||
|
|
||||||
|
with tf.Session() as module_session:
|
||||||
|
module_session.run(tf.local_variables_initializer())
|
||||||
|
module_session.run(tf.global_variables_initializer())
|
||||||
|
module_session.run(
|
||||||
|
assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))
|
||||||
|
|
||||||
|
if overwrite and os.path.exists(path):
|
||||||
|
shutil.rmtree(path)
|
||||||
|
os.makedirs(path)
|
||||||
|
hub_module.export(path, module_session)
|
||||||
|
|
||||||
|
|
||||||
|
class SmartModuleImport(object):
|
||||||
|
"""A class for importing graph building objects from TF-Hub modules."""
|
||||||
|
|
||||||
|
def __init__(self, module):
|
||||||
|
self._module = module
|
||||||
|
self._method_specs = pickle.loads(
|
||||||
|
self._module.get_attached_message("methods",
|
||||||
|
tf.train.BytesList).value[0])
|
||||||
|
self._properties = pickle.loads(
|
||||||
|
self._module.get_attached_message("properties",
|
||||||
|
tf.train.BytesList).value[0])
|
||||||
|
|
||||||
|
def _create_wrapped_method(self, method):
|
||||||
|
"""Creates a wrapped method that converts nested inputs and outputs."""
|
||||||
|
|
||||||
|
def wrapped_method(*args, **kwargs):
|
||||||
|
"""A wrapped method around a TF-Hub module signature."""
|
||||||
|
|
||||||
|
inputs = _getcallargs(self._method_specs[method]["specs"], *args,
|
||||||
|
**kwargs)
|
||||||
|
nest.assert_same_structure(self._method_specs[method]["inputs"], inputs)
|
||||||
|
flat_inputs = nest.flatten(inputs)
|
||||||
|
flat_inputs = {
|
||||||
|
str(k): v for k, v in zip(range(len(flat_inputs)), flat_inputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature = "default" if method == "__call__" else method
|
||||||
|
flat_outputs = self._module(
|
||||||
|
flat_inputs, signature=signature, as_dict=True)
|
||||||
|
flat_outputs = [v for _, v in sorted(flat_outputs.items())]
|
||||||
|
|
||||||
|
output_spec = self._method_specs[method]["outputs"]
|
||||||
|
if output_spec is None:
|
||||||
|
if len(flat_outputs) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected output containing a single tensor, found {}".format(
|
||||||
|
flat_outputs))
|
||||||
|
outputs = flat_outputs[0]
|
||||||
|
else:
|
||||||
|
outputs = nest.unflatten_as(output_spec, flat_outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return wrapped_method
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self._method_specs:
|
||||||
|
return self._create_wrapped_method(name)
|
||||||
|
|
||||||
|
if name in self._properties:
|
||||||
|
return self._properties[name]
|
||||||
|
|
||||||
|
return getattr(self._module, name)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self._create_wrapped_method("__call__")(*args, **kwargs)
|
||||||
Reference in New Issue
Block a user