diff --git a/hierarchical_transformer_memory/README.md b/hierarchical_transformer_memory/README.md new file mode 100644 index 0000000..df54a18 --- /dev/null +++ b/hierarchical_transformer_memory/README.md @@ -0,0 +1,67 @@ +# Towards mental time travel: A hierarchical memory for RL agents + +This provides an implementation of two components of the paper "Towards mental +time travel: A hierarchical memory for reinforcement learning agents." The +article can be found on arXiv at [https://arxiv.org/abs/2105.14039](https://arxiv.org/abs/2105.14039) + +Specifically, this repository contains: + +1) A JAX/Haiku implementation of hierarchical transformer attention over memory. +2) An implementation of the Ballet environment used in the paper. + +We have also released the Rapid Word Learning tasks from the paper, but to +simplify dependencies they are located in the `dm_fast_mapping` repository: +[deepmind/dm_fast_mapping](https://github.com/deepmind/dm_fast_mapping) see the +[documentation](https://github.com/deepmind/dm_fast_mapping/blob/master/docs/index.md) +for that repository for further details about using those tasks. + +## Setup + +For easy installation, run: + +```shell +python3 -m venv htm_env +source htm_env/bin/activate +pip install --upgrade pip +pip install -r requirements.txt +``` + +Note that this installs the components needed for both the attention module and +the environment. If you only wish to use the environment, you do not need to +install JAX, Haiku, or Chex. + +## Using the hierarchical attention module: + +Please see `hierarchical_attention/htm_attention_test.py` for some examples of +the expected inputs for this module. + +## Running the ballet environment + +The ballet environment is contained in the `pycolab_ballet/` subfolder. To load +a simple ballet environment with 2 dancers and short delays, and watch a few +steps of the dances, you can do: + +``` +from pycolab_ballet import ballet_environment + +env = ballet_environment.simple_builder(level_name='2_delay16') +timestep = env.reset() +for _ in range(5): + action = 0 + timestep = env.step(action) +``` + +## Citing this work + +If you use this code, please cite the associated paper: + +``` +@article{lampinen2021towards, + title={Towards mental time travel: + a hierarchical memory for reinforcement learning agents}, + author={Lampinen, Andrew Kyle and Chan, Stephanie CY and Banino, Andrea and + Hill, Felix}, + journal={arXiv preprint arXiv:2105.14039}, + year={2021} +} +``` diff --git a/hierarchical_transformer_memory/hierarchical_attention/htm_attention.py b/hierarchical_transformer_memory/hierarchical_attention/htm_attention.py new file mode 100644 index 0000000..82505e1 --- /dev/null +++ b/hierarchical_transformer_memory/hierarchical_attention/htm_attention.py @@ -0,0 +1,218 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Haiku module implementing hierarchical attention over memory.""" + +from typing import Optional, NamedTuple + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +_EPSILON = 1e-3 + + +class HierarchicalMemory(NamedTuple): + """Structure of the hierarchical memory. + + Where 'B' is batch size, 'M' is number of memories, 'C' is chunk size, and 'D' + is memory dimension. + """ + keys: jnp.ndarray # [B, M, D] + contents: jnp.ndarray # [B, M, C, D] + steps_since_last_write: jnp.ndarray # [B], steps since last memory write + accumulator: jnp.ndarray # [B, C, D], accumulates experiences before write + + +def sinusoid_position_encoding( + sequence_length: int, + hidden_size: int, + min_timescale: float = 2., + max_timescale: float = 1e4, +) -> jnp.ndarray: + """Creates sinusoidal encodings. + + Args: + sequence_length: length [L] of sequence to be position encoded. + hidden_size: dimension [D] of the positional encoding vectors. + min_timescale: minimum timescale for the frequency. + max_timescale: maximum timescale for the frequency. + + Returns: + An array of shape [L, D] + """ + freqs = np.arange(0, hidden_size, min_timescale) + inv_freq = max_timescale**(-freqs / hidden_size) + pos_seq = np.arange(sequence_length - 1, -1, -1.0) + sinusoid_inp = np.einsum("i,j->ij", pos_seq, inv_freq) + pos_emb = np.concatenate( + [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1) + return pos_emb + + +class HierarchicalMemoryAttention(hk.Module): + """Multi-head attention over hierarchical memory.""" + + def __init__(self, + feature_size: int, + k: int, + num_heads: int = 1, + memory_position_encoding: bool = True, + init_scale: float = 2., + name: Optional[str] = None) -> None: + """Constructor. + + Args: + feature_size: size of feature dimension of attention-over-memories + embedding. + k: number of memories to sample. + num_heads: number of attention heads. + memory_position_encoding: whether to add positional encodings to memories + during within memory attention. + init_scale: scale factor for Variance weight initializers. + name: module name. + """ + super().__init__(name=name) + self._size = feature_size + self._k = k + self._num_heads = num_heads + self._weights = None + self._memory_position_encoding = memory_position_encoding + self._init_scale = init_scale + + @property + def num_heads(self): + return self._num_heads + + @hk.transparent + def _singlehead_linear(self, + inputs: jnp.ndarray, + hidden_size: int, + name: str): + linear = hk.Linear( + hidden_size, + with_bias=False, + w_init=hk.initializers.VarianceScaling(scale=self._init_scale), + name=name) + out = linear(inputs) + return out + + def __call__( + self, + queries: jnp.ndarray, + hm_memory: HierarchicalMemory, + hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Do hierarchical attention over the stored memories. + + Args: + queries: Tensor [B, Q, E] Query(ies) in, for batch size B, query length + Q, and embedding dimension E. + hm_memory: Hierarchical Memory. + hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where false, + the corresponding query timepoints cannot attend to the corresponding + memory chunks. This can be used for enforcing causal attention on the + learner, not attending to memories from prior episodes, etc. + + Returns: + Value updates for each query slot: [B, Q, D] + """ + # some shape checks + batch_size, query_length, _ = queries.shape + (memory_batch_size, num_memories, + memory_chunk_size, mem_embbedding_size) = hm_memory.contents.shape + assert batch_size == memory_batch_size + chex.assert_shape(hm_memory.keys, + (batch_size, num_memories, mem_embbedding_size)) + chex.assert_shape(hm_memory.accumulator, + (memory_batch_size, memory_chunk_size, + mem_embbedding_size)) + chex.assert_shape(hm_memory.steps_since_last_write, + (memory_batch_size,)) + if hm_mask is not None: + chex.assert_type(hm_mask, bool) + chex.assert_shape(hm_mask, + (batch_size, query_length, num_memories)) + query_head = self._singlehead_linear(queries, self._size, "query") + key_head = self._singlehead_linear( + jax.lax.stop_gradient(hm_memory.keys), self._size, "key") + + # What times in the input [t] attend to what times in the memories [T]. + logits = jnp.einsum("btd,bTd->btT", query_head, key_head) + + scaled_logits = logits / np.sqrt(self._size) + + # Mask last dimension, replacing invalid logits with large negative values. + # This allows e.g. enforcing causal attention on learner, or blocking + # attention across episodes + if hm_mask is not None: + masked_logits = jnp.where(hm_mask, scaled_logits, -1e6) + else: + masked_logits = scaled_logits + + # identify the top-k memories and their relevance weights + top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k) + weights = jax.nn.softmax(top_k_logits) + + # set up the within-memory attention + assert self._size % self._num_heads == 0 + mha_key_size = self._size // self._num_heads + attention_layer = hk.MultiHeadAttention( + key_size=mha_key_size, + model_size=self._size, + num_heads=self._num_heads, + w_init_scale=self._init_scale, + name="within_mem_attn") + + # position encodings + augmented_contents = hm_memory.contents + if self._memory_position_encoding: + position_embs = sinusoid_position_encoding( + memory_chunk_size, mem_embbedding_size) + augmented_contents += position_embs[None, None, :, :] + + def _within_memory_attention(sub_inputs, sub_memory_contents, sub_weights, + sub_top_k_indices): + top_k_contents = sub_memory_contents[sub_top_k_indices, :, :] + + # Now we go deeper, with another vmap over **tokens**, because each token + # can each attend to different memories. + def do_attention(sub_sub_inputs, sub_sub_top_k_contents): + tiled_inputs = jnp.tile(sub_sub_inputs[None, None, :], + reps=(self._k, 1, 1)) + sub_attention_results = attention_layer( + query=tiled_inputs, + key=sub_sub_top_k_contents, + value=sub_sub_top_k_contents) + return sub_attention_results + do_attention = jax.vmap(do_attention, in_axes=0) + attention_results = do_attention(sub_inputs, top_k_contents) + attention_results = jnp.squeeze(attention_results, axis=2) + # Now collapse results across k memories + attention_results = sub_weights[:, :, None] * attention_results + attention_results = jnp.sum(attention_results, axis=1) + return attention_results + + # vmap across batch + batch_within_memory_attention = jax.vmap(_within_memory_attention, + in_axes=0) + outputs = batch_within_memory_attention( + queries, + jax.lax.stop_gradient(augmented_contents), + weights, + top_k_indices) + + return outputs diff --git a/hierarchical_transformer_memory/hierarchical_attention/htm_attention_test.py b/hierarchical_transformer_memory/hierarchical_attention/htm_attention_test.py new file mode 100644 index 0000000..fda05df --- /dev/null +++ b/hierarchical_transformer_memory/hierarchical_attention/htm_attention_test.py @@ -0,0 +1,106 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for hierarchical_attention.htm_attention.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import haiku as hk +import numpy as np + +from hierarchical_transformer_memory.hierarchical_attention import htm_attention + + +def _build_queries_and_memory(query_length, num_memories, mem_chunk_size, + batch_size=2, embedding_size=12): + """Builds dummy queries + memory contents for tests.""" + queries = np.random.random([batch_size, query_length, embedding_size]) + memory_contents = np.random.random( + [batch_size, num_memories, mem_chunk_size, embedding_size]) + # summary key = average across chunk + memory_keys = np.mean(memory_contents, axis=2) + # to accumulate newest memories before writing + memory_accumulator = np.zeros_like(memory_contents[:, -1, :, :]) + memory = htm_attention.HierarchicalMemory( + keys=memory_keys, + contents=memory_contents, + accumulator=memory_accumulator, + steps_since_last_write=np.zeros([batch_size,], dtype=np.int32)) + return queries, memory + + +class HierarchicalAttentionTest(parameterized.TestCase): + + @parameterized.parameters([ + { + 'query_length': 1, + 'num_memories': 7, + 'mem_chunk_size': 5, + 'mem_k': 4, + }, + { + 'query_length': 9, + 'num_memories': 7, + 'mem_chunk_size': 5, + 'mem_k': 4, + }, + ]) + @hk.testing.transform_and_run + def test_output_shapes(self, query_length, num_memories, mem_chunk_size, + mem_k): + np.random.seed(0) + batch_size = 2 + embedding_size = 12 + num_heads = 3 + queries, memory = _build_queries_and_memory( + query_length=query_length, num_memories=num_memories, + mem_chunk_size=mem_chunk_size, embedding_size=embedding_size) + hm_att = htm_attention.HierarchicalMemoryAttention( + feature_size=embedding_size, + k=mem_k, + num_heads=num_heads) + results = hm_att(queries, memory) + self.assertEqual(results.shape, + (batch_size, query_length, embedding_size)) + self.assertTrue(np.all(np.isfinite(results))) + + @hk.testing.transform_and_run + def test_masking(self): + np.random.seed(0) + batch_size = 2 + embedding_size = 12 + num_heads = 3 + query_length = 5 + num_memories = 7 + mem_chunk_size = 6 + mem_k = 4 + queries, memory = _build_queries_and_memory( + query_length=query_length, num_memories=num_memories, + mem_chunk_size=mem_chunk_size, embedding_size=embedding_size) + hm_att = htm_attention.HierarchicalMemoryAttention( + feature_size=embedding_size, + k=mem_k, + num_heads=num_heads) + # get a random boolean mask + mask = np.random.binomial( + 1, 0.5, [batch_size, query_length, num_memories]).astype(bool) + results = hm_att(queries, memory, hm_mask=mask) + self.assertEqual(results.shape, + (batch_size, query_length, embedding_size)) + self.assertTrue(np.all(np.isfinite(results))) + + +if __name__ == '__main__': + absltest.main() diff --git a/hierarchical_transformer_memory/pycolab_ballet/ballet_environment.py b/hierarchical_transformer_memory/pycolab_ballet/ballet_environment.py new file mode 100644 index 0000000..036db1d --- /dev/null +++ b/hierarchical_transformer_memory/pycolab_ballet/ballet_environment.py @@ -0,0 +1,376 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A pycolab environment for going to the ballet. + +A pycolab-based environment for testing memory for sequences of events. The +environment contains some number of "dancer" characters in (implicit) 3 x 3 +squares within a larger 9 x 9 room. The agent starts in the center of the room. +At the beginning of an episode, the dancers each do a dance solo of a fixed +length, separated by empty time of a fixed length. The agent's actions do +nothing during the dances. After the last dance ends, the agent must go up to a +dancer, identified using language describing the dance. The agent is rewarded +1 +for approaching the correct dancer, 0 otherwise. + +The room is upsampled at a size of 9 pixels per square to render a view for the +agent, which is cropped in egocentric perspective, i.e. the agent is always in +the center of its view (see https://arxiv.org/abs/1910.00571). +""" +from absl import app +from absl import flags +from absl import logging + +import dm_env + +import numpy as np +from pycolab import cropping + +from hierarchical_transformer_memory.pycolab_ballet import ballet_environment_core as ballet_core + +FLAGS = flags.FLAGS + +UPSAMPLE_SIZE = 9 # pixels per game square +SCROLL_CROP_SIZE = 11 # in game squares + +DANCER_SHAPES = [ + "triangle", "empty_square", "plus", "inverse_plus", "ex", "inverse_ex", + "circle", "empty_circle", "tee", "upside_down_tee", + "h", "u", "upside_down_u", "vertical_stripes", "horizontal_stripes" +] + +COLORS = { + "red": np.array([255, 0, 0]), + "green": np.array([0, 255, 0]), + "blue": np.array([0, 0, 255]), + "purple": np.array([128, 0, 128]), + "orange": np.array([255, 165, 0]), + "yellow": np.array([255, 255, 0]), + "brown": np.array([128, 64, 0]), + "pink": np.array([255, 64, 255]), + "cyan": np.array([0, 255, 255]), + "dark_green": np.array([0, 100, 0]), + "dark_red": np.array([100, 0, 0]), + "dark_blue": np.array([0, 0, 100]), + "olive": np.array([100, 100, 0]), + "teal": np.array([0, 100, 100]), + "lavender": np.array([215, 200, 255]), + "peach": np.array([255, 210, 170]), + "rose": np.array([255, 205, 230]), + "light_green": np.array([200, 255, 200]), + "light_yellow": np.array([255, 255, 200]), +} + + +def _generate_template(object_name): + """Generates a template object image, given a name with color and shape.""" + object_color, object_type = object_name.split() + template = np.zeros((UPSAMPLE_SIZE, UPSAMPLE_SIZE)) + half = UPSAMPLE_SIZE // 2 + if object_type == "triangle": + for i in range(UPSAMPLE_SIZE): + for j in range(UPSAMPLE_SIZE): + if (j <= half and i >= 2 * (half - j)) or (j > half and i >= 2 * + (j - half)): + template[i, j] = 1. + elif object_type == "square": + template[:, :] = 1. + elif object_type == "empty_square": + template[:2, :] = 1. + template[-2:, :] = 1. + template[:, :2] = 1. + template[:, -2:] = 1. + elif object_type == "plus": + template[:, half - 1:half + 2] = 1. + template[half - 1:half + 2, :] = 1. + elif object_type == "inverse_plus": + template[:, :] = 1. + template[:, half - 1:half + 2] = 0. + template[half - 1:half + 2, :] = 0. + elif object_type == "ex": + for i in range(UPSAMPLE_SIZE): + for j in range(UPSAMPLE_SIZE): + if abs(i - j) <= 1 or abs(UPSAMPLE_SIZE - 1 - j - i) <= 1: + template[i, j] = 1. + elif object_type == "inverse_ex": + for i in range(UPSAMPLE_SIZE): + for j in range(UPSAMPLE_SIZE): + if not (abs(i - j) <= 1 or abs(UPSAMPLE_SIZE - 1 - j - i) <= 1): + template[i, j] = 1. + elif object_type == "circle": + for i in range(UPSAMPLE_SIZE): + for j in range(UPSAMPLE_SIZE): + if (i - half)**2 + (j - half)**2 <= half**2: + template[i, j] = 1. + elif object_type == "empty_circle": + for i in range(UPSAMPLE_SIZE): + for j in range(UPSAMPLE_SIZE): + if abs((i - half)**2 + (j - half)**2 - half**2) < 6: + template[i, j] = 1. + elif object_type == "tee": + template[:, half - 1:half + 2] = 1. + template[:3, :] = 1. + elif object_type == "upside_down_tee": + template[:, half - 1:half + 2] = 1. + template[-3:, :] = 1. + elif object_type == "h": + template[:, :3] = 1. + template[:, -3:] = 1. + template[half - 1:half + 2, :] = 1. + elif object_type == "u": + template[:, :3] = 1. + template[:, -3:] = 1. + template[-3:, :] = 1. + elif object_type == "upside_down_u": + template[:, :3] = 1. + template[:, -3:] = 1. + template[:3, :] = 1. + elif object_type == "vertical_stripes": + for j in range(half + UPSAMPLE_SIZE % 2): + template[:, 2*j] = 1. + elif object_type == "horizontal_stripes": + for i in range(half + UPSAMPLE_SIZE % 2): + template[2*i, :] = 1. + else: + raise ValueError("Unknown object: {}".format(object_type)) + + if object_color not in COLORS: + raise ValueError("Unknown color: {}".format(object_color)) + + template = np.tensordot(template, COLORS[object_color], axes=0) + + return template + + +# Agent and wall templates +_CHAR_TO_TEMPLATE_BASE = { + ballet_core.AGENT_CHAR: + np.tensordot( + np.ones([UPSAMPLE_SIZE, UPSAMPLE_SIZE]), + np.array([255, 255, 255]), + axes=0), + ballet_core.WALL_CHAR: + np.tensordot( + np.ones([UPSAMPLE_SIZE, UPSAMPLE_SIZE]), + np.array([40, 40, 40]), + axes=0), +} + + +def get_scrolling_cropper(rows=9, cols=9, crop_pad_char=" "): + return cropping.ScrollingCropper(rows=rows, cols=cols, + to_track=[ballet_core.AGENT_CHAR], + pad_char=crop_pad_char, + scroll_margins=(None, None)) + + +class BalletEnvironment(dm_env.Environment): + """A Python environment API for pycolab ballet tasks.""" + + def __init__(self, num_dancers, dance_delay, max_steps, rng=None): + """Construct a BalletEnvironment that wraps pycolab games for agent use. + + This class inherits from dm_env and has all the expected methods and specs. + + Args: + num_dancers: The number of dancers to use, between 1 and 8 (inclusive). + dance_delay: How long to delay between the dances. + max_steps: The maximum number of steps to allow in an episode, after which + it will terminate. + rng: An optional numpy Random Generator, to set a fixed seed use e.g. + `rng=np.random.default_rng(seed=...)` + """ + self._num_dancers = num_dancers + self._dance_delay = dance_delay + self._max_steps = max_steps + + # internal state + if rng is None: + rng = np.random.default_rng() + self._rng = rng + self._current_game = None # Current pycolab game instance. + self._state = None # Current game step state. + self._game_over = None # Whether the game has ended. + self._char_to_template = None # Mapping of chars to sprite images. + + # rendering tools + self._cropper = get_scrolling_cropper(SCROLL_CROP_SIZE, SCROLL_CROP_SIZE, + " ") + + def _game_factory(self): + """Samples dancers and positions, returns a pycolab core game engine.""" + target_dancer_index = self._rng.integers(self._num_dancers) + motions = list(ballet_core.DANCE_SEQUENCES.keys()) + positions = ballet_core.DANCER_POSITIONS.copy() + colors = list(COLORS.keys()) + shapes = DANCER_SHAPES.copy() + self._rng.shuffle(positions) + self._rng.shuffle(motions) + self._rng.shuffle(colors) + self._rng.shuffle(shapes) + dancers_and_properties = [] + for dancer_i in range(self._num_dancers): + if dancer_i == target_dancer_index: + value = 1. + else: + value = 0. + dancers_and_properties.append( + (ballet_core.POSSIBLE_DANCER_CHARS[dancer_i], + positions[dancer_i], + motions[dancer_i], + shapes[dancer_i], + colors[dancer_i], + value)) + + logging.info("Making level with dancers_and_properties: %s", + dancers_and_properties) + + return ballet_core.make_game( + dancers_and_properties=dancers_and_properties, + dance_delay=self._dance_delay) + + def _render_observation(self, observation): + """Renders from raw pycolab image observation to agent-usable ones.""" + observation = self._cropper.crop(observation) + obs_rows, obs_cols = observation.board.shape + image = np.zeros([obs_rows * UPSAMPLE_SIZE, obs_cols * UPSAMPLE_SIZE, 3], + dtype=np.float32) + for i in range(obs_rows): + for j in range(obs_cols): + this_char = chr(observation.board[i, j]) + if this_char != ballet_core.FLOOR_CHAR: + image[ + i * UPSAMPLE_SIZE:(i + 1) * UPSAMPLE_SIZE, j * + UPSAMPLE_SIZE:(j + 1) * UPSAMPLE_SIZE] = self._char_to_template[ + this_char] + image /= 255. + language = np.array(self._current_game.the_plot["instruction_string"]) + full_observation = (image, language) + return full_observation + + def reset(self): + """Start a new episode.""" + # Build a new game and retrieve its first set of state/reward/discount. + self._current_game = self._game_factory() + # set up rendering, cropping, and state for current game + self._char_to_template = { + k: _generate_template(v) for k, v in self._current_game.the_plot[ + "char_to_color_shape"]} + self._char_to_template.update(_CHAR_TO_TEMPLATE_BASE) + self._cropper.set_engine(self._current_game) + self._state = dm_env.StepType.FIRST + # let's go! + observation, _, _ = self._current_game.its_showtime() + observation = self._render_observation(observation) + return dm_env.TimeStep( + step_type=self._state, + reward=None, + discount=None, + observation=observation) + + def step(self, action): + """Apply action, step the world forward, and return observations.""" + # If needed, reset and start new episode. + if self._state == dm_env.StepType.LAST: + self._clear_state() + if self._current_game is None: + return self.reset() + + # Execute the action in pycolab. + observation, reward, discount = self._current_game.play(action) + + self._game_over = self._is_game_over() + reward = reward if reward is not None else 0. + observation = self._render_observation(observation) + + # Check the current status of the game. + if self._game_over: + self._state = dm_env.StepType.LAST + else: + self._state = dm_env.StepType.MID + + return dm_env.TimeStep( + step_type=self._state, + reward=reward, + discount=discount, + observation=observation) + + @property + def observation_spec(self): + image_shape = (SCROLL_CROP_SIZE * UPSAMPLE_SIZE, + SCROLL_CROP_SIZE * UPSAMPLE_SIZE, + 3) + return ( + # vision + dm_env.specs.Array( + shape=image_shape, dtype=np.float32, name="image"), + # language + dm_env.specs.Array( + shape=[], dtype=str, name="language"), + ) + + @property + def action_spec(self): + return dm_env.specs.BoundedArray( + shape=[], dtype="int32", + minimum=0, maximum=7, + name="grid_actions") + + def _is_game_over(self): + """Returns whether it is game over, either from the engine or timeout.""" + return (self._current_game.game_over or + (self._current_game.the_plot.frame >= self._max_steps)) + + def _clear_state(self): + """Clear all the internal information about the game.""" + self._state = None + self._current_game = None + self._char_to_template = None + self._game_over = None + + +def simple_builder(level_name): + """Simplifies building from fixed defs. + + Args: + level_name: '{num_dancers}_delay{delay_length}', where each variable is an + integer. The levels used in the paper were: + ['2_delay16', '4_delay16', '8_delay16', + '2_delay48', '4_delay48', '8_delay48'] + + Returns: + A BalletEnvironment with the requested settings. + """ + num_dancers, dance_delay = level_name.split("_") + num_dancers = int(num_dancers) + dance_delay = int(dance_delay[5:]) + max_steps = 320 if dance_delay == 16 else 1024 + level_args = dict( + num_dancers=num_dancers, + dance_delay=dance_delay, + max_steps=max_steps) + return BalletEnvironment(**level_args) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + env = simple_builder("4_delay16") + for _ in range(3): + obs = env.reset().observation + for _ in range(300): + obs = env.step(0).observation + print(obs) + +if __name__ == "__main__": + app.run(main) diff --git a/hierarchical_transformer_memory/pycolab_ballet/ballet_environment_core.py b/hierarchical_transformer_memory/pycolab_ballet/ballet_environment_core.py new file mode 100644 index 0000000..0dd5ca5 --- /dev/null +++ b/hierarchical_transformer_memory/pycolab_ballet/ballet_environment_core.py @@ -0,0 +1,337 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The pycolab core of the environment for going to the ballet. + +This builds the text-based (non-graphical) engine of the environment, and offers +a UI which a human can play (for a fixed level). However, the logic of level +creation, the graphics, and anything that is external to the pycolab engine +itself is contained in ballet_environment.py. +""" +import curses +import enum + +from absl import app +from absl import flags + +from pycolab import ascii_art +from pycolab import human_ui +from pycolab.prefab_parts import sprites as prefab_sprites + +FLAGS = flags.FLAGS + +ROOM_SIZE = (11, 11) # one square around edge will be wall. +DANCER_POSITIONS = [(2, 2), (2, 5), (2, 8), + (5, 2), (5, 8), # space in center for agent + (8, 2), (8, 5), (8, 8)] +AGENT_START = (5, 5) +AGENT_CHAR = "A" +WALL_CHAR = "#" +FLOOR_CHAR = " " +RESERVED_CHARS = [AGENT_CHAR, WALL_CHAR, FLOOR_CHAR] +POSSIBLE_DANCER_CHARS = [ + chr(i) for i in range(65, 91) if chr(i) not in RESERVED_CHARS +] + +DANCE_SEQUENCE_LENGTHS = 16 + + +# movement directions for dancers / actions for agent +class DIRECTIONS(enum.IntEnum): + N = 0 + NE = 1 + E = 2 + SE = 3 + S = 4 + SW = 5 + W = 6 + NW = 7 + +DANCE_SEQUENCES = { + "circle_cw": [ + DIRECTIONS.N, DIRECTIONS.E, DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.W, + DIRECTIONS.W, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.E, DIRECTIONS.E, + DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.N, + DIRECTIONS.E + ], + "circle_ccw": [ + DIRECTIONS.N, DIRECTIONS.W, DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.E, + DIRECTIONS.E, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.W, DIRECTIONS.W, + DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.N, + DIRECTIONS.W + ], + "up_and_down": [ + DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.N, + DIRECTIONS.S, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.S, + DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.S, + DIRECTIONS.N + ], + "left_and_right": [ + DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E, + DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.W, + DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.W, + DIRECTIONS.E + ], + "diagonal_uldr": [ + DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW, + DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW, + DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW, + DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SE, DIRECTIONS.NW + ], + "diagonal_urdl": [ + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE, + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE, + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE, + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, DIRECTIONS.NE + ], + "plus_cw": [ + DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.S, + DIRECTIONS.N, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.N, DIRECTIONS.S, + DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.W, + DIRECTIONS.E + ], + "plus_ccw": [ + DIRECTIONS.N, DIRECTIONS.S, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.S, + DIRECTIONS.N, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.N, DIRECTIONS.S, + DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.S, DIRECTIONS.N, DIRECTIONS.E, + DIRECTIONS.W + ], + "times_cw": [ + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SE, DIRECTIONS.NW, + DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.NW, DIRECTIONS.SE, + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SE, DIRECTIONS.NW, + DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.NW, DIRECTIONS.SE + ], + "times_ccw": [ + DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SW, DIRECTIONS.NE, + DIRECTIONS.SE, DIRECTIONS.NW, DIRECTIONS.NE, DIRECTIONS.SW, + DIRECTIONS.NW, DIRECTIONS.SE, DIRECTIONS.SW, DIRECTIONS.NE, + DIRECTIONS.SE, DIRECTIONS.NW, DIRECTIONS.NE, DIRECTIONS.SW + ], + "zee": [ + DIRECTIONS.NE, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.E, DIRECTIONS.E, + DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.SW, + DIRECTIONS.E, DIRECTIONS.E, DIRECTIONS.W, DIRECTIONS.W, DIRECTIONS.NE, + DIRECTIONS.SW, DIRECTIONS.NE + ], + "chevron_down": [ + DIRECTIONS.NW, DIRECTIONS.S, DIRECTIONS.SE, DIRECTIONS.NE, DIRECTIONS.N, + DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.NE, + DIRECTIONS.S, DIRECTIONS.SW, DIRECTIONS.NW, DIRECTIONS.N, DIRECTIONS.SE, + DIRECTIONS.NW, DIRECTIONS.SE + ], + "chevron_up": [ + DIRECTIONS.SE, DIRECTIONS.N, DIRECTIONS.NW, DIRECTIONS.SW, DIRECTIONS.S, + DIRECTIONS.NE, DIRECTIONS.SW, DIRECTIONS.NE, DIRECTIONS.SW, + DIRECTIONS.N, DIRECTIONS.NE, DIRECTIONS.SE, DIRECTIONS.S, DIRECTIONS.NW, + DIRECTIONS.SE, DIRECTIONS.NW + ], +} + + +class DancerSprite(prefab_sprites.MazeWalker): + """A `Sprite` for dancers.""" + + def __init__(self, corner, position, character, motion, color, shape, + value=0.): + super(DancerSprite, self).__init__( + corner, position, character, impassable="#") + self.motion = motion + self.dance_sequence = DANCE_SEQUENCES[motion].copy() + self.color = color + self.shape = shape + self.value = value + self.is_dancing = False + + def update(self, actions, board, layers, backdrop, things, the_plot): + if the_plot["task_phase"] == "dance" and self.is_dancing: + if not self.dance_sequence: + raise ValueError( + "Dance sequence is empty! Was this dancer repeated in the order?") + dance_move = self.dance_sequence.pop(0) + if dance_move == DIRECTIONS.N: + self._north(board, the_plot) + elif dance_move == DIRECTIONS.NE: + self._northeast(board, the_plot) + elif dance_move == DIRECTIONS.E: + self._east(board, the_plot) + elif dance_move == DIRECTIONS.SE: + self._southeast(board, the_plot) + elif dance_move == DIRECTIONS.S: + self._south(board, the_plot) + elif dance_move == DIRECTIONS.SW: + self._southwest(board, the_plot) + elif dance_move == DIRECTIONS.W: + self._west(board, the_plot) + elif dance_move == DIRECTIONS.NW: + self._northwest(board, the_plot) + + if not self.dance_sequence: # done! + self.is_dancing = False + the_plot["time_until_next_dance"] = the_plot["dance_delay"] + else: + if self.position == things[AGENT_CHAR].position: + # Award the player the appropriate amount of reward, and end episode. + the_plot.add_reward(self.value) + the_plot.terminate_episode() + + +class PlayerSprite(prefab_sprites.MazeWalker): + """The player / agent character. + + MazeWalker class methods handle basic movement and collision detection. + """ + + def __init__(self, corner, position, character): + super(PlayerSprite, self).__init__( + corner, position, character, impassable="#") + + def update(self, actions, board, layers, backdrop, things, the_plot): + if the_plot["task_phase"] == "dance": + # agent's actions are ignored, this logic updates the dance phases. + if the_plot["time_until_next_dance"] > 0: + the_plot["time_until_next_dance"] -= 1 + if the_plot["time_until_next_dance"] == 0: # next phase time! + if the_plot["dance_order"]: # start the next dance! + next_dancer = the_plot["dance_order"].pop(0) + things[next_dancer].is_dancing = True + else: # choice time! + the_plot["task_phase"] = "choice" + the_plot["instruction_string"] = the_plot[ + "choice_instruction_string"] + elif the_plot["task_phase"] == "choice": + # agent can now move and make its choice + if actions == DIRECTIONS.N: + self._north(board, the_plot) + elif actions == DIRECTIONS.NE: + self._northeast(board, the_plot) + elif actions == DIRECTIONS.E: + self._east(board, the_plot) + elif actions == DIRECTIONS.SE: + self._southeast(board, the_plot) + elif actions == DIRECTIONS.S: + self._south(board, the_plot) + elif actions == DIRECTIONS.SW: + self._southwest(board, the_plot) + elif actions == DIRECTIONS.W: + self._west(board, the_plot) + elif actions == DIRECTIONS.NW: + self._northwest(board, the_plot) + + +def make_game(dancers_and_properties, dance_delay=16): + """Constructs an ascii map, then uses pycolab to make it a game. + + Args: + dancers_and_properties: list of (character, (row, column), motion, shape, + color, value), for placing objects in the world. + dance_delay: how long to wait between dances. + + Returns: + this_game: Pycolab engine running the specified game. + """ + num_rows, num_cols = ROOM_SIZE + level_layout = [] + # upper wall + level_layout.append("".join([WALL_CHAR] * num_cols)) + # room + middle_string = "".join([WALL_CHAR] + [" "] * (num_cols - 2) + [WALL_CHAR]) + level_layout.extend([middle_string for _ in range(num_rows - 2)]) + # lower wall + level_layout.append("".join([WALL_CHAR] * num_cols)) + + def _add_to_map(obj, loc): + """Adds an ascii character to the level at the requested position.""" + obj_row = level_layout[loc[0]] + pre_string = obj_row[:loc[1]] + post_string = obj_row[loc[1] + 1:] + level_layout[loc[0]] = pre_string + obj + post_string + + _add_to_map(AGENT_CHAR, AGENT_START) + sprites = {AGENT_CHAR: PlayerSprite} + dance_order = [] + char_to_color_shape = [] + # add dancers to level + for obj, loc, motion, shape, color, value in dancers_and_properties: + _add_to_map(obj, loc) + sprites[obj] = ascii_art.Partial( + DancerSprite, motion=motion, color=color, shape=shape, value=value) + char_to_color_shape.append((obj, color + " " + shape)) + dance_order += obj + if value > 0.: + choice_instruction_string = motion + + this_game = ascii_art.ascii_art_to_game( + art=level_layout, + what_lies_beneath=" ", + sprites=sprites, + update_schedule=[[AGENT_CHAR], + dance_order]) + + this_game.the_plot["task_phase"] = "dance" + this_game.the_plot["instruction_string"] = "watch" + this_game.the_plot["choice_instruction_string"] = choice_instruction_string + this_game.the_plot["dance_order"] = dance_order + this_game.the_plot["dance_delay"] = dance_delay + this_game.the_plot["time_until_next_dance"] = 1 + this_game.the_plot["char_to_color_shape"] = char_to_color_shape + return this_game + + +def main(argv): + del argv # unused + these_dancers_and_properties = [ + (POSSIBLE_DANCER_CHARS[1], (2, 2), "chevron_up", "triangle", "red", 1), + (POSSIBLE_DANCER_CHARS[2], (2, 5), "circle_ccw", "triangle", "red", 0), + (POSSIBLE_DANCER_CHARS[3], (2, 8), "plus_cw", "triangle", "red", 0), + (POSSIBLE_DANCER_CHARS[4], (5, 2), "plus_ccw", "triangle", "red", 0), + (POSSIBLE_DANCER_CHARS[5], (5, 8), "times_cw", "triangle", "red", 0), + (POSSIBLE_DANCER_CHARS[6], (8, 2), "up_and_down", "plus", "blue", 0), + (POSSIBLE_DANCER_CHARS[7], (8, 5), "left_and_right", "plus", "blue", 0), + (POSSIBLE_DANCER_CHARS[8], (8, 8), "zee", "plus", "blue", 0), + ] + + game = make_game(dancers_and_properties=these_dancers_and_properties) + + # Note that these colors are only for human UI + fg_colours = { + AGENT_CHAR: (999, 999, 999), # Agent is white + WALL_CHAR: (300, 300, 300), # Wall, dark grey + FLOOR_CHAR: (0, 0, 0), # Floor + } + for (c, _, _, _, col, _) in these_dancers_and_properties: + fg_colours[c] = (999, 0, 0) if col == "red" else (0, 0, 999) + + bg_colours = { + c: (0, 0, 0) for c in RESERVED_CHARS + POSSIBLE_DANCER_CHARS[1:8] + } + + ui = human_ui.CursesUi( + keys_to_actions={ + # Basic movement. + curses.KEY_UP: DIRECTIONS.N, + curses.KEY_DOWN: DIRECTIONS.S, + curses.KEY_LEFT: DIRECTIONS.W, + curses.KEY_RIGHT: DIRECTIONS.E, + -1: 8, # Do nothing. + }, + delay=500, + colour_fg=fg_colours, + colour_bg=bg_colours) + + ui.play(game) + + +if __name__ == "__main__": + app.run(main) diff --git a/hierarchical_transformer_memory/pycolab_ballet/ballet_environment_test.py b/hierarchical_transformer_memory/pycolab_ballet/ballet_environment_test.py new file mode 100644 index 0000000..f14240b --- /dev/null +++ b/hierarchical_transformer_memory/pycolab_ballet/ballet_environment_test.py @@ -0,0 +1,85 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for pycolab_ballet.ballet_environment_wrapper.""" +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +from hierarchical_transformer_memory.pycolab_ballet import ballet_environment +from hierarchical_transformer_memory.pycolab_ballet import ballet_environment_core + + +class BalletEnvironmentTest(parameterized.TestCase): + + def test_full_wrapper(self): + env = ballet_environment.BalletEnvironment( + num_dancers=1, dance_delay=16, max_steps=200, + rng=np.random.default_rng(seed=0)) + result = env.reset() + self.assertIsNone(result.reward) + level_size = ballet_environment_core.ROOM_SIZE + upsample_size = ballet_environment.UPSAMPLE_SIZE + # wait for dance to complete + for i in range(30): + result = env.step(0).observation + self.assertEqual(result[0].shape, + (level_size[0] * upsample_size, + level_size[1] * upsample_size, + 3)) + self.assertEqual(str(result[1])[:5], + np.array("watch")) + for i in [1, 1, 1, 1]: # first gets eaten before agent can move + result = env.step(i) + self.assertEqual(result.observation[0].shape, + (level_size[0] * upsample_size, + level_size[1] * upsample_size, + 3)) + self.assertEqual(str(result.observation[1])[:11], + np.array("up_and_down")) + self.assertEqual(result.reward, 1.) + # check egocentric scrolling is working, by checking object is in center + np.testing.assert_array_almost_equal( + result.observation[0][45:54, 45:54], + ballet_environment._generate_template("orange plus") / 255.) + + @parameterized.parameters( + "2_delay16", + "4_delay16", + "8_delay48", + ) + def test_simple_builder(self, level_name): + dance_delay = int(level_name[-2:]) + np.random.seed(0) + env = ballet_environment.simple_builder(level_name) + # check max steps are set to match paper settings + self.assertEqual(env._max_steps, + 320 if dance_delay == 16 else 1024) + # test running a few steps of each + env.reset() + level_size = ballet_environment_core.ROOM_SIZE + upsample_size = ballet_environment.UPSAMPLE_SIZE + for i in range(8): + result = env.step(i) # check all 8 movements work + self.assertEqual(result.observation[0].shape, + (level_size[0] * upsample_size, + level_size[1] * upsample_size, + 3)) + self.assertEqual(str(result.observation[1])[:5], + np.array("watch")) + self.assertEqual(result.reward, 0.) + +if __name__ == "__main__": + absltest.main() diff --git a/hierarchical_transformer_memory/requirements.txt b/hierarchical_transformer_memory/requirements.txt new file mode 100644 index 0000000..796c272 --- /dev/null +++ b/hierarchical_transformer_memory/requirements.txt @@ -0,0 +1,7 @@ +absl-py +chex>=0.0.8 +dm-env +dm-haiku>=0.0.5.dev +jax>=0.2.17 +numpy +pycolab