Mirror of https://github.com/google-deepmind/deepmind-research.git

Release of option_keyboard code.

PiperOrigin-RevId: 310175317

Commit 391bc47f3c, parent 6f14cb5983, committed by Diego de Las Casas.
README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/

## Projects

* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
* [Multi-Object Representation Learning with Iterative Variational Inference (IODINE)](iodine)
option_keyboard/README.md
@@ -0,0 +1,58 @@

# The Option Keyboard: Combining Skills in Reinforcement Learning

This directory contains an implementation of the Option Keyboard framework.

From the [abstract](http://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning):

> The ability to combine known skills to create new ones may be crucial in the
solution of complex reinforcement learning problems that unfold over extended
periods. We argue that a robust way of combining skills is to define and manipulate
them in the space of pseudo-rewards (or “cumulants”). Based on this premise, we
propose a framework for combining skills using the formalism of options. We show
that every deterministic option can be unambiguously represented as a cumulant
defined in an extended domain. Building on this insight and on previous results
on transfer learning, we show how to approximate options whose cumulants are
linear combinations of the cumulants of known options. This means that, once we
have learned options associated with a set of cumulants, we can instantaneously
synthesise options induced by any linear combination of them, without any learning
involved. We describe how this framework provides a hierarchical interface to the
environment whose abstract actions correspond to combinations of basic skills.
We demonstrate the practical benefits of our approach in a resource management
problem and a navigation task involving a quadrupedal simulated robot.

If you use this code, please cite the paper:

> André Barreto, Diana Borsa, Shaobo Hou, Gheorghe Comanici, Eser Aygün, Philippe Hamel, Daniel Toyama, Jonathan Hunt, Shibl Mourad, David Silver, Doina Precup. *The Option Keyboard: Combining Skills in Reinforcement Learning*. NeurIPS 2019. [\[paper\]](https://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning).

## Running the code

### Setup

```
python3 -m venv ok_venv
source ok_venv/bin/activate
pip install -r option_keyboard/requirements.txt
```

### Scavenger Task

All agents are trained on a simple grid-world resource collection task. There
are two types of collectible objects in the world: collecting the less abundant
of the two object types yields a reward of -1, while collecting the more
abundant type yields a reward of +1. See section 5.1 in the paper for more
details.
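As a rough illustration of that reward rule (a sketch mirroring the
`BalancedCollectionRewarder` in `scavenger.py` below; the helper name and
signature are made up for illustration):

```
import numpy as np

def balanced_reward(remaining_counts, collected_type):
  """+1 if the collected type was (weakly) the most abundant, else -1."""
  counts = np.array(remaining_counts, dtype=float)
  counts[collected_type] += 1  # count before this object was removed
  return 1.0 if counts[collected_type] >= counts.max() else -1.0

# e.g. with 6 objects of type 0 and 4 of type 1 left in the arena:
balanced_reward([6, 4], collected_type=0)  # -> +1.0
balanced_reward([6, 4], collected_type=1)  # -> -1.0
```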
### Train the DQN baseline

```
python3 -m option_keyboard.run_dqn
```

This trains a DQN agent on the scavenger task.

### Train the Option Keyboard and agent

```
python3 -m option_keyboard.run_ok
```

This first trains an Option Keyboard on the cumulants in the task environment.
Then it trains a DQN agent on the true task reward using high-level abstract
actions provided by the keyboard.
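For intuition, here is a rough sketch (not code from this release; the
function and variable names are made up) of what an abstract action is: a
weight vector `w` over the cumulants, turned into a primitive action by
generalised policy improvement (GPI) over the keyboard's learned option
values, as in `OptionValueNet.gpi` in `keyboard_agent.py`:

```
import numpy as np

def gpi_action(q_values, w):
  """q_values: array [num_policies, num_cumulants, num_actions] for one state.

  w: cumulant weights defining the abstract action, e.g. w = [1, -1] to seek
  the first object type while avoiding the second.
  """
  # Value of every (policy, action) pair under the combined cumulant w.
  q_w = np.tensordot(q_values, w, axes=([1], [0]))  # [num_policies, num_actions]
  # GPI: take the best action across all known policies.
  return int(q_w.max(axis=0).argmax())
```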
## Disclaimer

This is not an official Google or DeepMind product.
option_keyboard/auto_reset_environment.py
@@ -0,0 +1,58 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto-resetting environment base class.

The environment API states that stepping an environment after a LAST timestep
should return the first timestep of a new episode.

However, environment authors sometimes don't spot this part or find it awkward
to implement. This module contains a class that helps implement the reset
behaviour.
"""

import abc
import dm_env


class Base(dm_env.Environment):
  """This class implements the required `step()` and `reset()` methods.

  It instead requires users to implement `_step()` and `_reset()`. This class
  handles the reset behaviour automatically when it detects a LAST timestep.
  """

  def __init__(self):
    self._reset_next_step = True

  @abc.abstractmethod
  def _reset(self):
    """Returns a `timestep` namedtuple as per the regular `reset()` method."""

  @abc.abstractmethod
  def _step(self, action):
    """Returns a `timestep` namedtuple as per the regular `step()` method."""

  def reset(self):
    self._reset_next_step = False
    return self._reset()

  def step(self, action):
    if self._reset_next_step:
      return self.reset()
    timestep = self._step(action)
    self._reset_next_step = timestep.last()
    return timestep
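A minimal usage sketch, not part of the released files: subclasses only
provide `_reset()` and `_step()`, and `Base` re-starts the episode on the step
after a LAST timestep. The toy `CountdownEnv` below is made up for
illustration.

```
import dm_env
import numpy as np
from dm_env import specs

from option_keyboard import auto_reset_environment


class CountdownEnv(auto_reset_environment.Base):
  """Toy environment that ends after three steps, regardless of the action."""

  def _reset(self):
    self._t = 0
    return dm_env.restart(np.float32(self._t))

  def _step(self, action):
    self._t += 1
    if self._t < 3:
      return dm_env.transition(reward=0.0, observation=np.float32(self._t))
    return dm_env.termination(reward=1.0, observation=np.float32(self._t))

  def observation_spec(self):
    return specs.Array(shape=(), dtype=np.float32, name="t")

  def action_spec(self):
    return specs.DiscreteArray(num_values=2, name="action")


env = CountdownEnv()
ts = env.reset()
while not ts.last():
  ts = env.step(0)
# Stepping again starts a fresh episode without an explicit reset() call.
ts = env.step(0)
assert ts.first()
```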
option_keyboard/configs.py
@@ -0,0 +1,41 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Environment configurations."""


def get_task_config():
  return dict(
      arena_size=11,
      num_channels=2,
      max_num_steps=50,  # 50 for the actual task.
      num_init_objects=10,
      object_priors=[0.5, 0.5],
      egocentric=True,
      rewarder="BalancedCollectionRewarder",
  )


def get_pretrain_config():
  return dict(
      arena_size=11,
      num_channels=2,
      max_num_steps=40,  # 40 for pretraining.
      num_init_objects=10,
      object_priors=[0.5, 0.5],
      egocentric=True,
      default_w=(1, 1),
  )
option_keyboard/dqn_agent.py
@@ -0,0 +1,160 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DQN agent."""

import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf


class Agent():
  """A DQN Agent."""

  def __init__(
      self,
      obs_spec,
      action_spec,
      network_kwargs,
      epsilon,
      additional_discount,
      batch_size,
      optimizer_name,
      optimizer_kwargs,
  ):
    """A simple DQN agent.

    Args:
      obs_spec: The observation spec.
      action_spec: The action spec.
      network_kwargs: Keyword arguments for snt.nets.MLP.
      epsilon: Exploration probability.
      additional_discount: Discount on returns used by the agent.
      batch_size: Size of update batch.
      optimizer_name: Name of an optimizer from tf.train.
      optimizer_kwargs: Keyword arguments for the optimizer.
    """

    self._epsilon = epsilon
    self._additional_discount = additional_discount
    self._batch_size = batch_size

    self._n_actions = action_spec.num_values
    self._network = ValueNet(self._n_actions, network_kwargs=network_kwargs)

    self._replay = []

    obs_spec = self._extract_observation(obs_spec)

    # Placeholders for policy.
    o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
    q = self._network(tf.expand_dims(o, axis=0))

    # Placeholders for update.
    o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
    a_tm1 = tf.placeholder(shape=(None,), dtype=tf.int32)
    r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)

    # Compute values over all options.
    q_tm1 = self._network(o_tm1)
    q_t = self._network(o_t)

    a_t = tf.cast(tf.argmax(q_t, axis=-1), tf.int32)
    qa_tm1 = _batched_index(q_tm1, a_tm1)
    qa_t = _batched_index(q_t, a_t)

    # TD error
    g = additional_discount * d_t
    td_error = tf.stop_gradient(r_t + g * qa_t) - qa_tm1
    loss = tf.reduce_sum(tf.square(td_error) / 2)

    with tf.variable_scope("optimizer"):
      self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
      train_op = self._optimizer.minimize(loss)

    # Make session and callables.
    session = tf.Session()
    self._update_fn = session.make_callable(train_op,
                                            [o_tm1, a_tm1, r_t, d_t, o_t])
    self._value_fn = session.make_callable(q, [o])
    session.run(tf.global_variables_initializer())

  def _extract_observation(self, obs):
    return obs["arena"]

  def step(self, timestep, is_training=False):
    """Select actions according to epsilon-greedy policy."""

    if is_training and np.random.rand() < self._epsilon:
      return np.random.randint(self._n_actions)

    q_values = self._value_fn(
        self._extract_observation(timestep.observation))
    return int(np.argmax(q_values))

  def update(self, step_tm1, action, step_t):
    """Takes in a transition from the environment."""

    transition = [
        self._extract_observation(step_tm1.observation),
        action,
        step_t.reward,
        step_t.discount,
        self._extract_observation(step_t.observation),
    ]
    self._replay.append(transition)

    if len(self._replay) == self._batch_size:
      batch = list(zip(*self._replay))
      self._update_fn(*batch)
      self._replay = []  # Just a queue.


class ValueNet(snt.AbstractModule):
  """Value Network."""

  def __init__(self,
               n_actions,
               network_kwargs,
               name="value_network"):
    """Construct a value network sonnet module.

    Args:
      n_actions: Number of actions.
      network_kwargs: Network arguments.
      name: Name.
    """
    super(ValueNet, self).__init__(name=name)
    self._n_actions = n_actions
    self._network_kwargs = network_kwargs

  def _build(self, observation):
    flat_obs = snt.BatchFlatten()(observation)
    net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
    net = snt.Linear(output_size=self._n_actions)(net)

    return net

  @property
  def num_actions(self):
    return self._n_actions


def _batched_index(values, indices):
  one_hot_indices = tf.one_hot(indices, values.shape[-1], dtype=values.dtype)
  return tf.reduce_sum(values * one_hot_indices, axis=-1)
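A quick reading aid, not part of the released file: the loss above is one-step
Q-learning with an extra agent-side discount folded in. In numpy terms (the
helper name is made up), the regression target for each transition is:

```
import numpy as np

def q_learning_target(r_t, d_t, q_next, additional_discount=0.9):
  # Target used by the TD error above: r + (additional_discount * d) * max_a Q(s', a).
  return r_t + additional_discount * d_t * np.max(q_next, axis=-1)
```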
option_keyboard/environment_wrappers.py
@@ -0,0 +1,190 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Environment with keyboard."""

import itertools

from absl import logging

import dm_env

import numpy as np
import tensorflow.compat.v1 as tf


class EnvironmentWithLogging(dm_env.Environment):
  """Wraps an environment with additional logging."""

  def __init__(self, env):
    self._env = env
    self._episode_return = 0

  def reset(self):
    self._episode_return = 0
    return self._env.reset()

  def step(self, action):
    """Take action in the environment and do some logging."""

    step = self._env.step(action)
    if step.first():
      step = self._env.step(action)
      self._episode_return = 0

    self._episode_return += step.reward
    return step

  @property
  def episode_return(self):
    return self._episode_return

  def action_spec(self):
    return self._env.action_spec()

  def observation_spec(self):
    return self._env.observation_spec()

  def __getattr__(self, name):
    return getattr(self._env, name)


class EnvironmentWithKeyboard(dm_env.Environment):
  """Wraps an environment with a keyboard."""

  def __init__(self,
               env,
               keyboard,
               keyboard_ckpt_path,
               n_actions_per_dim,
               additional_discount,
               call_and_return=False):
    self._env = env
    self._keyboard = keyboard
    self._discount = additional_discount
    self._call_and_return = call_and_return

    options = _discretize_actions(n_actions_per_dim, keyboard.num_cumulants)
    self._options_np = options
    options = tf.convert_to_tensor(options, dtype=tf.float32)
    self._options = options

    obs_spec = self._extract_observation(env.observation_spec())
    obs_ph = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
    option_ph = tf.placeholder(shape=(), dtype=tf.int32)
    gpi_action = self._keyboard.gpi(obs_ph, options[option_ph])

    session = tf.Session()
    self._gpi_action = session.make_callable(gpi_action, [obs_ph, option_ph])
    self._keyboard_action = session.make_callable(
        self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph])
    session.run(tf.global_variables_initializer())

    saver = tf.train.Saver(var_list=keyboard.variables)
    saver.restore(session, keyboard_ckpt_path)

  def _compute_reward(self, option, obs):
    return np.sum(self._options_np[option] * obs["cumulants"])

  def reset(self):
    return self._env.reset()

  def step(self, option):
    """Take a step in the keyboard, then the environment."""

    step_count = 0
    option_step = None
    while True:
      obs = self._extract_observation(self._env.observation())
      action = self._gpi_action(obs, option)
      action_step = self._env.step(action)
      step_count += 1

      if option_step is None:
        option_step = action_step
      else:
        new_discount = (
            option_step.discount * self._discount * action_step.discount)
        new_reward = (
            option_step.reward + new_discount * action_step.reward)
        option_step = option_step._replace(
            observation=action_step.observation,
            reward=new_reward,
            discount=new_discount,
            step_type=action_step.step_type)

      if action_step.last():
        break

      # Terminate option.
      if self._compute_reward(option, action_step.observation) > 0:
        break

      if not self._call_and_return:
        break

    return option_step

  def action_spec(self):
    return dm_env.specs.DiscreteArray(
        num_values=self._options_np.shape[0], name="action")

  def _extract_observation(self, obs):
    return obs["arena"]

  def observation_spec(self):
    return self._env.observation_spec()

  def __getattr__(self, name):
    return getattr(self._env, name)


def _discretize_actions(num_actions_per_dim,
                        action_space_dim,
                        min_val=-1.0,
                        max_val=1.0):
  """Discrete action space."""
  if num_actions_per_dim > 1:
    discretized_dim_action = np.linspace(
        min_val, max_val, num_actions_per_dim, endpoint=True)
    discretized_actions = [discretized_dim_action] * action_space_dim
    discretized_actions = itertools.product(*discretized_actions)
    discretized_actions = list(discretized_actions)
  elif num_actions_per_dim == 1:
    discretized_actions = [
        max_val * np.eye(action_space_dim),
        min_val * np.eye(action_space_dim),
    ]
    discretized_actions = np.concatenate(discretized_actions, axis=0)
  elif num_actions_per_dim == 0:
    discretized_actions = np.eye(action_space_dim)
  else:
    raise ValueError(
        "Unsupported num_actions_per_dim {}".format(num_actions_per_dim))

  discretized_actions = np.array(discretized_actions)

  # Remove options with all zeros.
  non_zero_entries = np.sum(np.square(discretized_actions), axis=-1) != 0.0
  # Remove options with no positive elements.
  non_negative_entries = np.any(discretized_actions > 0, axis=-1)
  discretized_actions = discretized_actions[np.logical_and(
      non_zero_entries, non_negative_entries)]
  logging.info("Total number of discretized actions: %s",
               len(discretized_actions))
  logging.info("Discretized actions: %s", discretized_actions)

  return discretized_actions
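A concrete reading aid, not part of the released file: with
`n_actions_per_dim=3` and two cumulants (the setting used in run_ok.py below),
`_discretize_actions` keeps only weight vectors that are non-zero and have at
least one positive entry, leaving five abstract actions, roughly (up to print
formatting; the helper is module-private and shown here only for
illustration):

```
>>> from option_keyboard import environment_wrappers
>>> environment_wrappers._discretize_actions(3, 2)
array([[-1.,  1.],
       [ 0.,  1.],
       [ 1., -1.],
       [ 1.,  0.],
       [ 1.,  1.]])
```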
option_keyboard/experiment.py
@@ -0,0 +1,70 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A simple training loop."""

from absl import logging

import numpy as np


def run(environment, agent, num_episodes, report_every=200):
  """Runs an agent on an environment.

  Args:
    environment: The environment.
    agent: The agent.
    num_episodes: Number of episodes to train for.
    report_every: Frequency at which training progress is reported (episodes).
  """

  train_returns = []
  eval_returns = []
  for episode_id in range(num_episodes):
    # Run a training episode.
    train_episode_return = run_episode(environment, agent, is_training=True)
    train_returns.append(train_episode_return)

    # Run an evaluation episode.
    eval_episode_return = run_episode(environment, agent, is_training=False)
    eval_returns.append(eval_episode_return)

    if ((episode_id + 1) % report_every) == 0:
      logging.info(
          "Episode %s, avg train return %.3f, avg eval return %.3f",
          episode_id + 1,
          np.mean(train_returns[-report_every:]),
          np.mean(eval_returns[-report_every:]),
      )


def run_episode(environment, agent, is_training=False):
  """Run a single episode."""

  timestep = environment.reset()

  while not timestep.last():
    action = agent.step(timestep, is_training)
    new_timestep = environment.step(action)

    if is_training:
      agent.update(timestep, action, new_timestep)

    timestep = new_timestep

  episode_return = environment.episode_return

  return episode_return
option_keyboard/keyboard_agent.py
@@ -0,0 +1,222 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Keyboard agent."""

import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf


class Agent():
  """An Option Keyboard Agent."""

  def __init__(
      self,
      obs_spec,
      action_spec,
      policy_weights,
      network_kwargs,
      epsilon,
      additional_discount,
      batch_size,
      optimizer_name,
      optimizer_kwargs,
  ):
    """A simple option keyboard agent.

    Args:
      obs_spec: The observation spec.
      action_spec: The action spec.
      policy_weights: A list of vectors each representing the cumulant weights
        for that particular option/policy.
      network_kwargs: Keyword arguments for snt.nets.MLP.
      epsilon: Exploration probability.
      additional_discount: Discount on returns used by the agent.
      batch_size: Size of update batch.
      optimizer_name: Name of an optimizer from tf.train.
      optimizer_kwargs: Keyword arguments for the optimizer.
    """

    self._policy_weights = tf.convert_to_tensor(
        policy_weights, dtype=tf.float32)
    self._current_policy = None

    self._epsilon = epsilon
    self._additional_discount = additional_discount
    self._batch_size = batch_size

    self._n_actions = action_spec.num_values
    self._n_policies, self._n_cumulants = policy_weights.shape
    self._network = OptionValueNet(
        self._n_policies,
        self._n_cumulants,
        self._n_actions,
        network_kwargs=network_kwargs,
    )

    self._replay = []

    obs_spec = self._extract_observation(obs_spec)

    def option_values(values, policy):
      return tf.tensordot(
          values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])

    # Placeholders for policy.
    o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
    p = tf.placeholder(shape=(), dtype=tf.int32)
    q = self._network(tf.expand_dims(o, axis=0))
    qo = option_values(q, p)

    # Placeholders for update.
    o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
    a_tm1 = tf.placeholder(shape=(None,), dtype=tf.int32)
    c_t = tf.placeholder(shape=(None, self._n_cumulants), dtype=tf.float32)
    d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)

    # Compute values over all options.
    q_tm1 = self._network(o_tm1)
    q_t = self._network(o_t)
    qo_t = option_values(q_t, p)

    a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
    qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1)
    qa_t = _batched_index(q_t[:, p, Ellipsis], a_t)

    # TD error
    g = additional_discount * tf.expand_dims(d_t, axis=-1)
    td_error = tf.stop_gradient(c_t + g * qa_t) - qa_tm1
    loss = tf.reduce_sum(tf.square(td_error) / 2)

    with tf.variable_scope("optimizer"):
      self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
      train_op = self._optimizer.minimize(loss)

    # Make session and callables.
    session = tf.Session()
    self._session = session
    self._update_fn = session.make_callable(
        train_op, [o_tm1, a_tm1, c_t, d_t, o_t, p])
    self._value_fn = session.make_callable(qo, [o, p])
    session.run(tf.global_variables_initializer())

    self._saver = tf.train.Saver(var_list=self._network.variables)

  @property
  def keyboard(self):
    return self._network

  def _extract_observation(self, obs):
    return obs["arena"]

  def step(self, timestep, is_training=False):
    """Select actions according to epsilon-greedy policy."""
    if timestep.first():
      self._current_policy = np.random.randint(self._n_policies)

    if is_training and np.random.rand() < self._epsilon:
      return np.random.randint(self._n_actions)

    q_values = self._value_fn(
        self._extract_observation(timestep.observation), self._current_policy)
    return int(np.argmax(q_values))

  def update(self, step_tm1, action, step_t):
    """Takes in a transition from the environment."""

    transition = [
        self._extract_observation(step_tm1.observation),
        action,
        step_t.observation["cumulants"],
        step_t.discount,
        self._extract_observation(step_t.observation),
    ]
    self._replay.append(transition)

    if len(self._replay) == self._batch_size:
      batch = list(zip(*self._replay)) + [self._current_policy]
      self._update_fn(*batch)
      self._replay = []  # Just a queue.

  def export(self, path):
    tf.logging.info("Exporting keyboard to %s", path)
    self._saver.save(self._session, path)


class OptionValueNet(snt.AbstractModule):
  """Option Value net."""

  def __init__(self,
               n_policies,
               n_cumulants,
               n_actions,
               network_kwargs,
               name="option_keyboard"):
    """Construct an Option Value Net sonnet module.

    Args:
      n_policies: Number of policies.
      n_cumulants: Number of cumulants.
      n_actions: Number of actions.
      network_kwargs: Network arguments.
      name: Name.
    """
    super(OptionValueNet, self).__init__(name=name)
    self._n_policies = n_policies
    self._n_cumulants = n_cumulants
    self._n_actions = n_actions
    self._network_kwargs = network_kwargs

  def _build(self, observation):
    values = []

    flat_obs = snt.BatchFlatten()(observation)
    for _ in range(self._n_cumulants):
      net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
      net = snt.Linear(output_size=self._n_policies * self._n_actions)(net)
      net = snt.BatchReshape([self._n_policies, self._n_actions])(net)
      values.append(net)
    values = tf.stack(values, axis=2)
    return values

  def gpi(self, observation, cumulant_weights):
    q_values = self.__call__(tf.expand_dims(observation, axis=0))[0]
    q_w = tf.tensordot(q_values, cumulant_weights, axes=[1, 0])  # [P,a]
    q_w_actions = tf.reduce_max(q_w, axis=0)

    action = tf.cast(tf.argmax(q_w_actions), tf.int32)

    return action

  @property
  def num_cumulants(self):
    return self._n_cumulants

  @property
  def num_policies(self):
    return self._n_policies

  @property
  def num_actions(self):
    return self._n_actions


def _batched_index(values, indices):
  one_hot_indices = tf.one_hot(indices, values.shape[-1], dtype=values.dtype)
  one_hot_indices = tf.expand_dims(one_hot_indices, axis=1)
  return tf.reduce_sum(values * one_hot_indices, axis=-1)
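A shape-level reading aid, not part of the released file: `OptionValueNet`
outputs a tensor of shape `[batch, num_policies, num_cumulants, num_actions]`,
and `option_values` above projects it onto one policy's cumulant weights. With
the identity `policy_weights` used in run_ok.py below, each policy's option
value is just its Q-values for its own cumulant. A rough numpy picture (names
made up for illustration):

```
import numpy as np

batch, n_policies, n_cumulants, n_actions = 10, 2, 2, 4
q = np.random.rand(batch, n_policies, n_cumulants, n_actions)
policy_weights = np.eye(2)  # as in run_ok.py: policy i maximises cumulant i

p = 1  # which policy/option is currently being trained
qo = np.tensordot(q[:, p], policy_weights[p], axes=[1, 0])  # [batch, n_actions]
assert qo.shape == (batch, n_actions)
```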
option_keyboard/requirements.txt
@@ -0,0 +1,6 @@
absl-py
dm-env==1.2
dm-sonnet==1.34
numpy==1.16.4
tensorflow==1.13.2
tensorflow_probability==0.6.0
Executable shell script (+21 lines)
@@ -0,0 +1,21 @@
#!/bin/sh
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

python3 -m venv ok_venv
source ok_venv/bin/activate
pip install -r option_keyboard/requirements.txt

python3 -m option_keyboard.run_dqn_test
python3 -m option_keyboard.run_ok_test
option_keyboard/run_dqn.py
@@ -0,0 +1,61 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run an experiment."""

from absl import app
from absl import flags

import tensorflow.compat.v1 as tf

from option_keyboard import configs
from option_keyboard import dqn_agent
from option_keyboard import environment_wrappers
from option_keyboard import experiment
from option_keyboard import scavenger

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")


def main(argv):
  del argv

  # Create the task environment.
  env_config = configs.get_task_config()
  env = scavenger.Scavenger(**env_config)
  env = environment_wrappers.EnvironmentWithLogging(env)

  # Create the flat agent.
  agent = dqn_agent.Agent(
      obs_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      network_kwargs=dict(
          output_sizes=(64, 128),
          activate_final=True,
      ),
      epsilon=0.1,
      additional_discount=0.9,
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

  experiment.run(env, agent, num_episodes=FLAGS.num_episodes)


if __name__ == "__main__":
  tf.disable_v2_behavior()
  app.run(main)
option_keyboard/run_dqn_test.py
@@ -0,0 +1,38 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for running the simple DQN agent."""

from absl import flags
from absl.testing import absltest

import tensorflow.compat.v1 as tf

from option_keyboard import run_dqn

FLAGS = flags.FLAGS


class RunDQNTest(absltest.TestCase):

  def test_run(self):
    FLAGS.num_episodes = 200
    run_dqn.main(None)


if __name__ == '__main__':
  tf.disable_v2_behavior()
  absltest.main()
option_keyboard/run_ok.py
@@ -0,0 +1,108 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run an experiment."""

from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf

from option_keyboard import configs
from option_keyboard import dqn_agent
from option_keyboard import environment_wrappers
from option_keyboard import experiment
from option_keyboard import keyboard_agent
from option_keyboard import scavenger

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
flags.DEFINE_integer("num_pretrain_episodes", 20000,
                     "Number of pretraining episodes.")


def _train_keyboard(num_episodes):
  """Train an option keyboard."""
  env_config = configs.get_pretrain_config()
  env = scavenger.Scavenger(**env_config)
  env = environment_wrappers.EnvironmentWithLogging(env)

  agent = keyboard_agent.Agent(
      obs_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      policy_weights=np.array([
          [1.0, 0.0],
          [0.0, 1.0],
      ]),
      network_kwargs=dict(
          output_sizes=(64, 128),
          activate_final=True,
      ),
      epsilon=0.1,
      additional_discount=0.9,
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

  experiment.run(env, agent, num_episodes=num_episodes)

  return agent


def main(argv):
  del argv

  # Pretrain the keyboard and save a checkpoint.
  pretrain_agent = _train_keyboard(num_episodes=FLAGS.num_pretrain_episodes)
  keyboard_ckpt_path = "/tmp/option_keyboard/keyboard.ckpt"
  pretrain_agent.export(keyboard_ckpt_path)

  # Create the task environment.
  base_env_config = configs.get_task_config()
  base_env = scavenger.Scavenger(**base_env_config)
  base_env = environment_wrappers.EnvironmentWithLogging(base_env)

  # Wrap the task environment with the keyboard.
  additional_discount = 0.9
  env = environment_wrappers.EnvironmentWithKeyboard(
      env=base_env,
      keyboard=pretrain_agent.keyboard,
      keyboard_ckpt_path=keyboard_ckpt_path,
      n_actions_per_dim=3,
      additional_discount=additional_discount,
      call_and_return=True)

  # Create the player agent.
  agent = dqn_agent.Agent(
      obs_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      network_kwargs=dict(
          output_sizes=(64, 128),
          activate_final=True,
      ),
      epsilon=0.1,
      additional_discount=additional_discount,
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

  experiment.run(env, agent, num_episodes=FLAGS.num_episodes)


if __name__ == "__main__":
  tf.disable_v2_behavior()
  app.run(main)
option_keyboard/run_ok_test.py
@@ -0,0 +1,39 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for training a keyboard and then running a DQN agent on top of it."""

from absl import flags
from absl.testing import absltest

import tensorflow.compat.v1 as tf

from option_keyboard import run_ok

FLAGS = flags.FLAGS


class RunOKTest(absltest.TestCase):

  def test_run(self):
    FLAGS.num_episodes = 200
    FLAGS.num_pretrain_episodes = 200
    run_ok.main(None)


if __name__ == '__main__':
  tf.disable_v2_behavior()
  absltest.main()
option_keyboard/scavenger.py
@@ -0,0 +1,269 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simple Scavenger environment."""

import copy
import enum
import sys

import dm_env

import numpy as np

from option_keyboard import auto_reset_environment

this_module = sys.modules[__name__]


class Action(enum.IntEnum):
  """Actions available to the player."""
  UP = 0
  DOWN = 1
  LEFT = 2
  RIGHT = 3


def _one_hot(indices, depth):
  return np.eye(depth)[indices]


def _random_pos(arena_size):
  return tuple(np.random.randint(0, arena_size, size=[2]).tolist())


class Scavenger(auto_reset_environment.Base):
  """Simple Scavenger."""

  def __init__(self,
               arena_size,
               num_channels,
               max_num_steps,
               default_w=None,
               num_init_objects=15,
               object_priors=None,
               egocentric=True,
               rewarder=None):
    self._arena_size = arena_size
    self._num_channels = num_channels
    self._max_num_steps = max_num_steps
    self._num_init_objects = num_init_objects
    self._egocentric = egocentric
    self._rewarder = (
        getattr(this_module, rewarder)() if rewarder is not None else None)

    if object_priors is None:
      self._object_priors = np.ones(num_channels) / num_channels
    else:
      assert len(object_priors) == num_channels
      self._object_priors = np.array(object_priors)

    if default_w is None:
      self._default_w = np.ones(shape=(num_channels,))
    else:
      self._default_w = default_w

    self._num_channels_all = self._num_channels + 2
    self._step_in_episode = None

  @property
  def state(self):
    return copy.deepcopy([
        self._step_in_episode,
        self._walls,
        self._objects,
        self._player_pos,
        self._prev_collected,
    ])

  def set_state(self, state):
    state_ = copy.deepcopy(state)
    self._step_in_episode = state_[0]
    self._walls = state_[1]
    self._objects = state_[2]
    self._player_pos = state_[3]
    self._prev_collected = state_[4]

  @property
  def player_pos(self):
    return self._player_pos

  def _reset(self):
    self._step_in_episode = 0

    # Walls.
    self._walls = []
    for col in range(self._arena_size):
      new_pos = (0, col)
      if new_pos not in self._walls:
        self._walls.append(new_pos)
    for row in range(self._arena_size):
      new_pos = (row, 0)
      if new_pos not in self._walls:
        self._walls.append(new_pos)

    # Objects.
    self._objects = dict()
    for _ in range(self._num_init_objects):
      while True:
        new_pos = _random_pos(self._arena_size)
        if new_pos not in self._objects and new_pos not in self._walls:
          self._objects[new_pos] = np.random.multinomial(1, self._object_priors)
          break

    # Player.
    self._player_pos = _random_pos(self._arena_size)
    while self._player_pos in self._objects or self._player_pos in self._walls:
      self._player_pos = _random_pos(self._arena_size)

    self._prev_collected = np.zeros(shape=(self._num_channels,))

    obs = self.observation()

    return dm_env.restart(obs)

  def _step(self, action):
    self._step_in_episode += 1

    if action == Action.UP:
      new_player_pos = (self._player_pos[0], self._player_pos[1] + 1)
    elif action == Action.DOWN:
      new_player_pos = (self._player_pos[0], self._player_pos[1] - 1)
    elif action == Action.LEFT:
      new_player_pos = (self._player_pos[0] - 1, self._player_pos[1])
    elif action == Action.RIGHT:
      new_player_pos = (self._player_pos[0] + 1, self._player_pos[1])
    else:
      raise ValueError("Invalid action `{}`".format(action))

    # Toroidal.
    new_player_pos = (
        (new_player_pos[0] + self._arena_size) % self._arena_size,
        (new_player_pos[1] + self._arena_size) % self._arena_size,
    )

    if new_player_pos not in self._walls:
      self._player_pos = new_player_pos

    # Compute rewards.
    consumed = self._objects.pop(self._player_pos,
                                 np.zeros(shape=(self._num_channels,)))
    if self._rewarder is None:
      reward = np.dot(consumed, np.array(self._default_w))
    else:
      reward = self._rewarder.get_reward(self.state, consumed)
    self._prev_collected = np.copy(consumed)

    assert self._player_pos not in self._objects
    assert self._player_pos not in self._walls

    # Render everything.
    obs = self.observation()

    if self._step_in_episode < self._max_num_steps:
      return dm_env.transition(reward=reward, observation=obs)
    else:
      # Truncation: the time limit is reached, so the discount stays 1.0.
      return dm_env.truncation(reward=reward, observation=obs)

  def observation(self, force_non_egocentric=False):
    arena_shape = [self._arena_size] * 2 + [self._num_channels_all]
    arena = np.zeros(shape=arena_shape, dtype=np.float32)

    def offset_position(pos_):
      use_egocentric = self._egocentric and not force_non_egocentric
      offset = self._player_pos if use_egocentric else (0, 0)
      x = (pos_[0] - offset[0] + self._arena_size) % self._arena_size
      y = (pos_[1] - offset[1] + self._arena_size) % self._arena_size
      return (x, y)

    player_pos = offset_position(self._player_pos)
    arena[player_pos] = _one_hot(self._num_channels, self._num_channels_all)

    for pos, obj in self._objects.items():
      x, y = offset_position(pos)
      arena[x, y, :self._num_channels] = obj

    for pos in self._walls:
      x, y = offset_position(pos)
      arena[x, y] = _one_hot(self._num_channels + 1, self._num_channels_all)

    collected_resources = np.copy(self._prev_collected).astype(np.float32)

    return dict(
        arena=arena,
        cumulants=collected_resources,
    )

  def observation_spec(self):
    arena = dm_env.specs.BoundedArray(
        shape=(self._arena_size, self._arena_size, self._num_channels_all),
        dtype=np.float32,
        minimum=0.,
        maximum=1.,
        name="arena")
    collected_resources = dm_env.specs.BoundedArray(
        shape=(self._num_channels,),
        dtype=np.float32,
        minimum=-1e9,
        maximum=1e9,
        name="collected_resources")

    return dict(
        arena=arena,
        cumulants=collected_resources,
    )

  def action_spec(self):
    return dm_env.specs.DiscreteArray(num_values=len(Action), name="action")


class SequentialCollectionRewarder(object):
  """SequentialCollectionRewarder."""

  def get_reward(self, state, consumed):
    """Get reward."""

    object_counts = sum(list(state[2].values()) + [np.zeros(len(consumed))])

    reward = 0.0
    if np.sum(consumed) > 0:
      for i in range(len(consumed)):
        if np.all(object_counts[:i] <= object_counts[i]):
          reward += consumed[i]
        else:
          reward -= consumed[i]

    return reward


class BalancedCollectionRewarder(object):
  """BalancedCollectionRewarder."""

  def get_reward(self, state, consumed):
    """Get reward."""

    object_counts = sum(list(state[2].values()) + [np.zeros(len(consumed))])

    reward = 0.0
    if np.sum(consumed) > 0:
      for i in range(len(consumed)):
        if (object_counts[i] + consumed[i]) >= np.max(object_counts):
          reward += consumed[i]
        else:
          reward -= consumed[i]

    return reward
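A minimal usage sketch, not part of the released files: constructing the task
environment from `get_task_config()` above and inspecting one observation.
Shapes follow from the config (11x11 arena, 2 object channels plus player and
wall channels).

```
from option_keyboard import configs
from option_keyboard import scavenger

env = scavenger.Scavenger(**configs.get_task_config())
timestep = env.reset()

obs = timestep.observation
assert obs["arena"].shape == (11, 11, 4)   # 2 object channels + player + wall
assert obs["cumulants"].shape == (2,)      # per-channel resources just collected
assert env.action_spec().num_values == 4   # UP, DOWN, LEFT, RIGHT
```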