diff --git a/README.md b/README.md
index ea3abbf..5953ff4 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
 
 ## Projects
 
+* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
 * [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
 * [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
 * [Multi-Object Representation Learning with Iterative Variational Inference (IODINE)](iodine)
diff --git a/option_keyboard/README.md b/option_keyboard/README.md
new file mode 100644
index 0000000..804a36c
--- /dev/null
+++ b/option_keyboard/README.md
@@ -0,0 +1,58 @@
+# The Option Keyboard: Combining Skills in Reinforcement Learning
+
+This directory contains an implementation of the Option Keyboard framework.
+
+From the [abstract](http://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning):
+
+> The ability to combine known skills to create new ones may be crucial in the
+solution of complex reinforcement learning problems that unfold over extended
+periods. We argue that a robust way of combining skills is to define and manipulate
+them in the space of pseudo-rewards (or “cumulants”). Based on this premise, we
+propose a framework for combining skills using the formalism of options. We show
+that every deterministic option can be unambiguously represented as a cumulant
+defined in an extended domain. Building on this insight and on previous results
+on transfer learning, we show how to approximate options whose cumulants are
+linear combinations of the cumulants of known options. This means that, once we
+have learned options associated with a set of cumulants, we can instantaneously
+synthesise options induced by any linear combination of them, without any learning
+involved. We describe how this framework provides a hierarchical interface to the
+environment whose abstract actions correspond to combinations of basic skills.
+We demonstrate the practical benefits of our approach in a resource management
+problem and a navigation task involving a quadrupedal simulated robot.
+
+If you use the code here, please cite this paper:
+
+> Andre Barreto, Diana Borsa, Shaobo Hou, Gheorghe Comanici, Eser Aygün, Philippe Hamel, Daniel Toyama, Jonathan Hunt, Shibl Mourad, David Silver, Doina Precup. *The Option Keyboard: Combining Skills in Reinforcement Learning*. NeurIPS 2019. [\[paper\]](https://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning).
+
+## Running the code
+
+### Setup
+```
+python3 -m venv ok_venv
+source ok_venv/bin/activate
+pip install -r option_keyboard/requirements.txt
+```
+
+### Scavenger Task
+All agents are trained on a simple grid-world resource collection task. There
+are two types of collectible objects in the world: collecting an object of the
+less abundant type yields a reward of -1, while collecting an object of the
+more abundant type yields a reward of +1. See section 5.1 in the paper for
+more details.
+
+### Train the DQN baseline
+```
+python3 -m option_keyboard.run_dqn
+```
+This trains a DQN agent on the scavenger task.
+
+### Train the Option Keyboard and agent
+```
+python3 -m option_keyboard.run_ok
+```
+This first trains an Option Keyboard on the cumulants in the task environment.
+Then it trains a DQN agent on the true task reward using high-level abstract
+actions provided by the keyboard.
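+
+Both scripts accept a `--num_episodes` flag, and `run_ok` additionally
+accepts `--num_pretrain_episodes` (defaults: 10000 and 20000 respectively),
+e.g.
+```
+python3 -m option_keyboard.run_ok --num_pretrain_episodes=20000 --num_episodes=10000
+```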
+
+## Disclaimer
+This is not an official Google or DeepMind product.
diff --git a/option_keyboard/auto_reset_environment.py b/option_keyboard/auto_reset_environment.py
new file mode 100644
index 0000000..f83de76
--- /dev/null
+++ b/option_keyboard/auto_reset_environment.py
@@ -0,0 +1,58 @@
+# Lint as: python3
+# pylint: disable=g-bad-file-header
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Auto-resetting environment base class.
+
+The environment API states that stepping an environment after a LAST timestep
+should return the first timestep of a new episode.
+
+However, environment authors sometimes don't spot this part or find it awkward
+to implement. This module contains a class that helps implement the reset
+behaviour.
+"""
+
+import abc
+import dm_env
+
+
+class Base(dm_env.Environment):
+  """This class implements the required `step()` and `reset()` methods.
+
+  It instead requires users to implement `_step()` and `_reset()`. This class
+  handles the reset behaviour automatically when it detects a LAST timestep.
+  """
+
+  def __init__(self):
+    self._reset_next_step = True
+
+  @abc.abstractmethod
+  def _reset(self):
+    """Returns a `timestep` namedtuple as per the regular `reset()` method."""
+
+  @abc.abstractmethod
+  def _step(self, action):
+    """Returns a `timestep` namedtuple as per the regular `step()` method."""
+
+  def reset(self):
+    self._reset_next_step = False
+    return self._reset()
+
+  def step(self, action):
+    if self._reset_next_step:
+      return self.reset()
+    timestep = self._step(action)
+    self._reset_next_step = timestep.last()
+    return timestep
diff --git a/option_keyboard/configs.py b/option_keyboard/configs.py
new file mode 100644
index 0000000..77ad229
--- /dev/null
+++ b/option_keyboard/configs.py
@@ -0,0 +1,41 @@
+# Lint as: python3
+# pylint: disable=g-bad-file-header
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Environment configurations."""
+
+
+def get_task_config():
+  return dict(
+      arena_size=11,
+      num_channels=2,
+      max_num_steps=50,  # 50 for the actual task.
+ num_init_objects=10, + object_priors=[0.5, 0.5], + egocentric=True, + rewarder="BalancedCollectionRewarder", + ) + + +def get_pretrain_config(): + return dict( + arena_size=11, + num_channels=2, + max_num_steps=40, # 40 for pretraining. + num_init_objects=10, + object_priors=[0.5, 0.5], + egocentric=True, + default_w=(1, 1), + ) diff --git a/option_keyboard/dqn_agent.py b/option_keyboard/dqn_agent.py new file mode 100644 index 0000000..4240f8e --- /dev/null +++ b/option_keyboard/dqn_agent.py @@ -0,0 +1,160 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DQN agent.""" + +import numpy as np +import sonnet as snt +import tensorflow.compat.v1 as tf + + +class Agent(): + """A DQN Agent.""" + + def __init__( + self, + obs_spec, + action_spec, + network_kwargs, + epsilon, + additional_discount, + batch_size, + optimizer_name, + optimizer_kwargs, + ): + """A simple DQN agent. + + Args: + obs_spec: The observation spec. + action_spec: The action spec. + network_kwargs: Keyword arguments for snt.nets.MLP + epsilon: Exploration probability. + additional_discount: Discount on returns used by the agent. + batch_size: Size of update batch. + optimizer_name: Name of an optimizer from tf.train + optimizer_kwargs: Keyword arguments for the optimizer. + """ + + self._epsilon = epsilon + self._additional_discount = additional_discount + self._batch_size = batch_size + + self._n_actions = action_spec.num_values + self._network = ValueNet(self._n_actions, network_kwargs=network_kwargs) + + self._replay = [] + + obs_spec = self._extract_observation(obs_spec) + + # Placeholders for policy + o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype) + q = self._network(tf.expand_dims(o, axis=0)) + + # Placeholders for update. + o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype) + a_tm1 = tf.placeholder(shape=(None,), dtype=tf.int32) + r_t = tf.placeholder(shape=(None,), dtype=tf.float32) + d_t = tf.placeholder(shape=(None,), dtype=tf.float32) + o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype) + + # Compute values over all options. + q_tm1 = self._network(o_tm1) + q_t = self._network(o_t) + + a_t = tf.cast(tf.argmax(q_t, axis=-1), tf.int32) + qa_tm1 = _batched_index(q_tm1, a_tm1) + qa_t = _batched_index(q_t, a_t) + + # TD error + g = additional_discount * d_t + td_error = tf.stop_gradient(r_t + g * qa_t) - qa_tm1 + loss = tf.reduce_sum(tf.square(td_error) / 2) + + with tf.variable_scope("optimizer"): + self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs) + train_op = self._optimizer.minimize(loss) + + # Make session and callables. 
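+    # (tf.Session.make_callable binds the fetches and feed placeholders into
+    # plain Python functions, avoiding some per-call overhead of session.run.)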
+ session = tf.Session() + self._update_fn = session.make_callable(train_op, + [o_tm1, a_tm1, r_t, d_t, o_t]) + self._value_fn = session.make_callable(q, [o]) + session.run(tf.global_variables_initializer()) + + def _extract_observation(self, obs): + return obs["arena"] + + def step(self, timestep, is_training=False): + """Select actions according to epsilon-greedy policy.""" + + if is_training and np.random.rand() < self._epsilon: + return np.random.randint(self._n_actions) + + q_values = self._value_fn( + self._extract_observation(timestep.observation)) + return int(np.argmax(q_values)) + + def update(self, step_tm1, action, step_t): + """Takes in a transition from the environment.""" + + transition = [ + self._extract_observation(step_tm1.observation), + action, + step_t.reward, + step_t.discount, + self._extract_observation(step_t.observation), + ] + self._replay.append(transition) + + if len(self._replay) == self._batch_size: + batch = list(zip(*self._replay)) + self._update_fn(*batch) + self._replay = [] # Just a queue. + + +class ValueNet(snt.AbstractModule): + """Value Network.""" + + def __init__(self, + n_actions, + network_kwargs, + name="value_network"): + """Construct a value network sonnet module. + + Args: + n_actions: Number of actions. + network_kwargs: Network arguments. + name: Name + """ + super(ValueNet, self).__init__(name=name) + self._n_actions = n_actions + self._network_kwargs = network_kwargs + + def _build(self, observation): + flat_obs = snt.BatchFlatten()(observation) + net = snt.nets.MLP(**self._network_kwargs)(flat_obs) + net = snt.Linear(output_size=self._n_actions)(net) + + return net + + @property + def num_actions(self): + return self._n_actions + + +def _batched_index(values, indices): + one_hot_indices = tf.one_hot(indices, values.shape[-1], dtype=values.dtype) + return tf.reduce_sum(values * one_hot_indices, axis=-1) diff --git a/option_keyboard/environment_wrappers.py b/option_keyboard/environment_wrappers.py new file mode 100644 index 0000000..908fc28 --- /dev/null +++ b/option_keyboard/environment_wrappers.py @@ -0,0 +1,190 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Environment with keyboard.""" + +import itertools + +from absl import logging + +import dm_env + +import numpy as np +import tensorflow.compat.v1 as tf + + +class EnvironmentWithLogging(dm_env.Environment): + """Wraps an environment with additional logging.""" + + def __init__(self, env): + self._env = env + self._episode_return = 0 + + def reset(self): + self._episode_return = 0 + return self._env.reset() + + def step(self, action): + """Take action in the environment and do some logging.""" + + step = self._env.step(action) + if step.first(): + step = self._env.step(action) + self._episode_return = 0 + + self._episode_return += step.reward + return step + + @property + def episode_return(self): + return self._episode_return + + def action_spec(self): + return self._env.action_spec() + + def observation_spec(self): + return self._env.observation_spec() + + def __getattr__(self, name): + return getattr(self._env, name) + + +class EnvironmentWithKeyboard(dm_env.Environment): + """Wraps an environment with a keyboard.""" + + def __init__(self, + env, + keyboard, + keyboard_ckpt_path, + n_actions_per_dim, + additional_discount, + call_and_return=False): + self._env = env + self._keyboard = keyboard + self._discount = additional_discount + self._call_and_return = call_and_return + + options = _discretize_actions(n_actions_per_dim, keyboard.num_cumulants) + self._options_np = options + options = tf.convert_to_tensor(options, dtype=tf.float32) + self._options = options + + obs_spec = self._extract_observation(env.observation_spec()) + obs_ph = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype) + option_ph = tf.placeholder(shape=(), dtype=tf.int32) + gpi_action = self._keyboard.gpi(obs_ph, options[option_ph]) + + session = tf.Session() + self._gpi_action = session.make_callable(gpi_action, [obs_ph, option_ph]) + self._keyboard_action = session.make_callable( + self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph]) + session.run(tf.global_variables_initializer()) + + saver = tf.train.Saver(var_list=keyboard.variables) + saver.restore(session, keyboard_ckpt_path) + + def _compute_reward(self, option, obs): + return np.sum(self._options_np[option] * obs["cumulants"]) + + def reset(self): + return self._env.reset() + + def step(self, option): + """Take a step in the keyboard, then the environment.""" + + step_count = 0 + option_step = None + while True: + obs = self._extract_observation(self._env.observation()) + action = self._gpi_action(obs, option) + action_step = self._env.step(action) + step_count += 1 + + if option_step is None: + option_step = action_step + else: + new_discount = ( + option_step.discount * self._discount * action_step.discount) + new_reward = ( + option_step.reward + new_discount * action_step.reward) + option_step = option_step._replace( + observation=action_step.observation, + reward=new_reward, + discount=new_discount, + step_type=action_step.step_type) + + if action_step.last(): + break + + # Terminate option. 
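+      # (Under call-and-return, the option terminates once its
+      # cumulant-weighted pseudo-reward turns positive, i.e. the agent
+      # collected something the selected option was asking for.)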
+      if self._compute_reward(option, action_step.observation) > 0:
+        break
+
+      if not self._call_and_return:
+        break
+
+    return option_step
+
+  def action_spec(self):
+    return dm_env.specs.DiscreteArray(
+        num_values=self._options_np.shape[0], name="action")
+
+  def _extract_observation(self, obs):
+    return obs["arena"]
+
+  def observation_spec(self):
+    return self._env.observation_spec()
+
+  def __getattr__(self, name):
+    return getattr(self._env, name)
+
+
+def _discretize_actions(num_actions_per_dim,
+                        action_space_dim,
+                        min_val=-1.0,
+                        max_val=1.0):
+  """Discretize a continuous action space."""
+  if num_actions_per_dim > 1:
+    discretized_dim_action = np.linspace(
+        min_val, max_val, num_actions_per_dim, endpoint=True)
+    discretized_actions = [discretized_dim_action] * action_space_dim
+    discretized_actions = itertools.product(*discretized_actions)
+    discretized_actions = list(discretized_actions)
+  elif num_actions_per_dim == 1:
+    discretized_actions = [
+        max_val * np.eye(action_space_dim),
+        min_val * np.eye(action_space_dim),
+    ]
+    discretized_actions = np.concatenate(discretized_actions, axis=0)
+  elif num_actions_per_dim == 0:
+    discretized_actions = np.eye(action_space_dim)
+  else:
+    raise ValueError(
+        "Unsupported num_actions_per_dim {}".format(num_actions_per_dim))
+
+  discretized_actions = np.array(discretized_actions)
+
+  # Remove options with all zeros.
+  non_zero_entries = np.sum(np.square(discretized_actions), axis=-1) != 0.0
+  # Remove options with no positive elements.
+  non_negative_entries = np.any(discretized_actions > 0, axis=-1)
+  discretized_actions = discretized_actions[np.logical_and(
+      non_zero_entries, non_negative_entries)]
+  logging.info("Total number of discretized actions: %s",
+               len(discretized_actions))
+  logging.info("Discretized actions: %s", discretized_actions)
+
+  return discretized_actions
diff --git a/option_keyboard/experiment.py b/option_keyboard/experiment.py
new file mode 100644
index 0000000..703f998
--- /dev/null
+++ b/option_keyboard/experiment.py
@@ -0,0 +1,70 @@
+# Lint as: python3
+# pylint: disable=g-bad-file-header
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A simple training loop."""
+
+from absl import logging
+
+import numpy as np
+
+
+def run(environment, agent, num_episodes, report_every=200):
+  """Runs an agent on an environment.
+
+  Args:
+    environment: The environment.
+    agent: The agent.
+    num_episodes: Number of episodes to train for.
+    report_every: Frequency (in episodes) at which training progress is
+      reported.
+  """
+
+  train_returns = []
+  eval_returns = []
+  for episode_id in range(num_episodes):
+    # Run a training episode.
+    train_episode_return = run_episode(environment, agent, is_training=True)
+    train_returns.append(train_episode_return)
+
+    # Run an evaluation episode.
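+    # (Evaluation episodes run the greedy policy; epsilon-exploration is only
+    # applied when is_training=True.)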
+    eval_episode_return = run_episode(environment, agent, is_training=False)
+    eval_returns.append(eval_episode_return)
+
+    if ((episode_id + 1) % report_every) == 0:
+      logging.info(
+          "Episode %s, avg train return %.3f, avg eval return %.3f",
+          episode_id + 1,
+          np.mean(train_returns[-report_every:]),
+          np.mean(eval_returns[-report_every:]),
+      )
+
+
+def run_episode(environment, agent, is_training=False):
+  """Run a single episode."""
+
+  timestep = environment.reset()
+
+  while not timestep.last():
+    action = agent.step(timestep, is_training)
+    new_timestep = environment.step(action)
+
+    if is_training:
+      agent.update(timestep, action, new_timestep)
+
+    timestep = new_timestep
+
+  episode_return = environment.episode_return
+
+  return episode_return
diff --git a/option_keyboard/keyboard_agent.py b/option_keyboard/keyboard_agent.py
new file mode 100644
index 0000000..f698df1
--- /dev/null
+++ b/option_keyboard/keyboard_agent.py
@@ -0,0 +1,222 @@
+# Lint as: python3
+# pylint: disable=g-bad-file-header
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Keyboard agent."""
+
+import numpy as np
+import sonnet as snt
+import tensorflow.compat.v1 as tf
+
+
+class Agent():
+  """An Option Keyboard Agent."""
+
+  def __init__(
+      self,
+      obs_spec,
+      action_spec,
+      policy_weights,
+      network_kwargs,
+      epsilon,
+      additional_discount,
+      batch_size,
+      optimizer_name,
+      optimizer_kwargs,
+  ):
+    """An agent that learns an Option Keyboard using DQN-style updates.
+
+    Args:
+      obs_spec: The observation spec.
+      action_spec: The action spec.
+      policy_weights: A list of vectors each representing the cumulant weights
+        for that particular option/policy.
+      network_kwargs: Keyword arguments for snt.nets.MLP.
+      epsilon: Exploration probability.
+      additional_discount: Discount on returns used by the agent.
+      batch_size: Size of update batch.
+      optimizer_name: Name of an optimizer from tf.train.
+      optimizer_kwargs: Keyword arguments for the optimizer.
+    """
+
+    self._policy_weights = tf.convert_to_tensor(
+        policy_weights, dtype=tf.float32)
+    self._current_policy = None
+
+    self._epsilon = epsilon
+    self._additional_discount = additional_discount
+    self._batch_size = batch_size
+
+    self._n_actions = action_spec.num_values
+    self._n_policies, self._n_cumulants = policy_weights.shape
+    self._network = OptionValueNet(
+        self._n_policies,
+        self._n_cumulants,
+        self._n_actions,
+        network_kwargs=network_kwargs,
+    )
+
+    self._replay = []
+
+    obs_spec = self._extract_observation(obs_spec)
+
+    def option_values(values, policy):
+      return tf.tensordot(
+          values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])
+
+    # Placeholders for policy.
+    o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
+    p = tf.placeholder(shape=(), dtype=tf.int32)
+    q = self._network(tf.expand_dims(o, axis=0))
+    qo = option_values(q, p)
+
+    # Placeholders for update.
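+    # (Unlike the flat DQN agent, the update consumes the vector of cumulants
+    # `c_t` rather than a scalar task reward.)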
+ o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype) + a_tm1 = tf.placeholder(shape=(None,), dtype=tf.int32) + c_t = tf.placeholder(shape=(None, self._n_cumulants), dtype=tf.float32) + d_t = tf.placeholder(shape=(None,), dtype=tf.float32) + o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype) + + # Compute values over all options. + q_tm1 = self._network(o_tm1) + q_t = self._network(o_t) + qo_t = option_values(q_t, p) + + a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32) + qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1) + qa_t = _batched_index(q_t[:, p, Ellipsis], a_t) + + # TD error + g = additional_discount * tf.expand_dims(d_t, axis=-1) + td_error = tf.stop_gradient(c_t + g * qa_t) - qa_tm1 + loss = tf.reduce_sum(tf.square(td_error) / 2) + + with tf.variable_scope("optimizer"): + self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs) + train_op = self._optimizer.minimize(loss) + + # Make session and callables. + session = tf.Session() + self._session = session + self._update_fn = session.make_callable( + train_op, [o_tm1, a_tm1, c_t, d_t, o_t, p]) + self._value_fn = session.make_callable(qo, [o, p]) + session.run(tf.global_variables_initializer()) + + self._saver = tf.train.Saver(var_list=self._network.variables) + + @property + def keyboard(self): + return self._network + + def _extract_observation(self, obs): + return obs["arena"] + + def step(self, timestep, is_training=False): + """Select actions according to epsilon-greedy policy.""" + if timestep.first(): + self._current_policy = np.random.randint(self._n_policies) + + if is_training and np.random.rand() < self._epsilon: + return np.random.randint(self._n_actions) + + q_values = self._value_fn( + self._extract_observation(timestep.observation), self._current_policy) + return int(np.argmax(q_values)) + + def update(self, step_tm1, action, step_t): + """Takes in a transition from the environment.""" + + transition = [ + self._extract_observation(step_tm1.observation), + action, + step_t.observation["cumulants"], + step_t.discount, + self._extract_observation(step_t.observation), + ] + self._replay.append(transition) + + if len(self._replay) == self._batch_size: + batch = list(zip(*self._replay)) + [self._current_policy] + self._update_fn(*batch) + self._replay = [] # Just a queue. + + def export(self, path): + tf.logging.info("Exporting keyboard to %s", path) + self._saver.save(self._session, path) + + +class OptionValueNet(snt.AbstractModule): + """Option Value net.""" + + def __init__(self, + n_policies, + n_cumulants, + n_actions, + network_kwargs, + name="option_keyboard"): + """Construct an Option Value Net sonnet module. + + Args: + n_policies: Number of policies. + n_cumulants: Number of cumulants. + n_actions: Number of actions. + network_kwargs: Network arguments. 
+      name: Name of the module.
+    """
+    super(OptionValueNet, self).__init__(name=name)
+    self._n_policies = n_policies
+    self._n_cumulants = n_cumulants
+    self._n_actions = n_actions
+    self._network_kwargs = network_kwargs
+
+  def _build(self, observation):
+    values = []
+
+    flat_obs = snt.BatchFlatten()(observation)
+    for _ in range(self._n_cumulants):
+      net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
+      net = snt.Linear(output_size=self._n_policies * self._n_actions)(net)
+      net = snt.BatchReshape([self._n_policies, self._n_actions])(net)
+      values.append(net)
+    values = tf.stack(values, axis=2)
+    return values
+
+  def gpi(self, observation, cumulant_weights):
+    q_values = self.__call__(tf.expand_dims(observation, axis=0))[0]
+    q_w = tf.tensordot(q_values, cumulant_weights, axes=[1, 0])  # [P,a]
+    q_w_actions = tf.reduce_max(q_w, axis=0)
+
+    action = tf.cast(tf.argmax(q_w_actions), tf.int32)
+
+    return action
+
+  @property
+  def num_cumulants(self):
+    return self._n_cumulants
+
+  @property
+  def num_policies(self):
+    return self._n_policies
+
+  @property
+  def num_actions(self):
+    return self._n_actions
+
+
+def _batched_index(values, indices):
+  one_hot_indices = tf.one_hot(indices, values.shape[-1], dtype=values.dtype)
+  one_hot_indices = tf.expand_dims(one_hot_indices, axis=1)
+  return tf.reduce_sum(values * one_hot_indices, axis=-1)
diff --git a/option_keyboard/requirements.txt b/option_keyboard/requirements.txt
new file mode 100644
index 0000000..12f5f22
--- /dev/null
+++ b/option_keyboard/requirements.txt
@@ -0,0 +1,6 @@
+absl-py
+dm-env==1.2
+dm-sonnet==1.34
+numpy==1.16.4
+tensorflow==1.13.2
+tensorflow_probability==0.6.0
diff --git a/option_keyboard/run.sh b/option_keyboard/run.sh
new file mode 100755
index 0000000..f27d231
--- /dev/null
+++ b/option_keyboard/run.sh
@@ -0,0 +1,21 @@
+#!/bin/sh
+# Copyright 2020 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+python3 -m venv ok_venv
+. ok_venv/bin/activate
+pip install -r option_keyboard/requirements.txt
+
+python3 -m option_keyboard.run_dqn_test
+python3 -m option_keyboard.run_ok_test
diff --git a/option_keyboard/run_dqn.py b/option_keyboard/run_dqn.py
new file mode 100644
index 0000000..d6b1dee
--- /dev/null
+++ b/option_keyboard/run_dqn.py
@@ -0,0 +1,61 @@
+# Lint as: python3
+# pylint: disable=g-bad-file-header
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Run an experiment.""" + +from absl import app +from absl import flags + +import tensorflow.compat.v1 as tf + +from option_keyboard import configs +from option_keyboard import dqn_agent +from option_keyboard import environment_wrappers +from option_keyboard import experiment +from option_keyboard import scavenger + +FLAGS = flags.FLAGS +flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.") + + +def main(argv): + del argv + + # Create the task environment. + env_config = configs.get_task_config() + env = scavenger.Scavenger(**env_config) + env = environment_wrappers.EnvironmentWithLogging(env) + + # Create the flat agent. + agent = dqn_agent.Agent( + obs_spec=env.observation_spec(), + action_spec=env.action_spec(), + network_kwargs=dict( + output_sizes=(64, 128), + activate_final=True, + ), + epsilon=0.1, + additional_discount=0.9, + batch_size=10, + optimizer_name="AdamOptimizer", + optimizer_kwargs=dict(learning_rate=3e-4,)) + + experiment.run(env, agent, num_episodes=FLAGS.num_episodes) + + +if __name__ == "__main__": + tf.disable_v2_behavior() + app.run(main) diff --git a/option_keyboard/run_dqn_test.py b/option_keyboard/run_dqn_test.py new file mode 100644 index 0000000..263608e --- /dev/null +++ b/option_keyboard/run_dqn_test.py @@ -0,0 +1,38 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for running the simple DQN agent.""" + +from absl import flags +from absl.testing import absltest + +import tensorflow.compat.v1 as tf + +from option_keyboard import run_dqn + +FLAGS = flags.FLAGS + + +class RunDQNTest(absltest.TestCase): + + def test_run(self): + FLAGS.num_episodes = 200 + run_dqn.main(None) + + +if __name__ == '__main__': + tf.disable_v2_behavior() + absltest.main() diff --git a/option_keyboard/run_ok.py b/option_keyboard/run_ok.py new file mode 100644 index 0000000..7ff1414 --- /dev/null +++ b/option_keyboard/run_ok.py @@ -0,0 +1,108 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Run an experiment.""" + +from absl import app +from absl import flags + +import numpy as np +import tensorflow.compat.v1 as tf + +from option_keyboard import configs +from option_keyboard import dqn_agent +from option_keyboard import environment_wrappers +from option_keyboard import experiment +from option_keyboard import keyboard_agent +from option_keyboard import scavenger + +FLAGS = flags.FLAGS +flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.") +flags.DEFINE_integer("num_pretrain_episodes", 20000, + "Number of pretraining episodes.") + + +def _train_keyboard(num_episodes): + """Train an option keyboard.""" + env_config = configs.get_pretrain_config() + env = scavenger.Scavenger(**env_config) + env = environment_wrappers.EnvironmentWithLogging(env) + + agent = keyboard_agent.Agent( + obs_spec=env.observation_spec(), + action_spec=env.action_spec(), + policy_weights=np.array([ + [1.0, 0.0], + [0.0, 1.0], + ]), + network_kwargs=dict( + output_sizes=(64, 128), + activate_final=True, + ), + epsilon=0.1, + additional_discount=0.9, + batch_size=10, + optimizer_name="AdamOptimizer", + optimizer_kwargs=dict(learning_rate=3e-4,)) + + experiment.run(env, agent, num_episodes=num_episodes) + + return agent + + +def main(argv): + del argv + + # Pretrain the keyboard and save a checkpoint. + pretrain_agent = _train_keyboard(num_episodes=FLAGS.num_pretrain_episodes) + keyboard_ckpt_path = "/tmp/option_keyboard/keyboard.ckpt" + pretrain_agent.export(keyboard_ckpt_path) + + # Create the task environment. + base_env_config = configs.get_task_config() + base_env = scavenger.Scavenger(**base_env_config) + base_env = environment_wrappers.EnvironmentWithLogging(base_env) + + # Wrap the task environment with the keyboard. + additional_discount = 0.9 + env = environment_wrappers.EnvironmentWithKeyboard( + env=base_env, + keyboard=pretrain_agent.keyboard, + keyboard_ckpt_path=keyboard_ckpt_path, + n_actions_per_dim=3, + additional_discount=additional_discount, + call_and_return=True) + + # Create the player agent. + agent = dqn_agent.Agent( + obs_spec=env.observation_spec(), + action_spec=env.action_spec(), + network_kwargs=dict( + output_sizes=(64, 128), + activate_final=True, + ), + epsilon=0.1, + additional_discount=additional_discount, + batch_size=10, + optimizer_name="AdamOptimizer", + optimizer_kwargs=dict(learning_rate=3e-4,)) + + experiment.run(env, agent, num_episodes=FLAGS.num_episodes) + + +if __name__ == "__main__": + tf.disable_v2_behavior() + app.run(main) diff --git a/option_keyboard/run_ok_test.py b/option_keyboard/run_ok_test.py new file mode 100644 index 0000000..63c484f --- /dev/null +++ b/option_keyboard/run_ok_test.py @@ -0,0 +1,39 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Tests for training a keyboard and then running a DQN agent on top of it."""
+
+from absl import flags
+from absl.testing import absltest
+
+import tensorflow.compat.v1 as tf
+
+from option_keyboard import run_ok
+
+FLAGS = flags.FLAGS
+
+
+class RunOKTest(absltest.TestCase):
+
+  def test_run(self):
+    FLAGS.num_episodes = 200
+    FLAGS.num_pretrain_episodes = 200
+    run_ok.main(None)
+
+
+if __name__ == '__main__':
+  tf.disable_v2_behavior()
+  absltest.main()
diff --git a/option_keyboard/scavenger.py b/option_keyboard/scavenger.py
new file mode 100644
index 0000000..647e949
--- /dev/null
+++ b/option_keyboard/scavenger.py
@@ -0,0 +1,269 @@
+# Lint as: python3
+# pylint: disable=g-bad-file-header
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Simple Scavenger environment."""
+
+import copy
+import enum
+import sys
+
+import dm_env
+
+import numpy as np
+
+from option_keyboard import auto_reset_environment
+
+this_module = sys.modules[__name__]
+
+
+class Action(enum.IntEnum):
+  """Actions available to the player."""
+  UP = 0
+  DOWN = 1
+  LEFT = 2
+  RIGHT = 3
+
+
+def _one_hot(indices, depth):
+  return np.eye(depth)[indices]
+
+
+def _random_pos(arena_size):
+  return tuple(np.random.randint(0, arena_size, size=[2]).tolist())
+
+
+class Scavenger(auto_reset_environment.Base):
+  """Simple Scavenger."""
+
+  def __init__(self,
+               arena_size,
+               num_channels,
+               max_num_steps,
+               default_w=None,
+               num_init_objects=15,
+               object_priors=None,
+               egocentric=True,
+               rewarder=None):
+    self._arena_size = arena_size
+    self._num_channels = num_channels
+    self._max_num_steps = max_num_steps
+    self._num_init_objects = num_init_objects
+    self._egocentric = egocentric
+    self._rewarder = (
+        getattr(this_module, rewarder)() if rewarder is not None else None)
+
+    if object_priors is None:
+      self._object_priors = np.ones(num_channels) / num_channels
+    else:
+      assert len(object_priors) == num_channels
+      self._object_priors = np.array(object_priors)
+
+    if default_w is None:
+      self._default_w = np.ones(shape=(num_channels,))
+    else:
+      self._default_w = default_w
+
+    self._num_channels_all = self._num_channels + 2
+    self._step_in_episode = None
+
+  @property
+  def state(self):
+    return copy.deepcopy([
+        self._step_in_episode,
+        self._walls,
+        self._objects,
+        self._player_pos,
+        self._prev_collected,
+    ])
+
+  def set_state(self, state):
+    state_ = copy.deepcopy(state)
+    self._step_in_episode = state_[0]
+    self._walls = state_[1]
+    self._objects = state_[2]
+    self._player_pos = state_[3]
+    self._prev_collected = state_[4]
+
+  @property
+  def player_pos(self):
+    return self._player_pos
+
+  def _reset(self):
+    self._step_in_episode = 0
+
+    # Walls.
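+    # Walls line the first row and the first column of the arena; the player
+    # cannot move onto them, while moves across other edges wrap around.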
+ self._walls = [] + for col in range(self._arena_size): + new_pos = (0, col) + if new_pos not in self._walls: + self._walls.append(new_pos) + for row in range(self._arena_size): + new_pos = (row, 0) + if new_pos not in self._walls: + self._walls.append(new_pos) + + # Objects. + self._objects = dict() + for _ in range(self._num_init_objects): + while True: + new_pos = _random_pos(self._arena_size) + if new_pos not in self._objects and new_pos not in self._walls: + self._objects[new_pos] = np.random.multinomial(1, self._object_priors) + break + + # Player + self._player_pos = _random_pos(self._arena_size) + while self._player_pos in self._objects or self._player_pos in self._walls: + self._player_pos = _random_pos(self._arena_size) + + self._prev_collected = np.zeros(shape=(self._num_channels,)) + + obs = self.observation() + + return dm_env.restart(obs) + + def _step(self, action): + self._step_in_episode += 1 + + if action == Action.UP: + new_player_pos = (self._player_pos[0], self._player_pos[1] + 1) + elif action == Action.DOWN: + new_player_pos = (self._player_pos[0], self._player_pos[1] - 1) + elif action == Action.LEFT: + new_player_pos = (self._player_pos[0] - 1, self._player_pos[1]) + elif action == Action.RIGHT: + new_player_pos = (self._player_pos[0] + 1, self._player_pos[1]) + else: + raise ValueError("Invalid action `{}`".format(action)) + + # Toroidal. + new_player_pos = ( + (new_player_pos[0] + self._arena_size) % self._arena_size, + (new_player_pos[1] + self._arena_size) % self._arena_size, + ) + + if new_player_pos not in self._walls: + self._player_pos = new_player_pos + + # Compute rewards. + consumed = self._objects.pop(self._player_pos, + np.zeros(shape=(self._num_channels,))) + if self._rewarder is None: + reward = np.dot(consumed, np.array(self._default_w)) + else: + reward = self._rewarder.get_reward(self.state, consumed) + self._prev_collected = np.copy(consumed) + + assert self._player_pos not in self._objects + assert self._player_pos not in self._walls + + # Render everything. 
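+    # (With egocentric=True, the rendered arena is shifted so that the player
+    # always appears at position (0, 0); see observation() below.)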
+ obs = self.observation() + + if self._step_in_episode < self._max_num_steps: + return dm_env.transition(reward=reward, observation=obs) + else: + # termination with discount=1.0 + return dm_env.truncation(reward=reward, observation=obs) + + def observation(self, force_non_egocentric=False): + arena_shape = [self._arena_size] * 2 + [self._num_channels_all] + arena = np.zeros(shape=arena_shape, dtype=np.float32) + + def offset_position(pos_): + use_egocentric = self._egocentric and not force_non_egocentric + offset = self._player_pos if use_egocentric else (0, 0) + x = (pos_[0] - offset[0] + self._arena_size) % self._arena_size + y = (pos_[1] - offset[1] + self._arena_size) % self._arena_size + return (x, y) + + player_pos = offset_position(self._player_pos) + arena[player_pos] = _one_hot(self._num_channels, self._num_channels_all) + + for pos, obj in self._objects.items(): + x, y = offset_position(pos) + arena[x, y, :self._num_channels] = obj + + for pos in self._walls: + x, y = offset_position(pos) + arena[x, y] = _one_hot(self._num_channels + 1, self._num_channels_all) + + collected_resources = np.copy(self._prev_collected).astype(np.float32) + + return dict( + arena=arena, + cumulants=collected_resources, + ) + + def observation_spec(self): + arena = dm_env.specs.BoundedArray( + shape=(self._arena_size, self._arena_size, self._num_channels_all), + dtype=np.float32, + minimum=0., + maximum=1., + name="arena") + collected_resources = dm_env.specs.BoundedArray( + shape=(self._num_channels,), + dtype=np.float32, + minimum=-1e9, + maximum=1e9, + name="collected_resources") + + return dict( + arena=arena, + cumulants=collected_resources, + ) + + def action_spec(self): + return dm_env.specs.DiscreteArray(num_values=len(Action), name="action") + + +class SequentialCollectionRewarder(object): + """SequentialCollectionRewarder.""" + + def get_reward(self, state, consumed): + """Get reward.""" + + object_counts = sum(list(state[2].values()) + [np.zeros(len(consumed))]) + + reward = 0.0 + if np.sum(consumed) > 0: + for i in range(len(consumed)): + if np.all(object_counts[:i] <= object_counts[i]): + reward += consumed[i] + else: + reward -= consumed[i] + + return reward + + +class BalancedCollectionRewarder(object): + """BalancedCollectionRewarder.""" + + def get_reward(self, state, consumed): + """Get reward.""" + + object_counts = sum(list(state[2].values()) + [np.zeros(len(consumed))]) + + reward = 0.0 + if np.sum(consumed) > 0: + for i in range(len(consumed)): + if (object_counts[i] + consumed[i]) >= np.max(object_counts): + reward += consumed[i] + else: + reward -= consumed[i] + + return reward