Internal changes.

PiperOrigin-RevId: 389919323
This commit is contained in:
Diego de Las Casas
2021-08-10 18:30:14 +01:00
parent 30799687ed
commit 5e35a17cf4
7 changed files with 1196 additions and 0 deletions
+67
View File
@@ -0,0 +1,67 @@
# Towards mental time travel: A hierarchical memory for RL agents
This provides an implementation of two components of the paper "Towards mental
time travel: A hierarchical memory for reinforcement learning agents." The
article can be found on arXiv at [https://arxiv.org/abs/2105.14039](https://arxiv.org/abs/2105.14039)
Specifically, this repository contains:
1) A JAX/Haiku implementation of hierarchical transformer attention over memory.
2) An implementation of the Ballet environment used in the paper.
We have also released the Rapid Word Learning tasks from the paper, but to
simplify dependencies they are located in the `dm_fast_mapping` repository:
[deepmind/dm_fast_mapping](https://github.com/deepmind/dm_fast_mapping) see the
[documentation](https://github.com/deepmind/dm_fast_mapping/blob/master/docs/index.md)
for that repository for further details about using those tasks.
## Setup
For easy installation, run:
```shell
python3 -m venv htm_env
source htm_env/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
```
Note that this installs the components needed for both the attention module and
the environment. If you only wish to use the environment, you do not need to
install JAX, Haiku, or Chex.
## Using the hierarchical attention module:
Please see `hierarchical_attention/htm_attention_test.py` for some examples of
the expected inputs for this module.
## Running the ballet environment
The ballet environment is contained in the `pycolab_ballet/` subfolder. To load
a simple ballet environment with 2 dancers and short delays, and watch a few
steps of the dances, you can do:
```
from pycolab_ballet import ballet_environment
env = ballet_environment.simple_builder(level_name='2_delay16')
timestep = env.reset()
for _ in range(5):
action = 0
timestep = env.step(action)
```
## Citing this work
If you use this code, please cite the associated paper:
```
@article{lampinen2021towards,
title={Towards mental time travel:
a hierarchical memory for reinforcement learning agents},
author={Lampinen, Andrew Kyle and Chan, Stephanie CY and Banino, Andrea and
Hill, Felix},
journal={arXiv preprint arXiv:2105.14039},
year={2021}
}
```
@@ -0,0 +1,218 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Haiku module implementing hierarchical attention over memory."""
from typing import Optional, NamedTuple
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
_EPSILON = 1e-3
class HierarchicalMemory(NamedTuple):
"""Structure of the hierarchical memory.
Where 'B' is batch size, 'M' is number of memories, 'C' is chunk size, and 'D'
is memory dimension.
"""
keys: jnp.ndarray # [B, M, D]
contents: jnp.ndarray # [B, M, C, D]
steps_since_last_write: jnp.ndarray # [B], steps since last memory write
accumulator: jnp.ndarray # [B, C, D], accumulates experiences before write
def sinusoid_position_encoding(
sequence_length: int,
hidden_size: int,
min_timescale: float = 2.,
max_timescale: float = 1e4,
) -> jnp.ndarray:
"""Creates sinusoidal encodings.
Args:
sequence_length: length [L] of sequence to be position encoded.
hidden_size: dimension [D] of the positional encoding vectors.
min_timescale: minimum timescale for the frequency.
max_timescale: maximum timescale for the frequency.
Returns:
An array of shape [L, D]
"""
freqs = np.arange(0, hidden_size, min_timescale)
inv_freq = max_timescale**(-freqs / hidden_size)
pos_seq = np.arange(sequence_length - 1, -1, -1.0)
sinusoid_inp = np.einsum("i,j->ij", pos_seq, inv_freq)
pos_emb = np.concatenate(
[np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1)
return pos_emb
class HierarchicalMemoryAttention(hk.Module):
"""Multi-head attention over hierarchical memory."""
def __init__(self,
feature_size: int,
k: int,
num_heads: int = 1,
memory_position_encoding: bool = True,
init_scale: float = 2.,
name: Optional[str] = None) -> None:
"""Constructor.
Args:
feature_size: size of feature dimension of attention-over-memories
embedding.
k: number of memories to sample.
num_heads: number of attention heads.
memory_position_encoding: whether to add positional encodings to memories
during within memory attention.
init_scale: scale factor for Variance weight initializers.
name: module name.
"""
super().__init__(name=name)
self._size = feature_size
self._k = k
self._num_heads = num_heads
self._weights = None
self._memory_position_encoding = memory_position_encoding
self._init_scale = init_scale
@property
def num_heads(self):
return self._num_heads
@hk.transparent
def _singlehead_linear(self,
inputs: jnp.ndarray,
hidden_size: int,
name: str):
linear = hk.Linear(
hidden_size,
with_bias=False,
w_init=hk.initializers.VarianceScaling(scale=self._init_scale),
name=name)
out = linear(inputs)
return out
def __call__(
self,
queries: jnp.ndarray,
hm_memory: HierarchicalMemory,
hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
"""Do hierarchical attention over the stored memories.
Args:
queries: Tensor [B, Q, E] Query(ies) in, for batch size B, query length
Q, and embedding dimension E.
hm_memory: Hierarchical Memory.
hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where false,
the corresponding query timepoints cannot attend to the corresponding
memory chunks. This can be used for enforcing causal attention on the
learner, not attending to memories from prior episodes, etc.
Returns:
Value updates for each query slot: [B, Q, D]
"""
# some shape checks
batch_size, query_length, _ = queries.shape
(memory_batch_size, num_memories,
memory_chunk_size, mem_embbedding_size) = hm_memory.contents.shape
assert batch_size == memory_batch_size
chex.assert_shape(hm_memory.keys,
(batch_size, num_memories, mem_embbedding_size))
chex.assert_shape(hm_memory.accumulator,
(memory_batch_size, memory_chunk_size,
mem_embbedding_size))
chex.assert_shape(hm_memory.steps_since_last_write,
(memory_batch_size,))
if hm_mask is not None:
chex.assert_type(hm_mask, bool)
chex.assert_shape(hm_mask,
(batch_size, query_length, num_memories))
query_head = self._singlehead_linear(queries, self._size, "query")
key_head = self._singlehead_linear(
jax.lax.stop_gradient(hm_memory.keys), self._size, "key")
# What times in the input [t] attend to what times in the memories [T].
logits = jnp.einsum("btd,bTd->btT", query_head, key_head)
scaled_logits = logits / np.sqrt(self._size)
# Mask last dimension, replacing invalid logits with large negative values.
# This allows e.g. enforcing causal attention on learner, or blocking
# attention across episodes
if hm_mask is not None:
masked_logits = jnp.where(hm_mask, scaled_logits, -1e6)
else:
masked_logits = scaled_logits
# identify the top-k memories and their relevance weights
top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k)
weights = jax.nn.softmax(top_k_logits)
# set up the within-memory attention
assert self._size % self._num_heads == 0
mha_key_size = self._size // self._num_heads
attention_layer = hk.MultiHeadAttention(
key_size=mha_key_size,
model_size=self._size,
num_heads=self._num_heads,
w_init_scale=self._init_scale,
name="within_mem_attn")
# position encodings
augmented_contents = hm_memory.contents
if self._memory_position_encoding:
position_embs = sinusoid_position_encoding(
memory_chunk_size, mem_embbedding_size)
augmented_contents += position_embs[None, None, :, :]
def _within_memory_attention(sub_inputs, sub_memory_contents, sub_weights,
sub_top_k_indices):
top_k_contents = sub_memory_contents[sub_top_k_indices, :, :]
# Now we go deeper, with another vmap over **tokens**, because each token
# can each attend to different memories.
def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
tiled_inputs = jnp.tile(sub_sub_inputs[None, None, :],
reps=(self._k, 1, 1))
sub_attention_results = attention_layer(
query=tiled_inputs,
key=sub_sub_top_k_contents,
value=sub_sub_top_k_contents)
return sub_attention_results
do_attention = jax.vmap(do_attention, in_axes=0)
attention_results = do_attention(sub_inputs, top_k_contents)
attention_results = jnp.squeeze(attention_results, axis=2)
# Now collapse results across k memories
attention_results = sub_weights[:, :, None] * attention_results
attention_results = jnp.sum(attention_results, axis=1)
return attention_results
# vmap across batch
batch_within_memory_attention = jax.vmap(_within_memory_attention,
in_axes=0)
outputs = batch_within_memory_attention(
queries,
jax.lax.stop_gradient(augmented_contents),
weights,
top_k_indices)
return outputs
@@ -0,0 +1,106 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for hierarchical_attention.htm_attention."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import numpy as np
from hierarchical_transformer_memory.hierarchical_attention import htm_attention
def _build_queries_and_memory(query_length, num_memories, mem_chunk_size,
batch_size=2, embedding_size=12):
"""Builds dummy queries + memory contents for tests."""
queries = np.random.random([batch_size, query_length, embedding_size])
memory_contents = np.random.random(
[batch_size, num_memories, mem_chunk_size, embedding_size])
# summary key = average across chunk
memory_keys = np.mean(memory_contents, axis=2)
# to accumulate newest memories before writing
memory_accumulator = np.zeros_like(memory_contents[:, -1, :, :])
memory = htm_attention.HierarchicalMemory(
keys=memory_keys,
contents=memory_contents,
accumulator=memory_accumulator,
steps_since_last_write=np.zeros([batch_size,], dtype=np.int32))
return queries, memory
class HierarchicalAttentionTest(parameterized.TestCase):
@parameterized.parameters([
{
'query_length': 1,
'num_memories': 7,
'mem_chunk_size': 5,
'mem_k': 4,
},
{
'query_length': 9,
'num_memories': 7,
'mem_chunk_size': 5,
'mem_k': 4,
},
])
@hk.testing.transform_and_run
def test_output_shapes(self, query_length, num_memories, mem_chunk_size,
mem_k):
np.random.seed(0)
batch_size = 2
embedding_size = 12
num_heads = 3
queries, memory = _build_queries_and_memory(
query_length=query_length, num_memories=num_memories,
mem_chunk_size=mem_chunk_size, embedding_size=embedding_size)
hm_att = htm_attention.HierarchicalMemoryAttention(
feature_size=embedding_size,
k=mem_k,
num_heads=num_heads)
results = hm_att(queries, memory)
self.assertEqual(results.shape,
(batch_size, query_length, embedding_size))
self.assertTrue(np.all(np.isfinite(results)))
@hk.testing.transform_and_run
def test_masking(self):
np.random.seed(0)
batch_size = 2
embedding_size = 12
num_heads = 3
query_length = 5
num_memories = 7
mem_chunk_size = 6
mem_k = 4
queries, memory = _build_queries_and_memory(
query_length=query_length, num_memories=num_memories,
mem_chunk_size=mem_chunk_size, embedding_size=embedding_size)
hm_att = htm_attention.HierarchicalMemoryAttention(
feature_size=embedding_size,
k=mem_k,
num_heads=num_heads)
# get a random boolean mask
mask = np.random.binomial(
1, 0.5, [batch_size, query_length, num_memories]).astype(bool)
results = hm_att(queries, memory, hm_mask=mask)
self.assertEqual(results.shape,
(batch_size, query_length, embedding_size))
self.assertTrue(np.all(np.isfinite(results)))
if __name__ == '__main__':
absltest.main()
@@ -0,0 +1,376 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A pycolab environment for going to the ballet.
A pycolab-based environment for testing memory for sequences of events. The
environment contains some number of "dancer" characters in (implicit) 3 x 3
squares within a larger 9 x 9 room. The agent starts in the center of the room.
At the beginning of an episode, the dancers each do a dance solo of a fixed
length, separated by empty time of a fixed length. The agent's actions do
nothing during the dances. After the last dance ends, the agent must go up to a
dancer, identified using language describing the dance. The agent is rewarded +1
for approaching the correct dancer, 0 otherwise.
The room is upsampled at a size of 9 pixels per square to render a view for the
agent, which is cropped in egocentric perspective, i.e. the agent is always in
the center of its view (see https://arxiv.org/abs/1910.00571).
"""
from absl import app
from absl import flags
from absl import logging
import dm_env
import numpy as np
from pycolab import cropping
from hierarchical_transformer_memory.pycolab_ballet import ballet_environment_core as ballet_core
FLAGS = flags.FLAGS
UPSAMPLE_SIZE = 9 # pixels per game square
SCROLL_CROP_SIZE = 11 # in game squares
DANCER_SHAPES = [
"triangle", "empty_square", "plus", "inverse_plus", "ex", "inverse_ex",
"circle", "empty_circle", "tee", "upside_down_tee",
"h", "u", "upside_down_u", "vertical_stripes", "horizontal_stripes"
]
COLORS = {
"red": np.array([255, 0, 0]),
"green": np.array([0, 255, 0]),
"blue": np.array([0, 0, 255]),
"purple": np.array([128, 0, 128]),
"orange": np.array([255, 165, 0]),
"yellow": np.array([255, 255, 0]),
"brown": np.array([128, 64, 0]),
"pink": np.array([255, 64, 255]),
"cyan": np.array([0, 255, 255]),
"dark_green": np.array([0, 100, 0]),
"dark_red": np.array([100, 0, 0]),
"dark_blue": np.array([0, 0, 100]),
"olive": np.array([100, 100, 0]),
"teal": np.array([0, 100, 100]),
"lavender": np.array([215, 200, 255]),
"peach": np.array([255, 210, 170]),
"rose": np.array([255, 205, 230]),
"light_green": np.array([200, 255, 200]),
"light_yellow": np.array([255, 255, 200]),
}
def _generate_template(object_name):
"""Generates a template object image, given a name with color and shape."""
object_color, object_type = object_name.split()
template = np.zeros((UPSAMPLE_SIZE, UPSAMPLE_SIZE))
half = UPSAMPLE_SIZE // 2
if object_type == "triangle":
for i in range(UPSAMPLE_SIZE):
for j in range(UPSAMPLE_SIZE):
if (j <= half and i >= 2 * (half - j)) or (j > half and i >= 2 *
(j - half)):
template[i, j] = 1.
elif object_type == "square":
template[:, :] = 1.
elif object_type == "empty_square":
template[:2, :] = 1.
template[-2:, :] = 1.
template[:, :2] = 1.
template[:, -2:] = 1.
elif object_type == "plus":
template[:, half - 1:half + 2] = 1.
template[half - 1:half + 2, :] = 1.
elif object_type == "inverse_plus":
template[:, :] = 1.
template[:, half - 1:half + 2] = 0.
template[half - 1:half + 2, :] = 0.
elif object_type == "ex":
for i in range(UPSAMPLE_SIZE):
for j in range(UPSAMPLE_SIZE):
if abs(i - j) <= 1 or abs(UPSAMPLE_SIZE - 1 - j - i) <= 1:
template[i, j] = 1.
elif object_type == "inverse_ex":
for i in range(UPSAMPLE_SIZE):
for j in range(UPSAMPLE_SIZE):
if not (abs(i - j) <= 1 or abs(UPSAMPLE_SIZE - 1 - j - i) <= 1):
template[i, j] = 1.
elif object_type == "circle":
for i in range(UPSAMPLE_SIZE):
for j in range(UPSAMPLE_SIZE):
if (i - half)**2 + (j - half)**2 <= half**2:
template[i, j] = 1.
elif object_type == "empty_circle":
for i in range(UPSAMPLE_SIZE):
for j in range(UPSAMPLE_SIZE):
if abs((i - half)**2 + (j - half)**2 - half**2) < 6:
template[i, j] = 1.
elif object_type == "tee":
template[:, half - 1:half + 2] = 1.
template[:3, :] = 1.
elif object_type == "upside_down_tee":
template[:, half - 1:half + 2] = 1.
template[-3:, :] = 1.
elif object_type == "h":
template[:, :3] = 1.
template[:, -3:] = 1.
template[half - 1:half + 2, :] = 1.
elif object_type == "u":
template[:, :3] = 1.
template[:, -3:] = 1.
template[-3:, :] = 1.
elif object_type == "upside_down_u":
template[:, :3] = 1.
template[:, -3:] = 1.
template[:3, :] = 1.
elif object_type == "vertical_stripes":
for j in range(half + UPSAMPLE_SIZE % 2):
template[:, 2*j] = 1.
elif object_type == "horizontal_stripes":
for i in range(half + UPSAMPLE_SIZE % 2):
template[2*i, :] = 1.
else:
raise ValueError("Unknown object: {}".format(object_type))
if object_color not in COLORS:
raise ValueError("Unknown color: {}".format(object_color))
template = np.tensordot(template, COLORS[object_color], axes=0)
return template
# Agent and wall templates
_CHAR_TO_TEMPLATE_BASE = {
ballet_core.AGENT_CHAR:
np.tensordot(
np.ones([UPSAMPLE_SIZE, UPSAMPLE_SIZE]),
np.array([255, 255, 255]),
axes=0),
ballet_core.WALL_CHAR:
np.tensordot(
np.ones([UPSAMPLE_SIZE, UPSAMPLE_SIZE]),
np.array([40, 40, 40]),
axes=0),
}
def get_scrolling_cropper(rows=9, cols=9, crop_pad_char=" "):
return cropping.ScrollingCropper(rows=rows, cols=cols,
to_track=[ballet_core.AGENT_CHAR],
pad_char=crop_pad_char,
scroll_margins=(None, None))
class BalletEnvironment(dm_env.Environment):
"""A Python environment API for pycolab ballet tasks."""
def __init__(self, num_dancers, dance_delay, max_steps, rng=None):
"""Construct a BalletEnvironment that wraps pycolab games for agent use.
This class inherits from dm_env and has all the expected methods and specs.
Args:
num_dancers: The number of dancers to use, between 1 and 8 (inclusive).
dance_delay: How long to delay between the dances.
max_steps: The maximum number of steps to allow in an episode, after which
it will terminate.
rng: An optional numpy Random Generator, to set a fixed seed use e.g.
`rng=np.random.default_rng(seed=...)`
"""
self._num_dancers = num_dancers
self._dance_delay = dance_delay
self._max_steps = max_steps
# internal state
if rng is None:
rng = np.random.default_rng()
self._rng = rng
self._current_game = None # Current pycolab game instance.
self._state = None # Current game step state.
self._game_over = None # Whether the game has ended.
self._char_to_template = None # Mapping of chars to sprite images.
# rendering tools
self._cropper = get_scrolling_cropper(SCROLL_CROP_SIZE, SCROLL_CROP_SIZE,
" ")
def _game_factory(self):
"""Samples dancers and positions, returns a pycolab core game engine."""
target_dancer_index = self._rng.integers(self._num_dancers)
motions = list(ballet_core.DANCE_SEQUENCES.keys())
positions = ballet_core.DANCER_POSITIONS.copy()
colors = list(COLORS.keys())
shapes = DANCER_SHAPES.copy()
self._rng.shuffle(positions)
self._rng.shuffle(motions)
self._rng.shuffle(colors)
self._rng.shuffle(shapes)
dancers_and_properties = []
for dancer_i in range(self._num_dancers):
if dancer_i == target_dancer_index:
value = 1.
else:
value = 0.
dancers_and_properties.append(
(ballet_core.POSSIBLE_DANCER_CHARS[dancer_i],
positions[dancer_i],
motions[dancer_i],
shapes[dancer_i],
colors[dancer_i],
value))
logging.info("Making level with dancers_and_properties: %s",
dancers_and_properties)
return ballet_core.make_game(
dancers_and_properties=dancers_and_properties,
dance_delay=self._dance_delay)
def _render_observation(self, observation):
"""Renders from raw pycolab image observation to agent-usable ones."""
observation = self._cropper.crop(observation)
obs_rows, obs_cols = observation.board.shape
image = np.zeros([obs_rows * UPSAMPLE_SIZE, obs_cols * UPSAMPLE_SIZE, 3],
dtype=np.float32)
for i in range(obs_rows):
for j in range(obs_cols):
this_char = chr(observation.board[i, j])
if this_char != ballet_core.FLOOR_CHAR:
image[
i * UPSAMPLE_SIZE:(i + 1) * UPSAMPLE_SIZE, j *
UPSAMPLE_SIZE:(j + 1) * UPSAMPLE_SIZE] = self._char_to_template[
this_char]
image /= 255.
language = np.array(self._current_game.the_plot["instruction_string"])
full_observation = (image, language)
return full_observation
def reset(self):
"""Start a new episode."""
# Build a new game and retrieve its first set of state/reward/discount.
self._current_game = self._game_factory()
# set up rendering, cropping, and state for current game
self._char_to_template = {
k: _generate_template(v) for k, v in self._current_game.the_plot[
"char_to_color_shape"]}
self._char_to_template.update(_CHAR_TO_TEMPLATE_BASE)
self._cropper.set_engine(self._current_game)
self._state = dm_env.StepType.FIRST
# let's go!
observation, _, _ = self._current_game.its_showtime()
observation = self._render_observation(observation)
return dm_env.TimeStep(
step_type=self._state,
reward=None,
discount=None,
observation=observation)
def step(self, action):
"""Apply action, step the world forward, and return observations."""
# If needed, reset and start new episode.
if self._state == dm_env.StepType.LAST:
self._clear_state()
if self._current_game is None:
return self.reset()
# Execute the action in pycolab.
observation, reward, discount = self._current_game.play(action)
self._game_over = self._is_game_over()
reward = reward if reward is not None else 0.
observation = self._render_observation(observation)
# Check the current status of the game.
if self._game_over:
self._state = dm_env.StepType.LAST
else:
self._state = dm_env.StepType.MID
return dm_env.TimeStep(
step_type=self._state,
reward=reward,
discount=discount,
observation=observation)
@property
def observation_spec(self):
image_shape = (SCROLL_CROP_SIZE * UPSAMPLE_SIZE,
SCROLL_CROP_SIZE * UPSAMPLE_SIZE,
3)
return (
# vision
dm_env.specs.Array(
shape=image_shape, dtype=np.float32, name="image"),
# language
dm_env.specs.Array(
shape=[], dtype=str, name="language"),
)
@property
def action_spec(self):
return dm_env.specs.BoundedArray(
shape=[], dtype="int32",
minimum=0, maximum=7,
name="grid_actions")
def _is_game_over(self):
"""Returns whether it is game over, either from the engine or timeout."""
return (self._current_game.game_over or
(self._current_game.the_plot.frame >= self._max_steps))
def _clear_state(self):
"""Clear all the internal information about the game."""
self._state = None
self._current_game = None
self._char_to_template = None
self._game_over = None
def simple_builder(level_name):
"""Simplifies building from fixed defs.
Args:
level_name: '{num_dancers}_delay{delay_length}', where each variable is an
integer. The levels used in the paper were:
['2_delay16', '4_delay16', '8_delay16',
'2_delay48', '4_delay48', '8_delay48']
Returns:
A BalletEnvironment with the requested settings.
"""
num_dancers, dance_delay = level_name.split("_")
num_dancers = int(num_dancers)
dance_delay = int(dance_delay[5:])
max_steps = 320 if dance_delay == 16 else 1024
level_args = dict(
num_dancers=num_dancers,
dance_delay=dance_delay,
max_steps=max_steps)
return BalletEnvironment(**level_args)
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
env = simple_builder("4_delay16")
for _ in range(3):
obs = env.reset().observation
for _ in range(300):
obs = env.step(0).observation
print(obs)
if __name__ == "__main__":
app.run(main)
@@ -0,0 +1,337 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The pycolab core of the environment for going to the ballet.
This builds the text-based (non-graphical) engine of the environment, and offers
a UI which a human can play (for a fixed level). However, the logic of level
creation, the graphics, and anything that is external to the pycolab engine
itself is contained in ballet_environment.py.
"""
import curses
import enum
from absl import app
from absl import flags
from pycolab import ascii_art
from pycolab import human_ui
from pycolab.prefab_parts import sprites as prefab_sprites
FLAGS = flags.FLAGS
ROOM_SIZE = (11, 11) # one square around edge will be wall.
DANCER_POSITIONS = [(2, 2), (2, 5), (2, 8),
(5, 2), (5, 8), # space in center for agent
(8, 2), (8, 5), (8, 8)]
AGENT_START = (5, 5)
AGENT_CHAR = "A"
WALL_CHAR = "#"
FLOOR_CHAR = " "
RESERVED_CHARS = [AGENT_CHAR, WALL_CHAR, FLOOR_CHAR]
POSSIBLE_DANCER_CHARS = [
chr(i) for i in range(65, 91) if chr(i) not in RESERVED_CHARS
]
DANCE_SEQUENCE_LENGTHS = 16
# movement directions for dancers / actions for agent
class DIRECTIONS(enum.IntEnum):
N = 0
NE = 1
E = 2
SE = 3
S = 4
SW = 5
W = 6
NW = 7
DANCE_SEQUENCES = {
"circle_cw": [
DIRECTIONS.N, DIRECTIONS.E, DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.W,
DIRECTIONS.W, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.E, DIRECTIONS.E,
DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.N,
DIRECTIONS.E
],
"circle_ccw": [
DIRECTIONS.N, DIRECTIONS.W, DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.E,
DIRECTIONS.E, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.W, DIRECTIONS.W,
DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.N,
DIRECTIONS.W
],
"up_and_down": [
DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.N,
DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.S,
DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.S,
DIRECTIONS.N
],
"left_and_right": [
DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E,
DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.W,
DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.W,
DIRECTIONS.E
],
"diagonal_uldr": [
DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW,
DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW,
DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW,
DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW
],
"diagonal_urdl": [
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE,
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE,
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE,
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE
],
"plus_cw": [
DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.S,
DIRECTIONS.N, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.N, DIRECTIONS.S,
DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.W,
DIRECTIONS.E
],
"plus_ccw": [
DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.S,
DIRECTIONS.N, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.N, DIRECTIONS.S,
DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.E,
DIRECTIONS.W
],
"times_cw": [
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SE, DIRECTIONS.NW,
DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.NW, DIRECTIONS.SE,
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SE, DIRECTIONS.NW,
DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.NW, DIRECTIONS.SE
],
"times_ccw": [
DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SW, DIRECTIONS.NE,
DIRECTIONS.SE, DIRECTIONS.NW, DIRECTIONS.NE, DIRECTIONS.SW,
DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SW, DIRECTIONS.NE,
DIRECTIONS.SE, DIRECTIONS.NW, DIRECTIONS.NE, DIRECTIONS.SW
],
"zee": [
DIRECTIONS.NE, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E,
DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW,
DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.NE,
DIRECTIONS.SW, DIRECTIONS.NE
],
"chevron_down": [
DIRECTIONS.NW, DIRECTIONS.S, DIRECTIONS.SE, DIRECTIONS.NE, DIRECTIONS.N,
DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.NE,
DIRECTIONS.S, DIRECTIONS.SW, DIRECTIONS.NW, DIRECTIONS.N, DIRECTIONS.SE,
DIRECTIONS.NW, DIRECTIONS.SE
],
"chevron_up": [
DIRECTIONS.SE, DIRECTIONS.N, DIRECTIONS.NW, DIRECTIONS.SW, DIRECTIONS.S,
DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.SW,
DIRECTIONS.N, DIRECTIONS.NE, DIRECTIONS.SE, DIRECTIONS.S, DIRECTIONS.NW,
DIRECTIONS.SE, DIRECTIONS.NW
],
}
class DancerSprite(prefab_sprites.MazeWalker):
"""A `Sprite` for dancers."""
def __init__(self, corner, position, character, motion, color, shape,
value=0.):
super(DancerSprite, self).__init__(
corner, position, character, impassable="#")
self.motion = motion
self.dance_sequence = DANCE_SEQUENCES[motion].copy()
self.color = color
self.shape = shape
self.value = value
self.is_dancing = False
def update(self, actions, board, layers, backdrop, things, the_plot):
if the_plot["task_phase"] == "dance" and self.is_dancing:
if not self.dance_sequence:
raise ValueError(
"Dance sequence is empty! Was this dancer repeated in the order?")
dance_move = self.dance_sequence.pop(0)
if dance_move == DIRECTIONS.N:
self._north(board, the_plot)
elif dance_move == DIRECTIONS.NE:
self._northeast(board, the_plot)
elif dance_move == DIRECTIONS.E:
self._east(board, the_plot)
elif dance_move == DIRECTIONS.SE:
self._southeast(board, the_plot)
elif dance_move == DIRECTIONS.S:
self._south(board, the_plot)
elif dance_move == DIRECTIONS.SW:
self._southwest(board, the_plot)
elif dance_move == DIRECTIONS.W:
self._west(board, the_plot)
elif dance_move == DIRECTIONS.NW:
self._northwest(board, the_plot)
if not self.dance_sequence: # done!
self.is_dancing = False
the_plot["time_until_next_dance"] = the_plot["dance_delay"]
else:
if self.position == things[AGENT_CHAR].position:
# Award the player the appropriate amount of reward, and end episode.
the_plot.add_reward(self.value)
the_plot.terminate_episode()
class PlayerSprite(prefab_sprites.MazeWalker):
"""The player / agent character.
MazeWalker class methods handle basic movement and collision detection.
"""
def __init__(self, corner, position, character):
super(PlayerSprite, self).__init__(
corner, position, character, impassable="#")
def update(self, actions, board, layers, backdrop, things, the_plot):
if the_plot["task_phase"] == "dance":
# agent's actions are ignored, this logic updates the dance phases.
if the_plot["time_until_next_dance"] > 0:
the_plot["time_until_next_dance"] -= 1
if the_plot["time_until_next_dance"] == 0: # next phase time!
if the_plot["dance_order"]: # start the next dance!
next_dancer = the_plot["dance_order"].pop(0)
things[next_dancer].is_dancing = True
else: # choice time!
the_plot["task_phase"] = "choice"
the_plot["instruction_string"] = the_plot[
"choice_instruction_string"]
elif the_plot["task_phase"] == "choice":
# agent can now move and make its choice
if actions == DIRECTIONS.N:
self._north(board, the_plot)
elif actions == DIRECTIONS.NE:
self._northeast(board, the_plot)
elif actions == DIRECTIONS.E:
self._east(board, the_plot)
elif actions == DIRECTIONS.SE:
self._southeast(board, the_plot)
elif actions == DIRECTIONS.S:
self._south(board, the_plot)
elif actions == DIRECTIONS.SW:
self._southwest(board, the_plot)
elif actions == DIRECTIONS.W:
self._west(board, the_plot)
elif actions == DIRECTIONS.NW:
self._northwest(board, the_plot)
def make_game(dancers_and_properties, dance_delay=16):
"""Constructs an ascii map, then uses pycolab to make it a game.
Args:
dancers_and_properties: list of (character, (row, column), motion, shape,
color, value), for placing objects in the world.
dance_delay: how long to wait between dances.
Returns:
this_game: Pycolab engine running the specified game.
"""
num_rows, num_cols = ROOM_SIZE
level_layout = []
# upper wall
level_layout.append("".join([WALL_CHAR] * num_cols))
# room
middle_string = "".join([WALL_CHAR] + [" "] * (num_cols - 2) + [WALL_CHAR])
level_layout.extend([middle_string for _ in range(num_rows - 2)])
# lower wall
level_layout.append("".join([WALL_CHAR] * num_cols))
def _add_to_map(obj, loc):
"""Adds an ascii character to the level at the requested position."""
obj_row = level_layout[loc[0]]
pre_string = obj_row[:loc[1]]
post_string = obj_row[loc[1] + 1:]
level_layout[loc[0]] = pre_string + obj + post_string
_add_to_map(AGENT_CHAR, AGENT_START)
sprites = {AGENT_CHAR: PlayerSprite}
dance_order = []
char_to_color_shape = []
# add dancers to level
for obj, loc, motion, shape, color, value in dancers_and_properties:
_add_to_map(obj, loc)
sprites[obj] = ascii_art.Partial(
DancerSprite, motion=motion, color=color, shape=shape, value=value)
char_to_color_shape.append((obj, color + " " + shape))
dance_order += obj
if value > 0.:
choice_instruction_string = motion
this_game = ascii_art.ascii_art_to_game(
art=level_layout,
what_lies_beneath=" ",
sprites=sprites,
update_schedule=[[AGENT_CHAR],
dance_order])
this_game.the_plot["task_phase"] = "dance"
this_game.the_plot["instruction_string"] = "watch"
this_game.the_plot["choice_instruction_string"] = choice_instruction_string
this_game.the_plot["dance_order"] = dance_order
this_game.the_plot["dance_delay"] = dance_delay
this_game.the_plot["time_until_next_dance"] = 1
this_game.the_plot["char_to_color_shape"] = char_to_color_shape
return this_game
def main(argv):
del argv # unused
these_dancers_and_properties = [
(POSSIBLE_DANCER_CHARS[1], (2, 2), "chevron_up", "triangle", "red", 1),
(POSSIBLE_DANCER_CHARS[2], (2, 5), "circle_ccw", "triangle", "red", 0),
(POSSIBLE_DANCER_CHARS[3], (2, 8), "plus_cw", "triangle", "red", 0),
(POSSIBLE_DANCER_CHARS[4], (5, 2), "plus_ccw", "triangle", "red", 0),
(POSSIBLE_DANCER_CHARS[5], (5, 8), "times_cw", "triangle", "red", 0),
(POSSIBLE_DANCER_CHARS[6], (8, 2), "up_and_down", "plus", "blue", 0),
(POSSIBLE_DANCER_CHARS[7], (8, 5), "left_and_right", "plus", "blue", 0),
(POSSIBLE_DANCER_CHARS[8], (8, 8), "zee", "plus", "blue", 0),
]
game = make_game(dancers_and_properties=these_dancers_and_properties)
# Note that these colors are only for human UI
fg_colours = {
AGENT_CHAR: (999, 999, 999), # Agent is white
WALL_CHAR: (300, 300, 300), # Wall, dark grey
FLOOR_CHAR: (0, 0, 0), # Floor
}
for (c, _, _, _, col, _) in these_dancers_and_properties:
fg_colours[c] = (999, 0, 0) if col == "red" else (0, 0, 999)
bg_colours = {
c: (0, 0, 0) for c in RESERVED_CHARS + POSSIBLE_DANCER_CHARS[1:8]
}
ui = human_ui.CursesUi(
keys_to_actions={
# Basic movement.
curses.KEY_UP: DIRECTIONS.N,
curses.KEY_DOWN: DIRECTIONS.S,
curses.KEY_LEFT: DIRECTIONS.W,
curses.KEY_RIGHT: DIRECTIONS.E,
-1: 8, # Do nothing.
},
delay=500,
colour_fg=fg_colours,
colour_bg=bg_colours)
ui.play(game)
if __name__ == "__main__":
app.run(main)
@@ -0,0 +1,85 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pycolab_ballet.ballet_environment_wrapper."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from hierarchical_transformer_memory.pycolab_ballet import ballet_environment
from hierarchical_transformer_memory.pycolab_ballet import ballet_environment_core
class BalletEnvironmentTest(parameterized.TestCase):
def test_full_wrapper(self):
env = ballet_environment.BalletEnvironment(
num_dancers=1, dance_delay=16, max_steps=200,
rng=np.random.default_rng(seed=0))
result = env.reset()
self.assertIsNone(result.reward)
level_size = ballet_environment_core.ROOM_SIZE
upsample_size = ballet_environment.UPSAMPLE_SIZE
# wait for dance to complete
for i in range(30):
result = env.step(0).observation
self.assertEqual(result[0].shape,
(level_size[0] * upsample_size,
level_size[1] * upsample_size,
3))
self.assertEqual(str(result[1])[:5],
np.array("watch"))
for i in [1, 1, 1, 1]: # first gets eaten before agent can move
result = env.step(i)
self.assertEqual(result.observation[0].shape,
(level_size[0] * upsample_size,
level_size[1] * upsample_size,
3))
self.assertEqual(str(result.observation[1])[:11],
np.array("up_and_down"))
self.assertEqual(result.reward, 1.)
# check egocentric scrolling is working, by checking object is in center
np.testing.assert_array_almost_equal(
result.observation[0][45:54, 45:54],
ballet_environment._generate_template("orange plus") / 255.)
@parameterized.parameters(
"2_delay16",
"4_delay16",
"8_delay48",
)
def test_simple_builder(self, level_name):
dance_delay = int(level_name[-2:])
np.random.seed(0)
env = ballet_environment.simple_builder(level_name)
# check max steps are set to match paper settings
self.assertEqual(env._max_steps,
320 if dance_delay == 16 else 1024)
# test running a few steps of each
env.reset()
level_size = ballet_environment_core.ROOM_SIZE
upsample_size = ballet_environment.UPSAMPLE_SIZE
for i in range(8):
result = env.step(i) # check all 8 movements work
self.assertEqual(result.observation[0].shape,
(level_size[0] * upsample_size,
level_size[1] * upsample_size,
3))
self.assertEqual(str(result.observation[1])[:5],
np.array("watch"))
self.assertEqual(result.reward, 0.)
if __name__ == "__main__":
absltest.main()
@@ -0,0 +1,7 @@
absl-py
chex>=0.0.8
dm-env
dm-haiku>=0.0.5.dev
jax>=0.2.17
numpy
pycolab