diff --git a/tvt/README.md b/tvt/README.md
new file mode 100644
index 0000000..d454faa
--- /dev/null
+++ b/tvt/README.md
@@ -0,0 +1,189 @@
+# TVT: Temporal Value Transport
+
+An open source implementation of agents, algorithm and environments related to
+the paper [Optimizing Agent Behavior over Long Time Scales by Transporting Value](https://arxiv.org/abs/1810.06721).
+
+## Installation
+
+TVT package installation and training can be run using `tvt/run.sh`. This will
+use all default flag values for the training script `tvt/main.py`. See the
+section on running experiments below for launching with non-default flags.
+
+Note that the default installation uses TensorFlow without GPU support.
+Replace `tensorflow` with `tensorflow-gpu` in `tvt/requirements.txt` to use
+TensorFlow with GPU support.
+
+## Differences between this implementation and the paper
+
+In the paper, agents were trained using a distributed A3C architecture with
+384 actors. This implementation runs a batched A2C agent on a single-GPU
+machine with a batch size of 16.
+
+## Tasks
+
+### Pycolab tasks
+
+So that training completes in a reasonable time on a single machine, we
+provide 2D grid-world versions of the paper's tasks using Pycolab in place of
+the original DeepMind Lab 3D tasks.
+
+Further details of the tasks are given in the Pycolab directory README, and
+users can also play the tasks themselves from the command line.
+
+Special thanks to Hamza Merzic for writing the two Pycolab task scripts.
+
+### DeepMind Lab tasks
+
+The DeepMind Lab tasks used in the paper are also provided as part of this
+release.
+
+Further details of specific tasks are given in the DeepMind Lab directory
+README.
+
+## Running experiments
+
+### Launching
+
+To start an experiment, run:
+
+```
+source tvt_venv/bin/activate
+python3 -m tvt.main
+```
+
+This will launch a default setup that uses the RMA agent on the 'Key To Door'
+Pycolab task.
+
+### Important flags
+`tvt.main` accepts many flags.
+
+Note that all the default hyperparameters are tuned for the TVT-RMA agent to
+solve both `key_to_door` and `active_visual_match` Pycolab tasks.
+
+#### Information logging:
+* `logging_frequency`: Frequency of logging to the console and TensorBoard.
+* `logdir`: Directory for TensorBoard logging.
+
+#### Agent configuration:
+* `with_memory`: Default True. Whether the agent has external memory. If set
+  to False, the agent has only LSTM memory.
+* `with_reconstruction`: Default True. Whether the agent reconstructs the
+  observation, as described in the Reconstructive Memory Agent (RMA)
+  architecture.
+* `gamma`: Agent discount factor.
+* `entropy_cost`: Weight of the entropy loss.
+* `image_cost_weight`: Weight of the image reconstruction loss.
+* `read_strength_cost`: Weight of the memory read strength, used to regularize
+  memory access.
+* `read_strength_tolerance`: The tolerance of the hinge loss for the read
+  strengths.
+* `do_tvt`: Default True. Whether to apply the Temporal Value Transport
+  algorithm (only works if the model has external memory).
+
+#### Optimization:
+* `batch_size`: Batch size for the batched A2C algorithm.
+* `learning_rate`: Learning rate for the Adam optimizer.
+* `beta1`: Adam optimizer beta1.
+* `beta2`: Adam optimizer beta2.
+* `epsilon`: Adam optimizer epsilon.
+* `num_episodes`: Number of episodes to train for. None means run forever.
+
+#### Pycolab-specific flags:
+* `pycolab_game`: Which game to run. One of 'key_to_door' or
+  'active_visual_match'. See pycolab/README for descriptions.
+* `pycolab_num_apples`: Number of apples to sample from.
+* `pycolab_apple_reward_min`: The minimum apple reward.
+* `pycolab_apple_reward_max`: The maximum apple reward.
+* `pycolab_fix_apple_reward_in_episode`: Default True. Fixes the sampled apple
+  reward within an episode.
+* `pycolab_final_reward`: Reward obtained in the last phase.
+* `pycolab_crop`: Default True. Whether to crop observations.
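+
+For example, assuming standard absl flag syntax for the boolean flags, the
+control configuration used in the example results below (RMA agent without
+TVT, undiscounted) can be launched on the Active Visual Match task with:
+
+```
+source tvt_venv/bin/activate
+python3 -m tvt.main --pycolab_game=active_visual_match --nodo_tvt --gamma=1
+```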
+
+
+### Monitoring results
+
+Key outputs are logged to the command line and to TensorBoard logs.
+[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
+can be used to track learning progress if `FLAGS.logdir` is set:
+```
+tensorboard --logdir=/path/to/logdir
+```
+
+Key values logged:
+
+* `reward`: The total reward the agent acquired in an episode.
+* `last phase reward`: The critical reward acquired in the exploit phase,
+  which depends on the behavior in the explore phase.
+* `tvt reward`: The total fictitious reward generated by the Temporal Value
+  Transport algorithm.
+* `total loss`: The sum of all losses, including the policy gradient loss,
+  value function loss, reconstruction loss, and memory read regularization
+  loss. We also log these losses separately.
+
+## Example results
+
+Here we show example results from running the TVT agent (with the default
+hyperparameters) and the best control RMA agent (with `do_tvt=False, gamma=1`).
+
+TVT is designed to reduce the variance in the learning signal for rewards that
+are temporally distant from the actions or information that lead to them. In
+the paper we therefore focus on the reward in the last phase of each task,
+which is the only reward that depends on actions or information from much
+earlier in the task than the time at which the reward is given. In the
+experiments here, the best way to track whether TVT is working is to monitor
+the `last phase reward`, as this is the critical performance we are interested
+in: both the TVT agent and the control agents do well in the apple-collecting
+phase, which contributes most of the episodic reward, but they do not all do
+well in the last phase.
+
+### Key-to-door
+Across 10 replicas, we found that the TVT agents reach a score of 10, meaning
+they reliably collect the key in the explore phase to open the door in the
+exploit phase.
+
+For 10 replicas without TVT and with the same hyperparameters, we see
+consistently low performance.
+
+For 5 replicas with gamma equal to 1, performance of the RMA agent without TVT
+is improved, but it is unstable and never goes above 7.
+
+### Active-visual-match
+Across 10 replicas, we found that the TVT agents reach a score of 10, meaning
+they reliably search for the pixel and remember its color in the explore
+phase, and then touch the corresponding pixel in the exploit phase.
+
+For 10 replicas without TVT and with the same hyperparameters, performance is
+better than chance but below the maximum, indicating that the agent is not
+able to actively seek information in the explore phase and instead must rely
+on randomly encountering it.
+
+For 5 replicas with gamma equal to 1, performance of the RMA agent without TVT
+is considerably worse, suggesting that the behavior learned from later phases
+results only in undirected exploration in the first phase.
+
+## Citing this work
+
+If you use this code in your work, please cite the accompanying paper:
+
+```
+@article{hung2019optimizing,
+  author  = {Chia{-}Chun Hung and
+             Timothy P. Lillicrap and
+             Josh Abramson and
+             Yan Wu and
+             Mehdi Mirza and
+             Federico Carnevale and
+             Arun Ahuja and
+             Greg Wayne},
+  title   = {Optimizing Agent Behavior over Long Time Scales by Transporting Value},
+  journal = {Nature Communications},
+  volume  = {10},
+  year    = {2019},
+  doi     = {10.1038/s41467-019-13073-w},
+```
+
+## Disclaimer
+
+This is not an officially supported Google or DeepMind product.
diff --git a/tvt/batch_env.py b/tvt/batch_env.py
new file mode 100644
index 0000000..bdc9c1f
--- /dev/null
+++ b/tvt/batch_env.py
@@ -0,0 +1,110 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Threaded batch environment wrapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from concurrent import futures
+
+from six.moves import range
+from six.moves import zip
+
+from tvt import nest_utils
+
+
+class BatchEnv(object):
+ """Wrapper that steps multiple environments in separate threads.
+
+ The threads are stepped in lock step, so all threads progress by one step
+ before any move to the next step.
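+
+  A minimal usage sketch (`env_builder` below stands for any callable
+  returning environments that expose `reset`, `step`, `num_actions`,
+  `observation_shape`, `episode_length` and `last_phase_reward`, as assumed
+  by this class):
+
+    batch_env = BatchEnv(batch_size=16, env_builder=env_builder)
+    observations, rewards = batch_env.reset()
+    actions = [0] * batch_env.batch_size  # One action per environment.
+    observations, rewards = batch_env.step(actions)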
+ """
+
+ def __init__(self, batch_size, env_builder, **env_kwargs):
+ self.batch_size = batch_size
+ self._envs = [env_builder(**env_kwargs) for _ in range(batch_size)]
+ self._num_actions = self._envs[0].num_actions
+ self._observation_shape = self._envs[0].observation_shape
+ self._episode_length = self._envs[0].episode_length
+
+ self._executor = futures.ThreadPoolExecutor(max_workers=self.batch_size)
+
+ def reset(self):
+ """Reset the entire batch of environments."""
+
+ def reset_environment(env):
+ return env.reset()
+
+ try:
+ output_list = []
+ for env in self._envs:
+ output_list.append(self._executor.submit(reset_environment, env))
+ output_list = [env_output.result() for env_output in output_list]
+ except KeyboardInterrupt:
+ self._executor.shutdown(wait=True)
+ raise
+
+ observations, rewards = nest_utils.nest_stack(output_list)
+ return observations, rewards
+
+ def step(self, action_list):
+ """Step batch of envs.
+
+ Args:
+      action_list: A list of actions, one per environment in the batch. Each
+        one should be a scalar int or a numpy scalar int.
+
+ Returns:
+ A tuple (observations, rewards):
+ observations: A nest of observations, each one a numpy array where the
+ first dimension has size equal to the number of environments in the
+ batch.
+ rewards: An array of rewards with size equal to the number of
+ environments in the batch.
+ """
+
+ def step_environment(env, action):
+ return env.step(action)
+
+ try:
+ output_list = []
+ for env, action in zip(self._envs, action_list):
+ output_list.append(self._executor.submit(step_environment, env, action))
+ output_list = [env_output.result() for env_output in output_list]
+ except KeyboardInterrupt:
+ self._executor.shutdown(wait=True)
+ raise
+
+ observations, rewards = nest_utils.nest_stack(output_list)
+ return observations, rewards
+
+ @property
+ def observation_shape(self):
+ """Observation shape per environment, i.e. with no batch dimension."""
+ return self._observation_shape
+
+ @property
+ def num_actions(self):
+ return self._num_actions
+
+ @property
+ def episode_length(self):
+ return self._episode_length
+
+ def last_phase_rewards(self):
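+    """Returns the last phase reward from each environment in the batch."""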
+ return [env.last_phase_reward() for env in self._envs]
diff --git a/tvt/dmlab/README.md b/tvt/dmlab/README.md
new file mode 100644
index 0000000..2cb07d3
--- /dev/null
+++ b/tvt/dmlab/README.md
@@ -0,0 +1,66 @@
+# DM Lab Tasks
+
+## General Structure
+
+There are 7 [DM Lab](https://github.com/deepmind/lab) tasks presented here.
+Each level is composed of 3 distinct phases (except `Key To Door To Match`,
+which has 5 phases). The first phase is the 'explore' phase, where the agent
+must acquire a piece of information or perform an action. For all tasks, the
+second phase is the 'distractor' phase, where the agent collects apples for
+rewards. The third phase is the 'exploit' phase, where the agent is rewarded
+based on the knowledge acquired or actions performed in phase 1.
+
+## Specific Tasks
+
+### Passive Visual Match
+
+* Phase 1: A colour square right in front of the agent.
+* Phase 2: Apple collection.
+* Phase 3: Choose the colour square matching that of Phase 1 among 4 options.
+
+### Active Visual Match
+
+* Phase 1: A colour square randomly placed in a two-connected room.
+* Phase 2: Apple collection.
+* Phase 3: Choose the colour square matching that of Phase 1 among 4 options.
+
+### Key To Door
+
+* Phase 1: A key randomly placed in a two-connected room.
+* Phase 2: Apple collection.
+* Phase 3: A small room with a door. If the agent has the key, it can open
+  the door and reach the goal behind it for a reward.
+
+### Key To Door Bluekey
+
+The same as Key To Door above, except the key is blue instead of black.
+
+### Two Negative Keys
+
+* Phase 1: A blue key and a red key are placed in a small room. The agent can
+  only pick up one of the keys.
+* Phase 2: Apple collection.
+* Phase 3: A small room with a door. If the agent has either key, it can open
+  the door to get a reward. The reward depends on which key it picked up in
+  Phase 1. All the rewards are negative in this level.
+
+### Latent Information Acquisition
+
+* Phase 1: Three randomly sampled objects are randomly placed in a small room.
+  When the agent touches an object, a red or green cue appears, indicating
+  the reward that object is associated with in this episode. No rewards are
+  given in this phase.
+* Phase 2: Apple collection.
+* Phase 3: The same three objects from Phase 1 are randomly placed again in
+  the room. The agent gets positive rewards for picking up the objects that
+  had green cues in Phase 1, and negative rewards for objects that had red
+  cues.
+
+### Key To Door To Match
+
+* Phase 1: A key is randomly placed in a room. The agent can pick it up.
+* Phase 2: Apple collection.
+* Phase 3: A colour square behind a door. If the agent has the key from
+  Phase 1, it can open the door to see the colour.
+* Phase 4: Apple collection.
+* Phase 5: Choose the colour square matching that of Phase 3 among 4 options.
diff --git a/tvt/dmlab/active_visual_match.lua b/tvt/dmlab/active_visual_match.lua
new file mode 100644
index 0000000..4c361e2
--- /dev/null
+++ b/tvt/dmlab/active_visual_match.lua
@@ -0,0 +1,24 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'visual_match_factory'
+
+return factory.createLevelApi{
+ exploreMapMode = 'TWO_ROOMS',
+ episodeLengthSeconds = 40,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 30,
+
+ differentDistractRoomTexture = true,
+ differentRewardRoomTexture = true,
+ correctReward = 10,
+ incorrectReward = 1,
+}
diff --git a/tvt/dmlab/image_utils.lua b/tvt/dmlab/image_utils.lua
new file mode 100644
index 0000000..c2211f0
--- /dev/null
+++ b/tvt/dmlab/image_utils.lua
@@ -0,0 +1,42 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local tensor = require 'dmlab.system.tensor'
+
+local utils = {}
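+-- A 16-colour palette (the classic EGA colours).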
+utils.COLORS = {
+ {0, 0, 0},
+ {0, 0, 170},
+ {0, 170, 0},
+ {0, 170, 170},
+ {170, 0, 0},
+ {170, 0, 170},
+ {170, 85, 0},
+ {170, 170, 170},
+ {85, 85, 85},
+ {85, 85, 255},
+ {85, 255, 85},
+ {85, 255, 255},
+ {255, 85, 85},
+ {255, 85, 255},
+ {255, 255, 85},
+ {255, 255, 255},
+}
+
+function utils:createByteImage(h, w, rgb)
+ return tensor.ByteTensor(h, w, 4):fill{rgb[1], rgb[2], rgb[3], 255}
+end
+
+function utils:createTransparentImage(h, w)
+ return tensor.ByteTensor(h, w, 4):fill{127, 127, 127, 0}
+end
+
+return utils
diff --git a/tvt/dmlab/key_to_door.lua b/tvt/dmlab/key_to_door.lua
new file mode 100644
index 0000000..af120cc
--- /dev/null
+++ b/tvt/dmlab/key_to_door.lua
@@ -0,0 +1,20 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'key_to_door_factory'
+
+return factory.createLevelApi{
+ episodeLengthSeconds = 37,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 30,
+ differentDistractRoomTexture = true,
+ differentRewardRoomTexture = true,
+}
diff --git a/tvt/dmlab/key_to_door_bluekey.lua b/tvt/dmlab/key_to_door_bluekey.lua
new file mode 100644
index 0000000..fe85183
--- /dev/null
+++ b/tvt/dmlab/key_to_door_bluekey.lua
@@ -0,0 +1,22 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'key_to_door_factory'
+
+return factory.createLevelApi{
+ keyColor = {0, 0, 255},
+ episodeLengthSeconds = 37,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 30,
+ differentDistractRoomTexture = true,
+ differentRewardRoomTexture = true,
+}
+
diff --git a/tvt/dmlab/key_to_door_factory.lua b/tvt/dmlab/key_to_door_factory.lua
new file mode 100644
index 0000000..2df8d32
--- /dev/null
+++ b/tvt/dmlab/key_to_door_factory.lua
@@ -0,0 +1,459 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local make_map = require 'common.make_map'
+local custom_observations = require 'decorators.custom_observations'
+local debug_observations = require 'decorators.debug_observations'
+local game = require 'dmlab.system.game'
+local map_maker = require 'dmlab.system.map_maker'
+local maze_generation = require 'dmlab.system.maze_generation'
+local pickup_decorator = require 'decorators.human_recognisable_pickups'
+local random = require 'common.random'
+local setting_overrides = require 'decorators.setting_overrides'
+local texture_sets = require 'themes.texture_sets'
+local themes = require 'themes.themes'
+local hrp = require 'common.human_recognisable_pickups'
+
+local DEFAULTS = {
+ EPISODE_LENGTH_SECONDS = 15,
+ EXPLORE_LENGTH_SECONDS = 5,
+ DISTRACTOR_LENGTH_SECONDS = 5,
+ REWARD_LENGTH_SECONDS = nil,
+ SHOW_KEY_COLOR_SQUARE_SECONDS = 1,
+ PROB_APPLE_IN_DISTRACTOR_MAP = 0.3,
+ APPLE_REWARD = 5,
+ APPLE_REWARD_PROB = 1.0,
+ APPLE_EXTRA_REWARD_RANGE = 0,
+ GOAL_REWARD = 10,
+ DISTRACTOR_ROOM_SIZE = {11, 11},
+ DIFFERENT_DISTRACT_ROOM_TEXTURE = false,
+ DIFFERENT_REWARD_ROOM_TEXTURE = false,
+ KEY_COLOR = {0, 0, 0},
+}
+
+local APPLE_ID = 998
+local GOAL_ID = 999
+local KEY_SPAWN_ID = 1000
+local DOOR_ID = 1001
+
+local KEY_CUE_RECTANGLE_WIDTH = 600
+local KEY_CUE_RECTANGLE_HEIGHT = 200
+
+-- Table that maps from full decal name to decal index number.
+local decalIndices = {}
+
+local EXPLORE_MAP = "exploreMap"
+local DISTRACTOR_MAP = "distractorMap"
+local REWARD_MAP = "rewardMap"
+
+-- Set texture set for all maps.
+local textureSet = texture_sets.PACMAN
+local secondTextureSet = texture_sets.TETRIS
+local thirdTextureSet = texture_sets.TRON
+
+local REWARD_ROOM = [[
+***
+*P*
+*H*
+*G*
+***
+]]
+
+local OPEN_TWO_ROOM = [[
+*********
+*********
+*PKK*KKK*
+*KKKKKKK*
+*KKK*KKK*
+*********
+]]
+local N_KEY_POS_IN_TWO_ROOM = 18 -- # of K in OPEN_TWO_ROOM
+
+local function createDistractorMaze(opts)
+ -- Example room with height = 2, width = 3
+ -- A are possible apple locations (everywhere)
+ -- *****
+ -- *APA*
+ -- *AAA*
+ -- *****
+
+ local roomHeight = opts.roomSize[1]
+ local roomWidth = opts.roomSize[2]
+  local centerWidth = 1 + math.ceil(roomWidth / 2)
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2, -- +2 for the two side of walls
+ width = roomWidth + 2
+ }
+
+  -- Fill the room with 'A' for apples. updateSpawnVars decides where to
+  -- put them.
+ for i = 2, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ maze:setEntityCell(i, j, 'A')
+ end
+ end
+ -- Override one cell with 'P' for spawn point.
+ maze:setEntityCell(2, centerWidth, 'P')
+ return maze
+end
+
+local function numPossibleAppleLocations(distractorRoomSize)
+ return distractorRoomSize[1] * distractorRoomSize[2] - 1
+end
+
+local factory = {}
+game:console('cg_drawScriptRectanglesAlways 1')
+
+function factory.createLevelApi(kwargs)
+ kwargs.episodeLengthSeconds = kwargs.episodeLengthSeconds or
+ DEFAULTS.EPISODE_LENGTH_SECONDS
+ kwargs.exploreLengthSeconds = kwargs.exploreLengthSeconds or
+ DEFAULTS.EXPLORE_LENGTH_SECONDS
+ kwargs.rewardLengthSeconds = kwargs.rewardLengthSeconds or
+ DEFAULTS.REWARD_LENGTH_SECONDS
+ kwargs.distractorLengthSeconds = kwargs.distractorLengthSeconds or
+ DEFAULTS.DISTRACTOR_LENGTH_SECONDS
+ kwargs.distractorRoomSize = kwargs.distractorRoomSize or
+ DEFAULTS.DISTRACTOR_ROOM_SIZE
+
+ kwargs.appleReward = kwargs.appleReward or DEFAULTS.APPLE_REWARD
+ kwargs.appleRewardProb = kwargs.appleRewardProb or DEFAULTS.APPLE_REWARD_PROB
+ kwargs.probAppleInDistractorMap = kwargs.probAppleInDistractorMap or
+ DEFAULTS.PROB_APPLE_IN_DISTRACTOR_MAP
+
+ kwargs.appleExtraRewardRange =
+ kwargs.appleExtraRewardRange or DEFAULTS.APPLE_EXTRA_REWARD_RANGE
+
+ kwargs.differentDistractRoomTexture = kwargs.differentDistractRoomTexture or
+ DEFAULTS.DIFFERENT_DISTRACT_ROOM_TEXTURE
+
+ kwargs.differentRewardRoomTexture = kwargs.differentRewardRoomTexture or
+ DEFAULTS.DIFFERENT_REWARD_ROOM_TEXTURE
+
+ kwargs.showKeyColorSquareSeconds = kwargs.showKeyColorSquareSeconds or
+ DEFAULTS.SHOW_KEY_COLOR_SQUARE_SECONDS
+ kwargs.goalReward = kwargs.goalReward or DEFAULTS.GOAL_REWARD
+ kwargs.keyColor = kwargs.keyColor or DEFAULTS.KEY_COLOR
+
+ local api = {}
+
+ function api:init(params)
+ self:_createExploreMap()
+ self:_createDistractorMap()
+ self:_createRewardMap()
+
+ local keyInfo = {
+ shape='key',
+ pattern='solid',
+ color1 = kwargs.keyColor,
+ color2 = kwargs.keyColor
+ }
+ self._keyObject = hrp.create(keyInfo)
+ self._keyCueRgba = {
+ kwargs.keyColor[1]/255,
+ kwargs.keyColor[2]/255,
+ kwargs.keyColor[3]/255,
+ 1
+ }
+ end
+
+ function api:_createRewardMap()
+ self._rewardMap = map_maker:mapFromTextLevel{
+ mapName = REWARD_MAP,
+ entityLayer = REWARD_ROOM,
+ }
+
+ -- Create map theme and override default wall decal placement.
+ local texture = textureSet
+ if kwargs.differentRewardRoomTexture then
+ texture = thirdTextureSet
+ end
+ local rewardMapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 0.0,
+ floorModelFrequency = 0.0,
+ }
+
+ self._rewardMap = map_maker:mapFromTextLevel{
+ mapName = REWARD_MAP,
+ entityLayer = REWARD_ROOM,
+ theme = rewardMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_createExploreMap()
+    local exploreMapInfo = {map = OPEN_TWO_ROOM}
+
+ -- Create map theme and override default wall decal placement.
+ local exploreMapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 0.0,
+ floorModelFrequency = 0.0,
+ }
+
+ self._exploreMap = map_maker:mapFromTextLevel{
+ mapName = EXPLORE_MAP,
+ entityLayer = exploreMapInfo.map,
+ theme = exploreMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_createDistractorMap()
+ -- Create maze to be converted into map.
+ local maze = createDistractorMaze{roomSize = kwargs.distractorRoomSize}
+
+ -- Create map theme with no wall decals.
+ local texture = textureSet
+ if kwargs.differentDistractRoomTexture then
+ texture = secondTextureSet
+ end
+ local distractorMapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 0.0,
+ floorModelFrequency = 0.0,
+ }
+
+ self._distractorMap = map_maker:mapFromTextLevel{
+ mapName = DISTRACTOR_MAP,
+ entityLayer = maze:entityLayer(),
+ theme = distractorMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:start(episode, seed)
+ random:seed(seed)
+
+ self._map = nil
+ self._time = 0
+ self._holdingKey = false
+ self._keyPosCount = 0
+ self._collectedGoal = false
+
+ if kwargs.distractorLengthSecondsRange then
+ self._distractorLen = random:uniformReal(
+ kwargs.distractorLengthSecondsRange[1],
+ kwargs.distractorLengthSecondsRange[2])
+ else
+ self._distractorLen = kwargs.distractorLengthSeconds
+ end
+
+ -- Sample the key position in phase 1.
+ self._keyPosition = random:uniformInt(1, N_KEY_POS_IN_TWO_ROOM)
+
+    -- Default the instruction channel to 0 (it indicates the reward in the
+    -- final phase).
+ self.setInstruction(tostring(0))
+ end
+
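+  -- Draws the coloured on-screen rectangle cue while the key cue is showing.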
+ function api:filledRectangles(args)
+ if self._showKeyCue then
+ return {{
+ x = 12,
+ y = 12,
+ width = KEY_CUE_RECTANGLE_WIDTH,
+ height = KEY_CUE_RECTANGLE_HEIGHT,
+ rgba = self._keyCueRgba
+ }}
+ end
+ return {}
+ end
+
+ function api:nextMap()
+ -- 1. Decide what is the next map.
+ if self._map == nil then
+ self._map = EXPLORE_MAP
+ elseif self._map == DISTRACTOR_MAP then
+ self._map = REWARD_MAP
+ elseif self._map == EXPLORE_MAP then
+ if self._distractorLen > 0.0 then
+ self._map = DISTRACTOR_MAP
+ else
+ self._map = REWARD_MAP
+ end
+ elseif self._map == REWARD_MAP then
+ -- Stay in distractor map till end of episode.
+ self._map = DISTRACTOR_MAP
+ self._collectedGoal = true
+ end
+
+ -- 2. Set up timeout for the up-coming map.
+ if self._map == DISTRACTOR_MAP and self._collectedGoal then
+ if not self._timeOut then -- don't override any existing timeout
+ self._timeOut = self._time + 0.1
+ end
+ elseif self._map == EXPLORE_MAP then
+ self._timeOut = self._time + kwargs.exploreLengthSeconds
+ elseif self._map == DISTRACTOR_MAP then
+ self._timeOut = self._time + self._distractorLen
+ elseif self._map == REWARD_MAP then
+ if kwargs.rewardLengthSeconds then
+ self._timeOut = self._time + kwargs.rewardLengthSeconds
+ else
+ self._timeOut = nil
+ end
+ end
+
+ return self._map
+ end
+
+ -- PICKUP functions ----------------------------------------------------------
+
+ function api:_makePickup(c)
+ if c == 'K' then
+ return 'key'
+ end
+ if c == 'G' then
+ return 'goal'
+ end
+ if c == 'A' then
+ return 'apple_reward'
+ end
+ end
+
+ function api:pickup(spawnId)
+ if spawnId == GOAL_ID then
+ local goalReward = kwargs.goalReward
+ game:addScore(goalReward - 10) -- Offset the default +10 for goal.
+ self.setInstruction(tostring(goalReward))
+ game:finishMap()
+ end
+ if spawnId == KEY_SPAWN_ID then
+ self._holdingKey = true
+ self._holdingKeyTime = self._time -- When the avatar got the key.
+ self._showKeyCue = true
+ end
+
+ if spawnId == APPLE_ID then
+ if kwargs.appleRewardProb >= 1 or
+ random:uniformReal(0, 1) < kwargs.appleRewardProb then
+ -- The -1 is to offset the default 1 point for apple in dmlab
+        local appleReward = kwargs.appleReward +
+            random:uniformInt(0, kwargs.appleExtraRewardRange) - 1
+ game:addScore(appleReward)
+ else
+ -- The -1 is to offset the default 1 point for apple in dmlab
+ game:addScore(-1)
+ end
+ end
+ end
+
+ -- TRIGGER functions ---------------------------------------------------------
+
+  function api:canTrigger(teleportId, targetName)
+    if string.sub(targetName, 1, 4) == 'door' then
+      return self._holdingKey
+    end
+    return true
+  end
+
+ function api:trigger(teleportId, targetName)
+ if string.sub(targetName, 1, 4) == 'door' then
+      -- When the door is opened, stop showing the key cue and set holding
+      -- key to false.
+ self._showKeyCue = false
+ self._holdingKey = false
+ return
+ end
+ end
+
+ function api:hasEpisodeFinished(timeSeconds)
+ self._time = timeSeconds
+
+ if self._map == REWARD_MAP or self._collectedGoal then
+ return self._timeOut and timeSeconds > self._timeOut
+ end
+
+ -- Control the timing of showing key cue.
+ if self._holdingKey then
+ local showTime = self._time - self._holdingKeyTime
+ if showTime > kwargs.showKeyColorSquareSeconds then
+ self._showKeyCue = false
+ end
+ end
+
+ if self._map == EXPLORE_MAP or self._map == DISTRACTOR_MAP then
+ if timeSeconds > self._timeOut then
+ game:finishMap()
+ end
+ return false
+ end
+ end
+
+ -- END TRIGGER functions -----------------------------------------------------
+
+ function api:updateSpawnVars(spawnVars)
+ local classname = spawnVars.classname
+ if classname == "info_player_start" then
+ -- Spawn facing South.
+ spawnVars.angle = "-90"
+ spawnVars.randomAngleRange = "0"
+ elseif classname == "func_door" then
+ spawnVars.id = tostring(DOOR_ID)
+ spawnVars.wait = "1000000" -- Open the door for long time.
+ elseif classname == "goal" then
+ spawnVars.id = tostring(GOAL_ID)
+ elseif classname == "apple_reward" then
+      -- We respawn the avatar in the distractor room after the goal is
+      -- reached; there will be no more apples in this case.
+ if self._collectedGoal == true then
+ return nil
+ end
+ local useApple = false
+ if kwargs.probAppleInDistractorMap > 0 then
+ useApple = random:uniformReal(0, 1) < kwargs.probAppleInDistractorMap
+ end
+ if useApple then
+ spawnVars.id = tostring(APPLE_ID)
+ else
+ return nil
+ end
+ elseif classname == "key" then
+ self._keyPosCount = self._keyPosCount + 1
+ if self._keyPosition == self._keyPosCount then
+ spawnVars.id = tostring(KEY_SPAWN_ID)
+ spawnVars.classname = self._keyObject
+ else
+ return nil
+ end
+ end
+ return spawnVars
+ end
+
+ custom_observations.decorate(api)
+ pickup_decorator.decorate(api)
+ setting_overrides.decorate{
+ api = api,
+ apiParams = kwargs,
+ decorateWithTimeout = true
+ }
+ return api
+end
+
+return factory
diff --git a/tvt/dmlab/key_to_door_to_match.lua b/tvt/dmlab/key_to_door_to_match.lua
new file mode 100644
index 0000000..b8c0ffc
--- /dev/null
+++ b/tvt/dmlab/key_to_door_to_match.lua
@@ -0,0 +1,28 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'visual_match_factory'
+
+return factory.createLevelApi{
+ exploreMapMode = 'KEY_TO_COLOR',
+ episodeLengthSeconds = 45,
+ secondOrderExploreLengthSeconds = 5,
+ preExploreDistractorLengthSeconds = 15,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 15,
+
+ differentDistractRoomTexture = true,
+ differentRewardRoomTexture = true,
+ differentSecondOrderRoomTexture = true,
+ secondOrderExploreRoomSize = {4, 4},
+ correctReward = 10,
+ incorrectReward = 1,
+}
diff --git a/tvt/dmlab/latent_information_acquisition.lua b/tvt/dmlab/latent_information_acquisition.lua
new file mode 100644
index 0000000..b388c29
--- /dev/null
+++ b/tvt/dmlab/latent_information_acquisition.lua
@@ -0,0 +1,23 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'latent_information_acquisition_factory'
+
+return factory.createLevelApi{
+ episodeLengthSeconds = 40,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 30,
+ numObjects = 3,
+ probGoodObject = 0.5,
+ correctReward = 20,
+ incorrectReward = -10,
+ differentDistractRoomTexture = true,
+}
diff --git a/tvt/dmlab/latent_information_acquisition_factory.lua b/tvt/dmlab/latent_information_acquisition_factory.lua
new file mode 100644
index 0000000..e87d7c6
--- /dev/null
+++ b/tvt/dmlab/latent_information_acquisition_factory.lua
@@ -0,0 +1,418 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local make_map = require 'common.make_map'
+local custom_decals = require 'decorators.custom_decals_decoration'
+local custom_entities = require 'common.custom_entities'
+local custom_observations = require 'decorators.custom_observations'
+local datasets_selector = require 'datasets.selector'
+local game = require 'dmlab.system.game'
+local maze_generation = require 'dmlab.system.maze_generation'
+local pickup_decorator = require 'decorators.human_recognisable_pickups'
+local random = require 'common.random'
+local setting_overrides = require 'decorators.setting_overrides'
+local texture_sets = require 'themes.texture_sets'
+local themes = require 'themes.themes'
+local hrp = require 'common.human_recognisable_pickups'
+
+local SHOW_COLOR_CUE_SECOND = 0.25
+local EPISODE_LENGTH_SECONDS = 30
+local EXPLORE_LENGTH_SECONDS = 10
+local DISTRACTOR_LENGTH_SECONDS = 10
+local NUM_OBJECTS = 3
+local PROB_GOOD_OBJECT = 0.5
+local GUARANTEE_GOOD_OBJECTS = 0
+local GUARANTEE_BAD_OBJECTS = 0
+
+local PROB_APPLE_IN_DISTRACTOR_MAP = 0.3
+local APPLE_REWARD = 5
+local APPLE_EXTRA_REWARD_RANGE = 0
+local DISTRACTOR_ROOM_SIZE = {11, 11}
+local APPLE_ID = 1000
+local CORRECT_REWARD = 2
+local INCORRECT_REWARD = -1
+local ROOM_SIZE = {3, 5}
+local OBJECT_SCALE = 1.62
+
+local EXPLORE_MAP = "exploreMap"
+local DISTRACTOR_MAP = "distractorMap"
+local EXPLOIT_MAP = "exploitMap"
+
+
+local DIFFERENT_DISTRACT_ROOM_TEXTURE = false
+
+-- Set texture set for all maps.
+local textureSet = texture_sets.TRON
+local secondTextureSet = texture_sets.TETRIS
+
+-- Takes goal/location:i -> i
+local function nameToLocationId(name)
+ return tonumber(name:match('^.+:(%d+)$'))
+end
+
+-- Takes goal/location:i -> goal/pickup
+local function nameToLocationClass(name)
+ return name:match('^(.+):%d+$')
+end
+
+local factory = {}
+game:console('cg_drawScriptRectanglesAlways 1')
+
+function factory.createLevelApi(kwargs)
+ kwargs.episodeLengthSeconds = kwargs.episodeLengthSeconds or
+ EPISODE_LENGTH_SECONDS
+ kwargs.exploreLengthSeconds = kwargs.exploreLengthSeconds or
+ EXPLORE_LENGTH_SECONDS
+ if kwargs.distractorLengthSeconds == 0 then
+ kwargs.skipDistractor = true
+ else
+ kwargs.distractorLengthSeconds = kwargs.distractorLengthSeconds or
+ DISTRACTOR_LENGTH_SECONDS
+ end
+ kwargs.numObjects = kwargs.numObjects or NUM_OBJECTS
+ kwargs.probGoodObject = kwargs.probGoodObject or PROB_GOOD_OBJECT
+  kwargs.guaranteeGoodObjects = kwargs.guaranteeGoodObjects or
+      GUARANTEE_GOOD_OBJECTS
+  kwargs.guaranteeBadObjects = kwargs.guaranteeBadObjects or
+      GUARANTEE_BAD_OBJECTS
+ kwargs.correctReward = kwargs.correctReward or CORRECT_REWARD
+ kwargs.incorrectReward = kwargs.incorrectReward or INCORRECT_REWARD
+ kwargs.roomSize = kwargs.roomSize or ROOM_SIZE
+ kwargs.distractorRoomSize = kwargs.distractorRoomSize or DISTRACTOR_ROOM_SIZE
+ kwargs.probAppleInDistractorMap = kwargs.probAppleInDistractorMap or
+ PROB_APPLE_IN_DISTRACTOR_MAP
+ kwargs.differentDistractRoomTexture = kwargs.differentDistractRoomTexture or
+ DIFFERENT_DISTRACT_ROOM_TEXTURE
+ kwargs.appleReward = kwargs.appleReward or APPLE_REWARD
+ kwargs.appleExtraRewardRange = kwargs.appleExtraRewardRange or
+ APPLE_EXTRA_REWARD_RANGE
+ kwargs.objectScale = kwargs.objectScale or OBJECT_SCALE
+
+ local api = {}
+
+ function api:init(params)
+ self:_createExploreMap()
+ self:_createDistractorMap()
+ self:_createExploitMap()
+ end
+
+ function api:pickup(spawnId)
+ if self._map == EXPLORE_MAP then
+ -- Setup to show color cue.
+ self._showObjectCue = true
+ self._cueColor = self._objects[spawnId].cueColor
+ self._cueStartTime = self._time
+ elseif self._map == EXPLOIT_MAP then
+      -- Give the corresponding reward and terminate when all good objects
+      -- are collected.
+      game:addScore(self._objects[spawnId].reward)
+      -- Update the instruction channel (to record final phase rewards).
+ self._finalRewardMainTask = (
+ self._finalRewardMainTask + self._objects[spawnId].reward)
+ self.setInstruction(tostring(self._finalRewardMainTask))
+ end
+
+ if spawnId == APPLE_ID then
+      -- Note the -1 to offset the default 1 point for an apple in dmlab.
+      local appleReward = kwargs.appleReward +
+          random:uniformInt(0, kwargs.appleExtraRewardRange) - 1
+ game:addScore(appleReward)
+ end
+ end
+
+ function api:_createRoomCommon()
+ local roomHeight = kwargs.roomSize[1]
+ local roomWidth = kwargs.roomSize[2]
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2,
+ width = roomWidth + 2
+ }
+
+ -- Set (2,2) as 'P' for the avatar location.
+ -- Set (i,j) as 'O' for possible object location if i%2 == 0 && j%2 == 0.
+ -- Otherwise, fill with '.' for empty location.
+ self._numLocations = 0
+ for i = 2, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ if i == 2 and j == 2 then
+ maze:setEntityCell(i, j, 'P')
+ elseif i % 2 == 0 and j % 2 == 0 then
+ maze:setEntityCell(i, j, 'O')
+ self._numLocations = self._numLocations + 1
+ else
+ maze:setEntityCell(i, j, '.')
+ end
+ end
+ end
+
+ return maze
+ end
+
+ function api:_createExploreMap()
+    local maze = self:_createRoomCommon()
+ print('Generated explore maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+
+ local mapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 0.0,
+ }
+
+ local counter = 1
+ self._exploreMap = make_map.makeMap{
+ mapName = EXPLORE_MAP,
+ mapEntityLayer = maze:entityLayer(),
+ theme = mapTheme,
+ callback = function (i, j, c, maker)
+ if c == 'O' then
+          local pickup = 'location:' .. counter
+ counter = counter + 1
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_createDistractorMap()
+ -- Create map theme with no wall decals.
+ local distractorMapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 0.0,
+ }
+
+ -- Example room with height = 2, width = 3
+ -- *****
+ -- *APA*
+ -- *AAA*
+ -- *****
+ local roomHeight = kwargs.distractorRoomSize[1]
+ local roomWidth = kwargs.distractorRoomSize[2]
+    local centerWidth = 1 + math.ceil(roomWidth / 2)
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2,
+ width = roomWidth + 2
+ }
+
+ -- Fill the room with 'A' for apples. updateSpawnVars decides which to use.
+ for i = 2, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ maze:setEntityCell(i, j, 'A')
+ end
+ end
+ -- Override one cell with 'P' for spawn point.
+ maze:setEntityCell(2, centerWidth, 'P')
+
+ print('Generated distractor maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+
+ local texture = textureSet
+ if kwargs.differentDistractRoomTexture then
+ texture = secondTextureSet
+ end
+ local mapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 0.0,
+ }
+ self._distractMap = make_map.makeMap{
+ mapName = DISTRACTOR_MAP,
+ mapEntityLayer = maze:entityLayer(),
+ theme = mapTheme,
+ }
+ end
+
+ function api:_createExploitMap()
+    local maze = self:_createRoomCommon()
+ print('Generated exploit maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+
+ local mapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 0.0,
+ }
+
+ local counter = 1
+ self.exploitMap = make_map.makeMap{
+ mapName = EXPLOIT_MAP,
+ mapEntityLayer = maze:entityLayer(),
+ theme = mapTheme,
+ useSkybox = false,
+ callback = function (i, j, c, maker)
+ if c == 'O' then
+          local pickup = 'location:' .. counter
+ counter = counter + 1
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_generateRandomObjects()
+ -- 1. Generate a random list of positive/negative reward, `objectValence`
+ -- as function(numObjects, guaranteeGood, guaranteeBad, probGoodObject)
+
+ local objectValence = {}
+ for i = 1, kwargs.numObjects do
+ if i <= kwargs.guaranteeGoodObjects then
+ objectValence[i] = 1
+      elseif i <= kwargs.guaranteeGoodObjects + kwargs.guaranteeBadObjects then
+ objectValence[i] = -1
+ else
+ if random:uniformReal(0, 1) < kwargs.probGoodObject then
+ objectValence[i] = 1
+ else
+ objectValence[i] = -1
+ end
+ end
+ end
+ random:shuffleInPlace(objectValence)
+
+ -- 2. Generate random objects and link to the object valence above.
+ local objects = hrp.uniquelyShapedPickups(kwargs.numObjects)
+ for i = 1, kwargs.numObjects do
+      objects[i].scale = kwargs.objectScale
+ end
+
+ self._objects = {}
+ for i, object in ipairs(objects) do
+ self._objects[i] = {}
+ self._objects[i].data = hrp.create(object)
+ if objectValence[i] == 1 then
+ self._objects[i].isGoodObject = true
+ self._objects[i].reward = kwargs.correctReward
+ self._objects[i].cueColor = {0, 1, 0, 1} -- green means good
+ else
+ self._objects[i].isGoodObject = false
+ self._objects[i].reward = kwargs.incorrectReward
+ self._objects[i].cueColor = {1, 0, 0, 1} -- red means bad
+ end
+ end
+ end
+
+ function api:start(episode, seed)
+ random:seed(seed)
+
+    -- Set up a random mapping from locationId to pickupId. There are more
+    -- locationIds than pickupIds; locations mapped to pickupId == 0 will
+    -- have no object placed there.
+ self._mapLocationIdToPickupId = {}
+ for i = 1, self._numLocations do
+ if i <= kwargs.numObjects then
+ self._mapLocationIdToPickupId[i] = i
+ else
+ self._mapLocationIdToPickupId[i] = 0
+ end
+ end
+ random:shuffleInPlace(self._mapLocationIdToPickupId)
+
+ self:_generateRandomObjects()
+ self._map = nil
+ self._numTrials = 0
+ self._timeOut = kwargs.exploreLengthSeconds
+
+ -- Set the instruction channel to record the rewards in the final phase.
+ self._finalRewardMainTask = 0
+ self.setInstruction("0")
+ end
+
+ function api:nextMap()
+ if self._map == nil then -- Start of episode.
+ self._map = EXPLORE_MAP
+ elseif not kwargs.skipDistractor and self._map == EXPLORE_MAP then
+ -- Move from explore to distractor.
+ self._map = DISTRACTOR_MAP
+ self._timeOut = self._time + kwargs.distractorLengthSeconds
+ elseif (kwargs.skipDistractor and self._map == EXPLORE_MAP)
+ or self._map == DISTRACTOR_MAP then
+ -- Move from distractor or explore map to exploit map.
+ self._map = EXPLOIT_MAP
+ random:shuffleInPlace(self._mapLocationIdToPickupId)
+ self._timeOut = nil
+ end
+
+ return self._map
+ end
+
+ function api:hasEpisodeFinished(timeSeconds)
+ self._time = timeSeconds
+ if self._showObjectCue then
+ if self._time - self._cueStartTime > SHOW_COLOR_CUE_SECOND then
+ self._showObjectCue = false
+ end
+ end
+
+ if self._map == EXPLORE_MAP or self._map == DISTRACTOR_MAP then
+ if timeSeconds > self._timeOut then
+ game:finishMap()
+ end
+ return false
+ end
+ end
+
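+  -- Draws the on-screen colour cue rectangle while an object cue is showing.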
+ function api:filledRectangles(args)
+ if self._map == EXPLORE_MAP and self._showObjectCue then
+ return {{
+ x = 12,
+ y = 12,
+ width = 600,
+ height = 300,
+ rgba = self._cueColor,
+ }}
+ end
+ return {}
+ end
+
+ function api:updateSpawnVars(spawnVars)
+ local classname = spawnVars.classname
+ if classname == "info_player_start" then
+ -- Spawn facing South.
+ spawnVars.angle = "-90"
+ spawnVars.randomAngleRange = "0"
+ elseif classname == "apple_reward" then
+ local useApple = false
+ if kwargs.probAppleInDistractorMap > 0 then
+ useApple = random:uniformReal(0, 1) < kwargs.probAppleInDistractorMap
+ spawnVars.id = tostring(APPLE_ID)
+ end
+ if not useApple then
+ return nil
+ end
+ else
+ -- Allocate objects onto the map by mapLocationIdToPickupId.
+ local locationClass = nameToLocationClass(classname)
+ if locationClass then
+ local locationId = nameToLocationId(classname)
+        local id = self._mapLocationIdToPickupId[locationId]
+ if id == 0 then
+ return nil
+ else
+ spawnVars.classname = self._objects[id].data
+ spawnVars.id = tostring(id)
+ end
+ end
+ end
+
+ return spawnVars
+ end
+
+ custom_observations.decorate(api)
+ pickup_decorator.decorate(api)
+ setting_overrides.decorate{
+ api = api,
+ apiParams = kwargs,
+ decorateWithTimeout = true
+ }
+ return api
+end
+
+return factory
diff --git a/tvt/dmlab/passive_visual_match.lua b/tvt/dmlab/passive_visual_match.lua
new file mode 100644
index 0000000..fcc7617
--- /dev/null
+++ b/tvt/dmlab/passive_visual_match.lua
@@ -0,0 +1,24 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'visual_match_factory'
+
+return factory.createLevelApi{
+ exploreMapMode = 'PASSIVE',
+ episodeLengthSeconds = 40,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 30,
+
+ differentDistractRoomTexture = true,
+ differentRewardRoomTexture = true,
+ correctReward = 10,
+ incorrectReward = 1,
+}
diff --git a/tvt/dmlab/two_keys_to_choose_factory.lua b/tvt/dmlab/two_keys_to_choose_factory.lua
new file mode 100644
index 0000000..4efcbed
--- /dev/null
+++ b/tvt/dmlab/two_keys_to_choose_factory.lua
@@ -0,0 +1,503 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local make_map = require 'common.make_map'
+local custom_observations = require 'decorators.custom_observations'
+local debug_observations = require 'decorators.debug_observations'
+local game = require 'dmlab.system.game'
+local image_utils = require 'image_utils'
+local map_maker = require 'dmlab.system.map_maker'
+local maze_generation = require 'dmlab.system.maze_generation'
+local pickup_decorator = require 'decorators.human_recognisable_pickups'
+local random = require 'common.random'
+local setting_overrides = require 'decorators.setting_overrides'
+local texture_sets = require 'themes.texture_sets'
+local themes = require 'themes.themes'
+local hrp = require 'common.human_recognisable_pickups'
+
+local EPISODE_LENGTH_SECONDS = 15
+local EXPLORE_LENGTH_SECONDS = 5
+local DISTRACTOR_LENGTH_SECONDS = 5
+local CUE_COLORS = {2, 5} -- Either red or blue cue.
+local APPLE_ID = 998
+local GOAL_ID = 999
+local KEY_SPAWN_ID = 1000
+local BAD_KEY_SPAWN_ID = 1001
+local DOOR_ID = 1002
+local KEY_CUE_RECTANGLE_WIDTH = 600
+local KEY_CUE_RECTANGLE_HEIGHT = 200
+local SHOW_COLOR_SQUARE_SECONDS = 1
+
+-- Table that maps from full decal name to decal index number.
+local decalIndices = {}
+
+local EXPLORE_MAP = "exploreMap"
+local DISTRACTOR_MAP = "distractorMap"
+local REWARD_MAP = "rewardMap"
+local COLORS = image_utils.COLORS
+
+local GOAL_WITH_GOOD_KEY_REWARD = -1
+local GOAL_WITH_BAD_KEY_REWARD = -10
+
+local DISTRACTOR_ROOM_SIZE = {11, 11}
+local EXPLORE_ROOM_SIZE = {4, 3}
+local APPLE_REWARD = 5
+local PROB_APPLE_IN_DISTRACTOR_MAP = 0.3
+local DEFAULT_FINAL_REWARD = -20
+local APPLE_EXTRA_REWARD_RANGE = 0
+local DIFFERENT_DISTRACT_ROOM_TEXTURE = false
+
+
+-- Set texture set for all maps.
+local textureSet = texture_sets.TRON
+local secondTextureSet = texture_sets.TETRIS
+
+local REWARD_ROOM = [[
+***
+*P*
+*H*
+*G*
+***
+]]
+
+local function createDistractorMaze(opts)
+ -- Example room with height = 2, width = 3
+ -- A are possible apple locations (everywhere)
+ -- *****
+ -- *APA*
+ -- *AAA*
+ -- *****
+
+ local roomHeight = opts.roomSize[1]
+ local roomWidth = opts.roomSize[2]
+  local centerWidth = 1 + math.ceil(roomWidth / 2)
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2, -- +2 for the two side of walls
+ width = roomWidth + 2
+ }
+
+ -- Fill the room with 'A' for apples. updateSpawnVars decides which to use.
+ for i = 2, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ maze:setEntityCell(i, j, 'A')
+ end
+ end
+ -- Override one cell with 'P' for spawn point.
+ maze:setEntityCell(2, centerWidth, 'P')
+
+ print('Generated distractor maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+ return maze
+end
+
+local function createExploreMaze(opts)
+  -- Procedurally generate a room like below:
+ -- xxxxxxx
+ -- x P x
+ -- x x
+ -- xK Kx
+ -- xxxxxxx
+
+ local roomHeight = opts.roomSize[1]
+ local roomWidth = opts.roomSize[2]
+  local centerWidth = 1 + math.ceil(roomWidth / 2)
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2,
+ width = roomWidth + 2
+ }
+
+ for i = 2, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ maze:setEntityCell(i, j, '.')
+ end
+ end
+
+ maze:setEntityCell(2, centerWidth, 'P')
+ maze:setEntityCell(roomHeight + 1, 2, 'K')
+ maze:setEntityCell(roomHeight + 1, roomWidth + 1, 'K')
+
+ print('Generated 2nd order explore maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+
+ return maze
+end
+
+local factory = {}
+game:console('cg_drawScriptRectanglesAlways 1')
+
+function factory.createLevelApi(kwargs)
+ kwargs.episodeLengthSeconds = kwargs.episodeLengthSeconds or
+ EPISODE_LENGTH_SECONDS
+ kwargs.exploreLengthSeconds = kwargs.exploreLengthSeconds or
+ EXPLORE_LENGTH_SECONDS
+ kwargs.distractorLengthSeconds = kwargs.distractorLengthSeconds or
+ DISTRACTOR_LENGTH_SECONDS
+ kwargs.distractorRoomSize = kwargs.distractorRoomSize or DISTRACTOR_ROOM_SIZE
+ kwargs.probAppleInDistractorMap = kwargs.probAppleInDistractorMap or
+ PROB_APPLE_IN_DISTRACTOR_MAP
+ kwargs.exploreRoomSize = kwargs.exploreRoomSize or EXPLORE_ROOM_SIZE
+ kwargs.appleExtraRewardRange =
+ kwargs.appleExtraRewardRange or APPLE_EXTRA_REWARD_RANGE
+ kwargs.differentDistractRoomTexture = kwargs.differentDistractRoomTexture or
+ DIFFERENT_DISTRACT_ROOM_TEXTURE
+ kwargs.defaultFinalReward = kwargs.defaultFinalReward or DEFAULT_FINAL_REWARD
+ kwargs.goalWithGoodKeyReward = kwargs.goalWithGoodKeyReward or
+ GOAL_WITH_GOOD_KEY_REWARD
+ kwargs.goalWithBadKeyReward = kwargs.goalWithBadKeyReward or
+ GOAL_WITH_BAD_KEY_REWARD
+ kwargs.appleReward = kwargs.appleReward or APPLE_REWARD
+
+ local api = {}
+
+ function api:init(params)
+ self:_createSquareExploreMap()
+ self:_createDistractorMap()
+ self:_createRewardMap()
+
+
+ -- key 1 is a red key, good, leads to less negative reward.
+ local keyInfo = {shape='key', pattern='solid',
+ color1 = {255, 0, 0}, color2={0, 0, 0}}
+ self._keyObject = hrp.create(keyInfo)
+ self._keyCueRgba = {1, 0, 0, 1}
+
+ -- key 2 is a blue key, bad, leads to more negative reward.
+ local keyInfo2 = {shape='key', pattern='solid',
+ color1 = {0, 0, 255}, color2={0, 0, 0}}
+ self._keyObject2 = hrp.create(keyInfo2)
+ self._keyCueRgba2 = {0, 0, 1, 1}
+
+ self._keyCueRgbaNoKey = {0, 0, 0, 1}
+ end
+
+ function api:_createRewardMap()
+
+ self._rewardMap = map_maker:mapFromTextLevel{
+ mapName = REWARD_MAP,
+ entityLayer = REWARD_ROOM,
+ }
+
+ -- Create map theme and override default wall decal placement.
+ local rewardMapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 0.0,
+ }
+
+ self._rewardMap = map_maker:mapFromTextLevel{
+ mapName = REWARD_MAP,
+ entityLayer = REWARD_ROOM,
+ theme = rewardMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_createSquareExploreMap()
+ -- Create a maze to be converted into map.
+ local maze = createExploreMaze{
+ roomSize = kwargs.exploreRoomSize
+ }
+
+ -- Create a map theme without wall decal placement.
+ local exploreMapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 0.0,
+ }
+
+ self._exploreMap = map_maker:mapFromTextLevel{
+ mapName = EXPLORE_MAP,
+ entityLayer = maze:entityLayer(),
+ theme = exploreMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_createDistractorMap()
+
+ -- Create maze to be converted into map.
+ local maze = createDistractorMaze{
+ roomSize = kwargs.distractorRoomSize,
+ }
+
+ -- Create map theme with no wall decals.
+ local texture = textureSet
+ if kwargs.differentDistractRoomTexture then
+ texture = secondTextureSet
+ end
+ local distractorMapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 0.0,
+ }
+
+ self._exploreMap = map_maker:mapFromTextLevel{
+ mapName = DISTRACTOR_MAP,
+ entityLayer = maze:entityLayer(),
+ theme = distractorMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:start(episode, seed)
+ random:seed(seed)
+
+ self._map = nil
+ self._time = 0
+ self._holdingKey = false
+ self._holdingBadKey = false
+ self._keyPosCount = 0
+
+ self._collectedGoal = false
+ self._showKeyCue = false
+ self._showNoKeyCue = false
+ self._finalReward = kwargs.defaultFinalReward
+ self._finalRewardAdded = false
+
+ if kwargs.distractorLengthSecondsRange then
+ self._distractorLen = random:uniformReal(
+ kwargs.distractorLengthSecondsRange[1],
+ kwargs.distractorLengthSecondsRange[2])
+ else
+ self._distractorLen = kwargs.distractorLengthSeconds
+ end
+
+ if kwargs.exploreRoomSize then
+      local posIndex = {1, 2}  -- Only 2 possible key locations.
+ random:shuffleInPlace(posIndex)
+ self._keyPosition = posIndex[1]
+ self._keyPosition2 = posIndex[2]
+ end
+
+    -- Set the instruction channel output to defaultFinalReward. It is later
+    -- set to the goal reward if the goal is collected.
+ self.setInstruction(tostring(kwargs.defaultFinalReward))
+ end
+
+ function api:filledRectangles(args)
+ if self._showKeyCue or self._showNoKeyCue then
+ local cueColor
+ if self._holdingKey then
+ cueColor = self._keyCueRgba
+ elseif self._holdingBadKey then
+ cueColor = self._keyCueRgba2
+ elseif self._showNoKeyCue then
+ cueColor = self._keyCueRgbaNoKey
+ end
+ return {{
+ x = 12,
+ y = 12,
+ width = KEY_CUE_RECTANGLE_WIDTH,
+ height = KEY_CUE_RECTANGLE_HEIGHT,
+ rgba = cueColor
+ }}
+ end
+ return {}
+ end
+
+ function api:nextMap()
+ -- 1. Decide what is the next map.
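+    -- Phase order: explore -> distractor (if distractorLen > 0) -> reward.
+    -- After the goal is collected, the avatar waits out the episode in an
+    -- apple-free distractor map.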
+ if self._map == nil then
+ self._map = EXPLORE_MAP
+ elseif self._map == DISTRACTOR_MAP then
+ self._map = REWARD_MAP
+ elseif self._map == EXPLORE_MAP then
+ if self._distractorLen > 0.0 then
+        -- If not holding any key, show the no-key cue.
+ if not self._holdingKey and not self._holdingBadKey then
+ self._showNoKeyCue = true
+ self._NoKeyCueTime = self._time
+ end
+ self._map = DISTRACTOR_MAP
+ else
+ self._map = REWARD_MAP
+ end
+ elseif self._map == REWARD_MAP then
+      -- Stay in the distractor map (no apples) until the episode ends.
+ self._map = DISTRACTOR_MAP
+ self._collectedGoal = true
+ end
+
+    -- 2. Set up the timeout for the upcoming map.
+ if self._map == EXPLORE_MAP then
+ self._timeOut = self._time + kwargs.exploreLengthSeconds
+ elseif self._map == DISTRACTOR_MAP and not self._collectedGoal then
+ self._timeOut = self._time + self._distractorLen
+ elseif self._map == REWARD_MAP then
+ self._timeOut = nil
+ end
+
+ return self._map
+ end
+
+ -- PICKUP functions ----------------------------------------------------------
+
+ function api:_makePickup(c)
+ if c == 'K' then
+ return 'key'
+ elseif c == 'G' then
+ return 'goal'
+ elseif c == 'A' then
+ return 'apple_reward'
+ end
+ end
+
+ function api:canPickup(spawnId)
+ -- Cannot pick up another key if avatar is already holding a key.
+ if spawnId == KEY_SPAWN_ID and self._holdingBadKey then
+ return false
+ end
+ if spawnId == BAD_KEY_SPAWN_ID and self._holdingKey then
+ return false
+ end
+
+ return true
+ end
+
+ function api:pickup(spawnId)
+ if spawnId == GOAL_ID then
+ local goalReward
+ if self._holdingKey then
+ goalReward = kwargs.goalWithGoodKeyReward
+ elseif self._holdingBadKey then
+ goalReward = kwargs.goalWithBadKeyReward
+ end
+ self.setInstruction(tostring(goalReward))
+      game:addScore(-10)  -- Offset the default +10 for picking up the goal.
+ self._finalReward = goalReward
+ game:finishMap()
+ end
+ if spawnId == KEY_SPAWN_ID then
+ self._holdingKey = true
+ self._holdingKeyTime = self._time
+ self._showKeyCue = true
+ end
+ if spawnId == BAD_KEY_SPAWN_ID then
+ self._holdingBadKey = true
+ self._holdingKeyTime = self._time
+ self._showKeyCue = true
+ end
+
+    if spawnId == APPLE_ID then
+      -- The -1 offsets the default 1 point for an apple in DMLab.
+      local appleReward = kwargs.appleReward +
+          random:uniformInt(0, kwargs.appleExtraRewardRange) - 1
+      game:addScore(appleReward)
+    end
+ end
+
+ function api:hasEpisodeFinished(timeSeconds)
+ self._time = timeSeconds
+
+ -- Give the final reward near the end of the episode.
+ if not self._finalRewardAdded and
+ timeSeconds > kwargs.episodeLengthSeconds - 0.1 then
+ game:addScore(self._finalReward)
+ self._finalRewardAdded = true
+ end
+
+ if (self._holdingKey or self._holdingBadKey) and
+ self._time - self._holdingKeyTime > SHOW_COLOR_SQUARE_SECONDS then
+ self._showKeyCue = false
+ end
+
+ if self._showNoKeyCue and
+ self._time - self._NoKeyCueTime > SHOW_COLOR_SQUARE_SECONDS then
+ self._showNoKeyCue = false
+ end
+
+ if self._map == EXPLORE_MAP or self._map == DISTRACTOR_MAP or
+ self._map == SECOND_ORDER_EXPLORE_MAP then
+ if timeSeconds > self._timeOut then
+ game:finishMap()
+ end
+ return false
+ end
+ end
+
+  function api:canTrigger(teleportId, targetName)
+    if string.sub(targetName, 1, 4) == 'door' then
+      -- Open the door no matter which key the avatar holds.
+      return self._holdingKey or self._holdingBadKey
+    end
+    return false
+  end
+
+ function api:updateSpawnVars(spawnVars)
+ local classname = spawnVars.classname
+ if classname == "info_player_start" then
+ -- Spawn facing South.
+ spawnVars.angle = "-90"
+ spawnVars.randomAngleRange = "0"
+ elseif classname == "func_door" then
+ spawnVars.id = tostring(DOOR_ID)
+ spawnVars.wait = "1000000" -- Door open for a long time.
+ elseif classname == "goal" then
+ spawnVars.id = tostring(GOAL_ID)
+ elseif classname == "apple_reward" then
+      -- The avatar is respawned into the distractor room after reaching the
+      -- goal; there should be no more apples in that case.
+ if self._collectedGoal then
+ return nil
+ end
+ local useApple = false
+ if kwargs.probAppleInDistractorMap > 0 then
+ useApple = random:uniformReal(0, 1) < kwargs.probAppleInDistractorMap
+ spawnVars.id = tostring(APPLE_ID)
+ end
+ if not useApple then
+ return nil
+ end
+ elseif classname == "key" then
+ self._keyPosCount = self._keyPosCount + 1
+ if self._keyPosition == self._keyPosCount then
+ spawnVars.id = tostring(KEY_SPAWN_ID)
+ spawnVars.classname = self._keyObject
+ elseif self._keyPosition2 == self._keyPosCount then
+ spawnVars.id = tostring(BAD_KEY_SPAWN_ID)
+ spawnVars.classname = self._keyObject2
+ else
+ return nil
+ end
+ end
+ return spawnVars
+ end
+
+ custom_observations.decorate(api)
+ pickup_decorator.decorate(api)
+ setting_overrides.decorate{
+ api = api,
+ apiParams = kwargs,
+ decorateWithTimeout = true
+ }
+ return api
+
+end
+
+return factory
diff --git a/tvt/dmlab/two_negative_keys.lua b/tvt/dmlab/two_negative_keys.lua
new file mode 100644
index 0000000..c220a28
--- /dev/null
+++ b/tvt/dmlab/two_negative_keys.lua
@@ -0,0 +1,19 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local factory = require 'two_keys_to_choose_factory'
+
+return factory.createLevelApi{
+ episodeLengthSeconds = 37,
+ exploreLengthSeconds = 5,
+ distractorLengthSeconds = 30,
+ differentDistractRoomTexture = true,
+}
diff --git a/tvt/dmlab/visual_match_factory.lua b/tvt/dmlab/visual_match_factory.lua
new file mode 100644
index 0000000..4faf9c2
--- /dev/null
+++ b/tvt/dmlab/visual_match_factory.lua
@@ -0,0 +1,776 @@
+-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+-- http://www.apache.org/licenses/LICENSE-2.0
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ============================================================================
+local make_map = require 'common.make_map'
+local custom_decals = require 'decorators.custom_decals_decoration'
+local custom_entities = require 'common.custom_entities'
+local custom_observations = require 'decorators.custom_observations'
+local datasets_selector = require 'datasets.selector'
+local debug_observations = require 'decorators.debug_observations'
+local game = require 'dmlab.system.game'
+local image_utils = require 'image_utils'
+local map_maker = require 'dmlab.system.map_maker'
+local maze_generation = require 'dmlab.system.maze_generation'
+local pickup_decorator = require 'decorators.human_recognisable_pickups'
+local random = require 'common.random'
+local setting_overrides = require 'decorators.setting_overrides'
+local texture_sets = require 'themes.texture_sets'
+local themes = require 'themes.themes'
+local hrp = require 'common.human_recognisable_pickups'
+
+local DEFAULTS = {
+ EXPLORE_MAP_MODE = 'PASSIVE',
+ EPISODE_LENGTH_SECONDS = 30,
+ SECOND_ORDER_EXPLORE_LENGTH_SECONDS = 4,
+ EXPLORE_LENGTH_SECONDS = 10,
+ DISTRACTOR_LENGTH_SECONDS = 10,
+ PRE_EXPLORE_DISTRACTOR_LENGTH_SECONDS = 0,
+ NUM_IMAGES = 4,
+ CORRECT_REWARD = 10,
+ INCORRECT_REWARD = 1,
+ IMAGE_SCALE = 3.0,
+ IMAGE_ROOM_HEIGHT = 4,
+ SHOW_KEY_COLOR_SQUARE_SECONDS = 1,
+ DISTRACTOR_ROOM_SIZE = {11, 11},
+ SECOND_ORDER_EXPLORE_ROOM_SIZE = {3, 3},
+ PROB_APPLE_IN_DISTRACTOR_MAP = 0.3,
+ APPLE_REWARD = 5,
+ APPLE_REWARD_PROB = 1.0,
+ APPLE_EXTRA_REWARD_RANGE = 0,
+ DIFFERENT_DISTRACT_ROOM_TEXTURE = false,
+ DIFFERENT_REWARD_ROOM_TEXTURE = false,
+ DIFFERENT_SECOND_ORDER_ROOM_TEXTURE = false,
+}
+
+local APPLE_ID = 999
+local KEY_OBJECT_SPAWN_ID = 1000
+local DOOR_ID = 1001
+
+-- Table that maps from full decal name to decal index number.
+local decalIndices = {}
+
+local SECOND_ORDER_EXPLORE_MAP = "secondOrderExploreMap"
+local EXPLORE_MAP = "exploreMap"
+local DISTRACTOR_MAP = "distractorMap"
+local IMAGE_MAP = "imageMap"
+local COLORS = image_utils.COLORS
+
+-- Set texture set for all maps.
+local textureSet = texture_sets.PACMAN
+local secondTextureSet = texture_sets.TETRIS
+local thirdTextureSet = texture_sets.TRON
+local fourthTextureSet = texture_sets.MINESWEEPER
+
+local SHORT_STRAIGHT_ROOM = [[
+***
+*P*
+* *
+* *
+***
+]]
+
+local SHORT_STRAIGHT_ROOM_WITH_DOOR = [[
+***
+*P*
+*H*
+* *
+***
+]]
+
+local TWO_ROOMS = [[
+*********
+*********
+* * *
+* P *
+* * *
+*********
+]]
+-- There are 24 walls for hanging the colour square.
+local TWO_ROOMS_VALID_PAINT_LOCATION = 24
+
+local EXPLORE_TEXT_MAP_DICT = {
+ PASSIVE = {
+ map = SHORT_STRAIGHT_ROOM,
+ targetPic = {row=4, col=2, dir='S'},
+ },
+ TWO_ROOMS = {
+ map = TWO_ROOMS,
+ targetPic = {row=0, col=0, dir='S'},
+ },
+ KEY_TO_COLOR = {
+ map = SHORT_STRAIGHT_ROOM_WITH_DOOR,
+ targetPic = {row=4, col=2, dir='S'},
+ }
+}
+
+--[[
+Setup image room maze.
+
+Example 1:
+numImages = 2
+imageRoomHeight = 3
+
+*****
+**P**
+* *
+* *
+*T T*
+*****
+ *t*
+ ***
+
+Example 2:
+numImages = 4
+imageRoomHeight = 4
+
+*********
+****P****
+* *
+* *
+* *
+*T T T T*
+*********
+ *t*
+ ***
+--]]
+local function createImageMaze(opts)
+ local numImages = opts.numImages
+ local imageRoomHeight = opts.imageRoomHeight or 3
+ local centerWidth = 1 + numImages
+
+ local width = 2 * numImages + 1
+  -- Height: imageRoomHeight + 3 rows for the walled image room and spawn row,
+  -- plus 2 rows for the finish area.
+ local height = (imageRoomHeight + 3) + 2
+
+ -- Initialize the maze. All cells start as '*' (wall).
+ local maze = maze_generation:mazeGeneration{
+ width = width,
+ height = height,
+ }
+ maze:setEntityCell(2, centerWidth, 'P') -- Avatar start location.
+
+ -- Fill image room with '.' (empty space).
+ local imageRoomHeightStart = 3
+ local imageRoomHeightEnd = imageRoomHeightStart + imageRoomHeight - 1
+ for i = imageRoomHeightStart, imageRoomHeightEnd do
+ for j = 2, width - 1 do
+ maze:setEntityCell(i, j, '.')
+ end
+ end
+
+ -- Teleports in final row of image room.
+ for n = 1, numImages do
+ maze:setEntityCell(imageRoomHeightEnd, 2 * n, 'T')
+ end
+ -- Teleport target in finish box after hallway.
+ maze:setEntityCell(imageRoomHeightEnd + 2, centerWidth, 't')
+
+ print('Generated image maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+
+ return maze
+end
+
+local function createSecondOrderExploreMaze(opts)
+ -- An open layout room of size = SECOND_ORDER_EXPLORE_ROOM_SIZE
+ -- the avatar is always at top-left corner, while the key is at other random
+ -- location. For example, a 3x3 room may be like this:
+ -- xxxxx
+ -- xPxxx
+ -- xKKKx
+ -- xKKKx
+ -- xxxxx
+
+  local roomHeight = opts.roomSize[1]
+  local roomWidth = opts.roomSize[2]
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2, -- +2 for the two side of walls
+ width = roomWidth + 2
+ }
+
+  -- Fill the room with 'K' (possible key locations).
+ for i = 3, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ maze:setEntityCell(i, j, 'K')
+ end
+ end
+ maze:setEntityCell(2, 2, 'P') -- Avatar start at top-left corner.
+
+ print('Generated 2nd order explore maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+
+ return maze
+end
+
+local function createDistractorMaze(opts)
+ -- Example room with height = 2, width = 3
+ -- A are possible apple locations (everywhere)
+ -- *****
+ -- *APA*
+ -- *AAA*
+ -- *****
+
+ local roomHeight = opts.roomSize[1]
+ local roomWidth = opts.roomSize[2]
+ local centerWidth = 1 + math.ceil(roomWidth / 2)
+ local maze = maze_generation:mazeGeneration{
+ height = roomHeight + 2, -- +2 for the two side of walls
+ width = roomWidth + 2
+ }
+
+ -- Fill the room with 'A' for apples. updateSpawnVars decides which to use.
+ for i = 2, roomHeight + 1 do
+ for j = 2, roomWidth + 1 do
+ maze:setEntityCell(i, j, 'A')
+ end
+ end
+ -- Override one cell with 'P' for spawn point.
+ maze:setEntityCell(2, centerWidth, 'P')
+
+ print('Generated distractor maze with entity layer:')
+ print(maze:entityLayer())
+ io.flush()
+ return maze
+end
+
+local factory = {}
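+-- Ensure script-drawn rectangles (the colour cue squares returned by
+-- api:filledRectangles) are always rendered.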
+game:console('cg_drawScriptRectanglesAlways 1')
+
+function factory.createLevelApi(kwargs)
+
+ kwargs.episodeLengthSeconds = kwargs.episodeLengthSeconds or
+ DEFAULTS.EPISODE_LENGTH_SECONDS
+ kwargs.secondOrderExploreLengthSeconds =
+ kwargs.secondOrderExploreLengthSeconds or
+ DEFAULTS.SECOND_ORDER_EXPLORE_LENGTH_SECONDS
+ kwargs.secondOrderExploreRoomSize = kwargs.secondOrderExploreRoomSize or
+ DEFAULTS.SECOND_ORDER_EXPLORE_ROOM_SIZE
+
+ kwargs.exploreLengthSeconds = kwargs.exploreLengthSeconds or
+ DEFAULTS.EXPLORE_LENGTH_SECONDS
+
+ kwargs.preExploreDistractorLengthSeconds =
+ kwargs.preExploreDistractorLengthSeconds or
+ DEFAULTS.PRE_EXPLORE_DISTRACTOR_LENGTH_SECONDS
+
+ kwargs.distractorLengthSeconds = kwargs.distractorLengthSeconds or
+ DEFAULTS.DISTRACTOR_LENGTH_SECONDS
+
+ kwargs.numImages = kwargs.numImages or DEFAULTS.NUM_IMAGES
+ kwargs.correctReward = kwargs.correctReward or DEFAULTS.CORRECT_REWARD
+ kwargs.incorrectReward = kwargs.incorrectReward or DEFAULTS.INCORRECT_REWARD
+
+ kwargs.appleReward = kwargs.appleReward or DEFAULTS.APPLE_REWARD
+ kwargs.appleRewardProb = kwargs.appleRewardProb or DEFAULTS.APPLE_REWARD_PROB
+ kwargs.appleExtraRewardRange =
+ kwargs.appleExtraRewardRange or DEFAULTS.APPLE_EXTRA_REWARD_RANGE
+
+ kwargs.imageScale = kwargs.imageScale or DEFAULTS.IMAGE_SCALE
+ kwargs.imageRoomHeight = kwargs.imageRoomHeight or DEFAULTS.IMAGE_ROOM_HEIGHT
+ kwargs.distractorRoomSize = kwargs.distractorRoomSize or
+ DEFAULTS.DISTRACTOR_ROOM_SIZE
+ kwargs.probAppleInDistractorMap = kwargs.probAppleInDistractorMap or
+ DEFAULTS.PROB_APPLE_IN_DISTRACTOR_MAP
+ kwargs.differentDistractRoomTexture = kwargs.differentDistractRoomTexture or
+ DEFAULTS.DIFFERENT_DISTRACT_ROOM_TEXTURE
+ kwargs.differentRewardRoomTexture = kwargs.differentRewardRoomTexture or
+ DEFAULTS.DIFFERENT_REWARD_ROOM_TEXTURE
+ kwargs.differentSecondOrderRoomTexture =
+ kwargs.differentSecondOrderRoomTexture or
+ DEFAULTS.DIFFERENT_SECOND_ORDER_ROOM_TEXTURE
+
+ kwargs.showKeyColorSquareSeconds = kwargs.showKeyColorSquareSeconds or
+ DEFAULTS.SHOW_KEY_COLOR_SQUARE_SECONDS
+
+ assert(kwargs.numImages % 2 == 0,
+ 'numImages must be an even number if there is space between images.')
+ assert(kwargs.numImages <= #COLORS,
+ 'numImages must be <=' .. #COLORS .. ' for simple color images.')
+
+ kwargs.exploreMapMode = kwargs.exploreMapMode or DEFAULTS.EXPLORE_MAP_MODE
+
+ local api = {}
+
+ function api:init(params)
+ self._isKeyToPaintingLevel = kwargs.exploreMapMode == 'KEY_TO_COLOR'
+
+ self:_createExploreMap()
+ self:_createDistractorMap()
+ self:_createImageMap()
+ if self._isKeyToPaintingLevel then
+ self:_createSecondOrderKeyExploreMap()
+ end
+
+ self._imageOrder = {}
+ for i = 1, kwargs.numImages do
+ self._imageOrder[i] = i
+ end
+ end
+
+ function api:_createSecondOrderKeyExploreMap()
+ -- Create maze to be converted into map.
+ local maze = createSecondOrderExploreMaze{
+ roomSize = kwargs.secondOrderExploreRoomSize
+ }
+
+ -- Create map theme with no wall decals.
+ local texture = textureSet
+ if kwargs.differentSecondOrderRoomTexture then
+ texture = fourthTextureSet
+ end
+
+ local keyExploreMapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 0.0,
+ floorModelFrequency = 0.0,
+ }
+
+ self._secondOrderExploreMap = map_maker:mapFromTextLevel{
+ mapName = SECOND_ORDER_EXPLORE_MAP,
+ entityLayer = maze:entityLayer(),
+ theme = keyExploreMapTheme,
+ callback = function (i, j, c, maker)
+ local pickup = self:_makePickup(c)
+ if pickup then
+ return maker:makeEntity{i = i, j = j, classname = pickup}
+ end
+ end
+ }
+ end
+
+ function api:_createExploreMap()
+ -- Create map theme and override default wall decal placement.
+ local exploreMapTheme = themes.fromTextureSet{
+ textureSet = textureSet,
+ decalFrequency = 1.0,
+ floorModelFrequency = 0.0,
+ }
+
+ local exploreMapInfo = EXPLORE_TEXT_MAP_DICT[kwargs.exploreMapMode]
+ local targetPic = exploreMapInfo.targetPic
+ local exploreMapEntityLayer = exploreMapInfo.map
+
+    -- Note on decalIndex meaning:
+    --   decalIndex in 1..numImages: the id of a painting in the image room.
+    --   decalIndex == numImages + 1: the target image in the explore room.
+
+    local function _matchTextureLocation(loc, target)
+      return loc.i == target.row and loc.j == target.col and
+          loc.direction == target.dir
+    end
+
+ function exploreMapTheme:placeWallDecals(allWallLocations)
+ local wallDecals = {}
+ local numPossiblePaintLocation = 0
+ for _, loc in pairs(allWallLocations) do
+ local decalIndex = nil
+ if kwargs.exploreMapMode ~= 'TWO_ROOMS' then
+ if _matchTextureLocation(loc, targetPic) then
+ decalIndex = kwargs.numImages + 1
+ end
+ else
+ if loc.i > 2 then
+ numPossiblePaintLocation = numPossiblePaintLocation + 1
+ decalIndex = numPossiblePaintLocation
+ end
+ end
+
+ if decalIndex then
+ local decal = textureSet.wallDecals[decalIndex]
+ local actualDecal = {
+ tex = decal.tex .. '_alpha',
+ scale = kwargs.imageScale,
+ }
+ wallDecals[#wallDecals + 1] = {
+ index = loc.index,
+ decal = actualDecal,
+ }
+ local fullTextureName = "textures/" .. decal.tex
+ decalIndices[fullTextureName] = decalIndex
+ end
+ end
+ return wallDecals
+ end
+
+ self._exploreMap = map_maker:mapFromTextLevel{
+ mapName = EXPLORE_MAP,
+ entityLayer = exploreMapEntityLayer,
+ theme = exploreMapTheme,
+ }
+ end
+
+ function api:_createDistractorMap()
+
+ -- Create a maze to be converted into map.
+ local maze = createDistractorMaze{
+ roomSize = kwargs.distractorRoomSize,
+ }
+
+ -- Create a map theme with no wall decals.
+ local texture = textureSet
+ if kwargs.differentDistractRoomTexture then
+ texture = secondTextureSet
+ end
+ local mapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 0.0,
+ floorModelFrequency = 0.0,
+ }
+
+ self._distractorMap = make_map.makeMap{
+ mapName = DISTRACTOR_MAP,
+ mapEntityLayer = maze:entityLayer(),
+ theme = mapTheme,
+ }
+ end
+
+ function api:_createImageMap()
+ -- Create a maze to be converted into map.
+ local imageMaze = createImageMaze{
+ numImages = kwargs.numImages,
+ imageRoomHeight = kwargs.imageRoomHeight,
+ }
+
+ local texture = textureSet
+ if kwargs.differentRewardRoomTexture then
+ texture = thirdTextureSet
+ end
+ -- Create map theme and override default wall decal placement.
+ local imageMapTheme = themes.fromTextureSet{
+ textureSet = texture,
+ decalFrequency = 1.0,
+ floorModelFrequency = 0.0,
+ }
+ local paintingsRow = kwargs.imageRoomHeight + 2
+ function imageMapTheme:placeWallDecals(allWallLocations)
+ local wallDecals = {}
+ local decalCount = 1
+ for _, loc in pairs(allWallLocations) do
+ if loc.direction == "S" then
+ local decalIndex = nil
+ if loc.i == paintingsRow then
+ -- Only use even columns for paintings.
+ if loc.j % 2 == 0 then
+ decalIndex = decalCount -- Will be between 1 and numImages.
+ decalCount = decalCount + 1
+ end
+ end
+ if decalIndex then
+ local decal = textureSet.wallDecals[decalIndex]
+ decal.scale = kwargs.imageScale
+ wallDecals[#wallDecals + 1] = {
+ index = loc.index,
+ decal = decal,
+ }
+ local fullTextureName = "textures/" .. decal.tex
+ decalIndices[fullTextureName] = decalIndex
+ end
+ end
+ end
+ return wallDecals
+ end
+
+ self._imageMap = map_maker:mapFromTextLevel{
+ mapName = IMAGE_MAP,
+ entityLayer = imageMaze:entityLayer(),
+ theme = imageMapTheme,
+ callback = function (i, j, c, maker)
+ if c == 'T' then
+ return custom_entities.makeTeleporter(
+ {imageMaze:toWorldPos(i + 1, j + 1)},
+ 'teleporter')
+ end
+ if c == 't' then
+ return custom_entities.makeTeleporterTarget(
+ {imageMaze:toWorldPos(i + 1, j + 1)},
+ 'teleporter')
+ end
+ end
+ }
+ end
+
+  function api:_prepareKey()
+ self._holdingKey = false
+ local keyInfo = {shape='key', pattern='solid',
+ color1 = {0, 0, 0}, color2={0, 0, 0}}
+ self._keyCueColorAlpha = {0, 0, 0, 1}
+ self._keyObject = hrp.create(keyInfo)
+ end
+
+ function api:start(episode, seed)
+ random:seed(seed)
+ self._map = nil
+ self._time = 0
+ self._targetIndex = 1
+ self._images = {}
+ self._preExploreDistractorLen = kwargs.preExploreDistractorLengthSeconds
+ self._distractorLen = kwargs.distractorLengthSeconds
+
+ local colorIndices = {}
+ for i = 1, #COLORS do
+ colorIndices[i] = i
+ end
+ random:shuffleInPlace(colorIndices)
+ for i = 1, kwargs.numImages do
+ local rgb = COLORS[colorIndices[i]]
+ self._images[i] = image_utils:createByteImage(3, 3, rgb)
+ end
+
+ if kwargs.exploreMapMode == 'TWO_ROOMS' then
+ self._images[kwargs.numImages + 1] =
+ image_utils:createTransparentImage(3, 3)
+ local nPaintPos = TWO_ROOMS_VALID_PAINT_LOCATION
+ self._targetPaintLocation = random:uniformInt(1, nPaintPos)
+ end
+
+ if self._isKeyToPaintingLevel then
+ self:_prepareKey()
+ -- Randomly sample the key location in secondOrderExploreMaze
+ local nPossibleKeyLocation = (kwargs.secondOrderExploreRoomSize[1] - 1) *
+ kwargs.secondOrderExploreRoomSize[2]
+ self._keyPosition = random:uniformInt(1, nPossibleKeyLocation)
+ end
+
+    -- Set the instruction channel output to 0; it is later set to the reward
+    -- obtained in the final phase.
+ self.setInstruction(tostring(0))
+ end
+
+ function api:filledRectangles(args)
+ if self._map == SECOND_ORDER_EXPLORE_MAP and self._showKeyCue then
+ return {{
+ x = 12,
+ y = 12,
+ width = 600,
+ height = 200,
+ rgba = self._keyCueColorAlpha
+ }}
+ end
+ return {}
+ end
+
+ function api:nextMap()
+ -- 1. Decide what is the next map.
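+    -- Phase order (bracketed phases are optional, depending on kwargs):
+    -- [secondOrderExplore ->] [preExploreDistractor ->] explore ->
+    -- [distractor ->] image, then the cycle restarts.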
+ if self._map == nil or self._map == IMAGE_MAP then
+ if self._isKeyToPaintingLevel then
+ self._map = SECOND_ORDER_EXPLORE_MAP
+ else
+ if self._preExploreDistractorLen > 0.0 then
+ self._notExploreYet = true
+ self._map = DISTRACTOR_MAP
+ else
+ self._map = EXPLORE_MAP
+ end
+ end
+ elseif self._map == SECOND_ORDER_EXPLORE_MAP then
+ self._notExploreYet = true
+ self._map = DISTRACTOR_MAP
+ elseif self._map == DISTRACTOR_MAP then
+ if self._notExploreYet then
+ self._notExploreYet = false
+ self._map = EXPLORE_MAP
+ else
+ self._map = IMAGE_MAP
+ end
+ elseif self._map == EXPLORE_MAP then
+ if self._distractorLen > 0.0 then
+ self._map = DISTRACTOR_MAP
+ else
+ self._map = IMAGE_MAP
+ end
+ end
+
+    -- 2. Set up the timeout and state for the upcoming map.
+ if self._map == DISTRACTOR_MAP and self._notExploreYet then
+ self._timeOut = self._time + self._preExploreDistractorLen
+ elseif self._map == SECOND_ORDER_EXPLORE_MAP then
+ self._holdingKey = false
+ self._timeOut = self._time + kwargs.secondOrderExploreLengthSeconds
+ self._possibleKeyPosCount = 0
+ elseif self._map == EXPLORE_MAP then
+ self._timeOut = self._time + kwargs.exploreLengthSeconds
+ elseif self._map == DISTRACTOR_MAP and not self._notExploreYet then
+ self._timeOut = self._time + self._distractorLen
+ elseif self._map == IMAGE_MAP then
+ self._timeOut = nil
+ self._teleportId = 0
+ random:shuffleInPlace(self._imageOrder)
+ for i, shuffled_i in ipairs(self._imageOrder) do
+ if self._targetIndex == shuffled_i then
+ self._shuffledTargetIndex = i
+ end
+ end
+ end
+
+ return self._map
+ end
+
+ function api:replaceShader(textureName)
+ local index = decalIndices[textureName]
+ if index then
+ textureName = textureName .. '_alpha'
+ end
+ return textureName
+ end
+
+ function api:loadTexture(textureName)
+ local fullTextureName = textureName .. "_nonsolid"
+ local index = decalIndices[fullTextureName]
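+    -- Outside TWO_ROOMS mode, indices 1..numImages select the shuffled option
+    -- paintings and numImages + 1 selects the target painting; in TWO_ROOMS
+    -- mode the index is a wall position and only the sampled target location
+    -- shows the target colour.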
+
+ if index then
+ if self._map == EXPLORE_MAP and
+ kwargs.exploreMapMode == 'TWO_ROOMS' then
+ if index == self._targetPaintLocation then
+          return self._images[self._targetIndex]  -- Set to target color.
+ else
+ return self._images[kwargs.numImages + 1] -- Set to transparent.
+ end
+ end
+
+ if index <= kwargs.numImages then
+ local shuffledIndex = self._imageOrder[index]
+ return self._images[shuffledIndex]
+ elseif index == kwargs.numImages + 1 then
+ return self._images[self._targetIndex]
+ end
+ end
+ end
+
+ -- PICKUP functions ----------------------------------------------------------
+
+ function api:_makePickup(c)
+ if c == 'K' then
+ return 'key'
+ end
+ end
+
+ function api:pickup(spawnId)
+ if spawnId == KEY_OBJECT_SPAWN_ID then
+ self._holdingKey = true
+ self._holdingKeyTime = self._time
+ self._showKeyCue = true
+ end
+
+ if spawnId == APPLE_ID then
+ if kwargs.appleRewardProb >= 1 or
+ random:uniformReal(0, 1) < kwargs.appleRewardProb then
+        -- The -1 offsets the default 1 point for an apple in DMLab.
+        local appleReward = kwargs.appleReward +
+            random:uniformInt(0, kwargs.appleExtraRewardRange) - 1
+ game:addScore(appleReward)
+ else
+        -- The -1 compensates for the default 1 point for an apple in DMLab.
+ game:addScore(-1)
+ end
+ end
+ end
+
+ -- TRIGGER functions ---------------------------------------------------------
+
+ function api:canTrigger(teleportId, targetName)
+ if string.sub(targetName, 1, 4) == 'door' and not self._holdingKey then
+ return false
+ end
+ return true
+ end
+
+ function api:trigger(teleportId, targetName)
+ if string.sub(targetName, 1, 4) == 'door' then
+ return
+ end
+
+ -- Decide if the correct teleport is triggered.
+ local reward = 0
+ if teleportId == self._shuffledTargetIndex then
+ self.setInstruction(tostring(kwargs.correctReward))
+ reward = kwargs.correctReward
+ else
+ self.setInstruction(tostring(kwargs.incorrectReward))
+ reward = kwargs.incorrectReward
+ end
+
+ game:addScore(reward)
+ self._timeOut = self._time + 0.2
+ end
+
+ function api:hasEpisodeFinished(timeSeconds)
+ self._time = timeSeconds
+
+ -- Decide the timing of showing the key cue.
+ if self._isKeyToPaintingLevel and self._holdingKey then
+      local showTime = self._time - self._holdingKeyTime
+ if showTime > kwargs.showKeyColorSquareSeconds then
+ self._showKeyCue = false
+ end
+ end
+
+ if self._map == EXPLORE_MAP or self._map == DISTRACTOR_MAP or
+ self._map == SECOND_ORDER_EXPLORE_MAP then
+ if timeSeconds > self._timeOut then
+ game:finishMap()
+ end
+ return false
+    else  -- In the image room map, only time out after being teleported.
+ return self._timeOut and timeSeconds > self._timeOut
+ end
+ end
+
+ -- END TRIGGER functions -----------------------------------------------------
+
+ function api:updateSpawnVars(spawnVars)
+ local classname = spawnVars.classname
+ if classname == "info_player_start" then
+ -- Spawn facing South.
+ spawnVars.angle = "-90"
+ spawnVars.randomAngleRange = "0"
+ elseif classname == "trigger_teleport" then
+ self._teleportId = self._teleportId + 1
+ spawnVars.id = tostring(self._teleportId)
+ elseif classname == "func_door" then
+ spawnVars.id = tostring(DOOR_ID)
+ spawnVars.wait = "1000000" -- Open the door for long time.
+ elseif classname == "apple_reward" then
+ local useApple = false
+ if kwargs.probAppleInDistractorMap > 0 then
+ useApple = random:uniformReal(0, 1) < kwargs.probAppleInDistractorMap
+ spawnVars.id = tostring(APPLE_ID)
+ end
+ if not useApple then
+ return nil
+ end
+ elseif classname == "key" then
+ self._possibleKeyPosCount = self._possibleKeyPosCount + 1
+ if self._keyPosition == self._possibleKeyPosCount then
+ spawnVars.id = tostring(KEY_OBJECT_SPAWN_ID)
+ spawnVars.classname = self._keyObject
+ else
+ return nil
+ end
+ end
+ return spawnVars
+ end
+
+ custom_observations.decorate(api)
+ pickup_decorator.decorate(api)
+ setting_overrides.decorate{
+ api = api,
+ apiParams = kwargs,
+ decorateWithTimeout = true
+ }
+ return api
+end
+
+return factory
diff --git a/tvt/images/RMA_gamma1_KtD.png b/tvt/images/RMA_gamma1_KtD.png
new file mode 100644
index 0000000..f3d7697
Binary files /dev/null and b/tvt/images/RMA_gamma1_KtD.png differ
diff --git a/tvt/images/RMA_gamma1_im2r.png b/tvt/images/RMA_gamma1_im2r.png
new file mode 100644
index 0000000..278a483
Binary files /dev/null and b/tvt/images/RMA_gamma1_im2r.png differ
diff --git a/tvt/images/avm_notvt.png b/tvt/images/avm_notvt.png
new file mode 100644
index 0000000..44b6b2a
Binary files /dev/null and b/tvt/images/avm_notvt.png differ
diff --git a/tvt/images/avm_tvt.png b/tvt/images/avm_tvt.png
new file mode 100644
index 0000000..94a1a49
Binary files /dev/null and b/tvt/images/avm_tvt.png differ
diff --git a/tvt/images/ktd_notvt.png b/tvt/images/ktd_notvt.png
new file mode 100644
index 0000000..0a5b678
Binary files /dev/null and b/tvt/images/ktd_notvt.png differ
diff --git a/tvt/images/ktd_tvt.png b/tvt/images/ktd_tvt.png
new file mode 100644
index 0000000..5d31f93
Binary files /dev/null and b/tvt/images/ktd_tvt.png differ
diff --git a/tvt/losses.py b/tvt/losses.py
new file mode 100644
index 0000000..d83e352
--- /dev/null
+++ b/tvt/losses.py
@@ -0,0 +1,157 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Loss functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+import tensorflow as tf
+
+
+def sum_time_average_batch(tensor, name=None):
+ """Computes the mean over B assuming tensor is of shape [T, B]."""
+ tensor.get_shape().assert_has_rank(2)
+ return tf.reduce_mean(tf.reduce_sum(tensor, axis=0), axis=0, name=name)
+
+
+def combine_logged_values(*logged_values_dicts):
+ """Combine logged values dicts. Throws if there are any repeated keys."""
+ combined_dict = dict()
+ for logged_values in logged_values_dicts:
+ for k, v in six.iteritems(logged_values):
+ if k in combined_dict:
+ raise ValueError('Key "%s" is repeated in loss logging.' % k)
+ combined_dict[k] = v
+ return combined_dict
+
+
+def reconstruction_losses(
+ recons,
+ targets,
+ image_cost,
+ action_cost,
+ reward_cost):
+ """Reconstruction losses."""
+ if image_cost > 0.0:
+ # Neg log prob of obs image given Bernoulli(recon image) distribution.
+ negative_image_log_prob = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=targets.image, logits=recons.image)
+ nll_per_time = tf.reduce_sum(negative_image_log_prob, [-3, -2, -1])
+ image_loss = image_cost * nll_per_time
+ image_loss = sum_time_average_batch(image_loss)
+ else:
+ image_loss = tf.constant(0.)
+
+ if action_cost > 0.0 and recons.last_action is not tuple():
+ # Labels have shape (T, B), logits have shape (T, B, num_actions).
+ action_loss = action_cost * tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=targets.last_action, logits=recons.last_action)
+ action_loss = sum_time_average_batch(action_loss)
+ else:
+ action_loss = tf.constant(0.)
+
+ if reward_cost > 0.0 and recons.last_reward is not tuple():
+ # MSE loss for reward.
+ recon_last_reward = recons.last_reward
+ recon_last_reward = tf.squeeze(recon_last_reward, -1)
+ reward_loss = 0.5 * reward_cost * tf.square(
+ recon_last_reward - targets.last_reward)
+ reward_loss = sum_time_average_batch(reward_loss)
+ else:
+ reward_loss = tf.constant(0.)
+
+ total_loss = image_loss + action_loss + reward_loss
+
+ logged_values = dict(
+ recon_loss_image=image_loss,
+ recon_loss_action=action_loss,
+ recon_loss_reward=reward_loss,
+ total_reconstruction_loss=total_loss,)
+
+ return total_loss, logged_values
+
+
+def read_regularization_loss(
+ read_info,
+ strength_cost,
+ strength_tolerance,
+ strength_reg_mode,
+ key_norm_cost,
+ key_norm_tolerance):
+ """Computes the sum of read strength and read key regularization losses."""
+
+ if (strength_cost <= 0.) and (key_norm_cost <= 0.):
+ read_reg_loss = tf.constant(0.)
+ return read_reg_loss, dict(read_regularization_loss=read_reg_loss)
+
+  if read_info == tuple():
+    raise ValueError('Make sure read regularization costs are zero when '
+                     'not outputting read info.')
+
+  if hasattr(read_info, 'read_strengths'):
+    read_strengths = read_info.read_strengths
+    read_keys = read_info.read_keys
+  else:
+    read_strengths = read_info.strengths
+    read_keys = read_info.keys
+
+ if strength_cost > 0.:
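+    # Hinge the strengths at the tolerance: values below strength_tolerance
+    # contribute a constant with zero gradient, so only reads stronger than
+    # the tolerance are penalized.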
+ strength_hinged = tf.maximum(strength_tolerance, read_strengths)
+ if strength_reg_mode == 'L2':
+ strength_loss = 0.5 * tf.square(strength_hinged)
+ elif strength_reg_mode == 'L1':
+ # Read strengths are always positive.
+ strength_loss = strength_hinged
+ else:
+ raise ValueError(
+ 'Strength regularization mode "{}" is not supported.'.format(
+ strength_reg_mode))
+
+ # Sum across read heads to reduce from [T, B, n_reads] to [T, B].
+ strength_loss = strength_cost * tf.reduce_sum(strength_loss, axis=2)
+
+ if key_norm_cost > 0.:
+ key_norm_norms = tf.norm(read_keys, axis=-1)
+ key_norm_norms_hinged = tf.maximum(key_norm_tolerance, key_norm_norms)
+ key_norm_loss = 0.5 * tf.square(key_norm_norms_hinged)
+
+ # Sum across read heads to reduce from [T, B, n_reads] to [T, B].
+ key_norm_loss = key_norm_cost * tf.reduce_sum(key_norm_loss, axis=2)
+
+ if strength_cost > 0.:
+ strength_loss = sum_time_average_batch(strength_loss)
+ else:
+ strength_loss = tf.constant(0.)
+
+ if key_norm_cost > 0.:
+ key_norm_loss = sum_time_average_batch(key_norm_loss)
+ else:
+ key_norm_loss = tf.constant(0.)
+
+ read_reg_loss = strength_loss + key_norm_loss
+
+ logged_values = dict(
+ read_reg_strength_loss=strength_loss,
+ read_reg_key_norm_loss=key_norm_loss,
+ total_read_reg_loss=read_reg_loss)
+
+ return read_reg_loss, logged_values
diff --git a/tvt/main.py b/tvt/main.py
new file mode 100644
index 0000000..c81b123
--- /dev/null
+++ b/tvt/main.py
@@ -0,0 +1,258 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Batched synchronous actor/learner training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+from six.moves import range
+from six.moves import zip
+import tensorflow as tf
+
+from tvt import batch_env
+from tvt import nest_utils
+from tvt import rma
+from tvt import tvt_rewards as tvt_module
+from tvt.pycolab import env as pycolab_env
+from tensorflow.contrib import framework as contrib_framework
+
+nest = contrib_framework.nest
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer('logging_frequency', 1,
+ 'Log training progress every logging_frequency episodes.')
+flags.DEFINE_string('logdir', None, 'Directory for tensorboard logging.')
+
+flags.DEFINE_boolean('with_memory', True,
+                     'Whether or not the agent has external memory.')
+flags.DEFINE_boolean('with_reconstruction', True,
+                     'Whether or not the agent reconstructs the observation.')
+flags.DEFINE_float('gamma', 0.92, 'Agent discount factor.')
+flags.DEFINE_float('entropy_cost', 0.05, 'Weight of the entropy loss.')
+flags.DEFINE_float('image_cost_weight', 50.,
+                   'Image reconstruction cost weight.')
+flags.DEFINE_float('read_strength_cost', 5e-5,
+                   'Cost weight of the memory read strength.')
+flags.DEFINE_float('read_strength_tolerance', 2.,
+                   'The tolerance of the hinge loss on the read strengths.')
+flags.DEFINE_boolean('do_tvt', True, 'Whether or not to do TVT.')
+flags.DEFINE_enum('pycolab_game', 'key_to_door',
+ ['key_to_door', 'active_visual_match'],
+ 'The name of the game in pycolab environment')
+flags.DEFINE_integer('num_episodes', None,
+ 'Number of episodes to train for. None means run forever.')
+
+flags.DEFINE_integer('batch_size', 16, 'Batch size')
+
+flags.DEFINE_float('learning_rate', 2e-4, 'Adam optimizer learning rate')
+flags.DEFINE_float('beta1', 0., 'Adam optimizer beta1')
+flags.DEFINE_float('beta2', 0.95, 'Adam optimizer beta2')
+flags.DEFINE_float('epsilon', 1e-6, 'Adam optimizer epsilon')
+
+# Pycolab-specific flags:
+flags.DEFINE_integer('pycolab_num_apples', 10,
+ 'Number of apples to sample from the distractor grid.')
+flags.DEFINE_float('pycolab_apple_reward_min', 1.,
+ 'A reward range [min, max) to uniformly sample from.')
+flags.DEFINE_float('pycolab_apple_reward_max', 10.,
+ 'A reward range [min, max) to uniformly sample from.')
+flags.DEFINE_boolean('pycolab_fix_apple_reward_in_episode', True,
+ 'Fix the sampled apple reward within an episode.')
+flags.DEFINE_float('pycolab_final_reward', 10.,
+ 'Reward obtained at the last phase.')
+flags.DEFINE_boolean('pycolab_crop', True,
+ 'Whether to crop observations or not.')
+
+
+def main(_):
+
+ batch_size = FLAGS.batch_size
+ env_builder = pycolab_env.PycolabEnvironment
+ env_kwargs = {
+ 'game': FLAGS.pycolab_game,
+ 'num_apples': FLAGS.pycolab_num_apples,
+ 'apple_reward': [FLAGS.pycolab_apple_reward_min,
+ FLAGS.pycolab_apple_reward_max],
+ 'fix_apple_reward_in_episode': FLAGS.pycolab_fix_apple_reward_in_episode,
+ 'final_reward': FLAGS.pycolab_final_reward,
+ 'crop': FLAGS.pycolab_crop
+ }
+ env = batch_env.BatchEnv(batch_size, env_builder, **env_kwargs)
+ ep_length = env.episode_length
+
+ agent = rma.Agent(batch_size=batch_size,
+ num_actions=env.num_actions,
+ observation_shape=env.observation_shape,
+ with_reconstructions=FLAGS.with_reconstruction,
+ gamma=FLAGS.gamma,
+ read_strength_cost=FLAGS.read_strength_cost,
+ read_strength_tolerance=FLAGS.read_strength_tolerance,
+ entropy_cost=FLAGS.entropy_cost,
+ with_memory=FLAGS.with_memory,
+ image_cost_weight=FLAGS.image_cost_weight)
+
+ # Agent step placeholders and agent step.
+ batch_shape = (batch_size,)
+ observation_ph = tf.placeholder(
+ dtype=tf.uint8, shape=batch_shape + env.observation_shape, name='obs')
+ reward_ph = tf.placeholder(
+ dtype=tf.float32, shape=batch_shape, name='reward')
+ state_ph = nest.map_structure(
+ lambda s: tf.placeholder(dtype=s.dtype, shape=s.shape, name='state'),
+ agent.initial_state(batch_size=batch_size))
+ step_outputs, state = agent.step(reward_ph, observation_ph, state_ph)
+
+ # Update op placeholders and update op.
+ observations_ph = tf.placeholder(
+ dtype=tf.uint8, shape=(ep_length + 1, batch_size) + env.observation_shape,
+ name='observations')
+ rewards_ph = tf.placeholder(
+ dtype=tf.float32, shape=(ep_length + 1, batch_size), name='rewards')
+ actions_ph = tf.placeholder(
+ dtype=tf.int64, shape=(ep_length, batch_size), name='actions')
+ tvt_rewards_ph = tf.placeholder(
+ dtype=tf.float32, shape=(ep_length, batch_size), name='tvt_rewards')
+
+ loss, loss_logs = agent.loss(
+ observations_ph, rewards_ph, actions_ph, tvt_rewards_ph)
+
+ optimizer = tf.train.AdamOptimizer(
+ learning_rate=FLAGS.learning_rate,
+ beta1=FLAGS.beta1,
+ beta2=FLAGS.beta2,
+ epsilon=FLAGS.epsilon)
+ update_op = optimizer.minimize(loss)
+ initial_state = agent.initial_state(batch_size)
+
+ if FLAGS.logdir:
+ if not tf.io.gfile.exists(FLAGS.logdir):
+ tf.io.gfile.makedirs(FLAGS.logdir)
+ summary_writer = tf.summary.FileWriter(FLAGS.logdir)
+
+ # Do init
+ init_ops = (tf.global_variables_initializer(),
+ tf.local_variables_initializer())
+ tf.get_default_graph().finalize()
+
+ sess = tf.Session()
+ sess.run(init_ops)
+
+ run = True
+ ep_num = 0
+ prev_logging_time = time.time()
+ while run:
+ observation, reward = env.reset()
+ agent_state = sess.run(initial_state)
+
+ # Initialise episode data stores.
+ observations = [observation]
+ rewards = [reward]
+ actions = []
+ baselines = []
+ read_infos = []
+
+ for _ in range(ep_length):
+ step_feed = {reward_ph: reward, observation_ph: observation}
+ for ph, ar in zip(nest.flatten(state_ph), nest.flatten(agent_state)):
+ step_feed[ph] = ar
+ step_output, agent_state = sess.run(
+ (step_outputs, state), feed_dict=step_feed)
+ action = step_output.action
+ baseline = step_output.baseline
+ read_info = step_output.read_info
+
+ # Take step in environment, append results.
+ observation, reward = env.step(action)
+
+ observations.append(observation)
+ rewards.append(reward)
+ actions.append(action)
+ baselines.append(baseline)
+ if read_info is not None:
+ read_infos.append(read_info)
+
+ # Stack the lists of length ep_length so that each array (or each element
+    # of nest structure for read_infos) has shape (ep_length, batch_size, ...).
+ observations = np.stack(observations)
+ rewards = np.array(rewards)
+ actions = np.array(actions)
+ baselines = np.array(baselines)
+ read_infos = nest_utils.nest_stack(read_infos)
+
+ # Compute TVT rewards.
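+    # TVT transports value from the time of each memory read back to the
+    # earlier steps that wrote the retrieved memories, yielding a synthetic
+    # reward array of shape (ep_length, batch_size).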
+ if FLAGS.do_tvt:
+ tvt_rewards = tvt_module.compute_tvt_rewards(read_infos,
+ baselines,
+ gamma=FLAGS.gamma)
+ else:
+ tvt_rewards = np.squeeze(np.zeros_like(baselines))
+
+ # Run update op.
+ loss_feed = {observations_ph: observations,
+ rewards_ph: rewards,
+ actions_ph: actions,
+ tvt_rewards_ph: tvt_rewards}
+ ep_loss, _, ep_loss_logs = sess.run([loss, update_op, loss_logs],
+ feed_dict=loss_feed)
+
+ # Log episode results.
+ if ep_num % FLAGS.logging_frequency == 0:
+ steps_per_second = (
+ FLAGS.logging_frequency * ep_length * batch_size / (
+ time.time() - prev_logging_time))
+ mean_reward = np.mean(np.sum(rewards, axis=0))
+ mean_last_phase_reward = np.mean(env.last_phase_rewards())
+ mean_tvt_reward = np.mean(np.sum(tvt_rewards, axis=0))
+
+ logging.info('Episode %d. SPS: %s', ep_num, steps_per_second)
+ logging.info('Episode %d. Mean episode reward: %f', ep_num, mean_reward)
+ logging.info('Episode %d. Last phase reward: %f', ep_num,
+ mean_last_phase_reward)
+ logging.info('Episode %d. Mean TVT episode reward: %f', ep_num,
+ mean_tvt_reward)
+ logging.info('Episode %d. Loss: %s', ep_num, ep_loss)
+ logging.info('Episode %d. Loss logs: %s', ep_num, ep_loss_logs)
+
+ if FLAGS.logdir:
+ summary = tf.Summary()
+ summary.value.add(tag='reward', simple_value=mean_reward)
+ summary.value.add(tag='last phase reward',
+ simple_value=mean_last_phase_reward)
+ summary.value.add(tag='tvt reward', simple_value=mean_tvt_reward)
+ summary.value.add(tag='total loss', simple_value=ep_loss)
+ for k, v in ep_loss_logs.items():
+ summary.value.add(tag='loss - {}'.format(k), simple_value=v)
+ # Tensorboard x-axis is total number of episodes run.
+ summary_writer.add_summary(summary, ep_num * batch_size)
+ summary_writer.flush()
+
+ prev_logging_time = time.time()
+
+ ep_num += 1
+ if FLAGS.num_episodes and ep_num >= FLAGS.num_episodes:
+ run = False
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/tvt/memory.py b/tvt/memory.py
new file mode 100644
index 0000000..56bc80a
--- /dev/null
+++ b/tvt/memory.py
@@ -0,0 +1,294 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Memory Reader/Writer for RMA."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import sonnet as snt
+import tensorflow as tf
+
+ReadInformation = collections.namedtuple(
+ 'ReadInformation', ('weights', 'indices', 'keys', 'strengths'))
+
+
+class MemoryWriter(snt.RNNCore):
+ """Memory Writer Module."""
+
+ def __init__(self, mem_shape, name='memory_writer'):
+ """Initializes the `MemoryWriter`.
+
+ Args:
+ mem_shape: The shape of the memory `(num_rows, memory_width)`.
+ name: The name to use for the Sonnet module.
+ """
+ super(MemoryWriter, self).__init__(name=name)
+ self._mem_shape = mem_shape
+
+ def _build(self, inputs, state):
+ """Inserts z into the argmin row of usage markers and updates all rows.
+
+ Returns an operation that, when executed, correctly updates the internal
+ state and usage markers.
+
+ Args:
+ inputs: A tuple consisting of:
+ * z, the value to write at this timestep
+ * mem_state, the state of the memory at this timestep before writing
+ state: The state is just the write_counter.
+
+ Returns:
+      A tuple of the new memory state and the next core state (the
+      incremented write counter).
+ """
+ z, mem_state = inputs
+
+ # Stop gradient on writes to memory.
+ z = tf.stop_gradient(z)
+
+ prev_write_counter = state
+ new_row_value = z
+
+ # Find the index to insert the next row into.
+ num_mem_rows = self._mem_shape[0]
+ write_index = tf.cast(prev_write_counter, dtype=tf.int32) % num_mem_rows
+ one_hot_row = tf.one_hot(write_index, num_mem_rows)
+ write_counter = prev_write_counter + 1
+
+    # Write the new row: the one-hot mask zeroes the old row and adds the new
+    # value; expand_dims broadcasts the mask to [R, 1] and the value to [1, M].
+ insert_new_row = lambda mem, o_hot, z: mem - (o_hot * mem) + (o_hot * z)
+ new_mem = insert_new_row(mem_state,
+ tf.expand_dims(one_hot_row, axis=-1),
+ tf.expand_dims(new_row_value, axis=-2))
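+    # E.g. with 3 rows and write_index == 1, one_hot_row is [0, 1, 0]: row 1
+    # of mem_state is replaced by z and the other rows pass through unchanged.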
+
+ new_state = write_counter
+
+ return new_mem, new_state
+
+ @property
+ def state_size(self):
+ """Returns a description of the state size, without batch dimension."""
+ return tf.TensorShape([])
+
+ @property
+ def output_size(self):
+ """Returns a description of the output size, without batch dimension."""
+ return self._mem_shape
+
+
+class MemoryReader(snt.AbstractModule):
+ """Memory Reader Module."""
+
+ def __init__(self,
+ memory_word_size,
+ num_read_heads,
+ top_k=0,
+ memory_size=None,
+ name='memory_reader'):
+ """Initializes the `MemoryReader`.
+
+ Args:
+ memory_word_size: The dimension of the 1-D read keys this memory reader
+ should produce. Each row of the memory is of length `memory_word_size`.
+ num_read_heads: The number of reads to perform.
+      top_k: If positive, softmax and summation when reading are computed only
+        over the top k most similar entries in memory. top_k=0 (default)
+        means dense reads, i.e. no top_k.
+ memory_size: Number of rows in memory.
+ name: The name for this Sonnet module.
+ """
+ super(MemoryReader, self).__init__(name=name)
+ self._memory_word_size = memory_word_size
+ self._num_read_heads = num_read_heads
+ self._top_k = top_k
+
+ # This is not an RNNCore but it is useful to expose the output size.
+ self._output_size = num_read_heads * memory_word_size
+
+ num_read_weights = top_k if top_k > 0 else memory_size
+ self._read_info_size = ReadInformation(
+ weights=tf.TensorShape([num_read_heads, num_read_weights]),
+ indices=tf.TensorShape([num_read_heads, num_read_weights]),
+ keys=tf.TensorShape([num_read_heads, memory_word_size]),
+ strengths=tf.TensorShape([num_read_heads]),
+ )
+
+ with self._enter_variable_scope():
+      # Linear transform producing a read key and one read strength per head.
+ output_dim = (memory_word_size + 1) * num_read_heads
+ self._keys_and_read_strengths_generator = snt.Linear(output_dim)
+
+ def _build(self, inputs):
+ """Looks up rows in memory.
+
+ In the args list, we have the following conventions:
+ B: batch size
+ M: number of slots in a row of the memory matrix
+ R: number of rows in the memory matrix
+ H: number of read heads in the memory controller
+
+ Args:
+ inputs: A tuple of
+ * read_inputs, a tensor of shape [B, ...] that will be flattened and
+ passed through a linear layer to get read keys/read_strengths for
+ each head.
+ * mem_state, the primary memory tensor. Of shape [B, R, M].
+
+ Returns:
+ The read from the memory (concatenated across read heads) and read
+ information.
+ """
+ # Assert input shapes are compatible and separate inputs.
+ _assert_compatible_memory_reader_input(inputs)
+ read_inputs, mem_state = inputs
+
+ # Determine the read weightings for each key.
+ flat_outputs = self._keys_and_read_strengths_generator(
+ snt.BatchFlatten()(read_inputs))
+
+ # Separate the read_strengths from the rest of the weightings.
+ h = self._num_read_heads
+ flat_keys = flat_outputs[:, :-h]
+ read_strengths = tf.nn.softplus(flat_outputs[:, -h:])
+
+ # Reshape the weights.
+ read_shape = (self._num_read_heads, self._memory_word_size)
+ read_keys = snt.BatchReshape(read_shape)(flat_keys)
+
+ # Read from memory.
+ memory_reads, read_weights, read_indices, read_strengths = (
+ read_from_memory(read_keys, read_strengths, mem_state, self._top_k))
+ concatenated_reads = snt.BatchFlatten()(memory_reads)
+
+ return concatenated_reads, ReadInformation(
+ weights=read_weights,
+ indices=read_indices,
+ keys=read_keys,
+ strengths=read_strengths)
+
+ @property
+ def output_size(self):
+ """Returns a description of the output size, without batch dimension."""
+ return self._output_size, self._read_info_size
+
+
+def read_from_memory(read_keys, read_strengths, mem_state, top_k):
+ """Function for cosine similarity content based reading from memory matrix.
+
+ In the args list, we have the following conventions:
+ B: batch size
+ M: number of slots in a row of the memory matrix
+ R: number of rows in the memory matrix
+ H: number of read heads (of the controller or the policy)
+ K: top_k if top_k>0
+
+ Args:
+ read_keys: the read keys of shape [B, H, M].
+ read_strengths: the coefficients used to compute the normalised weighting
+ vector of shape [B, H].
+ mem_state: the primary memory tensor. Of shape [B, R, M].
+ top_k: only use top k read matches, other reads do not go into softmax and
+ are zeroed out in the output. top_k=0 (default) means use dense reads.
+
+ Returns:
+    The memory reads [B, H, M], read weights [B, H, top k], read indices
+    [B, H, top k], and read strengths [B, H].
+ """
+ _assert_compatible_read_from_memory_inputs(read_keys, read_strengths,
+ mem_state)
+ batch_size = read_keys.shape[0]
+ num_read_heads = read_keys.shape[1]
+
+ with tf.name_scope('memory_reading'):
+ # Scale such that all rows are L2-unit vectors, for memory and read query.
+ scaled_read_keys = tf.math.l2_normalize(read_keys, axis=-1) # [B, H, M]
+ scaled_mem = tf.math.l2_normalize(mem_state, axis=-1) # [B, R, M]
+
+ # The cosine distance is then their dot product.
+ # Find the cosine distance between each read head and each row of memory.
+ cosine_distances = tf.matmul(
+ scaled_read_keys, scaled_mem, transpose_b=True) # [B, H, R]
+
+ # The rank must match cosine_distances for broadcasting to work.
+ read_strengths = tf.expand_dims(read_strengths, axis=-1) # [B, H, 1]
+ weighted_distances = read_strengths * cosine_distances # [B, H, R]
+
+ if top_k:
+ # Get top k indices (row indices with top k largest weighted distances).
+ top_k_output = tf.nn.top_k(weighted_distances, top_k, sorted=False)
+ read_indices = top_k_output.indices # [B, H, K]
+
+ # Create a sub-memory for each read head with only the top k rows.
+ # Each batch_gather is [B, K, M] and the list stacks to [B, H, K, M].
+ topk_mem_per_head = [tf.batch_gather(mem_state, ri_this_head)
+ for ri_this_head in tf.unstack(read_indices, axis=1)]
+ topk_mem = tf.stack(topk_mem_per_head, axis=1) # [B, H, K, M]
+ topk_scaled_mem = tf.math.l2_normalize(topk_mem, axis=-1) # [B, H, K, M]
+
+ # Calculate read weights for each head's top k sub-memory.
+ expanded_scaled_read_keys = tf.expand_dims(
+ scaled_read_keys, axis=2) # [B, H, 1, M]
+ topk_cosine_distances = tf.reduce_sum(
+ expanded_scaled_read_keys * topk_scaled_mem, axis=-1) # [B, H, K]
+ topk_weighted_distances = (
+ read_strengths * topk_cosine_distances) # [B, H, K]
+ read_weights = tf.nn.softmax(
+ topk_weighted_distances, axis=-1) # [B, H, K]
+
+ # For each head, read using the sub-memories and corresponding weights.
+ expanded_weights = tf.expand_dims(read_weights, axis=-1) # [B, H, K, 1]
+ memory_reads = tf.reduce_sum(
+ expanded_weights * topk_mem, axis=2) # [B, H, M]
+ else:
+ read_weights = tf.nn.softmax(weighted_distances, axis=-1)
+
+ num_rows_memory = mem_state.shape[1]
+ all_indices = tf.range(num_rows_memory, dtype=tf.int32)
+ all_indices = tf.reshape(all_indices, [1, 1, num_rows_memory])
+ read_indices = tf.tile(all_indices, [batch_size, num_read_heads, 1])
+
+ # This is the actual memory access.
+ # Note that matmul automatically batch applies for us.
+ memory_reads = tf.matmul(read_weights, mem_state)
+
+ read_keys.shape.assert_is_compatible_with(memory_reads.shape)
+
+ read_strengths = tf.squeeze(read_strengths, axis=-1) # [B, H, 1] -> [B, H]
+
+ return memory_reads, read_weights, read_indices, read_strengths
+
+
+def _assert_compatible_read_from_memory_inputs(read_keys, read_strengths,
+ mem_state):
+ read_keys.shape.assert_has_rank(3)
+ b_shape, h_shape, m_shape = read_keys.shape
+ mem_state.shape.assert_has_rank(3)
+ r_shape = mem_state.shape[1]
+
+ read_strengths.shape.assert_is_compatible_with(
+ tf.TensorShape([b_shape, h_shape]))
+ mem_state.shape.assert_is_compatible_with(
+ tf.TensorShape([b_shape, r_shape, m_shape]))
+
+
+def _assert_compatible_memory_reader_input(input_tensors):
+ """Asserts MemoryReader's _build has been given the correct shapes."""
+ assert len(input_tensors) == 2
+ _, mem_state = input_tensors
+ mem_state.shape.assert_has_rank(3)
diff --git a/tvt/nest_utils.py b/tvt/nest_utils.py
new file mode 100644
index 0000000..d95a116
--- /dev/null
+++ b/tvt/nest_utils.py
@@ -0,0 +1,85 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""nest utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import range
+from tensorflow.contrib import framework as contrib_framework
+
+nest = contrib_framework.nest
+
+
+def _nest_apply_over_list(list_of_nests, fn):
+ """Equivalent to fn, but works on list-of-nests.
+
+ Transforms a list-of-nests to a nest-of-lists, then applies `fn`
+ to each of the inner lists.
+
+  It is assumed that all nests have the same structure. Elements of the nest
+  may be None, in which case they are ignored, i.e. they do not form part of
+  the result. This is useful when combining agent states where parts of the
+  state nest have been filtered out.
+
+ Args:
+ list_of_nests: A Python list of nests.
+ fn: the function applied on the list of leaves.
+
+ Returns:
+    A nest with the same structure as the input nests, where each leaf is
+    formed by applying `fn` to the corresponding list of leaves.
+ """
+ list_of_flat_nests = [nest.flatten(n) for n in list_of_nests]
+ flat_nest_of_stacks = []
+ for position in range(len(list_of_flat_nests[0])):
+ new_list = [flat_nest[position] for flat_nest in list_of_flat_nests]
+ new_list = [x for x in new_list if x is not None]
+ flat_nest_of_stacks.append(fn(new_list))
+ return nest.pack_sequence_as(
+ structure=list_of_nests[0], flat_sequence=flat_nest_of_stacks)
+
+
+def _take_indices(inputs, indices):
+ return nest.map_structure(lambda t: np.take(t, indices, axis=0), inputs)
+
+
+def nest_stack(list_of_nests, axis=0):
+ """Equivalent to np.stack, but works on list-of-nests.
+
+ Transforms a list-of-nests to a nest-of-lists, then applies `np.stack`
+ to each of the inner lists.
+
+ It is assumed that all nests have the same structure. Elements of the nest may
+ be None, in which case they are ignored, i.e. they do not form part of the
+ stack. This is useful when stacking agent states where parts of the state nest
+ have been filtered.
+
+ Args:
+ list_of_nests: A Python list of nests.
+ axis: Optional, the `axis` argument for `np.stack`.
+
+ Returns:
+    A nest-of-arrays, where each array is formed by `np.stack`ing the
+    corresponding list of leaves.
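+
+  Example (illustrative):
+    nests = [{'a': np.zeros(2), 'r': 1.0}, {'a': np.ones(2), 'r': 2.0}]
+    stacked = nest_stack(nests)
+    # stacked['a'].shape == (2, 2); stacked['r'] == np.array([1., 2.])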
+ """
+ return _nest_apply_over_list(list_of_nests, lambda l: np.stack(l, axis=axis))
+
+
+def nest_unstack(batched_inputs, batch_size):
+ """Splits a sequence of numpy arrays along 0th dimension."""
+ return [_take_indices(batched_inputs, idx) for idx in range(batch_size)]
diff --git a/tvt/pycolab/README.md b/tvt/pycolab/README.md
new file mode 100644
index 0000000..eb6bd31
--- /dev/null
+++ b/tvt/pycolab/README.md
@@ -0,0 +1,31 @@
+# Pycolab Tasks
+
+## Playing the Pycolab Tasks
+
+We provide a script to allow human play of the Pycolab tasks. To play, run
+for example:
+
+`python3 pycolab/human_player.py -- --game=key_to_door`
+
+The arrow keys move the player and `q` (or `Q`) quits.
+
+## The Pycolab Tasks
+
+There are 2 [Pycolab](https://github.com/deepmind/pycolab) tasks presented
+here. Each level is composed of 3 distinct phases. The first phase is the
+'explore' phase, where the agent must acquire a piece of information (e.g. the
+colour of a square) or perform an action (e.g. pick up a key). For both tasks,
+the 2nd phase is the 'distractor' phase, where the agent collects apples for
+rewards. The 3rd phase is the 'exploit' phase, where the agent is rewarded
+based on the knowledge acquired or the actions performed in phase 1.
+
+Special thanks to Hamza Merzic for writing these task scripts.
+
+### Active Visual Match
+
+* Phase 1: A coloured square randomly placed in one of two connected rooms.
+* Phase 2: Apple collection.
+* Phase 3: Choose, among three coloured squares, the one matching the colour
+  seen in Phase 1.
+
+### Key To Door
+
+* Phase 1: A key randomly placed in one of two connected rooms.
+* Phase 2: Apple collection.
+* Phase 3: A small room with a door. If the agent has collected the key, it
+  can open the door and collect the reward behind it.
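+
+## Stepping the Tasks Programmatically
+
+The tasks can also be stepped directly through the `PycolabEnvironment`
+adapter in `tvt/pycolab/env.py`. Below is a minimal sketch; the random-action
+loop is illustrative only and not part of the released training code.
+
+```
+import numpy as np
+
+from tvt.pycolab import env as pycolab_env
+
+environment = pycolab_env.PycolabEnvironment(
+    game='key_to_door',
+    num_apples=10,
+    apple_reward=(1, 10),
+    fix_apple_reward_in_episode=True,
+    final_reward=10.)
+
+# reset() and step() each return an (image, reward) pair, where image is an
+# HWC uint8 RGB observation.
+image, reward = environment.reset()
+for _ in range(environment.episode_length):
+  # Actions 0-3 move the player north, south, west and east respectively.
+  image, reward = environment.step(np.random.randint(environment.num_actions))
+```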
diff --git a/tvt/pycolab/active_visual_match.py b/tvt/pycolab/active_visual_match.py
new file mode 100644
index 0000000..bfe4d7f
--- /dev/null
+++ b/tvt/pycolab/active_visual_match.py
@@ -0,0 +1,162 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Active visual match task.
+
+The game is split up into three phases:
+1. (exploration phase) player is in one room and there's a colour in the other,
+2. (distractor phase) player is collecting apples,
+3. (reward phase) player sees three doors of different colours and has to
+   select the one of the same colour as the one seen in the first phase.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from pycolab import ascii_art
+from pycolab import storytelling
+
+from tvt.pycolab import common
+from tvt.pycolab import game
+from tvt.pycolab import objects
+
+
+SYMBOLS_TO_SHUFFLE = ['b', 'c', 'e']
+
+EXPLORE_GRID = [
+ ' ppppppp ',
+ ' p p ',
+ ' p p ',
+ ' pp pp ',
+ ' p+++++p ',
+ ' p+++++p ',
+ ' ppppppp '
+]
+
+REWARD_GRID = [
+ '###########',
+ '# b c e #',
+ '# #',
+ '# #',
+ '#### ####',
+ ' # + # ',
+ ' ##### '
+]
+
+
+class Game(game.AbstractGame):
+ """Image Match Passive Game."""
+
+ def __init__(self,
+ rng,
+ num_apples=10,
+ apple_reward=(1, 10),
+ fix_apple_reward_in_episode=True,
+ final_reward=10.,
+ max_frames=common.DEFAULT_MAX_FRAMES_PER_PHASE):
+ self._rng = rng
+ self._num_apples = num_apples
+ self._apple_reward = apple_reward
+ self._fix_apple_reward_in_episode = fix_apple_reward_in_episode
+ self._final_reward = final_reward
+ self._max_frames = max_frames
+ self._episode_length = sum(self._max_frames.values())
+ self._num_actions = common.NUM_ACTIONS
+ self._colours = common.FIXED_COLOURS.copy()
+ self._colours.update(
+ common.get_shuffled_symbol_colour_map(rng, SYMBOLS_TO_SHUFFLE))
+
+ self._extra_observation_fields = ['chapter_reward_as_string']
+
+ @property
+ def extra_observation_fields(self):
+ """The field names of extra observations."""
+ return self._extra_observation_fields
+
+ @property
+ def num_actions(self):
+ """Number of possible actions in the game."""
+ return self._num_actions
+
+ @property
+ def episode_length(self):
+ return self._episode_length
+
+ @property
+ def colours(self):
+ """Symbol to colour map for key to door."""
+ return self._colours
+
+ def _make_explore_phase(self, target_char):
+ # Keep only one coloured position and one player position.
+ grid = common.keep_n_characters_in_grid(EXPLORE_GRID, 'p', 1, common.BORDER)
+ grid = common.keep_n_characters_in_grid(grid, 'p', 0, target_char)
+ grid = common.keep_n_characters_in_grid(grid, common.PLAYER, 1)
+
+ return ascii_art.ascii_art_to_game(
+ grid,
+ what_lies_beneath=' ',
+ sprites={
+ common.PLAYER:
+ ascii_art.Partial(
+ common.PlayerSprite,
+ impassable=common.BORDER + target_char),
+ target_char:
+ objects.ObjectSprite,
+ common.TIMER:
+ ascii_art.Partial(common.TimerSprite,
+ self._max_frames['explore']),
+ },
+ update_schedule=[common.PLAYER, target_char, common.TIMER],
+ z_order=[target_char, common.PLAYER, common.TIMER],
+ )
+
+ def _make_distractor_phase(self):
+ return common.distractor_phase(
+ player_sprite=common.PlayerSprite,
+ num_apples=self._num_apples,
+ max_frames=self._max_frames['distractor'],
+ apple_reward=self._apple_reward,
+ fix_apple_reward_in_episode=self._fix_apple_reward_in_episode)
+
+ def _make_reward_phase(self, target_char):
+ return ascii_art.ascii_art_to_game(
+ REWARD_GRID,
+ what_lies_beneath=' ',
+ sprites={
+ common.PLAYER: common.PlayerSprite,
+ 'b': objects.ObjectSprite,
+ 'c': objects.ObjectSprite,
+ 'e': objects.ObjectSprite,
+ common.TIMER: ascii_art.Partial(common.TimerSprite,
+ self._max_frames['reward'],
+ track_chapter_reward=True),
+ target_char: ascii_art.Partial(objects.ObjectSprite,
+ reward=self._final_reward),
+ },
+ update_schedule=[common.PLAYER, 'b', 'c', 'e', common.TIMER],
+ z_order=[common.PLAYER, 'b', 'c', 'e', common.TIMER],
+ )
+
+ def make_episode(self):
+ """Factory method for generating new episodes of the game."""
+ target_char = self._rng.choice(SYMBOLS_TO_SHUFFLE)
+ return storytelling.Story([
+ lambda: self._make_explore_phase(target_char),
+ self._make_distractor_phase,
+ lambda: self._make_reward_phase(target_char),
+ ], croppers=common.get_cropper())
diff --git a/tvt/pycolab/common.py b/tvt/pycolab/common.py
new file mode 100644
index 0000000..1cca6a2
--- /dev/null
+++ b/tvt/pycolab/common.py
@@ -0,0 +1,325 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Common utilities for Pycolab games."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import colorsys
+import numpy as np
+from pycolab import ascii_art
+from pycolab import cropping
+from pycolab import things as plab_things
+from pycolab.prefab_parts import sprites as prefab_sprites
+from six.moves import zip
+from tensorflow.contrib import framework as contrib_framework
+
+nest = contrib_framework.nest
+
+# Actions.
+# Those with a negative ID are not allowed for the agent.
+ACTION_QUIT = -2
+ACTION_DELAY = -1
+ACTION_NORTH = 0
+ACTION_SOUTH = 1
+ACTION_WEST = 2
+ACTION_EAST = 3
+
+NUM_ACTIONS = 4
+DEFAULT_MAX_FRAMES_PER_PHASE = {
+ 'explore': 15,
+ 'distractor': 90,
+ 'reward': 15
+}
+
+# Reserved symbols.
+PLAYER = '+'
+BORDER = '#'
+BACKGROUND = ' '
+KEY = 'k'
+DOOR = 'd'
+APPLE = 'a'
+TIMER = 't'
+INDICATOR = 'i'
+
+FIXED_COLOURS = {
+ PLAYER: (898, 584, 430),
+ BORDER: (100, 100, 100),
+ BACKGROUND: (800, 800, 800),
+ KEY: (627, 321, 176),
+ DOOR: (529, 808, 922),
+ APPLE: (550, 700, 0),
+}
+
+APPLE_DISTRACTOR_GRID = [
+ '###########',
+ '#a a a a a#',
+ '# a a a a #',
+ '#a a a a a#',
+ '# a a a a #',
+ '#a a + a a#',
+ '###########'
+]
+DEFAULT_APPLE_RESPAWN_TIME = 20
+DEFAULT_APPLE_REWARD = 1.
+
+
+def get_shuffled_symbol_colour_map(rng_or_seed, symbols,
+ num_potential_colours=None):
+ """Get a randomized mapping between symbols and colours.
+
+ Args:
+ rng_or_seed: A random state or random seed.
+ symbols: List of symbols.
+ num_potential_colours: Number of equally spaced colours to choose from.
+ Defaults to number of symbols. Colours are generated deterministically.
+
+ Returns:
+ Randomized mapping between symbols and colours.
+ """
+ num_symbols = len(symbols)
+ num_potential_colours = num_potential_colours or num_symbols
+ if isinstance(rng_or_seed, np.random.RandomState):
+ rng = rng_or_seed
+ else:
+ rng = np.random.RandomState(rng_or_seed)
+
+ # Generate a range of colours.
+ step = 1. / num_potential_colours
+ hues = np.arange(0, num_potential_colours) * step
+ potential_colours = [colorsys.hsv_to_rgb(h, 1.0, 1.0) for h in hues]
+
+ # Randomly draw num_symbols colours without replacement.
+ rng.shuffle(potential_colours)
+ colours = potential_colours[:num_symbols]
+
+ symbol_to_colour_map = dict(list(zip(symbols, colours)))
+
+ # Multiply each colour value by 1000.
+ return nest.map_structure(lambda c: int(c * 1000), symbol_to_colour_map)
+
+
+def get_cropper():
+ return cropping.ScrollingCropper(
+ rows=5,
+ cols=5,
+ to_track=PLAYER,
+ pad_char=BACKGROUND,
+ scroll_margins=(2, 2))
+
+
+def distractor_phase(player_sprite, num_apples, max_frames,
+ apple_reward=DEFAULT_APPLE_REWARD,
+ fix_apple_reward_in_episode=False,
+ respawn_every=DEFAULT_APPLE_RESPAWN_TIME):
+ """Distractor phase engine factory.
+
+ Args:
+ player_sprite: Player sprite class.
+ num_apples: Number of apples to sample from the apple distractor grid.
+ max_frames: Maximum duration of the distractor phase in frames.
+ apple_reward: Can either be a scalar specifying the reward or a reward range
+ [min, max), given as a list or tuple, to uniformly sample from.
+ fix_apple_reward_in_episode: The apple reward is constant throughout each
+ episode.
+ respawn_every: respawn frequency of apples.
+
+ Returns:
+ Distractor phase engine.
+ """
+ distractor_grid = keep_n_characters_in_grid(APPLE_DISTRACTOR_GRID, APPLE,
+ num_apples)
+
+ engine = ascii_art.ascii_art_to_game(
+ distractor_grid,
+ what_lies_beneath=BACKGROUND,
+ sprites={
+ PLAYER: player_sprite,
+ TIMER: ascii_art.Partial(TimerSprite, max_frames),
+ },
+ drapes={
+ APPLE: ascii_art.Partial(
+ AppleDrape,
+ reward=apple_reward,
+ fix_apple_reward_in_episode=fix_apple_reward_in_episode,
+ respawn_every=respawn_every)
+ },
+ update_schedule=[PLAYER, APPLE, TIMER],
+ z_order=[APPLE, PLAYER, TIMER],
+ )
+
+ return engine
+
+
+def replace_grid_symbols(grid, old_to_new_map):
+ """Replaces symbols in the grid.
+
+  Symbols without an entry in the mapping are left unchanged.
+
+ Args:
+ grid: Represented as a list of strings.
+ old_to_new_map: Mapping between symbols.
+
+ Returns:
+ Updated grid.
+ """
+ def symbol_map(x):
+ if x in old_to_new_map:
+ return old_to_new_map[x]
+ return x
+ new_grid = []
+ for row in grid:
+ new_grid.append(''.join(symbol_map(i) for i in row))
+ return new_grid
+
+
+def keep_n_characters_in_grid(grid, character, n, backdrop_char=BACKGROUND):
+ """Keeps only a sample of characters `character` in the grid."""
+ np_grid = np.array([list(i) for i in grid])
+ char_positions = np.argwhere(np_grid == character)
+
+ # Randomly select parts to remove.
+ num_empty_positions = char_positions.shape[0] - n
+ if num_empty_positions < 0:
+ raise ValueError('Not enough characters `{}` in grid.'.format(character))
+ empty_pos = np.random.permutation(char_positions)[:num_empty_positions]
+
+ # Remove characters.
+ grid = [list(row) for row in grid]
+ for (i, j) in empty_pos:
+ grid[i][j] = backdrop_char
+
+ return [''.join(row) for row in grid]
+
+
+class PlayerSprite(prefab_sprites.MazeWalker):
+ """Sprite for the actor."""
+
+ def __init__(self, corner, position, character, impassable=BORDER):
+ super(PlayerSprite, self).__init__(
+ corner, position, character, impassable=impassable,
+ confined_to_board=True)
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+
+ the_plot.add_reward(0.)
+
+ if actions == ACTION_QUIT:
+ the_plot.next_chapter = None
+ the_plot.terminate_episode()
+
+ if actions == ACTION_WEST:
+ self._west(board, the_plot)
+ elif actions == ACTION_EAST:
+ self._east(board, the_plot)
+ elif actions == ACTION_NORTH:
+ self._north(board, the_plot)
+ elif actions == ACTION_SOUTH:
+ self._south(board, the_plot)
+
+
+class AppleDrape(plab_things.Drape):
+ """Drape for the apples used in the distractor phase."""
+
+ def __init__(self,
+ curtain,
+ character,
+ respawn_every,
+ reward,
+ fix_apple_reward_in_episode):
+ """Constructor.
+
+ Args:
+ curtain: Array specifying locations of apples. Obtained from ascii grid.
+ character: Character representing the drape.
+ respawn_every: respawn frequency of apples.
+ reward: Can either be a scalar specifying the reward or a reward range
+ [min, max), given as a list or tuple, to uniformly sample from.
+ fix_apple_reward_in_episode: If set to True, then only sample the apple's
+ reward once in the episode and then fix the value.
+ """
+ super(AppleDrape, self).__init__(curtain, character)
+ self._respawn_every = respawn_every
+ if not isinstance(reward, (list, tuple)):
+ # Assuming scalar.
+ self._reward = [reward, reward]
+ else:
+ if len(reward) != 2:
+ raise ValueError('Reward must be a scalar or a two element list/tuple.')
+ self._reward = reward
+ self._fix_apple_reward_in_episode = fix_apple_reward_in_episode
+
+ # Grid specifying for each apple the last frame it was picked up.
+    # Initialized to infinity for cells with apples and -1 for cells without.
+ self._last_pickup = np.where(curtain,
+ np.inf * np.ones_like(curtain),
+ -1. * np.ones_like(curtain))
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ player_position = things[PLAYER].position
+    # If the apple reward is fixed for the episode, sample it once (one of the
+    # two range endpoints) and reuse it for every apple in the episode.
+    if (self._fix_apple_reward_in_episode and
+        'sampled_apple_reward' not in the_plot):
+ the_plot['sampled_apple_reward'] = np.random.choice((self._reward[0],
+ self._reward[1]))
+
+ if self.curtain[player_position]:
+ self._last_pickup[player_position] = the_plot.frame
+ self.curtain[player_position] = False
+ if not self._fix_apple_reward_in_episode:
+ the_plot.add_reward(np.random.uniform(*self._reward))
+ else:
+ the_plot.add_reward(the_plot['sampled_apple_reward'])
+
+ if self._respawn_every:
+ respawn_cond = the_plot.frame > self._last_pickup + self._respawn_every
+ respawn_cond &= self._last_pickup >= 0
+ self.curtain[respawn_cond] = True
+
+
+class TimerSprite(plab_things.Sprite):
+ """Sprite for the timer.
+
+ The timer is in charge of stopping the current chapter. Timer sprite should be
+ placed last in the update order to make sure everything is updated before the
+ chapter terminates.
+ """
+
+ def __init__(self, corner, position, character, max_frames,
+ track_chapter_reward=False):
+ super(TimerSprite, self).__init__(corner, position, character)
+ if not isinstance(max_frames, int):
+ raise ValueError('max_frames must be of type integer.')
+ self._max_frames = max_frames
+ self._visible = False
+ self._track_chapter_reward = track_chapter_reward
+ self._total_chapter_reward = 0.
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ directives = the_plot._get_engine_directives() # pylint: disable=protected-access
+
+ if self._track_chapter_reward:
+ self._total_chapter_reward += directives.summed_reward or 0.
+
+ # Every chapter starts at frame = 0.
+ if the_plot.frame >= self._max_frames or directives.game_over:
+ # Calculate the reward obtained in this phase and send it through the
+ # extra observations channel.
+ if self._track_chapter_reward:
+ the_plot['chapter_reward'] = self._total_chapter_reward
+ the_plot.terminate_episode()
diff --git a/tvt/pycolab/env.py b/tvt/pycolab/env.py
new file mode 100644
index 0000000..9309b3b
--- /dev/null
+++ b/tvt/pycolab/env.py
@@ -0,0 +1,105 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Pycolab env."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from pycolab import rendering
+
+from tvt.pycolab import active_visual_match
+from tvt.pycolab import key_to_door
+from tensorflow.contrib import framework as contrib_framework
+
+nest = contrib_framework.nest
+
+
+class PycolabEnvironment(object):
+ """A simple environment adapter for pycolab games."""
+
+ def __init__(self, game,
+ num_apples=10,
+ apple_reward=1.,
+ fix_apple_reward_in_episode=False,
+ final_reward=10.,
+ crop=True,
+ default_reward=0):
+ """Construct a `environment.Base` adapter that wraps a pycolab game."""
+ rng = np.random.RandomState()
+ if game == 'key_to_door':
+ self._game = key_to_door.Game(rng,
+ num_apples,
+ apple_reward,
+ fix_apple_reward_in_episode,
+ final_reward,
+ crop)
+ elif game == 'active_visual_match':
+ self._game = active_visual_match.Game(rng,
+ num_apples,
+ apple_reward,
+ fix_apple_reward_in_episode,
+ final_reward)
+ else:
+ raise ValueError('Unsupported game "%s".' % game)
+ self._default_reward = default_reward
+
+ self._num_actions = self._game.num_actions
+
+ # Agents expect HWC uint8 observations, Pycolab uses CHW float observations.
+ colours = nest.map_structure(lambda c: float(c) * 255 / 1000,
+ self._game.colours)
+ self._rgb_converter = rendering.ObservationToArray(
+ value_mapping=colours, permute=(1, 2, 0), dtype=np.uint8)
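+    # `permute=(1, 2, 0)` moves the channel axis last (CHW -> HWC), and the
+    # value mapping above rescales colours from [0, 1000] to uint8 [0, 255].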
+
+ episode = self._game.make_episode()
+ observation, _, _ = episode.its_showtime()
+ self._image_shape = self._rgb_converter(observation).shape
+
+ def _process_outputs(self, observation, reward):
+ if reward is None:
+ reward = self._default_reward
+ image = self._rgb_converter(observation)
+ return image, reward
+
+ def reset(self):
+ """Start a new episode."""
+ self._episode = self._game.make_episode()
+ observation, reward, _ = self._episode.its_showtime()
+ return self._process_outputs(observation, reward)
+
+ def step(self, action):
+ """Take step in episode."""
+ observation, reward, _ = self._episode.play(action)
+ return self._process_outputs(observation, reward)
+
+ @property
+ def num_actions(self):
+ return self._num_actions
+
+ @property
+ def observation_shape(self):
+ return self._image_shape
+
+ @property
+ def episode_length(self):
+ return self._game.episode_length
+
+ def last_phase_reward(self):
+    # In the Pycolab games here, the chapter reward is only tracked for the
+    # final chapter.
+ return float(self._episode.the_plot['chapter_reward'])
diff --git a/tvt/pycolab/game.py b/tvt/pycolab/game.py
new file mode 100644
index 0000000..f1617ae
--- /dev/null
+++ b/tvt/pycolab/game.py
@@ -0,0 +1,44 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Pycolab Game interface."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class AbstractGame(object):
+ """Abstract base class for Pycolab games."""
+
+ @abc.abstractmethod
+ def __init__(self, rng, **settings):
+ """Initialize the game."""
+
+ @abc.abstractproperty
+ def num_actions(self):
+ """Number of possible actions in the game."""
+
+ @abc.abstractproperty
+ def colours(self):
+ """Symbol to colour map for the game."""
+
+ @abc.abstractmethod
+ def make_episode(self):
+ """Factory method for generating new episodes of the game."""
diff --git a/tvt/pycolab/human_player.py b/tvt/pycolab/human_player.py
new file mode 100644
index 0000000..a8a147e
--- /dev/null
+++ b/tvt/pycolab/human_player.py
@@ -0,0 +1,67 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Pycolab human player."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import curses
+
+from absl import app
+from absl import flags
+import numpy as np
+from pycolab import human_ui
+
+from tvt.pycolab import active_visual_match
+from tvt.pycolab import common
+from tvt.pycolab import key_to_door
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_enum('game', 'key_to_door',
+ ['key_to_door', 'active_visual_match'],
+ 'The name of the game')
+
+
+def main(unused_argv):
+
+ rng = np.random.RandomState()
+
+ if FLAGS.game == 'key_to_door':
+ game = key_to_door.Game(rng)
+ elif FLAGS.game == 'active_visual_match':
+ game = active_visual_match.Game(rng)
+ else:
+ raise ValueError('Unsupported game "%s".' % FLAGS.game)
+ episode = game.make_episode()
+
+ ui = human_ui.CursesUi(
+ keys_to_actions={
+ curses.KEY_UP: common.ACTION_NORTH,
+ curses.KEY_DOWN: common.ACTION_SOUTH,
+ curses.KEY_LEFT: common.ACTION_WEST,
+ curses.KEY_RIGHT: common.ACTION_EAST,
+ -1: common.ACTION_DELAY,
+ 'q': common.ACTION_QUIT,
+ 'Q': common.ACTION_QUIT},
+ delay=-1,
+ colour_fg=game.colours
+ )
+ ui.play(episode)
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/tvt/pycolab/key_to_door.py b/tvt/pycolab/key_to_door.py
new file mode 100644
index 0000000..dc5cb33
--- /dev/null
+++ b/tvt/pycolab/key_to_door.py
@@ -0,0 +1,214 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Key to door task.
+
+The game is split up into three phases:
+1. (exploration phase) player can collect a key,
+2. (distractor phase) player is collecting apples,
+3. (reward phase) player can open the door and get the reward if the key was
+   previously collected.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from pycolab import ascii_art
+from pycolab import storytelling
+from pycolab import things as plab_things
+
+from tvt.pycolab import common
+from tvt.pycolab import game
+from tvt.pycolab import objects
+
+
+COLOURS = {
+ 'i': (1000, 1000, 1000), # Indicator.
+}
+
+EXPLORE_GRID = [
+ ' ####### ',
+ ' #kkkkk# ',
+ ' #kkkkk# ',
+ ' ## ## ',
+ ' #+++++# ',
+ ' #+++++# ',
+ ' ####### '
+]
+
+REWARD_GRID = [
+ ' ',
+ ' ##d## ',
+ ' # # ',
+ ' # + # ',
+ ' # # ',
+ ' ##### ',
+ ' ',
+]
+
+
+class KeySprite(plab_things.Sprite):
+ """Sprite for the key."""
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ player_position = things[common.PLAYER].position
+ pick_up = self.position == player_position
+
+ if self.visible and pick_up:
+ # Pass information to all phases.
+ the_plot['has_key'] = True
+ self._visible = False
+
+
+class DoorSprite(plab_things.Sprite):
+ """Sprite for the door."""
+
+ def __init__(self, corner, position, character, pickup_reward):
+ super(DoorSprite, self).__init__(corner, position, character)
+ self._pickup_reward = pickup_reward
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ player_position = things[common.PLAYER].position
+ pick_up = self.position == player_position
+
+ if pick_up and the_plot.get('has_key'):
+ the_plot.add_reward(self._pickup_reward)
+ # The key is lost after the first time opening the door
+ # to ensure only one reward per episode.
+ the_plot['has_key'] = False
+
+
+class PlayerSprite(common.PlayerSprite):
+ """Sprite for the actor."""
+
+ def __init__(self, corner, position, character):
+ super(PlayerSprite, self).__init__(
+ corner, position, character,
+ impassable=common.BORDER + common.INDICATOR + common.DOOR)
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+
+ # Allow moving through the door if key is previously collected.
+ if common.DOOR in self.impassable and the_plot.get('has_key'):
+ self._impassable.remove(common.DOOR)
+
+ super(PlayerSprite, self).update(actions, board, layers, backdrop, things,
+ the_plot)
+
+
+class Game(game.AbstractGame):
+ """Key To Door Game."""
+
+ def __init__(self,
+ rng,
+ num_apples=10,
+ apple_reward=(1, 10),
+ fix_apple_reward_in_episode=True,
+ final_reward=10.,
+ crop=True,
+ max_frames=common.DEFAULT_MAX_FRAMES_PER_PHASE):
+ del rng # Each episode is identical and colours are not randomised.
+ self._num_apples = num_apples
+ self._apple_reward = apple_reward
+ self._fix_apple_reward_in_episode = fix_apple_reward_in_episode
+ self._final_reward = final_reward
+ self._crop = crop
+ self._max_frames = max_frames
+ self._episode_length = sum(self._max_frames.values())
+ self._num_actions = common.NUM_ACTIONS
+ self._colours = common.FIXED_COLOURS.copy()
+ self._colours.update(COLOURS)
+ self._extra_observation_fields = ['chapter_reward_as_string']
+
+ @property
+ def extra_observation_fields(self):
+ """The field names of extra observations."""
+ return self._extra_observation_fields
+
+ @property
+ def num_actions(self):
+ """Number of possible actions in the game."""
+ return self._num_actions
+
+ @property
+ def episode_length(self):
+ return self._episode_length
+
+ @property
+ def colours(self):
+ """Symbol to colour map for key to door."""
+ return self._colours
+
+ def _make_explore_phase(self):
+ # Keep only one key and one player position.
+ explore_grid = common.keep_n_characters_in_grid(
+ EXPLORE_GRID, common.KEY, 1)
+ explore_grid = common.keep_n_characters_in_grid(
+ explore_grid, common.PLAYER, 1)
+ return ascii_art.ascii_art_to_game(
+ art=explore_grid,
+ what_lies_beneath=' ',
+ sprites={
+ common.PLAYER: PlayerSprite,
+ common.KEY: KeySprite,
+ common.INDICATOR: ascii_art.Partial(objects.IndicatorObjectSprite,
+ char_to_track=common.KEY,
+ override_position=(0, 5)),
+ common.TIMER: ascii_art.Partial(common.TimerSprite,
+ self._max_frames['explore']),
+ },
+ update_schedule=[
+ common.PLAYER, common.KEY, common.INDICATOR, common.TIMER],
+ z_order=[common.KEY, common.INDICATOR, common.PLAYER, common.TIMER],
+ )
+
+ def _make_distractor_phase(self):
+ return common.distractor_phase(
+ player_sprite=PlayerSprite,
+ num_apples=self._num_apples,
+ max_frames=self._max_frames['distractor'],
+ apple_reward=self._apple_reward,
+ fix_apple_reward_in_episode=self._fix_apple_reward_in_episode)
+
+ def _make_reward_phase(self):
+ return ascii_art.ascii_art_to_game(
+ art=REWARD_GRID,
+ what_lies_beneath=' ',
+ sprites={
+ common.PLAYER: PlayerSprite,
+ common.DOOR: ascii_art.Partial(DoorSprite,
+ pickup_reward=self._final_reward),
+ common.TIMER: ascii_art.Partial(common.TimerSprite,
+ self._max_frames['reward'],
+ track_chapter_reward=True),
+ },
+ update_schedule=[common.PLAYER, common.DOOR, common.TIMER],
+ z_order=[common.PLAYER, common.DOOR, common.TIMER],
+ )
+
+ def make_episode(self):
+ """Factory method for generating new episodes of the game."""
+ if self._crop:
+ croppers = common.get_cropper()
+ else:
+ croppers = None
+
+ return storytelling.Story([
+ self._make_explore_phase,
+ self._make_distractor_phase,
+ self._make_reward_phase,
+ ], croppers=croppers)
diff --git a/tvt/pycolab/objects.py b/tvt/pycolab/objects.py
new file mode 100644
index 0000000..0c261da
--- /dev/null
+++ b/tvt/pycolab/objects.py
@@ -0,0 +1,123 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Pycolab sprites."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from pycolab import things as plab_things
+from pycolab.prefab_parts import sprites as prefab_sprites
+import six
+from tvt.pycolab import common
+
+
+class PlayerSprite(prefab_sprites.MazeWalker):
+ """Sprite representing the agent."""
+
+ def __init__(self, corner, position, character,
+ max_steps_per_act, moving_player):
+    """Tells the superclass that the sprite can't walk off the board."""
+ super(PlayerSprite, self).__init__(
+ corner, position, character, impassable=[common.BORDER],
+ confined_to_board=True)
+
+ self._moving_player = moving_player
+ self._max_steps_per_act = max_steps_per_act
+ self._num_steps = 0
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ del backdrop # Unused.
+
+ if actions is not None:
+      # common.py defines individual ACTION_* constants but no ACTIONS
+      # collection, so check membership against the known action IDs directly.
+      assert actions in (common.ACTION_QUIT, common.ACTION_DELAY,
+                         common.ACTION_NORTH, common.ACTION_SOUTH,
+                         common.ACTION_WEST, common.ACTION_EAST)
+
+ the_plot.log("Step {} | Action {}".format(self._num_steps, actions))
+ the_plot.add_reward(0.0)
+ self._num_steps += 1
+
+ if actions == common.ACTION_QUIT:
+ the_plot.terminate_episode()
+
+ if self._moving_player:
+ if actions == common.ACTION_WEST:
+ self._west(board, the_plot)
+ elif actions == common.ACTION_EAST:
+ self._east(board, the_plot)
+ elif actions == common.ACTION_NORTH:
+ self._north(board, the_plot)
+ elif actions == common.ACTION_SOUTH:
+ self._south(board, the_plot)
+
+ if self._max_steps_per_act == self._num_steps:
+ the_plot.terminate_episode()
+
+
+class ObjectSprite(plab_things.Sprite):
+ """Sprite for a generic object which can be collectable."""
+
+ def __init__(self, corner, position, character, reward=0., collectable=True,
+ terminate=True):
+ super(ObjectSprite, self).__init__(corner, position, character)
+ self._reward = reward # Reward on pickup.
+ self._collectable = collectable
+
+ def set_visibility(self, visible):
+ self._visible = visible
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ player_position = things[common.PLAYER].position
+ pick_up = self.position == player_position
+
+ if pick_up and self.visible:
+ the_plot.add_reward(self._reward)
+ if self._collectable:
+ self.set_visibility(False)
+        # Set all other objects to be invisible.
+ for v in six.itervalues(things):
+ if isinstance(v, ObjectSprite):
+ v.set_visibility(False)
+
+
+class IndicatorObjectSprite(plab_things.Sprite):
+ """Sprite for the indicator object.
+
+  The indicator object spawns at a designated position once the player picks
+  up the object defined by the `char_to_track` argument, and it remains
+  visible for just a single frame.
+ """
+
+ def __init__(self, corner, position, character, char_to_track,
+ override_position=None):
+ super(IndicatorObjectSprite, self).__init__(corner, position, character)
+ if override_position is not None:
+ self._position = override_position
+ self._char_to_track = char_to_track
+ self._visible = False
+ self._pickup_frame = None
+
+ def update(self, actions, board, layers, backdrop, things, the_plot):
+ player_position = things[common.PLAYER].position
+ pick_up = things[self._char_to_track].position == player_position
+
+ if self._pickup_frame:
+ self._visible = False
+
+ if pick_up and not self._pickup_frame:
+ self._visible = True
+ self._pickup_frame = the_plot.frame
diff --git a/tvt/requirements.txt b/tvt/requirements.txt
new file mode 100644
index 0000000..62481c7
--- /dev/null
+++ b/tvt/requirements.txt
@@ -0,0 +1,8 @@
+absl-py
+dm-sonnet==1.34
+numpy
+pycolab
+six
+trfl
+tensorflow==1.13.2
+tensorflow-probability==0.6.0
diff --git a/tvt/rma.py b/tvt/rma.py
new file mode 100644
index 0000000..dccbd32
--- /dev/null
+++ b/tvt/rma.py
@@ -0,0 +1,584 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""RMA agent."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+from six.moves import range
+from six.moves import zip
+import sonnet as snt
+import tensorflow as tf
+import trfl
+
+from tvt import losses
+from tvt import memory as memory_module
+from tensorflow.contrib import framework as contrib_framework
+
+nest = contrib_framework.nest
+
+PolicyOutputs = collections.namedtuple(
+ 'PolicyOutputs', ['policy', 'action', 'baseline'])
+
+StepOutput = collections.namedtuple(
+ 'StepOutput', ['action', 'baseline', 'read_info'])
+
+AgentState = collections.namedtuple(
+ 'AgentState', ['core_state', 'prev_action'])
+
+Observation = collections.namedtuple(
+ 'Observation', ['image', 'last_action', 'last_reward'])
+
+RNNStateNoMem = collections.namedtuple(
+ 'RNNStateNoMem', ['controller_outputs', 'h_controller'])
+
+RNNState = collections.namedtuple(
+ 'RNNState',
+ list(RNNStateNoMem._fields) + ['memory', 'mem_reads', 'h_mem_writer'])
+
+CoreOutputs = collections.namedtuple(
+ 'CoreOutputs', ['action', 'policy', 'baseline', 'z', 'read_info'])
+
+
+def rnn_inputs_to_static_rnn_inputs(inputs):
+ """Converts time major tensors to timestep lists."""
+ # Inputs to core build method are expected to be a tensor or tuple of tensors.
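+  # Example (illustrative): a [T, B, D] tensor becomes a length-T list of
+  # [B, D] tensors, and a tuple of such tensors becomes a length-T list of
+  # per-timestep tuples.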
+ if isinstance(inputs, tuple):
+ num_timesteps = inputs[0].shape.as_list()[0]
+ converted_inputs = [tf.unstack(input_, num_timesteps) for input_ in inputs]
+ return list(zip(*converted_inputs))
+ else:
+ return tf.unstack(inputs)
+
+
+def static_rnn_outputs_to_core_outputs(outputs):
+ """Convert from length T list of nests to nest of tensors with first dim T."""
+ list_of_flats = [nest.flatten(n) for n in outputs]
+ new_outputs = list()
+ for i in range(len(list_of_flats[0])):
+ new_outputs.append(tf.stack([flat_nest[i] for flat_nest in list_of_flats]))
+ return nest.pack_sequence_as(structure=outputs[0], flat_sequence=new_outputs)
+
+
+def unroll(core, initial_state, inputs, dtype=tf.float32):
+ """Perform a static unroll of the core."""
+ static_rnn_inputs = rnn_inputs_to_static_rnn_inputs(inputs)
+ static_outputs, _ = tf.nn.static_rnn(
+ core,
+ inputs=static_rnn_inputs,
+ initial_state=initial_state,
+ dtype=dtype)
+ core_outputs = static_rnn_outputs_to_core_outputs(static_outputs)
+ return core_outputs
+
+
+class ImageEncoderDecoder(snt.AbstractModule):
+ """Image Encoder/Decoder module."""
+
+ def __init__(
+ self,
+ image_code_size,
+ name='image_encoder_decoder'):
+ """Initialize the image encoder/decoder."""
+ super(ImageEncoderDecoder, self).__init__(name=name)
+
+ # This is set by a call to `encode`. `decode` will fail before this is set.
+ self._convnet_output_shape = None
+
+ with self._enter_variable_scope():
+ self._convnet = snt.nets.ConvNet2D(
+ output_channels=(16, 32),
+ kernel_shapes=(3, 3),
+ strides=(1, 1),
+ paddings=('SAME',))
+ self._post_convnet_layer = snt.Linear(image_code_size, name='final_layer')
+
+ @snt.reuse_variables
+ def encode(self, image):
+ """Encode the image observation."""
+ convnet_output = self._convnet(image)
+
+ # Store unflattened convnet output shape for use in decoder.
+ self._convnet_output_shape = convnet_output.shape[1:]
+
+ # Flatten convnet outputs and pass through final layer to get image code.
+ return self._post_convnet_layer(snt.BatchFlatten()(convnet_output))
+
+ @snt.reuse_variables
+ def decode(self, code):
+ """Decode the image observation from a latent code."""
+ if self._convnet_output_shape is None:
+ raise ValueError('Must call `encode` before `decode`.')
+ transpose_convnet_in_flat = snt.Linear(
+ self._convnet_output_shape.num_elements(),
+ name='decode_initial_linear')(
+ code)
+ transpose_convnet_in_flat = tf.nn.relu(transpose_convnet_in_flat)
+ transpose_convnet_in = snt.BatchReshape(
+ self._convnet_output_shape.as_list())(transpose_convnet_in_flat)
+ return self._convnet.transpose(None)(transpose_convnet_in)
+
+ def _build(self, *args): # Unused. Use encode/decode instead.
+ raise NotImplementedError('Use encode/decode methods instead of __call__.')
+
+
+class Policy(snt.AbstractModule):
+ """A policy module possibly containing a read-only DNC."""
+
+ def __init__(self,
+ num_actions,
+ num_policy_hiddens=(),
+ num_baseline_hiddens=(),
+ activation=tf.nn.tanh,
+ policy_clip_abs_value=10.0,
+ name='Policy'):
+ """Construct a policy module possibly containing a read-only DNC.
+
+ Args:
+ num_actions: Number of discrete actions to choose from.
+ num_policy_hiddens: Tuple or List, sizes of policy MLP hidden layers.
+ num_baseline_hiddens: Tuple or List, sizes of baseline MLP hidden layers.
+ An empty tuple/list results in a linear layer instead of an MLP.
+ activation: Callable, e.g. tf.nn.tanh.
+ policy_clip_abs_value: float, Policy gradient clip value.
+ name: A string, the module's name
+ """
+ super(Policy, self).__init__(name=name)
+
+ self._num_actions = num_actions
+ self._policy_layers = tuple(num_policy_hiddens) + (num_actions,)
+ self._baseline_layers = tuple(num_baseline_hiddens) + (1,)
+ self._policy_clip_abs_value = policy_clip_abs_value
+ self._activation = activation
+
+ def _build(self, inputs):
+ (shared_inputs, extra_policy_inputs) = inputs
+ policy_in = tf.concat([shared_inputs, extra_policy_inputs], axis=1)
+
+ policy = snt.nets.MLP(
+ output_sizes=self._policy_layers,
+ activation=self._activation,
+ name='policy_mlp')(
+ policy_in)
+
+ # Sample an action from the policy logits.
+ action = tf.multinomial(policy, num_samples=1, output_dtype=tf.int32)
+ action = tf.squeeze(action, 1) # [B, 1] -> [B]
+
+ if self._policy_clip_abs_value > 0:
+ policy = snt.clip_gradient(
+ net=policy,
+ clip_value_min=-self._policy_clip_abs_value,
+ clip_value_max=self._policy_clip_abs_value)
+
+ baseline_in = tf.concat([shared_inputs, tf.stop_gradient(policy)], axis=1)
+ baseline = snt.nets.MLP(
+ self._baseline_layers,
+ activation=self._activation,
+ name='baseline_mlp')(
+ baseline_in)
+ baseline = tf.squeeze(baseline, axis=-1) # [B, 1] -> [B]
+
+ if self._policy_clip_abs_value > 0:
+ baseline = snt.clip_gradient(
+ net=baseline,
+ clip_value_min=-self._policy_clip_abs_value,
+ clip_value_max=self._policy_clip_abs_value)
+
+ outputs = PolicyOutputs(
+ policy=policy,
+ action=action,
+ baseline=baseline)
+
+ return outputs
+
+
+class _RMACore(snt.RNNCore):
+ """RMA RNN Core."""
+
+ def __init__(self,
+ num_actions,
+ with_memory=True,
+ name='rma_core'):
+ super(_RMACore, self).__init__(name=name)
+
+ # MLP activation as callable.
+ mlp_activation = tf.nn.tanh
+
+ # Size of latent code written to memory (if using it) and used to
+ # reconstruct from (if including reconstructions).
+ num_latents = 200
+
+ # Value function decode settings.
+ baseline_mlp_num_hiddens = (200,)
+
+ # Policy settings.
+ num_policy_hiddens = (200,) # Only used for non-recurrent core.
+
+ # Controller settings.
+ control_hidden_size = 256
+ control_num_layers = 2
+
+ # Memory settings (only used if with_memory=True).
+ memory_size = 1000
+ memory_num_reads = 3
+ memory_top_k = 50
+
+ self._with_memory = with_memory
+
+ with self._enter_variable_scope():
+ # Construct the features -> latent encoder.
+ self._z_encoder_mlp = snt.nets.MLP(
+ output_sizes=(2 * num_latents, num_latents),
+ activation=mlp_activation,
+ activate_final=False,
+ name='z_encoder_mlp')
+
+ # Construct controller.
+ rnn_cores = [snt.LSTM(control_hidden_size)
+ for _ in range(control_num_layers)]
+ self._controller = snt.DeepRNN(
+ rnn_cores, skip_connections=True, name='controller')
+
+ # Construct memory.
+ if self._with_memory:
+ memory_dim = num_latents # Each write to memory is of size memory_dim.
+ self._mem_shape = (memory_size, memory_dim)
+ self._memory_reader = memory_module.MemoryReader(
+ memory_word_size=memory_dim,
+ num_read_heads=memory_num_reads,
+ top_k=memory_top_k,
+ memory_size=memory_size)
+ self._memory_writer = memory_module.MemoryWriter(
+ mem_shape=self._mem_shape)
+
+ # Construct policy, starting with policy_core and policy_action_head.
+ # `extra_inputs` in this case will be mem_out from current time step (note
+ # that mem_out is just the controller output if with_memory=False).
+ self._policy = Policy(
+ num_policy_hiddens=num_policy_hiddens,
+ num_actions=num_actions,
+ num_baseline_hiddens=baseline_mlp_num_hiddens,
+ activation=mlp_activation,
+ policy_clip_abs_value=10.0,)
+
+ # Set state_size and output_size.
+ controller_out_size = self._controller.output_size
+ controller_state_size = self._controller.state_size
+ self._state_size = RNNStateNoMem(controller_outputs=controller_out_size,
+ h_controller=controller_state_size)
+ read_info_size = ()
+ if self._with_memory:
+ mem_reads_size, read_info_size = self._memory_reader.output_size
+ mem_writer_state_size = self._memory_writer.state_size
+ self._state_size = RNNState(memory=tf.TensorShape(self._mem_shape),
+ mem_reads=mem_reads_size,
+ h_mem_writer=mem_writer_state_size,
+ **self._state_size._asdict())
+
+ z_size = num_latents
+ self._output_size = CoreOutputs(
+ action=tf.TensorShape([]), # Scalar tensor shapes must be explicit.
+ policy=num_actions,
+ baseline=tf.TensorShape([]), # Scalar tensor shapes must be explicit.
+ z=z_size,
+ read_info=read_info_size)
+
+ def _build(self, inputs, h_prev):
+ features = inputs
+
+ z_net_inputs = [features, h_prev.controller_outputs]
+ if self._with_memory:
+ z_net_inputs.append(h_prev.mem_reads)
+ z_net_inputs_concat = tf.concat(z_net_inputs, axis=1)
+ z = self._z_encoder_mlp(z_net_inputs_concat)
+
+ controller_out, h_controller = self._controller(z, h_prev.h_controller)
+
+ read_info = ()
+ if self._with_memory:
+ # Perform a memory read/write step before generating the policy_modules.
+ mem_reads, read_info = self._memory_reader((controller_out,
+ h_prev.memory))
+ memory, h_mem_writer = self._memory_writer((z, h_prev.memory),
+ h_prev.h_mem_writer)
+ policy_extra_input = tf.concat([controller_out, mem_reads], axis=1)
+ else:
+ policy_extra_input = controller_out
+
+ # Get policy, action and (possible empty) baseline from policy module.
+ policy_inputs = (z, policy_extra_input)
+ policy_outputs = self._policy(policy_inputs)
+ core_outputs = CoreOutputs(
+ z=z,
+ read_info=read_info,
+ **policy_outputs._asdict())
+
+ h_next = RNNStateNoMem(controller_outputs=controller_out,
+ h_controller=h_controller)
+ if self._with_memory:
+ h_next = RNNState(memory=memory,
+ mem_reads=mem_reads,
+ h_mem_writer=h_mem_writer,
+ **h_next._asdict())
+
+ return core_outputs, h_next
+
+ def initial_state(self, batch_size):
+ """Use initial state for RNN modules, otherwise use zero state."""
+ zero_state = self.zero_state(batch_size, dtype=tf.float32)
+ controller_out = zero_state.controller_outputs
+ h_controller = self._controller.initial_state(batch_size)
+
+ state = RNNStateNoMem(controller_outputs=controller_out,
+ h_controller=h_controller)
+ if self._with_memory:
+ memory = zero_state.memory
+ mem_reads = zero_state.mem_reads
+ h_mem_writer = self._memory_writer.initial_state(batch_size)
+ state = RNNState(memory=memory,
+ mem_reads=mem_reads,
+ h_mem_writer=h_mem_writer,
+ **state._asdict())
+ return state
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+
+class Agent(snt.AbstractModule):
+ """Myriad RMA agent.
+
+ `latents` here refers to a purely deterministic encoding of the inputs, rather
+ than VAE-like latents in e.g. the MERLIN agent.
+ """
+
+ def __init__(self,
+ batch_size,
+ with_reconstructions=True,
+ with_memory=True,
+ image_code_size=500,
+ image_cost_weight=50.,
+ num_actions=None,
+ observation_shape=None,
+ entropy_cost=0.01,
+ return_cost_weight=0.4,
+ gamma=0.96,
+ read_strength_cost=5e-5,
+ read_strength_tolerance=2.,
+ name='rma_agent'):
+ super(Agent, self).__init__(name=name)
+
+ self._batch_size = batch_size
+ self._with_reconstructions = with_reconstructions
+ self._image_cost_weight = image_cost_weight
+ self._image_code_size = image_code_size
+ self._entropy_cost = entropy_cost
+ self._return_cost_weight = return_cost_weight
+ self._gamma = gamma
+ self._read_strength_cost = read_strength_cost
+ self._read_strength_tolerance = read_strength_tolerance
+ self._num_actions = num_actions
+ self._name = name
+ self._logged_values = {}
+
+ # Store total number of pixels across channels (for image loss scaling).
+ self._total_num_pixels = np.prod(observation_shape)
+
+ with self._enter_variable_scope():
+
+ # Construct image encoder/decoder.
+ self._image_encoder_decoder = ImageEncoderDecoder(
+ image_code_size=image_code_size)
+
+ self._core = _RMACore(
+ num_actions=self._num_actions,
+ with_memory=with_memory)
+
+ def initial_state(self, batch_size):
+ with tf.name_scope(self._name + '/initial_state'):
+ return AgentState(
+ core_state=self._core.initial_state(batch_size),
+ prev_action=tf.zeros(shape=(batch_size,), dtype=tf.int32))
+
+ def _prepare_observations(self, observation, last_reward, last_action):
+ image = observation
+
+ # Make sure the entries are in [0, 1) range.
+ if image.dtype.is_integer:
+ image = tf.cast(image, tf.float32) / 255.
+
+ if last_reward is None:
+ # For some envs, in the first timestep the last_reward can be None.
+ batch_size = observation.shape[0]
+ last_reward = tf.zeros((batch_size,), dtype=tf.float32)
+
+ return Observation(
+ image=image,
+ last_action=last_action,
+ last_reward=last_reward)
+
+ @snt.reuse_variables
+ def _encode(self, observation, last_reward, last_action):
+ inputs = self._prepare_observations(observation, last_reward, last_action)
+
+ # Encode image observation.
+ obs_code = self._image_encoder_decoder.encode(inputs.image)
+
+ # Encode last action.
+ action_code = tf.one_hot(inputs.last_action, self._num_actions)
+
+ # Encode last reward.
+ reward_code = tf.expand_dims(inputs.last_reward, -1)
+
+ features = tf.concat([obs_code, action_code, reward_code], axis=1)
+
+ return inputs, features
+
+ @snt.reuse_variables
+ def _decode(self, z):
+ # Decode image.
+ image_recon = self._image_encoder_decoder.decode(z)
+
+ # Decode action.
+ action_recon = snt.Linear(self._num_actions, name='action_recon_linear')(z)
+
+ # Decode reward.
+ reward_recon = snt.Linear(1, name='reward_recon_linear')(z)
+
+ # Full reconstructions.
+ recons = Observation(
+ image=image_recon,
+ last_reward=reward_recon,
+ last_action=action_recon)
+
+ return recons
+
+ def step(self, reward, observation, prev_state):
+ with tf.name_scope(self._name + '/step'):
+ _, features = self._encode(observation, reward, prev_state.prev_action)
+
+ core_outputs, next_core_state = self._core(
+ features, prev_state.core_state)
+
+ action = core_outputs.action
+
+ step_output = StepOutput(
+ action=action,
+ baseline=core_outputs.baseline,
+ read_info=core_outputs.read_info)
+ agent_state = AgentState(
+ core_state=next_core_state,
+ prev_action=action)
+ return step_output, agent_state
+
+ @snt.reuse_variables
+ def loss(self, observations, rewards, actions, additional_rewards=None):
+ """Compute the loss."""
+ dummy_zeroth_step_actions = tf.zeros_like(actions[:1])
+ all_actions = tf.concat([dummy_zeroth_step_actions, actions], axis=0)
+ inputs, features = snt.BatchApply(self._encode)(
+ observations, rewards, all_actions)
+
+    rewards = rewards[1:]  # The zeroth-step reward precedes any agent action.
+ if additional_rewards is not None:
+ # Additional rewards are not passed to the encoder (above) in order to be
+ # consistent with the step, nor to the recon loss so that recons are
+ # consistent with the observations. Thus, additional rewards only affect
+ # the returns used to learn the value function.
+ rewards += additional_rewards
+
+ initial_state = self._core.initial_state(self._batch_size)
+
+ rnn_inputs = features
+ core_outputs = unroll(self._core, initial_state, rnn_inputs)
+
+ # Remove final timestep of outputs.
+ core_outputs = nest.map_structure(lambda t: t[:-1], core_outputs)
+
+ if self._with_reconstructions:
+ recons = snt.BatchApply(self._decode)(core_outputs.z)
+ recon_targets = nest.map_structure(lambda t: t[:-1], inputs)
+ recon_loss, recon_logged_values = losses.reconstruction_losses(
+ recons=recons,
+ targets=recon_targets,
+ image_cost=self._image_cost_weight / self._total_num_pixels,
+ action_cost=1.,
+ reward_cost=1.)
+ else:
+ recon_loss = tf.constant(0.0)
+ recon_logged_values = dict()
+
+    # `read_info` is an empty tuple when the core was built without memory.
+    if core_outputs.read_info:
+ read_reg_loss, read_reg_logged_values = (
+ losses.read_regularization_loss(
+ read_info=core_outputs.read_info,
+ strength_cost=self._read_strength_cost,
+ strength_tolerance=self._read_strength_tolerance,
+ strength_reg_mode='L1',
+ key_norm_cost=0.,
+ key_norm_tolerance=1.))
+ else:
+ read_reg_loss = tf.constant(0.0)
+ read_reg_logged_values = dict()
+
+ # Bootstrap value is at end of episode so is zero.
+ bootstrap_value = tf.zeros(shape=(self._batch_size,), dtype=tf.float32)
+
+ discounts = self._gamma * tf.ones_like(rewards)
+
+ a2c_loss, a2c_loss_extra = trfl.sequence_advantage_actor_critic_loss(
+ policy_logits=core_outputs.policy,
+ baseline_values=core_outputs.baseline,
+ actions=actions,
+ rewards=rewards,
+ pcontinues=discounts,
+ bootstrap_value=bootstrap_value,
+ lambda_=self._gamma,
+ entropy_cost=self._entropy_cost,
+ baseline_cost=self._return_cost_weight,
+ name='SequenceA2CLoss')
+
+ a2c_loss = tf.reduce_mean(a2c_loss) # Average over batch.
+
+ total_loss = a2c_loss + recon_loss + read_reg_loss
+
+ a2c_loss_logged_values = dict(
+ pg_loss=tf.reduce_mean(a2c_loss_extra.policy_gradient_loss),
+ baseline_loss=tf.reduce_mean(a2c_loss_extra.baseline_loss),
+ entropy_loss=tf.reduce_mean(a2c_loss_extra.entropy_loss))
+ agent_loss_log = losses.combine_logged_values(
+ a2c_loss_logged_values,
+ recon_logged_values,
+ read_reg_logged_values)
+ agent_loss_log['total_loss'] = total_loss
+
+ return total_loss, agent_loss_log
+
+ def _build(self, *args): # Unused.
+ # pylint: disable=no-value-for-parameter
+ return self.step(*args)
+ # pylint: enable=no-value-for-parameter
diff --git a/tvt/run.sh b/tvt/run.sh
new file mode 100755
index 0000000..bcee93e
--- /dev/null
+++ b/tvt/run.sh
@@ -0,0 +1,20 @@
+#!/bin/sh
+# Copyright 2019 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+python3 -m venv tvt_venv
+. tvt_venv/bin/activate  # `source` is a bashism; `.` is the POSIX equivalent.
+pip install -r tvt/requirements.txt
+
+python3 -m tvt.main
diff --git a/tvt/tvt_rewards.py b/tvt/tvt_rewards.py
new file mode 100644
index 0000000..93b1db7
--- /dev/null
+++ b/tvt/tvt_rewards.py
@@ -0,0 +1,247 @@
+# Lint as: python2, python3
+# pylint: disable=g-bad-file-header
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Temporal Value Transport implementation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from concurrent import futures
+import numpy as np
+from six.moves import range
+from six.moves import zip
+
+
+def _unstack(array, axis):
+ """Opposite of np.stack."""
+ split_array = np.split(array, array.shape[axis], axis=axis)
+ return [np.squeeze(a, axis=axis) for a in split_array]
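+# For example (illustrative values): given a = np.arange(6).reshape(2, 3),
+# _unstack(a, axis=1) returns [array([0, 3]), array([1, 4]), array([2, 5])],
+# and np.stack(_unstack(a, axis=1), axis=1) recovers a.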
+
+
+def _top_k_args(array, k):
+ """Return top k arguments or all arguments if array size is less than k."""
+ if len(array) <= k:
+ return np.arange(len(array))
+ return np.argpartition(array, kth=-k)[-k:]
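+# E.g. (illustrative): _top_k_args(np.array([5., 1., 3., 4.]), k=2) returns
+# the indices of the two largest values, {0, 3}, in no guaranteed order;
+# with k >= len(array) it simply returns np.arange(len(array)).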
+
+
+def _threshold_read_event_times(read_strengths, threshold):
+ """Return the times of max read strengths within one threshold read event."""
+ chosen_times = []
+ over_threshold = False
+ max_read_strength = 0.
+ # Wait until the threshold is crossed, then track the maximum read strength
+ # and the time at which it occurs until the strengths drop back below the
+ # threshold; at that point record the time of that maximum. Repeat for each
+ # subsequent threshold crossing.
+ for time, strength in enumerate(read_strengths):
+ if strength > threshold:
+ over_threshold = True
+ if strength > max_read_strength:
+ max_read_strength = strength
+ max_read_strength_time = time
+ else:
+ # If coming back under threshold, add the time of the last max read.
+ if over_threshold:
+ chosen_times.append(max_read_strength_time)
+ max_read_strength = 0.
+ over_threshold = False
+ # Add max read strength time if episode finishes before going under threshold.
+ if over_threshold:
+ chosen_times.append(max_read_strength_time)
+ return np.array(chosen_times)
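+# Worked example (illustrative values): with threshold 2., the strengths
+# [0., 3., 5., 2., 0., 4., 1.] contain two above-threshold events, one
+# peaking at strength 5. (time 2) and one at strength 4. (time 5), so the
+# function returns np.array([2, 5]).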
+
+
+def _tvt_rewards_single_head(read_weights, read_strengths, read_times,
+ baselines, alpha, top_k_t1,
+ read_strength_threshold, no_transport_period):
+ """Compute TVT rewards for a single read head, no batch dimension.
+
+ This performs the updates for one read head.
+ `t1` and `t2` refer to times to where and from where the value is being
+ transported, respectively. I.e. the rewards at `t1` times are being modified
+ based on values at times `t2`.
+
+ Args:
+ read_weights: shape (ep_length, top_k).
+ read_strengths: shape (ep_length,).
+ read_times: shape (ep_length, top_k).
+ baselines: shape (ep_length,).
+ alpha: The multiplier for the temporal value transport rewards.
+ top_k_t1: For each read event time, this determines how many time points
+ to send tvt reward to.
+ read_strength_threshold: Read strengths below this value are ignored.
+ no_transport_period: Value is not transported to read times within this
+ many steps of the reading time.
+
+ Returns:
+ An array of TVT rewards with shape (ep_length,).
+ """
+ tvt_rewards = np.zeros_like(baselines)
+
+ # Mask read_weights for reads that read back to times within
+ # no_transport_period of current time.
+ ep_length = read_times.shape[0]
+ times = np.arange(ep_length)
+ # Expand dims for correct broadcasting when subtracting read_times.
+ times = np.expand_dims(times, -1)
+ read_past_no_transport_period = (times - read_times) > no_transport_period
+ read_weights_masked = np.where(read_past_no_transport_period,
+ read_weights,
+ np.zeros_like(read_weights))
+
+ # Find t2 times with maximum read weights. Ignore t2 times whose maximum
+ # read weights fall inside the no_transport_period.
+ max_read_weight_args = np.argmax(read_weights, axis=1) # (ep_length,)
+ times = np.arange(ep_length)
+ max_read_weight_times = read_times[times,
+ max_read_weight_args] # (ep_length,)
+ read_strengths_cut = np.where(
+ times - max_read_weight_times > no_transport_period,
+ read_strengths,
+ np.zeros_like(read_strengths))
+
+ # Filter t2 candidates to perform value transport on local maximums
+ # above a threshold.
+ t2_times_with_largest_reads = _threshold_read_event_times(
+ read_strengths_cut, read_strength_threshold)
+
+ # Loop through all t2 candidates and transport value to top_k_t1 read times.
+ for t2 in t2_times_with_largest_reads:
+ try:
+ baseline_value_when_reading = baselines[t2]
+ except IndexError:
+ raise RuntimeError("Attempting to access baselines array with length {}"
+ " at index {}. Make sure output_baseline is set in"
+ " the agent config.".format(len(baselines), t2))
+ read_times_from_t2 = read_times[t2]
+ read_weights_from_t2 = read_weights_masked[t2]
+
+ # Find the top_k_t1 read times for this t2 and their corresponding read
+ # weights. The call to _top_k_args() here gives the array indices for the
+ # times and weights of the top_k_t1 reads from t2.
+ top_t1_indices = _top_k_args(read_weights_from_t2, top_k_t1)
+ top_t1_read_times = np.take(read_times_from_t2, top_t1_indices)
+ top_t1_read_weights = np.take(read_weights_from_t2, top_t1_indices)
+
+ # For each of the top_k_t1 read times t and corresponding read weight w,
+ # find the trajectory that contains step_num (t + shift) and modify the
+ # reward at step_num (t + shift) using w and the baseline value at t2.
+ # We ignore any read times t >= t2. These can arise because, when nothing
+ # in memory matches the read query positively, the top reads may fall in
+ # the still-empty region of the memory.
+ for step_num, read_weight in zip(top_t1_read_times, top_t1_read_weights):
+ if step_num >= t2:
+ # Skip this step_num as it is not really a memory time.
+ continue
+
+ # Compute the tvt reward and add it on.
+ tvt_reward = alpha * read_weight * baseline_value_when_reading
+ tvt_rewards[step_num] += tvt_reward
+
+ return tvt_rewards
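+# Toy illustration (made-up numbers, no batch dim): suppose ep_length = 5,
+# top_k = 1, the only above-threshold read happens at t2 = 4 and reads back
+# to t1 = 1 with weight 0.5, and baselines[4] == 2.0. With alpha = 0.9 and a
+# no_transport_period shorter than 4 - 1 = 3, the loop adds
+# alpha * 0.5 * baselines[4] == 0.9 to tvt_rewards[1] and leaves the rest 0.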
+
+
+def _compute_tvt_rewards_from_read_info(
+ read_weights, read_strengths, read_times, baselines, gamma,
+ alpha=0.9, top_k_t1=50,
+ read_strength_threshold=2.,
+ no_transport_period_when_gamma_1=25):
+ """Compute TVT rewards given supplied read information, no batch dimension.
+
+ Args:
+ read_weights: shape (ep_length, num_read_heads, top_k).
+ read_strengths: shape (ep_length, num_read_heads).
+ read_times: shape (ep_length, num_read_heads, top_k).
+ baselines: shape (ep_length,).
+ gamma: Scalar discount factor used to calculate the no_transport_period.
+ alpha: The multiplier for the temporal value transport rewards.
+ top_k_t1: For each read event time, this determines how many time points
+ to send tvt reward to.
+ read_strength_threshold: Read strengths below this value are ignored.
+ no_transport_period_when_gamma_1: The no-transport period to use when
+ gamma == 1.
+
+ Returns:
+ An array of TVT rewards with shape (ep_length,).
+ """
+
+ if gamma < 1:
+ no_transport_period = int(1 / (1 - gamma))
+ else:
+ if no_transport_period_when_gamma_1 is None:
+ raise ValueError("No transport period must be defined when gamma == 1.")
+ no_transport_period = no_transport_period_when_gamma_1
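+ # For instance, gamma = 0.75 (exact in binary floating point) gives
+ # int(1 / (1 - 0.75)) == 4: value is only transported to reads more than
+ # 4 steps in the past. With gamma closer to 1, floating-point rounding can
+ # shift the result by one step.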
+
+ # Split read infos by read head.
+ num_read_heads = read_weights.shape[1]
+ read_weights = _unstack(read_weights, axis=1)
+ read_strengths = _unstack(read_strengths, axis=1)
+ read_times = _unstack(read_times, axis=1)
+
+ # Calculate TVT rewards for each read head separately and add to the total.
+ tvt_rewards = np.zeros_like(baselines)
+ for i in range(num_read_heads):
+ tvt_rewards += _tvt_rewards_single_head(
+ read_weights[i], read_strengths[i], read_times[i],
+ baselines, alpha, top_k_t1, read_strength_threshold,
+ no_transport_period)
+
+ return tvt_rewards
+
+
+def compute_tvt_rewards(read_infos, baselines, gamma=.96):
+ """Compute TVT rewards from EpisodeOutputs.
+
+ Args:
+ read_infos: A memory_reader.ReadInformation namedtuple, where each element
+ has shape (ep_length, batch_size, num_read_heads, ...).
+ baselines: A numpy float array with shape (ep_length, batch_size).
+ gamma: Discount factor.
+
+ Returns:
+ An array of TVT rewards with shape (ep_length, batch_size).
+ """
+ if not read_infos:
+ return np.zeros_like(baselines)
+
+ # The TVT reward computation has no batch dimension, so we need to split
+ # read_infos and baselines into per-batch-element components.
+ batch_size = baselines.shape[1]
+
+ # Split each element of read info on batch dim.
+ read_weights = _unstack(read_infos.weights, axis=1)
+ read_strengths = _unstack(read_infos.strengths, axis=1)
+ read_indices = _unstack(read_infos.indices, axis=1)
+ # Split baselines on batch dim.
+ baselines = _unstack(baselines, axis=1)
+
+ # Compute TVT rewards for each element in the batch (threading over batch).
+ tvt_rewards = []
+ with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
+ for i in range(batch_size):
+ tvt_rewards.append(
+ executor.submit(
+ _compute_tvt_rewards_from_read_info,
+ read_weights[i],
+ read_strengths[i],
+ read_indices[i],
+ baselines[i],
+ gamma)
+ )
+ tvt_rewards = [f.result() for f in tvt_rewards]
+
+ # Process TVT rewards back into an array of shape (ep_length, batch_size).
+ return np.stack(tvt_rewards, axis=1)
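+
+# Example usage (a sketch; shapes follow the docstring above, and all sizes
+# are illustrative): with ep_length = 100, batch_size = 16,
+# num_read_heads = 3 and top_k = 50,
+#
+#   read_infos.weights    # (100, 16, 3, 50)
+#   read_infos.strengths  # (100, 16, 3)
+#   read_infos.indices    # (100, 16, 3, 50)
+#   baselines             # (100, 16)
+#
+#   tvt_rs = compute_tvt_rewards(read_infos, baselines, gamma=0.96)
+#   tvt_rs.shape          # (100, 16); added to the environment rewards as
+#                         # the agent's `additional_rewards` at loss time.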