mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-21 23:07:29 +08:00
Internal changes.
PiperOrigin-RevId: 389919323
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
# Towards mental time travel: A hierarchical memory for RL agents
|
||||
|
||||
This provides an implementation of two components of the paper "Towards mental
|
||||
time travel: A hierarchical memory for reinforcement learning agents." The
|
||||
article can be found on arXiv at [https://arxiv.org/abs/2105.14039](https://arxiv.org/abs/2105.14039).
|
||||
|
||||
Specifically, this repository contains:
|
||||
|
||||
1) A JAX/Haiku implementation of hierarchical transformer attention over memory.
|
||||
2) An implementation of the Ballet environment used in the paper.
|
||||
|
||||
We have also released the Rapid Word Learning tasks from the paper, but to
|
||||
simplify dependencies they are located in the `dm_fast_mapping` repository:
|
||||
[deepmind/dm_fast_mapping](https://github.com/deepmind/dm_fast_mapping). See the
|
||||
[documentation](https://github.com/deepmind/dm_fast_mapping/blob/master/docs/index.md)
|
||||
for that repository for further details about using those tasks.
|
||||
|
||||
## Setup
|
||||
|
||||
For easy installation, run:
|
||||
|
||||
```shell
|
||||
python3 -m venv htm_env
|
||||
source htm_env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Note that this installs the components needed for both the attention module and
|
||||
the environment. If you only wish to use the environment, you do not need to
|
||||
install JAX, Haiku, or Chex.
|
||||
|
||||
## Using the hierarchical attention module
|
||||
|
||||
Please see `hierarchical_attention/htm_attention_test.py` for some examples of
|
||||
the expected inputs for this module.
|
||||
|
||||
## Running the ballet environment
|
||||
|
||||
The ballet environment is contained in the `pycolab_ballet/` subfolder. To load
|
||||
a simple ballet environment with 2 dancers and short delays, and watch a few
|
||||
steps of the dances, you can do:
|
||||
|
||||
```python
|
||||
from pycolab_ballet import ballet_environment
|
||||
|
||||
env = ballet_environment.simple_builder(level_name='2_delay16')
|
||||
timestep = env.reset()
|
||||
for _ in range(5):
|
||||
    action = 0
    timestep = env.step(action)
|
||||
```
|
||||
|
||||
## Citing this work
|
||||
|
||||
If you use this code, please cite the associated paper:
|
||||
|
||||
```
|
||||
@article{lampinen2021towards,
|
||||
title={Towards mental time travel:
|
||||
a hierarchical memory for reinforcement learning agents},
|
||||
author={Lampinen, Andrew Kyle and Chan, Stephanie CY and Banino, Andrea and
|
||||
Hill, Felix},
|
||||
journal={arXiv preprint arXiv:2105.14039},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,218 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Haiku module implementing hierarchical attention over memory."""
|
||||
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
_EPSILON = 1e-3
|
||||
|
||||
|
||||
class HierarchicalMemory(NamedTuple):
  """State of the hierarchical memory.

  Shape legend: 'B' is the batch size, 'M' the number of memories, 'C' the
  chunk size and 'D' the memory dimension.
  """
  # Keys of the stored memories, [B, M, D].
  keys: jnp.ndarray
  # Contents of the stored memories, [B, M, C, D].
  contents: jnp.ndarray
  # Number of steps since the last memory write, [B].
  steps_since_last_write: jnp.ndarray
  # Buffer accumulating experiences before they are written, [B, C, D].
  accumulator: jnp.ndarray
|
||||
|
||||
|
||||
def sinusoid_position_encoding(
    sequence_length: int,
    hidden_size: int,
    min_timescale: float = 2.,
    max_timescale: float = 1e4,
) -> jnp.ndarray:
  """Creates sinusoidal position encodings.

  Positions are counted backwards: the first row encodes position
  `sequence_length - 1` and the last row encodes position 0. Each row is the
  concatenation of the sines and the cosines of the position scaled by every
  frequency.

  Args:
    sequence_length: length [L] of sequence to be position encoded.
    hidden_size: dimension [D] of the positional encoding vectors.
    min_timescale: step between the frequency indices taken from
      `[0, hidden_size)`. The default of 2 gives `ceil(D / 2)` frequencies, so
      that the sine and cosine halves together span D columns.
    max_timescale: maximum timescale for the frequency.

  Returns:
    An array of shape [L, D]. Note that a `min_timescale` larger than 2
    produces fewer than D columns.
  """
  freqs = np.arange(0, hidden_size, min_timescale)
  inv_freq = max_timescale**(-freqs / hidden_size)
  pos_seq = np.arange(sequence_length - 1, -1, -1.0)
  sinusoid_inp = np.einsum("i,j->ij", pos_seq, inv_freq)
  pos_emb = np.concatenate(
      [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1)
  # An odd hidden size yields 2 * ceil(D / 2) = D + 1 columns; drop the excess
  # so that the result always broadcasts against [..., D] memory contents.
  return pos_emb[:, :hidden_size]
|
||||
|
||||
|
||||
class HierarchicalMemoryAttention(hk.Module):
  """Multi-head attention over hierarchical memory.

  Attention happens in two stages: the queries are first scored against the
  memory keys to pick the top-k memories, and then a multi-head attention is
  run inside each selected memory. The per-memory results are combined using
  the softmax of the top-k relevance scores.
  """

  def __init__(self,
               feature_size: int,
               k: int,
               num_heads: int = 1,
               memory_position_encoding: bool = True,
               init_scale: float = 2.,
               name: Optional[str] = None) -> None:
    """Constructor.

    Args:
      feature_size: size of feature dimension of attention-over-memories
        embedding.
      k: number of memories to sample.
      num_heads: number of attention heads.
      memory_position_encoding: whether to add positional encodings to memories
        during within memory attention.
      init_scale: scale factor for Variance weight initializers.
      name: module name.
    """
    super().__init__(name=name)
    self._size = feature_size
    self._k = k
    self._num_heads = num_heads
    self._weights = None
    self._memory_position_encoding = memory_position_encoding
    self._init_scale = init_scale

  @property
  def num_heads(self):
    """Number of heads used by the within-memory attention."""
    return self._num_heads

  @hk.transparent
  def _singlehead_linear(self,
                         inputs: jnp.ndarray,
                         hidden_size: int,
                         name: str):
    """Applies a bias-free linear projection of size `hidden_size`."""
    linear = hk.Linear(
        hidden_size,
        with_bias=False,
        w_init=hk.initializers.VarianceScaling(scale=self._init_scale),
        name=name)
    out = linear(inputs)
    return out

  def __call__(
      self,
      queries: jnp.ndarray,
      hm_memory: HierarchicalMemory,
      hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """Do hierarchical attention over the stored memories.

    The memory keys and contents are wrapped in `stop_gradient`, so no
    gradients flow into the stored memories. `hm_memory` is never modified.

    Args:
      queries: Tensor [B, Q, E] Query(ies) in, for batch size B, query length
        Q, and embedding dimension E.
      hm_memory: Hierarchical Memory.
      hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where false,
        the corresponding query timepoints cannot attend to the corresponding
        memory chunks. This can be used for enforcing causal attention on the
        learner, not attending to memories from prior episodes, etc.

    Returns:
      Value updates for each query slot: [B, Q, F], where F is `feature_size`
      (the `model_size` of the within-memory attention).
    """
    # some shape checks
    batch_size, query_length, _ = queries.shape
    (memory_batch_size, num_memories,
     memory_chunk_size, mem_embbedding_size) = hm_memory.contents.shape
    assert batch_size == memory_batch_size
    chex.assert_shape(hm_memory.keys,
                      (batch_size, num_memories, mem_embbedding_size))
    chex.assert_shape(hm_memory.accumulator,
                      (memory_batch_size, memory_chunk_size,
                       mem_embbedding_size))
    chex.assert_shape(hm_memory.steps_since_last_write,
                      (memory_batch_size,))
    if hm_mask is not None:
      chex.assert_type(hm_mask, bool)
      chex.assert_shape(hm_mask,
                        (batch_size, query_length, num_memories))
    query_head = self._singlehead_linear(queries, self._size, "query")
    key_head = self._singlehead_linear(
        jax.lax.stop_gradient(hm_memory.keys), self._size, "key")

    # What times in the input [t] attend to what times in the memories [T].
    logits = jnp.einsum("btd,bTd->btT", query_head, key_head)

    scaled_logits = logits / np.sqrt(self._size)

    # Mask last dimension, replacing invalid logits with large negative values.
    # This allows e.g. enforcing causal attention on learner, or blocking
    # attention across episodes
    if hm_mask is not None:
      masked_logits = jnp.where(hm_mask, scaled_logits, -1e6)
    else:
      masked_logits = scaled_logits

    # identify the top-k memories and their relevance weights
    top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k)
    weights = jax.nn.softmax(top_k_logits)

    # set up the within-memory attention
    assert self._size % self._num_heads == 0
    mha_key_size = self._size // self._num_heads
    attention_layer = hk.MultiHeadAttention(
        key_size=mha_key_size,
        model_size=self._size,
        num_heads=self._num_heads,
        w_init_scale=self._init_scale,
        name="within_mem_attn")

    # position encodings
    augmented_contents = hm_memory.contents
    if self._memory_position_encoding:
      position_embs = sinusoid_position_encoding(
          memory_chunk_size, mem_embbedding_size)
      # Add out of place: `+=` would write into the caller's array whenever
      # the memory contents are a (mutable) NumPy array, silently corrupting
      # the stored memories a little more on every call.
      augmented_contents = (
          hm_memory.contents + position_embs[None, None, :, :])

    def _within_memory_attention(sub_inputs, sub_memory_contents, sub_weights,
                                 sub_top_k_indices):
      """Attends within the top-k memories for a single batch element.

      Args:
        sub_inputs: queries, [Q, E].
        sub_memory_contents: memory contents, [M, C, D].
        sub_weights: relevance weights of the selected memories, [Q, k].
        sub_top_k_indices: indices of the selected memories, [Q, k].

      Returns:
        Attention results summed over the k memories, one row per query.
      """
      # Gather the selected memories for every query: [Q, k, C, D].
      top_k_contents = sub_memory_contents[sub_top_k_indices, :, :]

      # Now we go deeper, with another vmap over **tokens**, because each token
      # can each attend to different memories.
      def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
        # The k memories play the role of the batch dimension: the same
        # single-token query is repeated once per selected memory.
        tiled_inputs = jnp.tile(sub_sub_inputs[None, None, :],
                                reps=(self._k, 1, 1))
        sub_attention_results = attention_layer(
            query=tiled_inputs,
            key=sub_sub_top_k_contents,
            value=sub_sub_top_k_contents)
        return sub_attention_results
      do_attention = jax.vmap(do_attention, in_axes=0)
      attention_results = do_attention(sub_inputs, top_k_contents)
      # Drop the singleton query-length axis introduced by the tiling.
      attention_results = jnp.squeeze(attention_results, axis=2)
      # Now collapse results across k memories
      attention_results = sub_weights[:, :, None] * attention_results
      attention_results = jnp.sum(attention_results, axis=1)
      return attention_results

    # vmap across batch
    batch_within_memory_attention = jax.vmap(_within_memory_attention,
                                             in_axes=0)
    outputs = batch_within_memory_attention(
        queries,
        jax.lax.stop_gradient(augmented_contents),
        weights,
        top_k_indices)

    return outputs
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for hierarchical_attention.htm_attention."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import haiku as hk
|
||||
import numpy as np
|
||||
|
||||
from hierarchical_transformer_memory.hierarchical_attention import htm_attention
|
||||
|
||||
|
||||
def _build_queries_and_memory(query_length, num_memories, mem_chunk_size,
                              batch_size=2, embedding_size=12):
  """Creates random queries and a random hierarchical memory for the tests."""
  query_shape = [batch_size, query_length, embedding_size]
  contents_shape = [batch_size, num_memories, mem_chunk_size, embedding_size]
  # Queries are drawn before the contents; swapping the two draws would change
  # the values obtained for a given seed.
  queries = np.random.random(query_shape)
  contents = np.random.random(contents_shape)
  memory = htm_attention.HierarchicalMemory(
      # The summary key of a memory is the average across its chunk.
      keys=contents.mean(axis=2),
      contents=contents,
      steps_since_last_write=np.zeros([batch_size,], dtype=np.int32),
      # Empty buffer with the shape of a single memory chunk, used to
      # accumulate the newest memories before writing.
      accumulator=np.zeros_like(contents[:, -1, :, :]))
  return queries, memory
|
||||
|
||||
|
||||
class HierarchicalAttentionTest(parameterized.TestCase):
  """Shape and finiteness checks for HierarchicalMemoryAttention."""

  def _assert_valid_results(self, results, expected_shape):
    """Checks that `results` has the expected shape and only finite values."""
    self.assertEqual(results.shape, expected_shape)
    self.assertTrue(np.all(np.isfinite(results)))

  @parameterized.parameters([
      dict(query_length=1, num_memories=7, mem_chunk_size=5, mem_k=4),
      dict(query_length=9, num_memories=7, mem_chunk_size=5, mem_k=4),
  ])
  @hk.testing.transform_and_run
  def test_output_shapes(self, query_length, num_memories, mem_chunk_size,
                         mem_k):
    """Unmasked attention returns finite values of shape [B, Q, E]."""
    np.random.seed(0)
    batch_size, embedding_size, num_heads = 2, 12, 3
    queries, memory = _build_queries_and_memory(
        query_length=query_length,
        num_memories=num_memories,
        mem_chunk_size=mem_chunk_size,
        embedding_size=embedding_size)
    attention = htm_attention.HierarchicalMemoryAttention(
        feature_size=embedding_size, k=mem_k, num_heads=num_heads)
    self._assert_valid_results(
        attention(queries, memory),
        (batch_size, query_length, embedding_size))

  @hk.testing.transform_and_run
  def test_masking(self):
    """Attention with a random boolean mask still returns finite values."""
    np.random.seed(0)
    batch_size, embedding_size, num_heads = 2, 12, 3
    query_length, num_memories, mem_chunk_size, mem_k = 5, 7, 6, 4
    queries, memory = _build_queries_and_memory(
        query_length=query_length,
        num_memories=num_memories,
        mem_chunk_size=mem_chunk_size,
        embedding_size=embedding_size)
    attention = htm_attention.HierarchicalMemoryAttention(
        feature_size=embedding_size, k=mem_k, num_heads=num_heads)
    # Random boolean mask, drawn after the queries and the memory.
    mask_shape = [batch_size, query_length, num_memories]
    mask = np.random.binomial(1, 0.5, mask_shape).astype(bool)
    self._assert_valid_results(
        attention(queries, memory, hm_mask=mask),
        (batch_size, query_length, embedding_size))
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Run the tests of this module when the file is executed as a script.
  absltest.main()
|
||||
@@ -0,0 +1,376 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A pycolab environment for going to the ballet.
|
||||
|
||||
A pycolab-based environment for testing memory for sequences of events. The
|
||||
environment contains some number of "dancer" characters in (implicit) 3 x 3
|
||||
squares within a larger 9 x 9 room. The agent starts in the center of the room.
|
||||
At the beginning of an episode, the dancers each do a dance solo of a fixed
|
||||
length, separated by empty time of a fixed length. The agent's actions do
|
||||
nothing during the dances. After the last dance ends, the agent must go up to a
|
||||
dancer, identified using language describing the dance. The agent is rewarded +1
|
||||
for approaching the correct dancer, 0 otherwise.
|
||||
|
||||
The room is upsampled at a size of 9 pixels per square to render a view for the
|
||||
agent, which is cropped in egocentric perspective, i.e. the agent is always in
|
||||
the center of its view (see https://arxiv.org/abs/1910.00571).
|
||||
"""
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import dm_env
|
||||
|
||||
import numpy as np
|
||||
from pycolab import cropping
|
||||
|
||||
from hierarchical_transformer_memory.pycolab_ballet import ballet_environment_core as ballet_core
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
UPSAMPLE_SIZE = 9 # pixels per game square
|
||||
SCROLL_CROP_SIZE = 11 # in game squares
|
||||
|
||||
# Shape names that can be assigned to dancers. The order matters: the list is
# copied and shuffled with the environment's random generator.
DANCER_SHAPES = [
    "triangle",
    "empty_square",
    "plus",
    "inverse_plus",
    "ex",
    "inverse_ex",
    "circle",
    "empty_circle",
    "tee",
    "upside_down_tee",
    "h",
    "u",
    "upside_down_u",
    "vertical_stripes",
    "horizontal_stripes",
]
|
||||
|
||||
# Color name -> RGB value (0-255 per channel). Insertion order is preserved
# because the key list is shuffled with the environment's random generator.
COLORS = {
    color_name: np.array(rgb)
    for color_name, rgb in (
        ("red", [255, 0, 0]),
        ("green", [0, 255, 0]),
        ("blue", [0, 0, 255]),
        ("purple", [128, 0, 128]),
        ("orange", [255, 165, 0]),
        ("yellow", [255, 255, 0]),
        ("brown", [128, 64, 0]),
        ("pink", [255, 64, 255]),
        ("cyan", [0, 255, 255]),
        ("dark_green", [0, 100, 0]),
        ("dark_red", [100, 0, 0]),
        ("dark_blue", [0, 0, 100]),
        ("olive", [100, 100, 0]),
        ("teal", [0, 100, 100]),
        ("lavender", [215, 200, 255]),
        ("peach", [255, 210, 170]),
        ("rose", [255, 205, 230]),
        ("light_green", [200, 255, 200]),
        ("light_yellow", [255, 255, 200]),
    )
}
|
||||
|
||||
|
||||
def _generate_template(object_name):
  """Builds the RGB sprite for an object named "<color> <shape>".

  Args:
    object_name: string holding a color name (a key of `COLORS`) and a shape
      name, separated by whitespace.

  Returns:
    An array of shape [UPSAMPLE_SIZE, UPSAMPLE_SIZE, 3]: the binary mask of
    the shape multiplied by the RGB value of the color.

  Raises:
    ValueError: if the shape or the color is unknown (the shape is checked
      first), or if `object_name` does not split into exactly two parts.
  """
  object_color, object_type = object_name.split()
  size = UPSAMPLE_SIZE
  half = size // 2
  center_band = slice(half - 1, half + 2)  # three pixels through the center
  # Pixel coordinate grids and the derived masks shared by several shapes.
  rows, cols = np.indices((size, size))
  on_diagonals = (
      (np.abs(rows - cols) <= 1) | (np.abs(size - 1 - cols - rows) <= 1))
  squared_radius = (rows - half)**2 + (cols - half)**2

  template = np.zeros((size, size))
  if object_type == "triangle":
    # Apex at the top center, widening by one column every two rows.
    template[rows >= 2 * np.abs(cols - half)] = 1.
  elif object_type == "square":
    template[:, :] = 1.
  elif object_type == "empty_square":
    # Two-pixel-wide frame.
    for edge in (slice(None, 2), slice(-2, None)):
      template[edge, :] = 1.
      template[:, edge] = 1.
  elif object_type == "plus":
    template[:, center_band] = 1.
    template[center_band, :] = 1.
  elif object_type == "inverse_plus":
    template[:, :] = 1.
    template[:, center_band] = 0.
    template[center_band, :] = 0.
  elif object_type == "ex":
    template[on_diagonals] = 1.
  elif object_type == "inverse_ex":
    template[~on_diagonals] = 1.
  elif object_type == "circle":
    template[squared_radius <= half**2] = 1.
  elif object_type == "empty_circle":
    template[np.abs(squared_radius - half**2) < 6] = 1.
  elif object_type == "tee":
    template[:, center_band] = 1.
    template[:3, :] = 1.
  elif object_type == "upside_down_tee":
    template[:, center_band] = 1.
    template[-3:, :] = 1.
  elif object_type == "h":
    template[:, :3] = 1.
    template[:, -3:] = 1.
    template[center_band, :] = 1.
  elif object_type == "u":
    template[:, :3] = 1.
    template[:, -3:] = 1.
    template[-3:, :] = 1.
  elif object_type == "upside_down_u":
    template[:, :3] = 1.
    template[:, -3:] = 1.
    template[:3, :] = 1.
  elif object_type == "vertical_stripes":
    # Every other column, starting with the first one.
    template[:, ::2] = 1.
  elif object_type == "horizontal_stripes":
    # Every other row, starting with the first one.
    template[::2, :] = 1.
  else:
    raise ValueError("Unknown object: {}".format(object_type))

  if object_color not in COLORS:
    raise ValueError("Unknown color: {}".format(object_color))

  # Outer product: [size, size] mask x [3] color -> [size, size, 3] image.
  return np.tensordot(template, COLORS[object_color], axes=0)
|
||||
|
||||
|
||||
# Agent and wall templates
|
||||
# Uniformly colored [UPSAMPLE_SIZE, UPSAMPLE_SIZE, 3] sprites for the
# characters that are present in every level.
_CHAR_TO_TEMPLATE_BASE = {
    char: np.tensordot(
        np.ones([UPSAMPLE_SIZE, UPSAMPLE_SIZE]), np.array(rgb), axes=0)
    for char, rgb in (
        (ballet_core.AGENT_CHAR, [255, 255, 255]),  # agent: white
        (ballet_core.WALL_CHAR, [40, 40, 40]),  # wall: dark grey
    )
}
|
||||
|
||||
|
||||
def get_scrolling_cropper(rows=9, cols=9, crop_pad_char=" "):
  """Returns a pycolab `ScrollingCropper` that tracks the agent character.

  Args:
    rows: number of rows of the cropped view.
    cols: number of columns of the cropped view.
    crop_pad_char: character passed to the cropper as `pad_char`.

  Returns:
    A `cropping.ScrollingCropper` created with `scroll_margins=(None, None)`.
  """
  cropper_kwargs = dict(
      rows=rows,
      cols=cols,
      to_track=[ballet_core.AGENT_CHAR],
      pad_char=crop_pad_char,
      scroll_margins=(None, None))
  return cropping.ScrollingCropper(**cropper_kwargs)
|
||||
|
||||
|
||||
class BalletEnvironment(dm_env.Environment):
  """A dm_env interface to the pycolab ballet tasks."""

  def __init__(self, num_dancers, dance_delay, max_steps, rng=None):
    """Builds a BalletEnvironment wrapping pycolab games for agent use.

    Args:
      num_dancers: The number of dancers to use, between 1 and 8 (inclusive).
      dance_delay: How long to delay between the dances.
      max_steps: The maximum number of steps to allow in an episode, after which
        it will terminate.
      rng: An optional numpy Random Generator, to set a fixed seed use e.g.
        `rng=np.random.default_rng(seed=...)`
    """
    # Level configuration.
    self._num_dancers = num_dancers
    self._dance_delay = dance_delay
    self._max_steps = max_steps

    # Episode state; populated by `reset` and wiped by `_clear_state`.
    self._rng = np.random.default_rng() if rng is None else rng
    self._current_game = None  # Current pycolab game instance.
    self._state = None  # Step type of the most recent time step.
    self._game_over = None  # Whether the game has ended.
    self._char_to_template = None  # Mapping of chars to sprite images.

    # Rendering: an agent-centered crop of the board.
    self._cropper = get_scrolling_cropper(
        SCROLL_CROP_SIZE, SCROLL_CROP_SIZE, " ")

  def _game_factory(self):
    """Samples dancers and positions, returns a pycolab core game engine."""
    target_dancer_index = self._rng.integers(self._num_dancers)
    positions = ballet_core.DANCER_POSITIONS.copy()
    motions = list(ballet_core.DANCE_SEQUENCES.keys())
    colors = list(COLORS.keys())
    shapes = DANCER_SHAPES.copy()
    # The shuffles consume the random generator in this exact order; changing
    # it would change the levels produced for a given seed.
    for pool in (positions, motions, colors, shapes):
      self._rng.shuffle(pool)

    # One tuple per dancer; only the target dancer carries a value of 1.
    dancers_and_properties = [
        (ballet_core.POSSIBLE_DANCER_CHARS[index],
         positions[index],
         motions[index],
         shapes[index],
         colors[index],
         1. if index == target_dancer_index else 0.)
        for index in range(self._num_dancers)
    ]

    logging.info("Making level with dancers_and_properties: %s",
                 dancers_and_properties)

    return ballet_core.make_game(
        dancers_and_properties=dancers_and_properties,
        dance_delay=self._dance_delay)

  def _render_observation(self, observation):
    """Renders a raw pycolab observation into an (image, language) pair."""
    board = self._cropper.crop(observation).board
    num_rows, num_cols = board.shape
    image = np.zeros(
        [num_rows * UPSAMPLE_SIZE, num_cols * UPSAMPLE_SIZE, 3],
        dtype=np.float32)
    for (row, col), char_code in np.ndenumerate(board):
      char = chr(char_code)
      if char == ballet_core.FLOOR_CHAR:
        continue  # floor squares stay black
      top, left = row * UPSAMPLE_SIZE, col * UPSAMPLE_SIZE
      image[top:top + UPSAMPLE_SIZE,
            left:left + UPSAMPLE_SIZE] = self._char_to_template[char]
    image /= 255.  # rescale the 0-255 sprite values to [0, 1]
    language = np.array(self._current_game.the_plot["instruction_string"])
    return (image, language)

  def reset(self):
    """Starts a new episode and returns its first time step."""
    self._current_game = self._game_factory()
    # Sprites of the freshly sampled dancers, plus the fixed agent/wall ones.
    sprites = {
        char: _generate_template(color_shape)
        for char, color_shape in self._current_game.the_plot[
            "char_to_color_shape"]
    }
    sprites.update(_CHAR_TO_TEMPLATE_BASE)
    self._char_to_template = sprites
    self._cropper.set_engine(self._current_game)
    self._state = dm_env.StepType.FIRST
    # Start the game; the initial reward and discount are discarded.
    first_observation, _, _ = self._current_game.its_showtime()
    return dm_env.TimeStep(
        step_type=self._state,
        reward=None,
        discount=None,
        observation=self._render_observation(first_observation))

  def step(self, action):
    """Applies `action`, steps the world forward and returns a time step."""
    # Stepping a finished (or never started) episode starts a new one instead.
    if self._state == dm_env.StepType.LAST:
      self._clear_state()
    if self._current_game is None:
      return self.reset()

    # Execute the action in pycolab.
    raw_observation, reward, discount = self._current_game.play(action)

    self._game_over = self._is_game_over()
    if reward is None:
      reward = 0.
    observation = self._render_observation(raw_observation)

    # Record the current status of the game.
    self._state = (
        dm_env.StepType.LAST if self._game_over else dm_env.StepType.MID)

    return dm_env.TimeStep(
        step_type=self._state,
        reward=reward,
        discount=discount,
        observation=observation)

  # NOTE(review): the two specs below are exposed as properties, so they are
  # read as `env.observation_spec` without a call. Confirm this against code
  # that expects the dm_env convention of calling `env.observation_spec()`.
  @property
  def observation_spec(self):
    """Specs of the (image, language) observation tuple."""
    side = SCROLL_CROP_SIZE * UPSAMPLE_SIZE
    vision_spec = dm_env.specs.Array(
        shape=(side, side, 3), dtype=np.float32, name="image")
    language_spec = dm_env.specs.Array(
        shape=[], dtype=str, name="language")
    return (vision_spec, language_spec)

  @property
  def action_spec(self):
    """Spec of the scalar integer action, between 0 and 7 (inclusive)."""
    return dm_env.specs.BoundedArray(
        shape=[],
        dtype="int32",
        minimum=0,
        maximum=7,
        name="grid_actions")

  def _is_game_over(self):
    """Returns whether it is game over, either from the engine or timeout."""
    game = self._current_game
    return (game.game_over or
            (game.the_plot.frame >= self._max_steps))

  def _clear_state(self):
    """Forgets everything about the current game."""
    self._state = None
    self._current_game = None
    self._char_to_template = None
    self._game_over = None
|
||||
|
||||
|
||||
def simple_builder(level_name):
  """Builds an environment from a level name.

  Args:
    level_name: '{num_dancers}_delay{delay_length}', where each variable is an
      integer. The levels used in the paper were:
      ['2_delay16', '4_delay16', '8_delay16',
       '2_delay48', '4_delay48', '8_delay48']

  Returns:
    A BalletEnvironment with the requested settings. Episodes are capped at
    320 steps when the delay is 16 and at 1024 steps otherwise.
  """
  dancers_part, delay_part = level_name.split("_")
  num_dancers = int(dancers_part)
  # Skip the five characters of the "delay" prefix.
  dance_delay = int(delay_part[len("delay"):])
  if dance_delay == 16:
    max_steps = 320
  else:
    max_steps = 1024
  return BalletEnvironment(
      num_dancers=num_dancers,
      dance_delay=dance_delay,
      max_steps=max_steps)
|
||||
|
||||
|
||||
def main(argv):
  """Runs three episodes of a 4-dancer level, always taking action 0."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  num_episodes = 3
  steps_per_episode = 300
  env = simple_builder("4_delay16")
  for _ in range(num_episodes):
    obs = env.reset().observation
    for _ in range(steps_per_episode):
      obs = env.step(0).observation
      # NOTE(review): the flattened source does not show whether the print
      # belongs to the inner or the outer loop; it is assumed to run every
      # step. Confirm against the upstream file.
      print(obs)
|
||||
|
||||
if __name__ == "__main__":
  # Let absl run `main` when the file is executed as a script.
  app.run(main)
|
||||
@@ -0,0 +1,337 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The pycolab core of the environment for going to the ballet.
|
||||
|
||||
This builds the text-based (non-graphical) engine of the environment, and offers
|
||||
a UI which a human can play (for a fixed level). However, the logic of level
|
||||
creation, the graphics, and anything that is external to the pycolab engine
|
||||
itself is contained in ballet_environment.py.
|
||||
"""
|
||||
import curses
|
||||
import enum
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from pycolab import ascii_art
|
||||
from pycolab import human_ui
|
||||
from pycolab.prefab_parts import sprites as prefab_sprites
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
ROOM_SIZE = (11, 11)  # one square around edge will be wall.
# A 3 x 3 grid of dancer positions in row-major order; the center square is
# left out because the agent starts there.
DANCER_POSITIONS = [
    (row, col)
    for row in (2, 5, 8)
    for col in (2, 5, 8)
    if (row, col) != (5, 5)
]
AGENT_START = (5, 5)
AGENT_CHAR = "A"
WALL_CHAR = "#"
FLOOR_CHAR = " "
RESERVED_CHARS = [AGENT_CHAR, WALL_CHAR, FLOOR_CHAR]
# The upper-case letters that are not reserved, in alphabetical order.
POSSIBLE_DANCER_CHARS = [
    char
    for char in map(chr, range(ord("A"), ord("Z") + 1))
    if char not in RESERVED_CHARS
]

DANCE_SEQUENCE_LENGTHS = 16
|
||||
|
||||
|
||||
# movement directions for dancers / actions for agent
class DIRECTIONS(enum.IntEnum):
  """The eight compass directions, numbered clockwise starting from north."""
  N = 0
  NE = 1
  E = 2
  SE = 3
  S = 4
  SW = 5
  W = 6
  NW = 7
|
||||
|
||||
# Maps each dance name to its list of DIRECTIONS moves, in the order they are
# performed. Each dance is written as a space-separated string of DIRECTIONS
# member names and expanded into a fresh list of enum members.
DANCE_SEQUENCES = {
    name: [DIRECTIONS[step] for step in steps.split()]
    for name, steps in (
        ("circle_cw", "N E S S W W N N E E S S W W N E"),
        ("circle_ccw", "N W S S E E N N W W S S E E N W"),
        ("up_and_down", "N S S N N S S N N S S N N S S N"),
        ("left_and_right", "E W W E E W W E E W W E E W W E"),
        ("diagonal_uldr", "NW SE SE NW NW SE SE NW NW SE SE NW NW SE SE NW"),
        ("diagonal_urdl", "NE SW SW NE NE SW SW NE NE SW SW NE NE SW SW NE"),
        ("plus_cw", "N S E W S N W E N S E W S N W E"),
        ("plus_ccw", "N S W E S N E W N S W E S N E W"),
        ("times_cw", "NE SW SE NW SW NE NW SE NE SW SE NW SW NE NW SE"),
        ("times_ccw", "NW SE SW NE SE NW NE SW NW SE SW NE SE NW NE SW"),
        ("zee", "NE W W E E SW NE SW SW E E W W NE SW NE"),
        ("chevron_down", "NW S SE NE N SW NE SW NE S SW NW N SE NW SE"),
        ("chevron_up", "SE N NW SW S NE SW NE SW N NE SE S NW SE NW"),
    )
}
|
||||
|
||||
|
||||
class DancerSprite(prefab_sprites.MazeWalker):
  """A `Sprite` for dancers.

  While `is_dancing` is set during the "dance" phase, each update performs the
  next move of the dancer's sequence. In every other update the dancer checks
  whether the agent is on its square and, if so, adds `value` as reward and
  terminates the episode.
  """

  def __init__(self, corner, position, character, motion, color, shape,
               value=0.):
    """Initializes the dancer.

    Args:
      corner: forwarded to `MazeWalker`.
      position: forwarded to `MazeWalker`.
      character: forwarded to `MazeWalker`.
      motion: key into `DANCE_SEQUENCES` selecting this dancer's dance.
      color: stored as `self.color`; not otherwise used by this class.
      shape: stored as `self.shape`; not otherwise used by this class.
      value: reward added when the agent is found on this dancer's square.
    """
    super().__init__(corner, position, character, impassable=WALL_CHAR)
    self.motion = motion
    # Private copy, because `update` consumes the sequence move by move.
    self.dance_sequence = DANCE_SEQUENCES[motion].copy()
    self.color = color
    self.shape = shape
    self.value = value
    self.is_dancing = False

  def update(self, actions, board, layers, backdrop, things, the_plot):
    """Performs one dance move, or checks whether the agent chose this dancer.

    Raises:
      ValueError: if asked to dance with no moves left in the sequence.
    """
    if the_plot["task_phase"] == "dance" and self.is_dancing:
      if not self.dance_sequence:
        raise ValueError(
            "Dance sequence is empty! Was this dancer repeated in the order?")
      dance_move = self.dance_sequence.pop(0)
      steps = {
          DIRECTIONS.N: self._north,
          DIRECTIONS.NE: self._northeast,
          DIRECTIONS.E: self._east,
          DIRECTIONS.SE: self._southeast,
          DIRECTIONS.S: self._south,
          DIRECTIONS.SW: self._southwest,
          DIRECTIONS.W: self._west,
          DIRECTIONS.NW: self._northwest,
      }
      step = steps.get(dance_move)
      if step is not None:  # moves outside DIRECTIONS are ignored
        step(board, the_plot)

      if not self.dance_sequence:  # done!
        self.is_dancing = False
        the_plot["time_until_next_dance"] = the_plot["dance_delay"]
    elif self.position == things[AGENT_CHAR].position:
      # Award the player the appropriate amount of reward, and end episode.
      the_plot.add_reward(self.value)
      the_plot.terminate_episode()
|
||||
|
||||
|
||||
class PlayerSprite(prefab_sprites.MazeWalker):
  """The player / agent character.

  MazeWalker class methods handle basic movement and collision detection.
  This sprite also drives the dance schedule: during the "dance" phase it
  counts down to the next dance, starts the next dancer, and switches the
  plot to the "choice" phase once no dancers remain.
  """

  def __init__(self, corner, position, character):
    """Initializes the player; all arguments are forwarded to `MazeWalker`."""
    super().__init__(corner, position, character, impassable=WALL_CHAR)

  def update(self, actions, board, layers, backdrop, things, the_plot):
    """Advances the dance schedule, or moves the agent in the choice phase."""
    if the_plot["task_phase"] == "dance":
      # agent's actions are ignored, this logic updates the dance phases.
      if the_plot["time_until_next_dance"] > 0:
        the_plot["time_until_next_dance"] -= 1
        if the_plot["time_until_next_dance"] == 0:  # next phase time!
          if the_plot["dance_order"]:  # start the next dance!
            next_dancer = the_plot["dance_order"].pop(0)
            things[next_dancer].is_dancing = True
          else:  # choice time!
            the_plot["task_phase"] = "choice"
            the_plot["instruction_string"] = the_plot[
                "choice_instruction_string"]
    elif the_plot["task_phase"] == "choice":
      # agent can now move and make its choice
      moves = (
          (DIRECTIONS.N, self._north),
          (DIRECTIONS.NE, self._northeast),
          (DIRECTIONS.E, self._east),
          (DIRECTIONS.SE, self._southeast),
          (DIRECTIONS.S, self._south),
          (DIRECTIONS.SW, self._southwest),
          (DIRECTIONS.W, self._west),
          (DIRECTIONS.NW, self._northwest),
      )
      # Compare with `==` instead of hashing so that any action value is
      # accepted; values matching no direction leave the agent in place.
      for direction, move in moves:
        if actions == direction:
          move(board, the_plot)
          break
|
||||
|
||||
|
||||
def make_game(dancers_and_properties, dance_delay=16):
  """Constructs an ascii map, then uses pycolab to make it a game.

  Args:
    dancers_and_properties: list of (character, (row, column), motion, shape,
      color, value), for placing objects in the world. Dancers perform in list
      order. The motion of the last dancer with value > 0 becomes the
      instruction shown in the choice phase.
    dance_delay: how long to wait between dances.

  Returns:
    this_game: Pycolab engine running the specified game.

  Raises:
    ValueError: if no dancer has value > 0, so there is no choice instruction.
  """
  num_rows, num_cols = ROOM_SIZE
  wall_row = WALL_CHAR * num_cols
  room_row = WALL_CHAR + FLOOR_CHAR * (num_cols - 2) + WALL_CHAR
  # upper wall, room, lower wall
  level_layout = [wall_row] + [room_row] * (num_rows - 2) + [wall_row]

  def _add_to_map(obj, loc):
    """Adds an ascii character to the level at the requested position."""
    row, col = loc[0], loc[1]
    obj_row = level_layout[row]
    level_layout[row] = obj_row[:col] + obj + obj_row[col + 1:]

  _add_to_map(AGENT_CHAR, AGENT_START)
  sprites = {AGENT_CHAR: PlayerSprite}
  dance_order = []
  char_to_color_shape = []
  choice_instruction_string = None
  # add dancers to level
  for obj, loc, motion, shape, color, value in dancers_and_properties:
    _add_to_map(obj, loc)
    sprites[obj] = ascii_art.Partial(
        DancerSprite, motion=motion, color=color, shape=shape, value=value)
    char_to_color_shape.append((obj, color + " " + shape))
    dance_order.append(obj)
    if value > 0.:
      choice_instruction_string = motion

  if choice_instruction_string is None:
    raise ValueError(
        "make_game needs at least one dancer with value > 0 to define the "
        "choice instruction.")

  this_game = ascii_art.ascii_art_to_game(
      art=level_layout,
      what_lies_beneath=FLOOR_CHAR,
      sprites=sprites,
      update_schedule=[[AGENT_CHAR],
                       dance_order])

  this_game.the_plot["task_phase"] = "dance"
  this_game.the_plot["instruction_string"] = "watch"
  this_game.the_plot["choice_instruction_string"] = choice_instruction_string
  # The plot gets its own list: PlayerSprite pops entries from it as dances
  # start, which must not alter the list handed to the update schedule.
  this_game.the_plot["dance_order"] = list(dance_order)
  this_game.the_plot["dance_delay"] = dance_delay
  this_game.the_plot["time_until_next_dance"] = 1
  this_game.the_plot["char_to_color_shape"] = char_to_color_shape
  return this_game
|
||||
|
||||
|
||||
def main(argv):
  """Plays a fixed eight-dancer level in the terminal through the curses UI."""
  del argv  # unused
  these_dancers_and_properties = [
      (POSSIBLE_DANCER_CHARS[1], (2, 2), "chevron_up", "triangle", "red", 1),
      (POSSIBLE_DANCER_CHARS[2], (2, 5), "circle_ccw", "triangle", "red", 0),
      (POSSIBLE_DANCER_CHARS[3], (2, 8), "plus_cw", "triangle", "red", 0),
      (POSSIBLE_DANCER_CHARS[4], (5, 2), "plus_ccw", "triangle", "red", 0),
      (POSSIBLE_DANCER_CHARS[5], (5, 8), "times_cw", "triangle", "red", 0),
      (POSSIBLE_DANCER_CHARS[6], (8, 2), "up_and_down", "plus", "blue", 0),
      (POSSIBLE_DANCER_CHARS[7], (8, 5), "left_and_right", "plus", "blue", 0),
      (POSSIBLE_DANCER_CHARS[8], (8, 8), "zee", "plus", "blue", 0),
  ]

  game = make_game(dancers_and_properties=these_dancers_and_properties)

  # Note that these colors are only for human UI
  fg_colours = {
      AGENT_CHAR: (999, 999, 999),  # Agent is white
      WALL_CHAR: (300, 300, 300),  # Wall, dark grey
      FLOOR_CHAR: (0, 0, 0),  # Floor
  }
  for dancer_char, _, _, _, colour, _ in these_dancers_and_properties:
    fg_colours[dancer_char] = (
        (999, 0, 0) if colour == "red" else (0, 0, 999))

  # Take the characters from the dancer list itself so that every dancer on
  # the board gets an entry.
  dancer_chars = [dancer[0] for dancer in these_dancers_and_properties]
  bg_colours = {c: (0, 0, 0) for c in RESERVED_CHARS + dancer_chars}

  ui = human_ui.CursesUi(
      keys_to_actions={
          # Basic movement.
          curses.KEY_UP: DIRECTIONS.N,
          curses.KEY_DOWN: DIRECTIONS.S,
          curses.KEY_LEFT: DIRECTIONS.W,
          curses.KEY_RIGHT: DIRECTIONS.E,
          -1: 8,  # Do nothing.
      },
      delay=500,
      colour_fg=fg_colours,
      colour_bg=bg_colours)

  ui.play(game)
|
||||
|
||||
|
||||
# Run the human-playable demo when this file is executed as a script.
if __name__ == "__main__":
  app.run(main)
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for pycolab_ballet.ballet_environment."""
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from hierarchical_transformer_memory.pycolab_ballet import ballet_environment
|
||||
from hierarchical_transformer_memory.pycolab_ballet import ballet_environment_core
|
||||
|
||||
|
||||
class BalletEnvironmentTest(parameterized.TestCase):
  """Tests for `BalletEnvironment` and `simple_builder`."""

  def test_full_wrapper(self):
    """Runs a one-dancer episode from reset through the rewarded choice."""
    env = ballet_environment.BalletEnvironment(
        num_dancers=1, dance_delay=16, max_steps=200,
        rng=np.random.default_rng(seed=0))
    result = env.reset()
    self.assertIsNone(result.reward)
    level_size = ballet_environment_core.ROOM_SIZE
    upsample_size = ballet_environment.UPSAMPLE_SIZE
    expected_shape = (level_size[0] * upsample_size,
                      level_size[1] * upsample_size,
                      3)
    # wait for dance to complete
    for _ in range(30):
      result = env.step(0).observation
      self.assertEqual(result[0].shape, expected_shape)
      self.assertEqual(str(result[1])[:5], "watch")
    for i in [1, 1, 1, 1]:  # first gets eaten before agent can move
      result = env.step(i)
      self.assertEqual(result.observation[0].shape, expected_shape)
      self.assertEqual(str(result.observation[1])[:11], "up_and_down")
    self.assertEqual(result.reward, 1.)
    # check egocentric scrolling is working, by checking object is in center
    np.testing.assert_array_almost_equal(
        result.observation[0][45:54, 45:54],
        ballet_environment._generate_template("orange plus") / 255.)

  @parameterized.parameters(
      "2_delay16",
      "4_delay16",
      "8_delay48",
  )
  def test_simple_builder(self, level_name):
    """Builds a named level and steps through each of the eight actions."""
    dance_delay = int(level_name[-2:])
    np.random.seed(0)
    env = ballet_environment.simple_builder(level_name)
    # check max steps are set to match paper settings
    self.assertEqual(env._max_steps,
                     320 if dance_delay == 16 else 1024)
    # test running a few steps of each
    env.reset()
    level_size = ballet_environment_core.ROOM_SIZE
    upsample_size = ballet_environment.UPSAMPLE_SIZE
    expected_shape = (level_size[0] * upsample_size,
                      level_size[1] * upsample_size,
                      3)
    for i in range(8):
      result = env.step(i)  # check all 8 movements work
      self.assertEqual(result.observation[0].shape, expected_shape)
      self.assertEqual(str(result.observation[1])[:5], "watch")
      self.assertEqual(result.reward, 0.)
|
||||
|
||||
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
|
||||
@@ -0,0 +1,7 @@
|
||||
absl-py
|
||||
chex>=0.0.8
|
||||
dm-env
|
||||
dm-haiku>=0.0.5.dev
|
||||
jax>=0.2.17
|
||||
numpy
|
||||
pycolab
|
||||
Reference in New Issue
Block a user