Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-24 00:05:19 +08:00)

Update citation and add demo results for no-TVT with gamma<1.

PiperOrigin-RevId: 281522361

Committed by Diego de Las Casas
Parent: 5c9f992652
Commit: 94505a89e6
@@ -0,0 +1,189 @@
# TVT: Temporal Value Transport

An open source implementation of the agents, algorithm, and environments
related to the paper [Optimizing Agent Behavior over Long Time Scales by Transporting Value](https://arxiv.org/abs/1810.06721).

## Installation

To install the TVT package and start training, run `tvt/run.sh`. This uses the
default values of all flags of the training script `tvt/main.py`. See the
section on running experiments below for launching with non-default flags.

Note that the default installation uses TensorFlow without GPU support. Replace
`tensorflow` with `tensorflow-gpu` in `tvt/requirements.txt` to use TensorFlow
with a GPU.

## Differences between this implementation and the paper

In the paper, agents were trained using a distributed A3C architecture with
384 actors. This implementation runs a batched A2C agent on a single GPU machine
with batch size 16.

## Tasks

### Pycolab tasks

So that training completes in a reasonable time on a single machine, we
provide 2D grid world versions of the paper tasks using Pycolab, to replace
the original DeepMind Lab 3D tasks.

Further details of the tasks are given in the Pycolab directory README. Users
can also play the tasks themselves from the command line.

Special thanks to Hamza Merzic for writing the two Pycolab task scripts.

### DeepMind Lab tasks

The DeepMind Lab tasks used in the paper are also provided as part of this
release.

Further details of specific tasks are given in the DeepMind Lab directory
README.

## Running experiments

### Launching

To start an experiment, run:

```
source tvt_venv/bin/activate
python3 -m tvt.main
```

This will launch a default setup that uses the RMA agent on the 'Key To Door'
Pycolab task.
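
To launch with non-default flags, append them to the same command. The snippet below is a minimal sketch that assembles such a command line from a dictionary. It assumes that `tvt.main` accepts the usual `--name=value` flag syntax, and `build_command` is a hypothetical helper, not part of the package. The flag names (`pycolab_game`, `do_tvt`, `gamma`) are the ones documented under "Important flags".

```python
import shlex


def build_command(flags):
  """Return a shell command that runs tvt.main with the given flags.

  Hypothetical helper for illustration; it assumes the training script
  accepts `--name=value` style flags.
  """
  args = ["python3", "-m", "tvt.main"]
  for name, value in sorted(flags.items()):
    args.append("--{}={}".format(name, value))
  return " ".join(shlex.quote(arg) for arg in args)


# Control run: RMA agent without TVT and with gamma equal to 1.
command = build_command({
    "pycolab_game": "active_visual_match",
    "do_tvt": False,
    "gamma": 1,
})
print(command)
```

Run the printed command inside the activated `tvt_venv` environment, as in the default launch above.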

### Important flags
`tvt.main` accepts many flags.

Note that all the default hyperparameters are tuned for the TVT-RMA agent to
solve both the `key_to_door` and `active_visual_match` Pycolab tasks.

#### Information logging:
`logging_frequency`: Frequency of logging to the console and TensorBoard. <br>
`logdir`: Directory for TensorBoard logging. <br>

#### Agent configuration:
`with_memory`: default True. Whether the agent has external memory. If set to
False, the agent has only LSTM memory.<br>
`with_reconstruction`: default True. Whether the agent reconstructs the
observation as described in the Reconstructive Memory Agent (RMA) architecture.<br>
`gamma`: Agent discount factor.<br>
`entropy_cost`: Weight of the entropy loss. <br>
`image_cost_weight`: Weight of the image reconstruction loss.<br>
`read_strength_cost`: Weight of the memory read strength penalty. Used to
regularize memory access.<br>
`read_strength_tolerance`: The tolerance of the hinge loss for the read strengths.
<br>
`do_tvt`: default True. Whether to apply the Temporal Value Transport
algorithm (only works if the model has external memory).<br>
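
The interaction between `read_strength_cost` and `read_strength_tolerance` can be pictured as a hinge penalty: read strengths at or below the tolerance cost nothing, and any excess is penalised linearly. The sketch below is illustrative only. The exact reduction used in `tvt/main.py` is not shown in this README, and the function name and numeric values are assumptions.

```python
def read_strength_penalty(read_strengths, cost, tolerance):
  """Hinge penalty on memory read strengths (illustrative, not the repo code)."""
  # Only the part of each read strength above the tolerance is penalised.
  excess = [max(0.0, strength - tolerance) for strength in read_strengths]
  return cost * sum(excess)


# Assumed example values: two reads within tolerance or slightly above, one far above.
penalty = read_strength_penalty([1.0, 2.5, 4.0], cost=0.5, tolerance=2.0)
print(penalty)
```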

#### Optimization:
`batch_size`: Batch size for the batched A2C algorithm.<br>
`learning_rate`: Learning rate for the Adam optimizer.<br>
`beta1`: Adam optimizer beta1.<br>
`beta2`: Adam optimizer beta2.<br>
`epsilon`: Adam optimizer epsilon.<br>
`num_episodes`: Number of episodes to train for. None means run forever.<br>

#### Pycolab-specific flags:
`pycolab_game`: Which game to run. One of 'key_to_door' or
'active_visual_match'. See pycolab/README for a description.<br>

`pycolab_num_apples`: Number of apples to sample from.<br>
`pycolab_apple_reward_min`: The minimum apple reward.<br>
`pycolab_apple_reward_max`: The maximum apple reward.<br>
`pycolab_fix_apple_reward_in_episode`: default True. This fixes the sampled apple
reward within an episode.<br>
`pycolab_final_reward`: Reward obtained at the last phase.<br>
`pycolab_crop`: default True. Whether to crop observations or not.<br>


### Monitoring results

Key outputs are logged to the command line and to TensorBoard logs.
We can use [tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
to track the learning progress if the `logdir` flag is set.<br>
```
tensorboard --logdir=<logdir>
```
<br>
Key values logged:
`reward`: The total reward the agent acquired in an episode. <br>
`last phase reward`: The critical reward acquired in the exploit phase, which
depends on the behavior in the explore phase.<br>
`tvt reward`: The total fictitious rewards generated by the Temporal Value
Transport algorithm.<br>
`total loss`: The sum of all losses, including policy gradient loss, value
function loss, reconstruction loss, and memory read regularization loss. We also
log these losses separately.

## Example results

Here we show example results of running the TVT agent (with the default
hyperparameters) and the best control RMA agent (with `do_tvt=False, gamma=1`).

Since TVT is designed to reduce the variance of the learning signal for rewards
that are temporally far from the actions or information that lead to them,
in the paper we focus on the reward in the last phase of each task, which is
the only reward that depends on actions or information from much earlier in the
task than the time at which the reward is given. In the experiments here, the
best way to check whether TVT is working is to monitor the `last phase reward`,
as this is the critical measure of performance. Both the agent with TVT and the
control agents do well in the apple-collecting phase, which contributes most of
the episodic reward, so the total reward does not separate them; the control
agents fall short only in the last phase.
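
To see why a distant reward gives a weak learning signal when `gamma` is below 1, it helps to compute how much of a final reward survives discounting back to an earlier step. The reward of 10, the discount factors, and the delays below are illustrative assumptions, not the repository defaults.

```python
def discounted_value(reward, gamma, delay_steps):
  """Value, at an earlier step, of a reward received `delay_steps` later."""
  return reward * gamma ** delay_steps


# Assumed values chosen only to show the trend.
for gamma in (1.0, 0.99, 0.9):
  for delay in (10, 100):
    print("gamma={:.2f} delay={:3d} value={:.4f}".format(
        gamma, delay, discounted_value(10.0, gamma, delay)))
```

With `gamma` equal to 1 the reward is undiminished, but the return then also includes every distractor reward in between, which is the source of variance that TVT is designed to reduce.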

### Key-to-door
Across 10 replicas, we found that the TVT agents reach a score of 10,
meaning they reliably collected the key in the explore phase to open the door in
the exploit phase.<br>
For 10 replicas without TVT and with the same hyperparameters, we see consistently
low performance.<br>
For 5 replicas with gamma equal to 1, performance of the RMA agent without TVT
is improved, but is unstable and never goes above 7.<br>

### Active-visual-match
Across 10 replicas, we found that the TVT agents reach a score of 10,
meaning they reliably searched for the pixel and remembered its color in the
explore phase, and then touched the corresponding pixel in the exploit
phase.<br>
For 10 replicas without TVT and with the same hyperparameters, performance is
better than chance level but not at the maximum level, indicating that the agent
is not able to actively seek information in the explore phase and instead must
rely on randomly encountering the information.<br>
For 5 replicas with gamma equal to 1, performance of the RMA agent without TVT
is considerably worse, suggesting the behavior learnt from later phases does not
result in undirected exploration in the first phase.

## Citing this work

If you use this code in your work, please cite the accompanying paper:

```
@article{
  author    = {Chia{-}Chun Hung and
               Timothy P. Lillicrap and
               Josh Abramson and
               Yan Wu and
               Mehdi Mirza and
               Federico Carnevale and
               Arun Ahuja and
               Greg Wayne},
  title     = {Optimizing Agent Behavior over Long Time Scales by Transporting Value},
  journal   = {Nat Commun},
  volume    = {10},
  year      = {2019},
  doi       = {https://doi.org/10.1038/s41467-019-13073-w},
}
```

## Disclaimer

This is not an officially supported Google or DeepMind product.
@@ -0,0 +1,110 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Threaded batch environment wrapper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from concurrent import futures

from six.moves import range
from six.moves import zip

from tvt import nest_utils


class BatchEnv(object):
  """Wrapper that steps multiple environments in separate threads.

  The threads are stepped in lock step, so all threads progress by one step
  before any move to the next step.
  """

  def __init__(self, batch_size, env_builder, **env_kwargs):
    self.batch_size = batch_size
    self._envs = [env_builder(**env_kwargs) for _ in range(batch_size)]
    self._num_actions = self._envs[0].num_actions
    self._observation_shape = self._envs[0].observation_shape
    self._episode_length = self._envs[0].episode_length

    self._executor = futures.ThreadPoolExecutor(max_workers=self.batch_size)

  def reset(self):
    """Reset the entire batch of environments."""

    def reset_environment(env):
      return env.reset()

    try:
      output_list = []
      for env in self._envs:
        output_list.append(self._executor.submit(reset_environment, env))
      output_list = [env_output.result() for env_output in output_list]
    except KeyboardInterrupt:
      self._executor.shutdown(wait=True)
      raise

    observations, rewards = nest_utils.nest_stack(output_list)
    return observations, rewards

  def step(self, action_list):
    """Step batch of envs.

    Args:
      action_list: A list of actions, one per environment in the batch. Each one
        should be a scalar int or a numpy scalar int.

    Returns:
      A tuple (observations, rewards):
        observations: A nest of observations, each one a numpy array where the
          first dimension has size equal to the number of environments in the
          batch.
        rewards: An array of rewards with size equal to the number of
          environments in the batch.
    """

    def step_environment(env, action):
      return env.step(action)

    try:
      output_list = []
      for env, action in zip(self._envs, action_list):
        output_list.append(self._executor.submit(step_environment, env, action))
      output_list = [env_output.result() for env_output in output_list]
    except KeyboardInterrupt:
      self._executor.shutdown(wait=True)
      raise

    observations, rewards = nest_utils.nest_stack(output_list)
    return observations, rewards

  @property
  def observation_shape(self):
    """Observation shape per environment, i.e. with no batch dimension."""
    return self._observation_shape

  @property
  def num_actions(self):
    return self._num_actions

  @property
  def episode_length(self):
    return self._episode_length

  def last_phase_rewards(self):
    return [env.last_phase_reward() for env in self._envs]
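
The lock-step pattern used by `BatchEnv` can be tried in isolation with a toy environment. The sketch below is a simplified stand-in, not the module itself: `CounterEnv` and `step_batch` are hypothetical names, and plain Python lists replace the `nest_utils.nest_stack` call, whose implementation is not shown in this diff.

```python
from concurrent import futures


class CounterEnv(object):
  """Toy environment (hypothetical): the observation is the running action sum."""

  def __init__(self):
    self._total = 0

  def step(self, action):
    self._total += action
    reward = 1.0 if action > 0 else 0.0
    return self._total, reward


def step_batch(executor, envs, actions):
  """Submit one step per environment, then wait for all of them."""
  pending = [executor.submit(env.step, action)
             for env, action in zip(envs, actions)]
  # Calling result() on every future blocks until the whole batch has
  # stepped, so no environment runs ahead of the others.
  outputs = [future.result() for future in pending]
  observations = [obs for obs, _ in outputs]
  rewards = [reward for _, reward in outputs]
  return observations, rewards


envs = [CounterEnv() for _ in range(3)]
with futures.ThreadPoolExecutor(max_workers=len(envs)) as executor:
  first = step_batch(executor, envs, [1, 0, 2])
  second = step_batch(executor, envs, [1, 1, 1])
print(first, second)
```

Results are collected in submission order, so the i-th observation and reward always belong to the i-th environment regardless of which thread finishes first.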
@@ -0,0 +1,66 @@
# DM Lab Tasks

## General Structure

There are 7 [DM Lab](https://github.com/deepmind/lab) tasks presented here.
Each level is composed of 3 distinct phases (except `Key To Door To Match`,
which has 5 phases). The first phase is the 'explore' phase, where the agent
should learn a piece of information or perform some action. For all tasks, the
2nd phase is the 'distractor' phase, where the agent collects apples for rewards.
The 3rd phase is the 'exploit' phase, where the agent gets rewards based on the
knowledge acquired or actions performed in phase 1.
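
The three-phase layout can be summarised as a function from elapsed time to phase. The sketch below is illustrative only: `phase_at` is a hypothetical helper, its default durations are taken from the `key_to_door` level configuration in this release (5 s explore, 30 s distractor, 37 s episode), and it ignores any early termination of the episode.

```python
def phase_at(time_seconds, explore_seconds=5, distractor_seconds=30,
             episode_seconds=37):
  """Return the phase name at a given time, or None outside the episode."""
  if time_seconds < 0 or time_seconds >= episode_seconds:
    return None
  if time_seconds < explore_seconds:
    return "explore"
  if time_seconds < explore_seconds + distractor_seconds:
    return "distractor"
  return "exploit"


for t in (0, 4.9, 5, 34.9, 35, 36.9, 37):
  print(t, phase_at(t))
```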

## Specific Tasks

### Passive Visual Match

* Phase 1: A colour square is shown right in front of the agent.
* Phase 2: Apple collection.
* Phase 3: Choose, among 4 options, the colour square that matches the one
  shown in Phase 1.

### Active Visual Match

* Phase 1: A colour square is randomly placed in a map of two connected rooms.
* Phase 2: Apple collection.
* Phase 3: Choose, among 4 options, the colour square that matches the one
  shown in Phase 1.

### Key To Door

* Phase 1: A key is randomly placed in a map of two connected rooms.
* Phase 2: Apple collection.
* Phase 3: A small room with a door. If the agent has the key, it can open the
  door to reach the goal behind it and get the reward.
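
The dependence of Phase 3 on Phase 1 comes down to one piece of state: whether the agent is holding the key. The sketch below illustrates that rule with a hypothetical `KeyToDoorState` class. The goal reward of 10 is an assumed value for illustration, and the class is not part of the level scripts.

```python
class KeyToDoorState(object):
  """Tracks the key across phases (hypothetical illustration)."""

  def __init__(self, goal_reward=10):
    self.holding_key = False
    self.goal_reward = goal_reward  # Assumed value, for illustration only.

  def pick_up_key(self):
    self.holding_key = True

  def try_open_door(self):
    """Return True if the door opens; opening it uses up the key."""
    if not self.holding_key:
      return False
    self.holding_key = False
    return True


def exploit_phase_reward(state):
  """Reward for Phase 3: the goal is reachable only through the open door."""
  return state.goal_reward if state.try_open_door() else 0


with_key = KeyToDoorState()
with_key.pick_up_key()
without_key = KeyToDoorState()
print(exploit_phase_reward(with_key), exploit_phase_reward(without_key))
```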

### Key To Door Bluekey

Identical to Key To Door above, except that the key is blue instead of
black.

### Two Negative Keys

* Phase 1: A blue key and a red key are placed in a small room. The agent can
  only pick up one of the keys.
* Phase 2: Apple collection.
* Phase 3: A small room with a door. If the agent has either key, it can open
  the door to get the reward. The reward depends on which key it picked up in
  Phase 1. All the rewards are negative in this level.

### Latent Information Acquisition

* Phase 1: Three randomly sampled objects are randomly placed in a small room.
  When the agent touches an object, a red or green cue appears,
  indicating the reward the object is associated with in this episode. No
  rewards are given in this phase.
* Phase 2: Apple collection.
* Phase 3: The same three objects from Phase 1 are randomly placed in the room
  again. The agent gets positive rewards if it picks up objects that had green
  cues in Phase 1, and negative rewards for objects that had red cues.

### Key To Door To Match

* Phase 1: A key is randomly placed in a room. The agent can pick it up.
* Phase 2: Apple collection.
* Phase 3: A colour square behind a door. If the agent has the key from
  Phase 1, it can open the door to see the colour.
* Phase 4: Apple collection.
* Phase 5: Choose, among 4 options, the colour square that matches the one
  seen in Phase 3.
@@ -0,0 +1,24 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'visual_match_factory'

return factory.createLevelApi{
    exploreMapMode = 'TWO_ROOMS',
    episodeLengthSeconds = 40,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 30,

    differentDistractRoomTexture = true,
    differentRewardRoomTexture = true,
    correctReward = 10,
    incorrectReward = 1,
}
@@ -0,0 +1,42 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local tensor = require 'dmlab.system.tensor'

local utils = {}
utils.COLORS = {
    {0, 0, 0},
    {0, 0, 170},
    {0, 170, 0},
    {0, 170, 170},
    {170, 0, 0},
    {170, 0, 170},
    {170, 85, 0},
    {170, 170, 170},
    {85, 85, 85},
    {85, 85, 255},
    {85, 255, 85},
    {85, 255, 255},
    {255, 85, 85},
    {255, 85, 255},
    {255, 255, 85},
    {255, 255, 255},
}

function utils:createByteImage(h, w, rgb)
  return tensor.ByteTensor(h, w, 4):fill{rgb[1], rgb[2], rgb[3], 255}
end

function utils:createTransparentImage(h, w)
  return tensor.ByteTensor(h, w, 4):fill{127, 127, 127, 0}
end

return utils
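
For readers more familiar with Python, the two image helpers in this Lua module can be mirrored with nested lists. This is an illustrative analogue only, not part of the release: the function names are hypothetical and plain lists stand in for `tensor.ByteTensor`.

```python
def create_byte_image(height, width, rgb):
  """Return a height x width grid of opaque [r, g, b, 255] pixels."""
  pixel = [rgb[0], rgb[1], rgb[2], 255]
  # Copy the pixel for every cell so that cells do not share one list.
  return [[list(pixel) for _ in range(width)] for _ in range(height)]


def create_transparent_image(height, width):
  """Return a grid of mid-grey pixels with alpha 0 (fully transparent)."""
  return [[[127, 127, 127, 0] for _ in range(width)] for _ in range(height)]


image = create_byte_image(2, 3, (170, 85, 0))
print(len(image), len(image[0]), image[0][0])
```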
@@ -0,0 +1,20 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'key_to_door_factory'

return factory.createLevelApi{
    episodeLengthSeconds = 37,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 30,
    differentDistractRoomTexture = true,
    differentRewardRoomTexture = true,
}
@@ -0,0 +1,22 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'key_to_door_factory'

return factory.createLevelApi{
    keyColor = {0, 0, 255},
    episodeLengthSeconds = 37,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 30,
    differentDistractRoomTexture = true,
    differentRewardRoomTexture = true,
}

@@ -0,0 +1,459 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local make_map = require 'common.make_map'
local custom_observations = require 'decorators.custom_observations'
local debug_observations = require 'decorators.debug_observations'
local game = require 'dmlab.system.game'
local map_maker = require 'dmlab.system.map_maker'
local maze_generation = require 'dmlab.system.maze_generation'
local pickup_decorator = require 'decorators.human_recognisable_pickups'
local random = require 'common.random'
local setting_overrides = require 'decorators.setting_overrides'
local texture_sets = require 'themes.texture_sets'
local themes = require 'themes.themes'
local hrp = require 'common.human_recognisable_pickups'

local DEFAULTS = {
    EPISODE_LENGTH_SECONDS = 15,
    EXPLORE_LENGTH_SECONDS = 5,
    DISTRACTOR_LENGTH_SECONDS = 5,
    REWARD_LENGTH_SECONDS = nil,
    SHOW_KEY_COLOR_SQUARE_SECONDS = 1,
    PROB_APPLE_IN_DISTRACTOR_MAP = 0.3,
    APPLE_REWARD = 5,
    APPLE_REWARD_PROB = 1.0,
    APPLE_EXTRA_REWARD_RANGE = 0,
    GOAL_REWARD = 10,
    DISTRACTOR_ROOM_SIZE = {11, 11},
    DIFFERENT_DISTRACT_ROOM_TEXTURE = false,
    DIFFERENT_REWARD_ROOM_TEXTURE = false,
    KEY_COLOR = {0, 0, 0},
}

local APPLE_ID = 998
local GOAL_ID = 999
local KEY_SPAWN_ID = 1000
local DOOR_ID = 1001

local KEY_CUE_RECTANGLE_WIDTH = 600
local KEY_CUE_RECTANGLE_HEIGHT = 200

-- Table that maps from full decal name to decal index number.
local decalIndices = {}

local EXPLORE_MAP = "exploreMap"
local DISTRACTOR_MAP = "distractorMap"
local REWARD_MAP = "rewardMap"

-- Set texture set for all maps.
local textureSet = texture_sets.PACMAN
local secondTextureSet = texture_sets.TETRIS
local thirdTextureSet = texture_sets.TRON

local REWARD_ROOM = [[
***
*P*
*H*
*G*
***
]]

local OPEN_TWO_ROOM = [[
*********
*********
*PKK*KKK*
*KKKKKKK*
*KKK*KKK*
*********
]]
local N_KEY_POS_IN_TWO_ROOM = 18  -- # of K in OPEN_TWO_ROOM

local function createDistractorMaze(opts)
  -- Example room with height = 2, width = 3
  -- A are possible apple locations (everywhere)
  -- *****
  -- *APA*
  -- *AAA*
  -- *****

  local roomHeight = opts.roomSize[1]
  local roomWidth = opts.roomSize[2]
  -- Declared local so the value does not leak into the global environment.
  local centerWidth = 1 + math.ceil(roomWidth / 2)
  local maze = maze_generation:mazeGeneration{
      height = roomHeight + 2,  -- +2 for the walls on both sides
      width = roomWidth + 2
  }

  -- Fill the room with 'A' for apples. updateSpawnVars decides where to put
  -- them.
  for i = 2, roomHeight + 1 do
    for j = 2, roomWidth + 1 do
      maze:setEntityCell(i, j, 'A')
    end
  end
  -- Override one cell with 'P' for spawn point.
  maze:setEntityCell(2, centerWidth, 'P')
  return maze
end

local function numPossibleAppleLocations(distractorRoomSize)
  return distractorRoomSize[1] * distractorRoomSize[2] - 1
end

local factory = {}
game:console('cg_drawScriptRectanglesAlways 1')

function factory.createLevelApi(kwargs)
  kwargs.episodeLengthSeconds = kwargs.episodeLengthSeconds or
      DEFAULTS.EPISODE_LENGTH_SECONDS
  kwargs.exploreLengthSeconds = kwargs.exploreLengthSeconds or
      DEFAULTS.EXPLORE_LENGTH_SECONDS
  kwargs.rewardLengthSeconds = kwargs.rewardLengthSeconds or
      DEFAULTS.REWARD_LENGTH_SECONDS
  kwargs.distractorLengthSeconds = kwargs.distractorLengthSeconds or
      DEFAULTS.DISTRACTOR_LENGTH_SECONDS
  kwargs.distractorRoomSize = kwargs.distractorRoomSize or
      DEFAULTS.DISTRACTOR_ROOM_SIZE

  kwargs.appleReward = kwargs.appleReward or DEFAULTS.APPLE_REWARD
  kwargs.appleRewardProb = kwargs.appleRewardProb or DEFAULTS.APPLE_REWARD_PROB
  kwargs.probAppleInDistractorMap = kwargs.probAppleInDistractorMap or
      DEFAULTS.PROB_APPLE_IN_DISTRACTOR_MAP

  kwargs.appleExtraRewardRange =
      kwargs.appleExtraRewardRange or DEFAULTS.APPLE_EXTRA_REWARD_RANGE

  kwargs.differentDistractRoomTexture = kwargs.differentDistractRoomTexture or
      DEFAULTS.DIFFERENT_DISTRACT_ROOM_TEXTURE

  kwargs.differentRewardRoomTexture = kwargs.differentRewardRoomTexture or
      DEFAULTS.DIFFERENT_REWARD_ROOM_TEXTURE

  kwargs.showKeyColorSquareSeconds = kwargs.showKeyColorSquareSeconds or
      DEFAULTS.SHOW_KEY_COLOR_SQUARE_SECONDS
  kwargs.goalReward = kwargs.goalReward or DEFAULTS.GOAL_REWARD
  kwargs.keyColor = kwargs.keyColor or DEFAULTS.KEY_COLOR

  local api = {}

  function api:init(params)
    self:_createExploreMap()
    self:_createDistractorMap()
    self:_createRewardMap()

    local keyInfo = {
        shape='key',
        pattern='solid',
        color1 = kwargs.keyColor,
        color2 = kwargs.keyColor
    }
    self._keyObject = hrp.create(keyInfo)
    self._keyCueRgba = {
        kwargs.keyColor[1]/255,
        kwargs.keyColor[2]/255,
        kwargs.keyColor[3]/255,
        1
    }
  end

  function api:_createRewardMap()
    self._rewardMap = map_maker:mapFromTextLevel{
        mapName = REWARD_MAP,
        entityLayer = REWARD_ROOM,
    }

    -- Create map theme and override default wall decal placement.
    local texture = textureSet
    if kwargs.differentRewardRoomTexture then
      texture = thirdTextureSet
    end
    local rewardMapTheme = themes.fromTextureSet{
        textureSet = texture,
        decalFrequency = 0.0,
        floorModelFrequency = 0.0,
    }

    self._rewardMap = map_maker:mapFromTextLevel{
        mapName = REWARD_MAP,
        entityLayer = REWARD_ROOM,
        theme = rewardMapTheme,
        callback = function (i, j, c, maker)
          local pickup = self:_makePickup(c)
          if pickup then
            return maker:makeEntity{i = i, j = j, classname = pickup}
          end
        end
    }
  end

  function api:_createExploreMap()
    -- Declared local so the table does not leak into the global environment.
    local exploreMapInfo = {map = OPEN_TWO_ROOM}

    -- Create map theme and override default wall decal placement.
    local exploreMapTheme = themes.fromTextureSet{
        textureSet = textureSet,
        decalFrequency = 0.0,
        floorModelFrequency = 0.0,
    }

    self._exploreMap = map_maker:mapFromTextLevel{
        mapName = EXPLORE_MAP,
        entityLayer = exploreMapInfo.map,
        theme = exploreMapTheme,
        callback = function (i, j, c, maker)
          local pickup = self:_makePickup(c)
          if pickup then
            return maker:makeEntity{i = i, j = j, classname = pickup}
          end
        end
    }
  end

  function api:_createDistractorMap()
    -- Create maze to be converted into map.
    local maze = createDistractorMaze{roomSize = kwargs.distractorRoomSize}

    -- Create map theme with no wall decals.
    local texture = textureSet
    if kwargs.differentDistractRoomTexture then
      texture = secondTextureSet
    end
    local distractorMapTheme = themes.fromTextureSet{
        textureSet = texture,
        decalFrequency = 0.0,
        floorModelFrequency = 0.0,
    }

    self._distractorMap = map_maker:mapFromTextLevel{
        mapName = DISTRACTOR_MAP,
        entityLayer = maze:entityLayer(),
        theme = distractorMapTheme,
        callback = function (i, j, c, maker)
          local pickup = self:_makePickup(c)
          if pickup then
            return maker:makeEntity{i = i, j = j, classname = pickup}
          end
        end
    }
  end

  function api:start(episode, seed)
    random:seed(seed)

    self._map = nil
    self._time = 0
    self._holdingKey = false
    self._keyPosCount = 0
    self._collectedGoal = false

    if kwargs.distractorLengthSecondsRange then
      self._distractorLen = random:uniformReal(
          kwargs.distractorLengthSecondsRange[1],
          kwargs.distractorLengthSecondsRange[2])
    else
      self._distractorLen = kwargs.distractorLengthSeconds
    end

    -- Sample the key position in phase 1.
    self._keyPosition = random:uniformInt(1, N_KEY_POS_IN_TWO_ROOM)

    -- Default instruction channel to 0 (indicating the rewards in final phase.)
    self.setInstruction(tostring(0))
  end

  function api:filledRectangles(args)
    if self._showKeyCue then
      return {{
          x = 12,
          y = 12,
          width = KEY_CUE_RECTANGLE_WIDTH,
          height = KEY_CUE_RECTANGLE_HEIGHT,
          rgba = self._keyCueRgba
      }}
    end
    return {}
  end

  function api:nextMap()
    -- 1. Decide what is the next map.
    if self._map == nil then
      self._map = EXPLORE_MAP
    elseif self._map == DISTRACTOR_MAP then
      self._map = REWARD_MAP
    elseif self._map == EXPLORE_MAP then
      if self._distractorLen > 0.0 then
        self._map = DISTRACTOR_MAP
      else
        self._map = REWARD_MAP
      end
    elseif self._map == REWARD_MAP then
      -- Stay in distractor map till end of episode.
      self._map = DISTRACTOR_MAP
      self._collectedGoal = true
    end

    -- 2. Set up timeout for the up-coming map.
    if self._map == DISTRACTOR_MAP and self._collectedGoal then
      if not self._timeOut then  -- don't override any existing timeout
        self._timeOut = self._time + 0.1
      end
    elseif self._map == EXPLORE_MAP then
      self._timeOut = self._time + kwargs.exploreLengthSeconds
    elseif self._map == DISTRACTOR_MAP then
      self._timeOut = self._time + self._distractorLen
    elseif self._map == REWARD_MAP then
      if kwargs.rewardLengthSeconds then
        self._timeOut = self._time + kwargs.rewardLengthSeconds
      else
        self._timeOut = nil
      end
    end

    return self._map
  end
|
||||
-- PICKUP functions ----------------------------------------------------------
|
||||
|
||||
function api:_makePickup(c)
|
||||
if c == 'K' then
|
||||
return 'key'
|
||||
end
|
||||
if c == 'G' then
|
||||
return 'goal'
|
||||
end
|
||||
if c == 'A' then
|
||||
return 'apple_reward'
|
||||
end
|
||||
end
|
||||
|
||||
function api:pickup(spawnId)
|
||||
if spawnId == GOAL_ID then
|
||||
local goalReward = kwargs.goalReward
|
||||
game:addScore(goalReward - 10) -- Offset the default +10 for goal.
|
||||
self.setInstruction(tostring(goalReward))
|
||||
game:finishMap()
|
||||
end
|
||||
if spawnId == KEY_SPAWN_ID then
|
||||
self._holdingKey = true
|
||||
self._holdingKeyTime = self._time -- When the avatar got the key.
|
||||
self._showKeyCue = true
|
||||
end
|
||||
|
||||
if spawnId == APPLE_ID then
|
||||
if kwargs.appleRewardProb >= 1 or
|
||||
random:uniformReal(0, 1) < kwargs.appleRewardProb then
|
||||
-- The -1 is to offset the default 1 point for apple in dmlab
|
||||
local appleReward = kwargs.appleReward +
|
||||
random:uniformInt(0, kwargs.appleExtraRewardRange) - 1
|
||||
game:addScore(appleReward)
|
||||
else
|
||||
-- The -1 is to offset the default 1 point for apple in dmlab
|
||||
game:addScore(-1)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- TRIGGER functions ---------------------------------------------------------
|
||||
|
||||
function api:canTrigger(teleportId, targetName)
|
||||
if string.sub(targetName, 1, 4) == 'door' then
|
||||
if self._holdingKey then
|
||||
return true
|
||||
else
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
function api:trigger(teleportId, targetName)
|
||||
if string.sub(targetName, 1, 4) == 'door' then
|
||||
-- When the door is opened, stop showing the key cue and clear the holding-key flag.
|
||||
self._showKeyCue = false
|
||||
self._holdingKey = false
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
function api:hasEpisodeFinished(timeSeconds)
|
||||
self._time = timeSeconds
|
||||
|
||||
if self._map == REWARD_MAP or self._collectedGoal then
|
||||
return self._timeOut and timeSeconds > self._timeOut
|
||||
end
|
||||
|
||||
-- Control the timing of showing key cue.
|
||||
if self._holdingKey then
|
||||
local showTime = self._time - self._holdingKeyTime
|
||||
if showTime > kwargs.showKeyColorSquareSeconds then
|
||||
self._showKeyCue = false
|
||||
end
|
||||
end
|
||||
|
||||
if self._map == EXPLORE_MAP or self._map == DISTRACTOR_MAP then
|
||||
if timeSeconds > self._timeOut then
|
||||
game:finishMap()
|
||||
end
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- END TRIGGER functions -----------------------------------------------------
|
||||
|
||||
function api:updateSpawnVars(spawnVars)
|
||||
local classname = spawnVars.classname
|
||||
if classname == "info_player_start" then
|
||||
-- Spawn facing South.
|
||||
spawnVars.angle = "-90"
|
||||
spawnVars.randomAngleRange = "0"
|
||||
elseif classname == "func_door" then
|
||||
spawnVars.id = tostring(DOOR_ID)
|
||||
spawnVars.wait = "1000000" -- Keep the door open for a long time.
|
||||
elseif classname == "goal" then
|
||||
spawnVars.id = tostring(GOAL_ID)
|
||||
elseif classname == "apple_reward" then
|
||||
-- After the goal is reached, the avatar respawns in the distractor room;
-- no more apples are spawned in that case.
|
||||
if self._collectedGoal == true then
|
||||
return nil
|
||||
end
|
||||
local useApple = false
|
||||
if kwargs.probAppleInDistractorMap > 0 then
|
||||
useApple = random:uniformReal(0, 1) < kwargs.probAppleInDistractorMap
|
||||
end
|
||||
if useApple then
|
||||
spawnVars.id = tostring(APPLE_ID)
|
||||
else
|
||||
return nil
|
||||
end
|
||||
elseif classname == "key" then
|
||||
self._keyPosCount = self._keyPosCount + 1
|
||||
if self._keyPosition == self._keyPosCount then
|
||||
spawnVars.id = tostring(KEY_SPAWN_ID)
|
||||
spawnVars.classname = self._keyObject
|
||||
else
|
||||
return nil
|
||||
end
|
||||
end
|
||||
return spawnVars
|
||||
end
|
||||
|
||||
custom_observations.decorate(api)
|
||||
pickup_decorator.decorate(api)
|
||||
setting_overrides.decorate{
|
||||
api = api,
|
||||
apiParams = kwargs,
|
||||
decorateWithTimeout = true
|
||||
}
|
||||
return api
|
||||
end
|
||||
|
||||
return factory
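
The phase ordering in step 1 of `api:nextMap()` above can be summarized as a small state machine. The Python sketch below is an illustration only and not part of the release: the function name `next_phase` and the string values of the phase constants are assumptions made for this example, and the timeout bookkeeping from step 2 is left out.

```python
# Phase names are placeholders; the values of the Lua constants are not shown here.
EXPLORE, DISTRACTOR, REWARD = "explore", "distractor", "reward"


def next_phase(current, distractor_len, collected_goal=False):
    """Returns (next phase, collected_goal flag), mirroring step 1 of nextMap."""
    if current is None:
        # Start of episode.
        return EXPLORE, collected_goal
    if current == DISTRACTOR:
        return REWARD, collected_goal
    if current == EXPLORE:
        # A zero-length distractor phase is skipped entirely.
        if distractor_len > 0.0:
            return DISTRACTOR, collected_goal
        return REWARD, collected_goal
    if current == REWARD:
        # Leaving the reward map parks the avatar in the distractor map
        # and sets the goal flag, as the Lua code does.
        return DISTRACTOR, True
    raise ValueError("Unknown phase: %r" % (current,))


phase, done = None, False
for _ in range(4):
    phase, done = next_phase(phase, distractor_len=30.0, collected_goal=done)
    print(phase, done)
```

With a positive distractor length the loop walks through explore, distractor, and reward, then returns to the distractor map with the goal flag set.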
|
||||
@@ -0,0 +1,28 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'visual_match_factory'

return factory.createLevelApi{
    exploreMapMode = 'KEY_TO_COLOR',
    episodeLengthSeconds = 45,
    secondOrderExploreLengthSeconds = 5,
    preExploreDistractorLengthSeconds = 15,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 15,

    differentDistractRoomTexture = true,
    differentRewardRoomTexture = true,
    differentSecondOrderRoomTexture = true,
    secondOrderExploreRoomSize = {4, 4},
    correctReward = 10,
    incorrectReward = 1,
}
|
||||
@@ -0,0 +1,23 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'latent_information_acquisition_factory'

return factory.createLevelApi{
    episodeLengthSeconds = 40,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 30,
    numObjects = 3,
    probGoodObject = 0.5,
    correctReward = 20,
    incorrectReward = -10,
    differentDistractRoomTexture = true,
}
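
As a rough check of the timings in this level file: the explore and distractor phases have explicit lengths, while the final phase in the corresponding factory has no timeout of its own. Assuming the episode ends at `episodeLengthSeconds` (via `decorateWithTimeout`) and ignoring map loading time, the time left for the final phase can be estimated as below. `final_phase_budget` is a hypothetical helper written for this note, not a function in the release.

```python
def final_phase_budget(episode_len, explore_len, distractor_len):
    """Seconds left for the final phase after explore and distractor phases.

    Assumes the episode ends exactly at episode_len and that switching
    maps takes no time; both are simplifications.
    """
    remaining = episode_len - explore_len - distractor_len
    if remaining < 0:
        raise ValueError("Phases are longer than the episode.")
    return remaining


# Values from the level file above: 40 s episode, 5 s explore, 30 s distractor.
print(final_phase_budget(40, 5, 30))
```

Under these assumptions the agent has about 5 seconds in the final phase for this level.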
|
||||
@@ -0,0 +1,418 @@
|
||||
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
-- ============================================================================
|
||||
local make_map = require 'common.make_map'
|
||||
local custom_decals = require 'decorators.custom_decals_decoration'
|
||||
local custom_entities = require 'common.custom_entities'
|
||||
local custom_observations = require 'decorators.custom_observations'
|
||||
local datasets_selector = require 'datasets.selector'
|
||||
local game = require 'dmlab.system.game'
|
||||
local maze_generation = require 'dmlab.system.maze_generation'
|
||||
local pickup_decorator = require 'decorators.human_recognisable_pickups'
|
||||
local random = require 'common.random'
|
||||
local setting_overrides = require 'decorators.setting_overrides'
|
||||
local texture_sets = require 'themes.texture_sets'
|
||||
local themes = require 'themes.themes'
|
||||
local hrp = require 'common.human_recognisable_pickups'
|
||||
|
||||
local SHOW_COLOR_CUE_SECOND = 0.25
|
||||
local EPISODE_LENGTH_SECONDS = 30
|
||||
local EXPLORE_LENGTH_SECONDS = 10
|
||||
local DISTRACTOR_LENGTH_SECONDS = 10
|
||||
local NUM_OBJECTS = 3
|
||||
local PROB_GOOD_OBJECT = 0.5
|
||||
local GAURANTEE_GOOD_OBJECTS = 0
|
||||
local GAURANTEE_BAD_OBJECTS = 0
|
||||
|
||||
local PROB_APPLE_IN_DISTRACTOR_MAP = 0.3
|
||||
local APPLE_REWARD = 5
|
||||
local APPLE_EXTRA_REWARD_RANGE = 0
|
||||
local DISTRACTOR_ROOM_SIZE = {11, 11}
|
||||
local APPLE_ID = 1000
|
||||
local CORRECT_REWARD = 2
|
||||
local INCORRECT_REWARD = -1
|
||||
local ROOM_SIZE = {3, 5}
|
||||
local OBJECT_SCALE = 1.62
|
||||
|
||||
local EXPLORE_MAP = "exploreMap"
|
||||
local DISTRACTOR_MAP = "distractorMap"
|
||||
local EXPLOIT_MAP = "exploitMap"
|
||||
|
||||
|
||||
local DIFFERENT_DISTRACT_ROOM_TEXTURE = false
|
||||
|
||||
-- Set texture set for all maps.
|
||||
local textureSet = texture_sets.TRON
|
||||
local secondTextureSet = texture_sets.TETRIS
|
||||
|
||||
-- Takes goal/location:i -> i
|
||||
local function nameToLocationId(name)
|
||||
return tonumber(name:match('^.+:(%d+)$'))
|
||||
end
|
||||
|
||||
-- Takes goal/location:i -> goal/pickup
|
||||
local function nameToLocationClass(name)
|
||||
return name:match('^(.+):%d+$')
|
||||
end
|
||||
|
||||
local factory = {}
|
||||
game:console('cg_drawScriptRectanglesAlways 1')
|
||||
|
||||
function factory.createLevelApi(kwargs)
|
||||
kwargs.episodeLengthSeconds = kwargs.episodeLengthSeconds or
|
||||
EPISODE_LENGTH_SECONDS
|
||||
kwargs.exploreLengthSeconds = kwargs.exploreLengthSeconds or
|
||||
EXPLORE_LENGTH_SECONDS
|
||||
if kwargs.distractorLengthSeconds == 0 then
|
||||
kwargs.skipDistractor = true
|
||||
else
|
||||
kwargs.distractorLengthSeconds = kwargs.distractorLengthSeconds or
|
||||
DISTRACTOR_LENGTH_SECONDS
|
||||
end
|
||||
kwargs.numObjects = kwargs.numObjects or NUM_OBJECTS
|
||||
kwargs.probGoodObject = kwargs.probGoodObject or PROB_GOOD_OBJECT
|
||||
kwargs.guaranteeGoodObjects = kwargs.guaranteeGoodObjects or
|
||||
GAURANTEE_GOOD_OBJECTS
|
||||
kwargs.guaranteeBadObjects = kwargs.guaranteeBadObjects or
|
||||
GAURANTEE_BAD_OBJECTS
|
||||
kwargs.correctReward = kwargs.correctReward or CORRECT_REWARD
|
||||
kwargs.incorrectReward = kwargs.incorrectReward or INCORRECT_REWARD
|
||||
kwargs.roomSize = kwargs.roomSize or ROOM_SIZE
|
||||
kwargs.distractorRoomSize = kwargs.distractorRoomSize or DISTRACTOR_ROOM_SIZE
|
||||
kwargs.probAppleInDistractorMap = kwargs.probAppleInDistractorMap or
|
||||
PROB_APPLE_IN_DISTRACTOR_MAP
|
||||
kwargs.differentDistractRoomTexture = kwargs.differentDistractRoomTexture or
|
||||
DIFFERENT_DISTRACT_ROOM_TEXTURE
|
||||
kwargs.appleReward = kwargs.appleReward or APPLE_REWARD
|
||||
kwargs.appleExtraRewardRange = kwargs.appleExtraRewardRange or
|
||||
APPLE_EXTRA_REWARD_RANGE
|
||||
kwargs.objectScale = kwargs.objectScale or OBJECT_SCALE
|
||||
|
||||
local api = {}
|
||||
|
||||
function api:init(params)
|
||||
self:_createExploreMap()
|
||||
self:_createDistractorMap()
|
||||
self:_createExploitMap()
|
||||
end
|
||||
|
||||
function api:pickup(spawnId)
|
||||
if self._map == EXPLORE_MAP then
|
||||
-- Setup to show color cue.
|
||||
self._showObjectCue = true
|
||||
self._cueColor = self._objects[spawnId].cueColor
|
||||
self._cueStartTime = self._time
|
||||
elseif self._map == EXPLOIT_MAP then
|
||||
-- Give the corresponding reward and terminate when all good objects are collected.
|
||||
game:addScore(self._objects[spawnId].reward)
|
||||
-- Update the instruction channel (to record final phase rewards.)
|
||||
self._finalRewardMainTask = (
|
||||
self._finalRewardMainTask + self._objects[spawnId].reward)
|
||||
self.setInstruction(tostring(self._finalRewardMainTask))
|
||||
end
|
||||
|
||||
if spawnId == APPLE_ID then
|
||||
-- note the -1 to offset default 1 point for apple in dmlab
|
||||
local appleReward = kwargs.appleReward +
|
||||
random:uniformInt(0, kwargs.appleExtraRewardRange) - 1
|
||||
game:addScore(appleReward)
|
||||
end
|
||||
end
|
||||
|
||||
function api:_createRoomCommon()
|
||||
local roomHeight = kwargs.roomSize[1]
|
||||
local roomWidth = kwargs.roomSize[2]
|
||||
local maze = maze_generation:mazeGeneration{
|
||||
height = roomHeight + 2,
|
||||
width = roomWidth + 2
|
||||
}
|
||||
|
||||
-- Set (2,2) as 'P' for the avatar location.
|
||||
-- Set (i,j) as 'O' for possible object location if i%2 == 0 && j%2 == 0.
|
||||
-- Otherwise, fill with '.' for empty location.
|
||||
self._numLocations = 0
|
||||
for i = 2, roomHeight + 1 do
|
||||
for j = 2, roomWidth + 1 do
|
||||
if i == 2 and j == 2 then
|
||||
maze:setEntityCell(i, j, 'P')
|
||||
elseif i % 2 == 0 and j % 2 == 0 then
|
||||
maze:setEntityCell(i, j, 'O')
|
||||
self._numLocations = self._numLocations + 1
|
||||
else
|
||||
maze:setEntityCell(i, j, '.')
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return maze
|
||||
end
|
||||
|
||||
function api:_createExploreMap()
|
||||
local maze = self:_createRoomCommon()
|
||||
print('Generated explore maze with entity layer:')
|
||||
print(maze:entityLayer())
|
||||
io.flush()
|
||||
|
||||
local mapTheme = themes.fromTextureSet{
|
||||
textureSet = textureSet,
|
||||
decalFrequency = 0.0,
|
||||
}
|
||||
|
||||
local counter = 1
|
||||
self._exploreMap = make_map.makeMap{
|
||||
mapName = EXPLORE_MAP,
|
||||
mapEntityLayer = maze:entityLayer(),
|
||||
theme = mapTheme,
|
||||
callback = function (i, j, c, maker)
|
||||
if c == 'O' then
|
||||
local pickup = 'location:' .. counter
|
||||
counter = counter + 1
|
||||
return maker:makeEntity{i = i, j = j, classname = pickup}
|
||||
end
|
||||
end
|
||||
}
|
||||
end
|
||||
|
||||
function api:_createDistractorMap()
|
||||
-- Create map theme with no wall decals.
|
||||
local distractorMapTheme = themes.fromTextureSet{
|
||||
textureSet = textureSet,
|
||||
decalFrequency = 0.0,
|
||||
}
|
||||
|
||||
-- Example room with height = 2, width = 3
|
||||
-- *****
|
||||
-- *APA*
|
||||
-- *AAA*
|
||||
-- *****
|
||||
local roomHeight = kwargs.distractorRoomSize[1]
|
||||
local roomWidth = kwargs.distractorRoomSize[2]
|
||||
local centerWidth = 1 + math.ceil(roomWidth / 2)
|
||||
local maze = maze_generation:mazeGeneration{
|
||||
height = roomHeight + 2,
|
||||
width = roomWidth + 2
|
||||
}
|
||||
|
||||
-- Fill the room with 'A' for apples. updateSpawnVars decides which to use.
|
||||
for i = 2, roomHeight + 1 do
|
||||
for j = 2, roomWidth + 1 do
|
||||
maze:setEntityCell(i, j, 'A')
|
||||
end
|
||||
end
|
||||
-- Override one cell with 'P' for spawn point.
|
||||
maze:setEntityCell(2, centerWidth, 'P')
|
||||
|
||||
print('Generated distractor maze with entity layer:')
|
||||
print(maze:entityLayer())
|
||||
io.flush()
|
||||
|
||||
local texture = textureSet
|
||||
if kwargs.differentDistractRoomTexture then
|
||||
texture = secondTextureSet
|
||||
end
|
||||
local mapTheme = themes.fromTextureSet{
|
||||
textureSet = texture,
|
||||
decalFrequency = 0.0,
|
||||
}
|
||||
self._distractMap = make_map.makeMap{
|
||||
mapName = DISTRACTOR_MAP,
|
||||
mapEntityLayer = maze:entityLayer(),
|
||||
theme = mapTheme,
|
||||
}
|
||||
end
|
||||
|
||||
function api:_createExploitMap()
|
||||
local maze = self:_createRoomCommon()
|
||||
print('Generated exploit maze with entity layer:')
|
||||
print(maze:entityLayer())
|
||||
io.flush()
|
||||
|
||||
local mapTheme = themes.fromTextureSet{
|
||||
textureSet = textureSet,
|
||||
decalFrequency = 0.0,
|
||||
}
|
||||
|
||||
local counter = 1
|
||||
self.exploitMap = make_map.makeMap{
|
||||
mapName = EXPLOIT_MAP,
|
||||
mapEntityLayer = maze:entityLayer(),
|
||||
theme = mapTheme,
|
||||
useSkybox = false,
|
||||
callback = function (i, j, c, maker)
|
||||
if c == 'O' then
|
||||
local pickup = 'location:' .. counter
|
||||
counter = counter + 1
|
||||
return maker:makeEntity{i = i, j = j, classname = pickup}
|
||||
end
|
||||
end
|
||||
}
|
||||
end
|
||||
|
||||
function api:_generateRandomObjects()
|
||||
-- 1. Generate a random list of positive/negative reward, `objectValence`
|
||||
-- as function(numObjects, guaranteeGood, guaranteeBad, probGoodObject)
|
||||
|
||||
local objectValence = {}
|
||||
for i = 1, kwargs.numObjects do
|
||||
if i <= kwargs.guaranteeGoodObjects then
|
||||
objectValence[i] = 1
|
||||
elseif i<= kwargs.guaranteeGoodObjects + kwargs.guaranteeBadObjects then
|
||||
objectValence[i] = -1
|
||||
else
|
||||
if random:uniformReal(0, 1) < kwargs.probGoodObject then
|
||||
objectValence[i] = 1
|
||||
else
|
||||
objectValence[i] = -1
|
||||
end
|
||||
end
|
||||
end
|
||||
random:shuffleInPlace(objectValence)
|
||||
|
||||
-- 2. Generate random objects and link to the object valence above.
|
||||
local objects = hrp.uniquelyShapedPickups(kwargs.numObjects)
|
||||
for i = 1, kwargs.numObjects do
|
||||
objects[i].scale= kwargs.objectScale
|
||||
end
|
||||
|
||||
self._objects = {}
|
||||
for i, object in ipairs(objects) do
|
||||
self._objects[i] = {}
|
||||
self._objects[i].data = hrp.create(object)
|
||||
if objectValence[i] == 1 then
|
||||
self._objects[i].isGoodObject = true
|
||||
self._objects[i].reward = kwargs.correctReward
|
||||
self._objects[i].cueColor = {0, 1, 0, 1} -- green means good
|
||||
else
|
||||
self._objects[i].isGoodObject = false
|
||||
self._objects[i].reward = kwargs.incorrectReward
|
||||
self._objects[i].cueColor = {1, 0, 0, 1} -- red means bad
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function api:start(episode, seed)
|
||||
random:seed(seed)
|
||||
|
||||
-- Set up a random mapping from locationId to pickupId.
-- There should be more location IDs than pickup IDs.
-- Locations mapped to pickupId == 0 have no object placed there.
|
||||
self._mapLocationIdToPickupId = {}
|
||||
for i = 1, self._numLocations do
|
||||
if i <= kwargs.numObjects then
|
||||
self._mapLocationIdToPickupId[i] = i
|
||||
else
|
||||
self._mapLocationIdToPickupId[i] = 0
|
||||
end
|
||||
end
|
||||
random:shuffleInPlace(self._mapLocationIdToPickupId)
|
||||
|
||||
self:_generateRandomObjects()
|
||||
self._map = nil
|
||||
self._numTrials = 0
|
||||
self._timeOut = kwargs.exploreLengthSeconds
|
||||
|
||||
-- Set the instruction channel to record the rewards in the final phase.
|
||||
self._finalRewardMainTask = 0
|
||||
self.setInstruction("0")
|
||||
end
|
||||
|
||||
function api:nextMap()
|
||||
if self._map == nil then -- Start of episode.
|
||||
self._map = EXPLORE_MAP
|
||||
elseif not kwargs.skipDistractor and self._map == EXPLORE_MAP then
|
||||
-- Move from explore to distractor.
|
||||
self._map = DISTRACTOR_MAP
|
||||
self._timeOut = self._time + kwargs.distractorLengthSeconds
|
||||
elseif (kwargs.skipDistractor and self._map == EXPLORE_MAP)
|
||||
or self._map == DISTRACTOR_MAP then
|
||||
-- Move from distractor or explore map to exploit map.
|
||||
self._map = EXPLOIT_MAP
|
||||
random:shuffleInPlace(self._mapLocationIdToPickupId)
|
||||
self._timeOut = nil
|
||||
end
|
||||
|
||||
return self._map
|
||||
end
|
||||
|
||||
function api:hasEpisodeFinished(timeSeconds)
|
||||
self._time = timeSeconds
|
||||
if self._showObjectCue then
|
||||
if self._time - self._cueStartTime > SHOW_COLOR_CUE_SECOND then
|
||||
self._showObjectCue = false
|
||||
end
|
||||
end
|
||||
|
||||
if self._map == EXPLORE_MAP or self._map == DISTRACTOR_MAP then
|
||||
if timeSeconds > self._timeOut then
|
||||
game:finishMap()
|
||||
end
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- END TRIGGER functions -----------------------------------------------------
|
||||
function api:filledRectangles(args)
|
||||
if self._map == EXPLORE_MAP and self._showObjectCue then
|
||||
return {{
|
||||
x = 12,
|
||||
y = 12,
|
||||
width = 600,
|
||||
height = 300,
|
||||
rgba = self._cueColor,
|
||||
}}
|
||||
end
|
||||
return {}
|
||||
end
|
||||
|
||||
function api:updateSpawnVars(spawnVars)
|
||||
local classname = spawnVars.classname
|
||||
if classname == "info_player_start" then
|
||||
-- Spawn facing South.
|
||||
spawnVars.angle = "-90"
|
||||
spawnVars.randomAngleRange = "0"
|
||||
elseif classname == "apple_reward" then
|
||||
local useApple = false
|
||||
if kwargs.probAppleInDistractorMap > 0 then
|
||||
useApple = random:uniformReal(0, 1) < kwargs.probAppleInDistractorMap
|
||||
spawnVars.id = tostring(APPLE_ID)
|
||||
end
|
||||
if not useApple then
|
||||
return nil
|
||||
end
|
||||
else
|
||||
-- Allocate objects onto the map by mapLocationIdToPickupId.
|
||||
local locationClass = nameToLocationClass(classname)
|
||||
if locationClass then
|
||||
local locationId = nameToLocationId(classname)
|
||||
local id = self._mapLocationIdToPickupId[locationId]
|
||||
if id == 0 then
|
||||
return nil
|
||||
else
|
||||
spawnVars.classname = self._objects[id].data
|
||||
spawnVars.id = tostring(id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return spawnVars
|
||||
end
|
||||
|
||||
custom_observations.decorate(api)
|
||||
pickup_decorator.decorate(api)
|
||||
setting_overrides.decorate{
|
||||
api = api,
|
||||
apiParams = kwargs,
|
||||
decorateWithTimeout = true
|
||||
}
|
||||
return api
|
||||
end
|
||||
|
||||
return factory
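
The valence sampling in `api:_generateRandomObjects()` above (a guaranteed number of good objects, then a guaranteed number of bad ones, then independent draws, followed by a shuffle) can be sketched in Python as follows. This is an illustration under assumptions, not the released code: `sample_object_valence` is a hypothetical name, and Python's `random.Random` stands in for DeepMind Lab's `common.random` module.

```python
import random


def sample_object_valence(num_objects, guarantee_good, guarantee_bad,
                          prob_good, rng):
    """Returns a shuffled list of +1 (good) and -1 (bad) object valences."""
    valence = []
    for i in range(1, num_objects + 1):  # 1-based, as in the Lua loop.
        if i <= guarantee_good:
            valence.append(1)
        elif i <= guarantee_good + guarantee_bad:
            valence.append(-1)
        else:
            # Remaining objects are good with probability prob_good.
            valence.append(1 if rng.random() < prob_good else -1)
    # Shuffle so the guaranteed objects do not always come first.
    rng.shuffle(valence)
    return valence


rng = random.Random(0)
print(sample_object_valence(5, guarantee_good=2, guarantee_bad=1,
                            prob_good=0.5, rng=rng))
```

Whatever the seed, the result contains at least `guarantee_good` entries equal to +1 and at least `guarantee_bad` entries equal to -1, as long as their sum does not exceed `num_objects`.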
|
||||
@@ -0,0 +1,24 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'visual_match_factory'

return factory.createLevelApi{
    exploreMapMode = 'PASSIVE',
    episodeLengthSeconds = 40,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 30,

    differentDistractRoomTexture = true,
    differentRewardRoomTexture = true,
    correctReward = 10,
    incorrectReward = 1,
}
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,19 @@
-- Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- ============================================================================
local factory = require 'two_keys_to_choose_factory'

return factory.createLevelApi{
    episodeLengthSeconds = 37,
    exploreLengthSeconds = 5,
    distractorLengthSeconds = 30,
    differentDistractRoomTexture = true,
}
|
||||
File diff suppressed because it is too large
+157
@@ -0,0 +1,157 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf


def sum_time_average_batch(tensor, name=None):
  """Sums over time T, then averages over batch B, for a [T, B] tensor."""
  tensor.get_shape().assert_has_rank(2)
  return tf.reduce_mean(tf.reduce_sum(tensor, axis=0), axis=0, name=name)


def combine_logged_values(*logged_values_dicts):
  """Combine logged values dicts. Throws if there are any repeated keys."""
  combined_dict = dict()
  for logged_values in logged_values_dicts:
    for k, v in six.iteritems(logged_values):
      if k in combined_dict:
        raise ValueError('Key "%s" is repeated in loss logging.' % k)
      combined_dict[k] = v
  return combined_dict


def reconstruction_losses(
    recons,
    targets,
    image_cost,
    action_cost,
    reward_cost):
  """Reconstruction losses."""
  if image_cost > 0.0:
    # Neg log prob of obs image given Bernoulli(recon image) distribution.
    negative_image_log_prob = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=targets.image, logits=recons.image)
    nll_per_time = tf.reduce_sum(negative_image_log_prob, [-3, -2, -1])
    image_loss = image_cost * nll_per_time
    image_loss = sum_time_average_batch(image_loss)
  else:
    image_loss = tf.constant(0.)

  if action_cost > 0.0 and recons.last_action is not tuple():
    # Labels have shape (T, B), logits have shape (T, B, num_actions).
    action_loss = action_cost * tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets.last_action, logits=recons.last_action)
    action_loss = sum_time_average_batch(action_loss)
  else:
    action_loss = tf.constant(0.)

  if reward_cost > 0.0 and recons.last_reward is not tuple():
    # MSE loss for reward.
    recon_last_reward = recons.last_reward
    recon_last_reward = tf.squeeze(recon_last_reward, -1)
    reward_loss = 0.5 * reward_cost * tf.square(
        recon_last_reward - targets.last_reward)
    reward_loss = sum_time_average_batch(reward_loss)
  else:
    reward_loss = tf.constant(0.)

  total_loss = image_loss + action_loss + reward_loss

  logged_values = dict(
      recon_loss_image=image_loss,
      recon_loss_action=action_loss,
      recon_loss_reward=reward_loss,
      total_reconstruction_loss=total_loss,)

  return total_loss, logged_values


def read_regularization_loss(
    read_info,
    strength_cost,
    strength_tolerance,
    strength_reg_mode,
    key_norm_cost,
    key_norm_tolerance):
  """Computes the sum of read strength and read key regularization losses."""

  if (strength_cost <= 0.) and (key_norm_cost <= 0.):
    read_reg_loss = tf.constant(0.)
    return read_reg_loss, dict(read_regularization_loss=read_reg_loss)

  # Check for missing read info before accessing any of its fields.
  if read_info == tuple():
    raise ValueError('Make sure read regularization costs are zero when '
                     'not outputting read info.')

  if hasattr(read_info, 'read_strengths'):
    read_strengths = read_info.read_strengths
    read_keys = read_info.read_keys
  else:
    read_strengths = read_info.strengths
    read_keys = read_info.keys

  if strength_cost > 0.:
    strength_hinged = tf.maximum(strength_tolerance, read_strengths)
    if strength_reg_mode == 'L2':
      strength_loss = 0.5 * tf.square(strength_hinged)
    elif strength_reg_mode == 'L1':
      # Read strengths are always positive.
      strength_loss = strength_hinged
    else:
      raise ValueError(
          'Strength regularization mode "{}" is not supported.'.format(
              strength_reg_mode))

    # Sum across read heads to reduce from [T, B, n_reads] to [T, B].
    strength_loss = strength_cost * tf.reduce_sum(strength_loss, axis=2)

  if key_norm_cost > 0.:
    key_norm_norms = tf.norm(read_keys, axis=-1)
    key_norm_norms_hinged = tf.maximum(key_norm_tolerance, key_norm_norms)
    key_norm_loss = 0.5 * tf.square(key_norm_norms_hinged)

    # Sum across read heads to reduce from [T, B, n_reads] to [T, B].
    key_norm_loss = key_norm_cost * tf.reduce_sum(key_norm_loss, axis=2)

  if strength_cost > 0.:
    strength_loss = sum_time_average_batch(strength_loss)
  else:
    strength_loss = tf.constant(0.)

  if key_norm_cost > 0.:
    key_norm_loss = sum_time_average_batch(key_norm_loss)
  else:
    key_norm_loss = tf.constant(0.)

  read_reg_loss = strength_loss + key_norm_loss

  logged_values = dict(
      read_reg_strength_loss=strength_loss,
      read_reg_key_norm_loss=key_norm_loss,
      total_read_reg_loss=read_reg_loss)

  return read_reg_loss, logged_values
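
To make the arithmetic of the reductions above concrete, here is a small pure-Python sketch of the two building blocks: the sum over time followed by a batch mean, and the hinged read-strength penalty. It uses nested lists instead of TensorFlow tensors, and the helper name `hinged_strength_loss` is an assumption made for this example. It illustrates the formulas only and is not a replacement for the functions above.

```python
def sum_time_average_batch(values):
    """Sums over time (rows), then averages over batch (columns).

    `values` is a list of T rows, each a list of B floats.
    """
    num_batch = len(values[0])
    per_batch_sum = [sum(row[b] for row in values) for b in range(num_batch)]
    return sum(per_batch_sum) / num_batch


def hinged_strength_loss(strength, tolerance, mode="L1"):
    """Penalty on one read strength, clamped from below at `tolerance`."""
    hinged = max(tolerance, strength)
    if mode == "L2":
        return 0.5 * hinged ** 2
    if mode == "L1":
        return hinged
    raise ValueError('Mode "{}" is not supported.'.format(mode))


# T = 2 time steps, B = 2 batch entries: column sums are 4 and 6, mean is 5.
print(sum_time_average_batch([[1.0, 2.0], [3.0, 4.0]]))
# Below the tolerance the penalty is constant, so it has no gradient there.
print(hinged_strength_loss(1.0, tolerance=2.0))
print(hinged_strength_loss(3.0, tolerance=2.0, mode="L2"))
```

Because the hinge is a maximum with the tolerance, read strengths below the tolerance contribute a constant and are not pushed further down; only strengths above it are penalized.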
|
||||
+258
@@ -0,0 +1,258 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Batched synchronous actor/learner training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from absl import app
from absl import flags
from absl import logging
import numpy as np
from six.moves import range
from six.moves import zip
import tensorflow as tf

from tvt import batch_env
from tvt import nest_utils
from tvt import rma
from tvt import tvt_rewards as tvt_module
from tvt.pycolab import env as pycolab_env
from tensorflow.contrib import framework as contrib_framework

nest = contrib_framework.nest

FLAGS = flags.FLAGS

flags.DEFINE_integer('logging_frequency', 1,
                     'Log training progress every logging_frequency episodes.')
flags.DEFINE_string('logdir', None, 'Directory for tensorboard logging.')

flags.DEFINE_boolean('with_memory', True,
                     'Whether or not the agent has external memory.')
flags.DEFINE_boolean('with_reconstruction', True,
                     'Whether or not the agent reconstructs the observation.')
flags.DEFINE_float('gamma', 0.92, 'Agent discount factor')
flags.DEFINE_float('entropy_cost', 0.05, 'Weight of the entropy loss.')
flags.DEFINE_float('image_cost_weight', 50., 'Image reconstruction cost weight.')
flags.DEFINE_float('read_strength_cost', 5e-5,
                   'Cost weight of the memory read strength.')
flags.DEFINE_float('read_strength_tolerance', 2.,
                   'The tolerance of hinge loss of the read_strength_cost.')
flags.DEFINE_boolean('do_tvt', True, 'Whether or not to do TVT.')
flags.DEFINE_enum('pycolab_game', 'key_to_door',
                  ['key_to_door', 'active_visual_match'],
                  'The name of the game in pycolab environment')
flags.DEFINE_integer('num_episodes', None,
                     'Number of episodes to train for. None means run forever.')

flags.DEFINE_integer('batch_size', 16, 'Batch size')

flags.DEFINE_float('learning_rate', 2e-4, 'Adam optimizer learning rate')
flags.DEFINE_float('beta1', 0., 'Adam optimizer beta1')
flags.DEFINE_float('beta2', 0.95, 'Adam optimizer beta2')
flags.DEFINE_float('epsilon', 1e-6, 'Adam optimizer epsilon')

# Pycolab-specific flags:
flags.DEFINE_integer('pycolab_num_apples', 10,
                     'Number of apples to sample from the distractor grid.')
flags.DEFINE_float('pycolab_apple_reward_min', 1.,
                   'A reward range [min, max) to uniformly sample from.')
flags.DEFINE_float('pycolab_apple_reward_max', 10.,
                   'A reward range [min, max) to uniformly sample from.')
flags.DEFINE_boolean('pycolab_fix_apple_reward_in_episode', True,
                     'Fix the sampled apple reward within an episode.')
flags.DEFINE_float('pycolab_final_reward', 10.,
                   'Reward obtained at the last phase.')
flags.DEFINE_boolean('pycolab_crop', True,
                     'Whether to crop observations or not.')


def main(_):

  batch_size = FLAGS.batch_size
  env_builder = pycolab_env.PycolabEnvironment
  env_kwargs = {
      'game': FLAGS.pycolab_game,
      'num_apples': FLAGS.pycolab_num_apples,
      'apple_reward': [FLAGS.pycolab_apple_reward_min,
                       FLAGS.pycolab_apple_reward_max],
      'fix_apple_reward_in_episode': FLAGS.pycolab_fix_apple_reward_in_episode,
      'final_reward': FLAGS.pycolab_final_reward,
      'crop': FLAGS.pycolab_crop
  }
  env = batch_env.BatchEnv(batch_size, env_builder, **env_kwargs)
  ep_length = env.episode_length

  agent = rma.Agent(batch_size=batch_size,
                    num_actions=env.num_actions,
                    observation_shape=env.observation_shape,
                    with_reconstructions=FLAGS.with_reconstruction,
                    gamma=FLAGS.gamma,
                    read_strength_cost=FLAGS.read_strength_cost,
                    read_strength_tolerance=FLAGS.read_strength_tolerance,
                    entropy_cost=FLAGS.entropy_cost,
                    with_memory=FLAGS.with_memory,
                    image_cost_weight=FLAGS.image_cost_weight)

  # Agent step placeholders and agent step.
  batch_shape = (batch_size,)
  observation_ph = tf.placeholder(
      dtype=tf.uint8, shape=batch_shape + env.observation_shape, name='obs')
  reward_ph = tf.placeholder(
      dtype=tf.float32, shape=batch_shape, name='reward')
  state_ph = nest.map_structure(
      lambda s: tf.placeholder(dtype=s.dtype, shape=s.shape, name='state'),
      agent.initial_state(batch_size=batch_size))
  step_outputs, state = agent.step(reward_ph, observation_ph, state_ph)

  # Update op placeholders and update op.
  observations_ph = tf.placeholder(
      dtype=tf.uint8, shape=(ep_length + 1, batch_size) + env.observation_shape,
      name='observations')
  rewards_ph = tf.placeholder(
      dtype=tf.float32, shape=(ep_length + 1, batch_size), name='rewards')
  actions_ph = tf.placeholder(
      dtype=tf.int64, shape=(ep_length, batch_size), name='actions')
  tvt_rewards_ph = tf.placeholder(
      dtype=tf.float32, shape=(ep_length, batch_size), name='tvt_rewards')

  loss, loss_logs = agent.loss(
      observations_ph, rewards_ph, actions_ph, tvt_rewards_ph)

  optimizer = tf.train.AdamOptimizer(
      learning_rate=FLAGS.learning_rate,
      beta1=FLAGS.beta1,
      beta2=FLAGS.beta2,
      epsilon=FLAGS.epsilon)
  update_op = optimizer.minimize(loss)
  initial_state = agent.initial_state(batch_size)

  if FLAGS.logdir:
    if not tf.io.gfile.exists(FLAGS.logdir):
      tf.io.gfile.makedirs(FLAGS.logdir)
    summary_writer = tf.summary.FileWriter(FLAGS.logdir)

  # Do init
  init_ops = (tf.global_variables_initializer(),
              tf.local_variables_initializer())
  tf.get_default_graph().finalize()

  sess = tf.Session()
  sess.run(init_ops)

  run = True
  ep_num = 0
  prev_logging_time = time.time()
  while run:
    observation, reward = env.reset()
    agent_state = sess.run(initial_state)

    # Initialise episode data stores.
    observations = [observation]
    rewards = [reward]
    actions = []
    baselines = []
    read_infos = []

    for _ in range(ep_length):
      step_feed = {reward_ph: reward, observation_ph: observation}
      for ph, ar in zip(nest.flatten(state_ph), nest.flatten(agent_state)):
        step_feed[ph] = ar
      step_output, agent_state = sess.run(
          (step_outputs, state), feed_dict=step_feed)
      action = step_output.action
      baseline = step_output.baseline
      read_info = step_output.read_info

      # Take step in environment, append results.
      observation, reward = env.step(action)

      observations.append(observation)
      rewards.append(reward)
      actions.append(action)
      baselines.append(baseline)
      if read_info is not None:
        read_infos.append(read_info)

    # Stack the per-step lists so that each array (or each element of the nest
    # structure for read_infos) has shape (num_steps, batch_size, ...), where
    # num_steps is ep_length + 1 for observations and rewards, else ep_length.
    observations = np.stack(observations)
    rewards = np.array(rewards)
    actions = np.array(actions)
    baselines = np.array(baselines)
    read_infos = nest_utils.nest_stack(read_infos)

    # Compute TVT rewards.
    if FLAGS.do_tvt:
      tvt_rewards = tvt_module.compute_tvt_rewards(read_infos,
                                                   baselines,
                                                   gamma=FLAGS.gamma)
    else:
      tvt_rewards = np.squeeze(np.zeros_like(baselines))

    # Run update op.
    loss_feed = {observations_ph: observations,
                 rewards_ph: rewards,
                 actions_ph: actions,
                 tvt_rewards_ph: tvt_rewards}
    ep_loss, _, ep_loss_logs = sess.run([loss, update_op, loss_logs],
                                        feed_dict=loss_feed)

    # Log episode results.
    if ep_num % FLAGS.logging_frequency == 0:
      steps_per_second = (
          FLAGS.logging_frequency * ep_length * batch_size / (
              time.time() - prev_logging_time))
      mean_reward = np.mean(np.sum(rewards, axis=0))
      mean_last_phase_reward = np.mean(env.last_phase_rewards())
      mean_tvt_reward = np.mean(np.sum(tvt_rewards, axis=0))

      logging.info('Episode %d. SPS: %s', ep_num, steps_per_second)
      logging.info('Episode %d. Mean episode reward: %f', ep_num, mean_reward)
      logging.info('Episode %d. Last phase reward: %f', ep_num,
                   mean_last_phase_reward)
      logging.info('Episode %d. Mean TVT episode reward: %f', ep_num,
                   mean_tvt_reward)
      logging.info('Episode %d. Loss: %s', ep_num, ep_loss)
      logging.info('Episode %d. Loss logs: %s', ep_num, ep_loss_logs)

      if FLAGS.logdir:
        summary = tf.Summary()
        summary.value.add(tag='reward', simple_value=mean_reward)
        summary.value.add(tag='last phase reward',
                          simple_value=mean_last_phase_reward)
        summary.value.add(tag='tvt reward', simple_value=mean_tvt_reward)
        summary.value.add(tag='total loss', simple_value=ep_loss)
        for k, v in ep_loss_logs.items():
          summary.value.add(tag='loss - {}'.format(k), simple_value=v)
        # Tensorboard x-axis is total number of episodes run.
        summary_writer.add_summary(summary, ep_num * batch_size)
        summary_writer.flush()

      prev_logging_time = time.time()

    ep_num += 1
    if FLAGS.num_episodes and ep_num >= FLAGS.num_episodes:
      run = False


if __name__ == '__main__':
  app.run(main)
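The training loop above collects one more observation and reward than actions per episode, because the values returned by `env.reset()` are stored before the first agent step. The sketch below reproduces only that bookkeeping with NumPy. The `fake_env_output` helper, the random stand-in for the agent, and the small sizes are hypothetical and exist solely to make the shapes visible.

```python
import numpy as np

ep_length, batch_size, obs_shape = 5, 3, (4, 4, 3)
rng = np.random.RandomState(0)


def fake_env_output():
  """Hypothetical stand-in for env.reset() / env.step()."""
  obs = rng.randint(0, 256, size=(batch_size,) + obs_shape).astype(np.uint8)
  reward = rng.uniform(size=batch_size).astype(np.float32)
  return obs, reward


observation, reward = fake_env_output()  # Plays the role of env.reset().
observations, rewards, actions = [observation], [reward], []
for _ in range(ep_length):
  action = rng.randint(0, 4, size=batch_size)  # Random stand-in for the agent.
  observation, reward = fake_env_output()  # Plays the role of env.step().
  observations.append(observation)
  rewards.append(reward)
  actions.append(action)

observations = np.stack(observations)  # (ep_length + 1, batch_size) + obs_shape
rewards = np.array(rewards)  # (ep_length + 1, batch_size)
actions = np.array(actions)  # (ep_length, batch_size)
```

These shapes match the `observations_ph`, `rewards_ph` and `actions_ph` placeholders fed to the update op.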
+294
@@ -0,0 +1,294 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Memory Reader/Writer for RMA."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import sonnet as snt
import tensorflow as tf

ReadInformation = collections.namedtuple(
    'ReadInformation', ('weights', 'indices', 'keys', 'strengths'))


class MemoryWriter(snt.RNNCore):
  """Memory Writer Module."""

  def __init__(self, mem_shape, name='memory_writer'):
    """Initializes the `MemoryWriter`.

    Args:
      mem_shape: The shape of the memory `(num_rows, memory_width)`.
      name: The name to use for the Sonnet module.
    """
    super(MemoryWriter, self).__init__(name=name)
    self._mem_shape = mem_shape

  def _build(self, inputs, state):
    """Inserts z into the memory row indexed by the write counter.

    Rows are overwritten in round-robin order: the row index is the write
    counter modulo the number of memory rows.

    Args:
      inputs: A tuple consisting of:
          * z, the value to write at this timestep
          * mem_state, the state of the memory at this timestep before writing
      state: The state is just the write_counter.

    Returns:
      A tuple of the new memory state and the next state (the incremented
      write counter).
    """
    z, mem_state = inputs

    # Stop gradient on writes to memory.
    z = tf.stop_gradient(z)

    prev_write_counter = state
    new_row_value = z

    # Find the index to insert the next row into.
    num_mem_rows = self._mem_shape[0]
    write_index = tf.cast(prev_write_counter, dtype=tf.int32) % num_mem_rows
    one_hot_row = tf.one_hot(write_index, num_mem_rows)
    write_counter = prev_write_counter + 1

    # Insert state variable to new row.
    # First you need to size it up to the full size.
    insert_new_row = lambda mem, o_hot, z: mem - (o_hot * mem) + (o_hot * z)
    new_mem = insert_new_row(mem_state,
                             tf.expand_dims(one_hot_row, axis=-1),
                             tf.expand_dims(new_row_value, axis=-2))

    new_state = write_counter

    return new_mem, new_state

  @property
  def state_size(self):
    """Returns a description of the state size, without batch dimension."""
    return tf.TensorShape([])

  @property
  def output_size(self):
    """Returns a description of the output size, without batch dimension."""
    return self._mem_shape


class MemoryReader(snt.AbstractModule):
  """Memory Reader Module."""

  def __init__(self,
               memory_word_size,
               num_read_heads,
               top_k=0,
               memory_size=None,
               name='memory_reader'):
    """Initializes the `MemoryReader`.

    Args:
      memory_word_size: The dimension of the 1-D read keys this memory reader
        should produce. Each row of the memory is of length `memory_word_size`.
      num_read_heads: The number of reads to perform.
      top_k: Softmax and summation when reading is only over top k most similar
        entries in memory. top_k=0 (default) means dense reads, i.e. no top_k.
      memory_size: Number of rows in memory.
      name: The name for this Sonnet module.
    """
    super(MemoryReader, self).__init__(name=name)
    self._memory_word_size = memory_word_size
    self._num_read_heads = num_read_heads
    self._top_k = top_k

    # This is not an RNNCore but it is useful to expose the output size.
    self._output_size = num_read_heads * memory_word_size

    num_read_weights = top_k if top_k > 0 else memory_size
    self._read_info_size = ReadInformation(
        weights=tf.TensorShape([num_read_heads, num_read_weights]),
        indices=tf.TensorShape([num_read_heads, num_read_weights]),
        keys=tf.TensorShape([num_read_heads, memory_word_size]),
        strengths=tf.TensorShape([num_read_heads]),
    )

    with self._enter_variable_scope():
      # Transforms to value-based read for each read head.
      output_dim = (memory_word_size + 1) * num_read_heads
      self._keys_and_read_strengths_generator = snt.Linear(output_dim)

  def _build(self, inputs):
    """Looks up rows in memory.

    In the args list, we have the following conventions:
      B: batch size
      M: number of slots in a row of the memory matrix
      R: number of rows in the memory matrix
      H: number of read heads in the memory controller

    Args:
      inputs: A tuple of
        *  read_inputs, a tensor of shape [B, ...] that will be flattened and
             passed through a linear layer to get read keys/read_strengths for
             each head.
        *  mem_state, the primary memory tensor. Of shape [B, R, M].

    Returns:
      The read from the memory (concatenated across read heads) and read
        information.
    """
    # Assert input shapes are compatible and separate inputs.
    _assert_compatible_memory_reader_input(inputs)
    read_inputs, mem_state = inputs

    # Determine the read weightings for each key.
    flat_outputs = self._keys_and_read_strengths_generator(
        snt.BatchFlatten()(read_inputs))

    # Separate the read_strengths from the rest of the weightings.
    h = self._num_read_heads
    flat_keys = flat_outputs[:, :-h]
    read_strengths = tf.nn.softplus(flat_outputs[:, -h:])

    # Reshape the weights.
    read_shape = (self._num_read_heads, self._memory_word_size)
    read_keys = snt.BatchReshape(read_shape)(flat_keys)

    # Read from memory.
    memory_reads, read_weights, read_indices, read_strengths = (
        read_from_memory(read_keys, read_strengths, mem_state, self._top_k))
    concatenated_reads = snt.BatchFlatten()(memory_reads)

    return concatenated_reads, ReadInformation(
        weights=read_weights,
        indices=read_indices,
        keys=read_keys,
        strengths=read_strengths)

  @property
  def output_size(self):
    """Returns a description of the output size, without batch dimension."""
    return self._output_size, self._read_info_size


def read_from_memory(read_keys, read_strengths, mem_state, top_k):
  """Function for cosine similarity content based reading from memory matrix.

  In the args list, we have the following conventions:
    B: batch size
    M: number of slots in a row of the memory matrix
    R: number of rows in the memory matrix
    H: number of read heads (of the controller or the policy)
    K: top_k if top_k>0

  Args:
    read_keys: the read keys of shape [B, H, M].
    read_strengths: the coefficients used to compute the normalised weighting
      vector of shape [B, H].
    mem_state: the primary memory tensor. Of shape [B, R, M].
    top_k: only use top k read matches, other reads do not go into softmax and
      are zeroed out in the output. top_k=0 means use dense reads.

  Returns:
    The memory reads [B, H, M], read weights [B, H, K] (or [B, H, R] for dense
    reads), read indices of the same shape as the read weights, and read
    strengths [B, H].
  """
  _assert_compatible_read_from_memory_inputs(read_keys, read_strengths,
                                             mem_state)
  batch_size = read_keys.shape[0]
  num_read_heads = read_keys.shape[1]

  with tf.name_scope('memory_reading'):
    # Scale such that all rows are L2-unit vectors, for memory and read query.
    scaled_read_keys = tf.math.l2_normalize(read_keys, axis=-1)  # [B, H, M]
    scaled_mem = tf.math.l2_normalize(mem_state, axis=-1)  # [B, R, M]

    # The cosine similarity is then their dot product (the variable names
    # below say "distance", but larger values mean a closer match).
    # Find the cosine similarity between each read head and each row of memory.
    cosine_distances = tf.matmul(
        scaled_read_keys, scaled_mem, transpose_b=True)  # [B, H, R]

    # The rank must match cosine_distances for broadcasting to work.
    read_strengths = tf.expand_dims(read_strengths, axis=-1)  # [B, H, 1]
    weighted_distances = read_strengths * cosine_distances  # [B, H, R]

    if top_k:
      # Get top k indices (row indices with top k largest weighted distances).
      top_k_output = tf.nn.top_k(weighted_distances, top_k, sorted=False)
      read_indices = top_k_output.indices  # [B, H, K]

      # Create a sub-memory for each read head with only the top k rows.
      # Each batch_gather is [B, K, M] and the list stacks to [B, H, K, M].
      topk_mem_per_head = [tf.batch_gather(mem_state, ri_this_head)
                           for ri_this_head in tf.unstack(read_indices, axis=1)]
      topk_mem = tf.stack(topk_mem_per_head, axis=1)  # [B, H, K, M]
      topk_scaled_mem = tf.math.l2_normalize(topk_mem, axis=-1)  # [B, H, K, M]

      # Calculate read weights for each head's top k sub-memory.
      expanded_scaled_read_keys = tf.expand_dims(
          scaled_read_keys, axis=2)  # [B, H, 1, M]
      topk_cosine_distances = tf.reduce_sum(
          expanded_scaled_read_keys * topk_scaled_mem, axis=-1)  # [B, H, K]
      topk_weighted_distances = (
          read_strengths * topk_cosine_distances)  # [B, H, K]
      read_weights = tf.nn.softmax(
          topk_weighted_distances, axis=-1)  # [B, H, K]

      # For each head, read using the sub-memories and corresponding weights.
      expanded_weights = tf.expand_dims(read_weights, axis=-1)  # [B, H, K, 1]
      memory_reads = tf.reduce_sum(
          expanded_weights * topk_mem, axis=2)  # [B, H, M]
    else:
      read_weights = tf.nn.softmax(weighted_distances, axis=-1)

      num_rows_memory = mem_state.shape[1]
      all_indices = tf.range(num_rows_memory, dtype=tf.int32)
      all_indices = tf.reshape(all_indices, [1, 1, num_rows_memory])
      read_indices = tf.tile(all_indices, [batch_size, num_read_heads, 1])

      # This is the actual memory access.
      # Note that matmul automatically batch applies for us.
      memory_reads = tf.matmul(read_weights, mem_state)

    read_keys.shape.assert_is_compatible_with(memory_reads.shape)

    read_strengths = tf.squeeze(read_strengths, axis=-1)  # [B, H, 1] -> [B, H]

    return memory_reads, read_weights, read_indices, read_strengths


def _assert_compatible_read_from_memory_inputs(read_keys, read_strengths,
                                               mem_state):
  read_keys.shape.assert_has_rank(3)
  b_shape, h_shape, m_shape = read_keys.shape
  mem_state.shape.assert_has_rank(3)
  r_shape = mem_state.shape[1]

  read_strengths.shape.assert_is_compatible_with(
      tf.TensorShape([b_shape, h_shape]))
  mem_state.shape.assert_is_compatible_with(
      tf.TensorShape([b_shape, r_shape, m_shape]))


def _assert_compatible_memory_reader_input(input_tensors):
  """Asserts MemoryReader's _build has been given the correct shapes."""
  assert len(input_tensors) == 2
  _, mem_state = input_tensors
  mem_state.shape.assert_has_rank(3)
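The content-based read in `read_from_memory` can be followed more easily without the TensorFlow graph machinery. The block below is a NumPy sketch of the same idea (cosine similarity, scaling by read strength, softmax over either all rows or the top k rows). It is not the module's code: `read_from_memory_np` and its helpers are hypothetical names, and the top-k branch gathers the already computed scores instead of renormalising the gathered rows, which should give the same values.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
  norm = np.sqrt(np.maximum(np.sum(x * x, axis=axis, keepdims=True), eps))
  return x / norm


def softmax(x, axis=-1):
  e = np.exp(x - np.max(x, axis=axis, keepdims=True))
  return e / np.sum(e, axis=axis, keepdims=True)


def read_from_memory_np(read_keys, read_strengths, mem_state, top_k=0):
  """Shapes: read_keys [B, H, M], read_strengths [B, H], mem_state [B, R, M]."""
  keys = l2_normalize(read_keys)
  mem = l2_normalize(mem_state)
  sims = np.einsum('bhm,brm->bhr', keys, mem)  # Cosine similarity, [B, H, R].
  scores = read_strengths[..., None] * sims  # [B, H, R]
  if top_k:
    # Row indices of the k highest scores per head (unsorted), [B, H, K].
    idx = np.argpartition(-scores, top_k - 1, axis=-1)[..., :top_k]
    weights = softmax(np.take_along_axis(scores, idx, axis=-1))  # [B, H, K]
    batch_idx = np.arange(mem_state.shape[0])[:, None, None]
    top_mem = mem_state[batch_idx, idx]  # [B, H, K, M]
    reads = np.sum(weights[..., None] * top_mem, axis=2)  # [B, H, M]
  else:
    weights = softmax(scores)  # [B, H, R]
    idx = np.broadcast_to(np.arange(mem_state.shape[1]), scores.shape)
    reads = np.matmul(weights, mem_state)  # [B, H, M]
  return reads, weights, idx


rng = np.random.RandomState(0)
keys = rng.randn(2, 3, 4)
strengths = np.ones((2, 3))
memory = rng.randn(2, 5, 4)
sparse_reads, sparse_weights, sparse_idx = read_from_memory_np(
    keys, strengths, memory, top_k=2)
dense_reads, dense_weights, _ = read_from_memory_np(keys, strengths, memory)
```

Setting `top_k` equal to the number of rows recovers the dense read, which is a convenient sanity check for either implementation.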
@@ -0,0 +1,85 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""nest utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import range
from tensorflow.contrib import framework as contrib_framework

nest = contrib_framework.nest


def _nest_apply_over_list(list_of_nests, fn):
  """Equivalent to fn, but works on list-of-nests.

  Transforms a list-of-nests to a nest-of-lists, then applies `fn`
  to each of the inner lists.

  It is assumed that all nests have the same structure. Elements of the nest may
  be None, in which case they are ignored, i.e. they do not form part of the
  stack. This is useful when stacking agent states where parts of the state nest
  have been filtered.

  Args:
    list_of_nests: A Python list of nests.
    fn: the function applied on the list of leaves.

  Returns:
    A nest-of-arrays, where the arrays are formed by `fn`ing a list.
  """
  list_of_flat_nests = [nest.flatten(n) for n in list_of_nests]
  flat_nest_of_stacks = []
  for position in range(len(list_of_flat_nests[0])):
    new_list = [flat_nest[position] for flat_nest in list_of_flat_nests]
    new_list = [x for x in new_list if x is not None]
    flat_nest_of_stacks.append(fn(new_list))
  return nest.pack_sequence_as(
      structure=list_of_nests[0], flat_sequence=flat_nest_of_stacks)


def _take_indices(inputs, indices):
  return nest.map_structure(lambda t: np.take(t, indices, axis=0), inputs)


def nest_stack(list_of_nests, axis=0):
  """Equivalent to np.stack, but works on list-of-nests.

  Transforms a list-of-nests to a nest-of-lists, then applies `np.stack`
  to each of the inner lists.

  It is assumed that all nests have the same structure. Elements of the nest may
  be None, in which case they are ignored, i.e. they do not form part of the
  stack. This is useful when stacking agent states where parts of the state nest
  have been filtered.

  Args:
    list_of_nests: A Python list of nests.
    axis: Optional, the `axis` argument for `np.stack`.

  Returns:
    A nest-of-arrays, where the arrays are formed by `np.stack`ing a list.
  """
  return _nest_apply_over_list(list_of_nests, lambda l: np.stack(l, axis=axis))


def nest_unstack(batched_inputs, batch_size):
  """Splits a sequence of numpy arrays along 0th dimension."""
  return [_take_indices(batched_inputs, idx) for idx in range(batch_size)]
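To illustrate what `nest_stack` does to a list of per-step read infos, here is a small sketch that depends only on NumPy. It is a simplified stand-in, not the module above: `stack_nests` is a hypothetical helper that supports dicts, tuples and namedtuples only, whereas the real function relies on the `nest` utilities from `tensorflow.contrib`. The `ReadInfo` type is likewise illustrative.

```python
import collections

import numpy as np

ReadInfo = collections.namedtuple('ReadInfo', ('weights', 'indices'))


def stack_nests(list_of_nests, axis=0):
  """Stacks identically structured nests leaf by leaf."""
  first = list_of_nests[0]
  if isinstance(first, dict):
    return {k: stack_nests([n[k] for n in list_of_nests], axis) for k in first}
  if isinstance(first, tuple):
    fields = [stack_nests(list(items), axis) for items in zip(*list_of_nests)]
    # Namedtuples are rebuilt from positional fields, plain tuples directly.
    return type(first)(*fields) if hasattr(first, '_fields') else tuple(fields)
  # Leaf case: drop None entries, as nest_stack does, then stack the rest.
  return np.stack([x for x in list_of_nests if x is not None], axis=axis)


# Four time steps, each holding [batch=2, heads=3] arrays.
steps = [ReadInfo(weights=np.full((2, 3), t, dtype=np.float32),
                  indices=np.arange(6).reshape(2, 3)) for t in range(4)]
stacked = stack_nests(steps)
```

The result keeps the namedtuple structure, and every field gains a leading time dimension, which is the layout the training loop passes on to the TVT reward computation.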
@@ -0,0 +1,31 @@
# Pycolab Tasks

## Playing the Pycolab Tasks

We provide a script to allow human play of the Pycolab tasks. To play, run e.g.

`python3 pycolab/human_player.py -- --game=key_to_door`

## The Pycolab Tasks

There are 2 [Pycolab](https://github.com/deepmind/pycolab) tasks presented here.
Each level is composed of 3 distinct phases. The first phase is the 'explore'
phase, where the agent should learn a piece of information or do something. For
both tasks, the 2nd phase is the 'distractor' phase, where the agent collects
apples for rewards. The 3rd phase is the 'exploit' phase, where the agent gets
rewards based on the knowledge acquired or actions performed in phase 1.

Special thanks to Hamza Merzic for writing these task scripts.

### Active Visual Match

* Phase 1: A coloured square is randomly placed in a pair of connected rooms.
* Phase 2: Apple collection.
* Phase 3: Choose the coloured square that matches the one from Phase 1 among
  the 3 options.

### Key To Door

* Phase 1: A key is randomly placed in a pair of connected rooms.
* Phase 2: Apple collection.
* Phase 3: A small room with a door. If the agent has the key, it can open the
  door and reach the goal behind it to get the reward.
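The three-phase structure described above can be sketched as a simple step-to-phase lookup. The frame counts are the defaults from `DEFAULT_MAX_FRAMES_PER_PHASE` in the common Pycolab utilities (where the third phase is called 'reward'); the helper `phase_at` is hypothetical, and it assumes each phase uses exactly its full frame budget, which may not match the timer semantics of the real environment.

```python
import collections

# Default frame budget per phase, in episode order.
MAX_FRAMES = collections.OrderedDict(
    [('explore', 15), ('distractor', 90), ('reward', 15)])


def phase_at(step):
  """Returns the phase name for a 0-based step index within an episode."""
  start = 0
  for name, frames in MAX_FRAMES.items():
    if step < start + frames:
      return name
    start += frames
  raise ValueError('step %d is past the end of the episode' % step)


episode_length = sum(MAX_FRAMES.values())
```

With these defaults the distractor phase accounts for most of the episode, which is what makes the delay between the first and last phase long relative to the agent's discount horizon.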
@@ -0,0 +1,162 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Active visual match task.

The game is split up into three phases:
1. (exploration phase) player is in one room and there's a colour in the other,
2. (distractor phase) player is collecting apples,
3. (reward phase) player sees three doors of different colours and has to select
    the one of the same colour as the colour in the first phase.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pycolab import ascii_art
from pycolab import storytelling

from tvt.pycolab import common
from tvt.pycolab import game
from tvt.pycolab import objects


SYMBOLS_TO_SHUFFLE = ['b', 'c', 'e']

EXPLORE_GRID = [
    '  ppppppp  ',
    '  p     p  ',
    '  p     p  ',
    '  pp   pp  ',
    '  p+++++p  ',
    '  p+++++p  ',
    '  ppppppp  '
]

REWARD_GRID = [
    '###########',
    '# b  c  e #',
    '#         #',
    '#         #',
    '####   ####',
    '   # + #   ',
    '   #####   '
]


class Game(game.AbstractGame):
  """Active Visual Match Game."""

  def __init__(self,
               rng,
               num_apples=10,
               apple_reward=(1, 10),
               fix_apple_reward_in_episode=True,
               final_reward=10.,
               max_frames=common.DEFAULT_MAX_FRAMES_PER_PHASE):
    self._rng = rng
    self._num_apples = num_apples
    self._apple_reward = apple_reward
    self._fix_apple_reward_in_episode = fix_apple_reward_in_episode
    self._final_reward = final_reward
    self._max_frames = max_frames
    self._episode_length = sum(self._max_frames.values())
    self._num_actions = common.NUM_ACTIONS
    self._colours = common.FIXED_COLOURS.copy()
    self._colours.update(
        common.get_shuffled_symbol_colour_map(rng, SYMBOLS_TO_SHUFFLE))

    self._extra_observation_fields = ['chapter_reward_as_string']

  @property
  def extra_observation_fields(self):
    """The field names of extra observations."""
    return self._extra_observation_fields

  @property
  def num_actions(self):
    """Number of possible actions in the game."""
    return self._num_actions

  @property
  def episode_length(self):
    return self._episode_length

  @property
  def colours(self):
    """Symbol to colour map for active visual match."""
    return self._colours

  def _make_explore_phase(self, target_char):
    # Keep only one coloured position and one player position.
    grid = common.keep_n_characters_in_grid(EXPLORE_GRID, 'p', 1, common.BORDER)
    grid = common.keep_n_characters_in_grid(grid, 'p', 0, target_char)
    grid = common.keep_n_characters_in_grid(grid, common.PLAYER, 1)

    return ascii_art.ascii_art_to_game(
        grid,
        what_lies_beneath=' ',
        sprites={
            common.PLAYER:
                ascii_art.Partial(
                    common.PlayerSprite,
                    impassable=common.BORDER + target_char),
            target_char:
                objects.ObjectSprite,
            common.TIMER:
                ascii_art.Partial(common.TimerSprite,
                                  self._max_frames['explore']),
        },
        update_schedule=[common.PLAYER, target_char, common.TIMER],
        z_order=[target_char, common.PLAYER, common.TIMER],
    )

  def _make_distractor_phase(self):
    return common.distractor_phase(
        player_sprite=common.PlayerSprite,
        num_apples=self._num_apples,
        max_frames=self._max_frames['distractor'],
        apple_reward=self._apple_reward,
        fix_apple_reward_in_episode=self._fix_apple_reward_in_episode)

  def _make_reward_phase(self, target_char):
    return ascii_art.ascii_art_to_game(
        REWARD_GRID,
        what_lies_beneath=' ',
        sprites={
            common.PLAYER: common.PlayerSprite,
            'b': objects.ObjectSprite,
            'c': objects.ObjectSprite,
            'e': objects.ObjectSprite,
            common.TIMER: ascii_art.Partial(common.TimerSprite,
                                            self._max_frames['reward'],
                                            track_chapter_reward=True),
            target_char: ascii_art.Partial(objects.ObjectSprite,
                                           reward=self._final_reward),
        },
        update_schedule=[common.PLAYER, 'b', 'c', 'e', common.TIMER],
        z_order=[common.PLAYER, 'b', 'c', 'e', common.TIMER],
    )

  def make_episode(self):
    """Factory method for generating new episodes of the game."""
    target_char = self._rng.choice(SYMBOLS_TO_SHUFFLE)
    return storytelling.Story([
        lambda: self._make_explore_phase(target_char),
        self._make_distractor_phase,
        lambda: self._make_reward_phase(target_char),
    ], croppers=common.get_cropper())
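The explore phase above calls `common.keep_n_characters_in_grid` three times: once to keep a single candidate position `'p'` and turn the rest into border, once to replace that remaining `'p'` with the target symbol, and once to keep a single player start. The helper's implementation is not shown here, so the following is only a sketch of the behaviour inferred from those calls. The name `keep_n_chars`, the explicit `rng` argument and the toy grid are all assumptions.

```python
import numpy as np


def keep_n_chars(grid, character, n, backdrop_char=' ', rng=None):
  """Keeps `n` random occurrences of `character`; replaces the others."""
  rng = rng or np.random.RandomState()
  cells = [list(row) for row in grid]
  positions = [(r, c) for r, row in enumerate(cells)
               for c, ch in enumerate(row) if ch == character]
  keep = set()
  if n:
    chosen = rng.choice(len(positions), size=n, replace=False)
    keep = {positions[i] for i in chosen}
  for r, c in positions:
    if (r, c) not in keep:
      cells[r][c] = backdrop_char
  return [''.join(row) for row in cells]


toy_grid = ['ppp',
            'p+p',
            '+++']
rng = np.random.RandomState(0)
grid = keep_n_chars(toy_grid, 'p', 1, '#', rng=rng)  # One candidate survives.
grid = keep_n_chars(grid, 'p', 0, 'b', rng=rng)  # It becomes the target 'b'.
grid = keep_n_chars(grid, '+', 1, ' ', rng=rng)  # One player start remains.
flat = ''.join(grid)
```

Chaining the calls this way leaves exactly one target square and one player position, with every other candidate cell turned into wall or background.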
@@ -0,0 +1,325 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common utilities for Pycolab games."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import colorsys
import numpy as np
from pycolab import ascii_art
from pycolab import cropping
from pycolab import things as plab_things
from pycolab.prefab_parts import sprites as prefab_sprites
from six.moves import zip
from tensorflow.contrib import framework as contrib_framework

nest = contrib_framework.nest

# Actions.
# Those with a negative ID are not allowed for the agent.
ACTION_QUIT = -2
ACTION_DELAY = -1
ACTION_NORTH = 0
ACTION_SOUTH = 1
ACTION_WEST = 2
ACTION_EAST = 3

NUM_ACTIONS = 4
DEFAULT_MAX_FRAMES_PER_PHASE = {
    'explore': 15,
    'distractor': 90,
    'reward': 15
}

# Reserved symbols.
PLAYER = '+'
BORDER = '#'
BACKGROUND = ' '
KEY = 'k'
DOOR = 'd'
APPLE = 'a'
TIMER = 't'
INDICATOR = 'i'

FIXED_COLOURS = {
    PLAYER: (898, 584, 430),
    BORDER: (100, 100, 100),
    BACKGROUND: (800, 800, 800),
    KEY: (627, 321, 176),
    DOOR: (529, 808, 922),
    APPLE: (550, 700, 0),
|
||||
}
|
||||
|
||||
APPLE_DISTRACTOR_GRID = [
|
||||
'###########',
|
||||
'#a a a a a#',
|
||||
'# a a a a #',
|
||||
'#a a a a a#',
|
||||
'# a a a a #',
|
||||
'#a a + a a#',
|
||||
'###########'
|
||||
]
|
||||
DEFAULT_APPLE_RESPAWN_TIME = 20
|
||||
DEFAULT_APPLE_REWARD = 1.
|
||||
|
||||
|
||||
def get_shuffled_symbol_colour_map(rng_or_seed, symbols,
|
||||
num_potential_colours=None):
|
||||
"""Get a randomized mapping between symbols and colours.
|
||||
|
||||
Args:
|
||||
rng_or_seed: A random state or random seed.
|
||||
symbols: List of symbols.
|
||||
num_potential_colours: Number of equally spaced colours to choose from.
|
||||
Defaults to number of symbols. Colours are generated deterministically.
|
||||
|
||||
Returns:
|
||||
Randomized mapping between symbols and colours.
|
||||
"""
|
||||
num_symbols = len(symbols)
|
||||
num_potential_colours = num_potential_colours or num_symbols
|
||||
if isinstance(rng_or_seed, np.random.RandomState):
|
||||
rng = rng_or_seed
|
||||
else:
|
||||
rng = np.random.RandomState(rng_or_seed)
|
||||
|
||||
# Generate a range of colours.
|
||||
step = 1. / num_potential_colours
|
||||
hues = np.arange(0, num_potential_colours) * step
|
||||
potential_colours = [colorsys.hsv_to_rgb(h, 1.0, 1.0) for h in hues]
|
||||
|
||||
# Randomly draw num_symbols colours without replacement.
|
||||
rng.shuffle(potential_colours)
|
||||
colours = potential_colours[:num_symbols]
|
||||
|
||||
symbol_to_colour_map = dict(list(zip(symbols, colours)))
|
||||
|
||||
# Multiply each colour value by 1000.
|
||||
return nest.map_structure(lambda c: int(c * 1000), symbol_to_colour_map)
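The colour generation in `get_shuffled_symbol_colour_map` can be illustrated without TensorFlow or NumPy. The sketch below is a simplified re-implementation, not the function above: it assumes `random.Random` in place of `np.random.RandomState`, and the names `equally_spaced_colours` and `shuffled_symbol_colour_map` are hypothetical.

```python
import colorsys
import random


def equally_spaced_colours(num_colours):
  """Returns saturated RGB colours with evenly spaced hues, scaled to 0..1000."""
  step = 1.0 / num_colours
  colours = []
  for k in range(num_colours):
    rgb = colorsys.hsv_to_rgb(k * step, 1.0, 1.0)
    colours.append(tuple(int(c * 1000) for c in rgb))
  return colours


def shuffled_symbol_colour_map(seed, symbols, num_potential_colours=None):
  """Assigns each symbol a colour drawn without replacement."""
  colours = equally_spaced_colours(num_potential_colours or len(symbols))
  random.Random(seed).shuffle(colours)
  return dict(zip(symbols, colours))


colour_map = shuffled_symbol_colour_map(0, ['b', 'c', 'e'])
```

The candidate colours are deterministic; only their assignment to symbols depends on the seed, so the same seed always yields the same mapping.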
|
||||
|
||||
|
||||
def get_cropper():
|
||||
return cropping.ScrollingCropper(
|
||||
rows=5,
|
||||
cols=5,
|
||||
to_track=PLAYER,
|
||||
pad_char=BACKGROUND,
|
||||
scroll_margins=(2, 2))
|
||||
|
||||
|
||||
def distractor_phase(player_sprite, num_apples, max_frames,
|
||||
apple_reward=DEFAULT_APPLE_REWARD,
|
||||
fix_apple_reward_in_episode=False,
|
||||
respawn_every=DEFAULT_APPLE_RESPAWN_TIME):
|
||||
"""Distractor phase engine factory.
|
||||
|
||||
Args:
|
||||
player_sprite: Player sprite class.
|
||||
num_apples: Number of apples to sample from the apple distractor grid.
|
||||
max_frames: Maximum duration of the distractor phase in frames.
|
||||
apple_reward: Can either be a scalar specifying the reward or a reward range
|
||||
[min, max), given as a list or tuple, to uniformly sample from.
|
||||
fix_apple_reward_in_episode: The apple reward is constant throughout each
|
||||
episode.
|
||||
respawn_every: Number of frames after a pickup before an apple respawns.
|
||||
|
||||
Returns:
|
||||
Distractor phase engine.
|
||||
"""
|
||||
distractor_grid = keep_n_characters_in_grid(APPLE_DISTRACTOR_GRID, APPLE,
|
||||
num_apples)
|
||||
|
||||
engine = ascii_art.ascii_art_to_game(
|
||||
distractor_grid,
|
||||
what_lies_beneath=BACKGROUND,
|
||||
sprites={
|
||||
PLAYER: player_sprite,
|
||||
TIMER: ascii_art.Partial(TimerSprite, max_frames),
|
||||
},
|
||||
drapes={
|
||||
APPLE: ascii_art.Partial(
|
||||
AppleDrape,
|
||||
reward=apple_reward,
|
||||
fix_apple_reward_in_episode=fix_apple_reward_in_episode,
|
||||
respawn_every=respawn_every)
|
||||
},
|
||||
update_schedule=[PLAYER, APPLE, TIMER],
|
||||
z_order=[APPLE, PLAYER, TIMER],
|
||||
)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def replace_grid_symbols(grid, old_to_new_map):
|
||||
"""Replaces symbols in the grid.
|
||||
|
||||
If mapping is not defined the symbol is not updated.
|
||||
|
||||
Args:
|
||||
grid: Represented as a list of strings.
|
||||
old_to_new_map: Mapping between symbols.
|
||||
|
||||
Returns:
|
||||
Updated grid.
|
||||
"""
|
||||
def symbol_map(x):
|
||||
if x in old_to_new_map:
|
||||
return old_to_new_map[x]
|
||||
return x
|
||||
new_grid = []
|
||||
for row in grid:
|
||||
new_grid.append(''.join(symbol_map(i) for i in row))
|
||||
return new_grid
|
||||
|
||||
|
||||
def keep_n_characters_in_grid(grid, character, n, backdrop_char=BACKGROUND):
|
||||
"""Keeps only `n` randomly chosen occurrences of `character` in the grid."""
|
||||
np_grid = np.array([list(i) for i in grid])
|
||||
char_positions = np.argwhere(np_grid == character)
|
||||
|
||||
# Randomly select parts to remove.
|
||||
num_empty_positions = char_positions.shape[0] - n
|
||||
if num_empty_positions < 0:
|
||||
raise ValueError('Not enough characters `{}` in grid.'.format(character))
|
||||
empty_pos = np.random.permutation(char_positions)[:num_empty_positions]
|
||||
|
||||
# Remove characters.
|
||||
grid = [list(row) for row in grid]
|
||||
for (i, j) in empty_pos:
|
||||
grid[i][j] = backdrop_char
|
||||
|
||||
return [''.join(row) for row in grid]
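The thinning step in `keep_n_characters_in_grid` can also be sketched in pure Python. This is an illustrative variant under stated assumptions: the function name `keep_n_characters` and the explicit `rng` argument are hypothetical additions (the function above draws from NumPy's global random state).

```python
import random


def keep_n_characters(grid, character, n, backdrop_char=' ', rng=None):
  """Keeps `n` random occurrences of `character` and blanks out the rest."""
  rng = rng or random.Random()
  positions = [(i, j) for i, row in enumerate(grid)
               for j, cell in enumerate(row) if cell == character]
  if len(positions) < n:
    raise ValueError('Not enough characters `{}` in grid.'.format(character))
  # Choose which occurrences to erase, not which ones to keep.
  to_clear = rng.sample(positions, len(positions) - n)
  rows = [list(row) for row in grid]
  for i, j in to_clear:
    rows[i][j] = backdrop_char
  return [''.join(row) for row in rows]


grid = ['#####', '#a a#', '# a #', '#a+a#', '#####']
thinned = keep_n_characters(grid, 'a', 2, rng=random.Random(0))
```

Passing a seeded generator makes the layout reproducible, which can help when debugging a specific episode.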
|
||||
|
||||
|
||||
class PlayerSprite(prefab_sprites.MazeWalker):
|
||||
"""Sprite for the actor."""
|
||||
|
||||
def __init__(self, corner, position, character, impassable=BORDER):
|
||||
super(PlayerSprite, self).__init__(
|
||||
corner, position, character, impassable=impassable,
|
||||
confined_to_board=True)
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
|
||||
the_plot.add_reward(0.)
|
||||
|
||||
if actions == ACTION_QUIT:
|
||||
the_plot.next_chapter = None
|
||||
the_plot.terminate_episode()
|
||||
|
||||
if actions == ACTION_WEST:
|
||||
self._west(board, the_plot)
|
||||
elif actions == ACTION_EAST:
|
||||
self._east(board, the_plot)
|
||||
elif actions == ACTION_NORTH:
|
||||
self._north(board, the_plot)
|
||||
elif actions == ACTION_SOUTH:
|
||||
self._south(board, the_plot)
|
||||
|
||||
|
||||
class AppleDrape(plab_things.Drape):
|
||||
"""Drape for the apples used in the distractor phase."""
|
||||
|
||||
def __init__(self,
|
||||
curtain,
|
||||
character,
|
||||
respawn_every,
|
||||
reward,
|
||||
fix_apple_reward_in_episode):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
curtain: Array specifying locations of apples. Obtained from ascii grid.
|
||||
character: Character representing the drape.
|
||||
respawn_every: Number of frames after a pickup before an apple respawns.
|
||||
reward: Can either be a scalar specifying the reward or a reward range
|
||||
[min, max), given as a list or tuple, to uniformly sample from.
|
||||
fix_apple_reward_in_episode: If set to True, then only sample the apple's
|
||||
reward once in the episode and then fix the value.
|
||||
"""
|
||||
super(AppleDrape, self).__init__(curtain, character)
|
||||
self._respawn_every = respawn_every
|
||||
if not isinstance(reward, (list, tuple)):
|
||||
# Assuming scalar.
|
||||
self._reward = [reward, reward]
|
||||
else:
|
||||
if len(reward) != 2:
|
||||
raise ValueError('Reward must be a scalar or a two element list/tuple.')
|
||||
self._reward = reward
|
||||
self._fix_apple_reward_in_episode = fix_apple_reward_in_episode
|
||||
|
||||
# Grid specifying for each apple the last frame it was picked up.
|
||||
# Initialized to infinity for cells with apples and -1 for cells without.
|
||||
self._last_pickup = np.where(curtain,
|
||||
np.inf * np.ones_like(curtain),
|
||||
-1. * np.ones_like(curtain))
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
player_position = things[PLAYER].position
|
||||
# If the reward is fixed per episode, sample it once from the two endpoints.
|
||||
if (self._fix_apple_reward_in_episode and
|
||||
not the_plot.get('sampled_apple_reward', None)):
|
||||
the_plot['sampled_apple_reward'] = np.random.choice((self._reward[0],
|
||||
self._reward[1]))
|
||||
|
||||
if self.curtain[player_position]:
|
||||
self._last_pickup[player_position] = the_plot.frame
|
||||
self.curtain[player_position] = False
|
||||
if not self._fix_apple_reward_in_episode:
|
||||
the_plot.add_reward(np.random.uniform(*self._reward))
|
||||
else:
|
||||
the_plot.add_reward(the_plot['sampled_apple_reward'])
|
||||
|
||||
if self._respawn_every:
|
||||
respawn_cond = the_plot.frame > self._last_pickup + self._respawn_every
|
||||
respawn_cond &= self._last_pickup >= 0
|
||||
self.curtain[respawn_cond] = True
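The respawn rule in `AppleDrape.update` relies on two sentinel values in `_last_pickup`: infinity for apples never collected and -1 for cells that never held an apple. A minimal sketch of the same condition on nested lists, assuming a hypothetical helper `respawn_mask` and sentinel names that are not part of the code above:

```python
import math

NEVER_PICKED = math.inf  # Apple present and not yet collected.
NO_APPLE = -1.0          # Cell never held an apple.


def respawn_mask(last_pickup, frame, respawn_every):
  """Returns True for cells whose apple should reappear at `frame`."""
  return [[t >= 0 and frame > t + respawn_every for t in row]
          for row in last_pickup]


# One uncollected apple, one empty cell, one apple collected on frame 5.
last_pickup = [[NEVER_PICKED, NO_APPLE, 5.0]]
mask_at_25 = respawn_mask(last_pickup, frame=25, respawn_every=20)
mask_at_26 = respawn_mask(last_pickup, frame=26, respawn_every=20)
```

Because the comparison is strict, an apple collected on frame 5 with `respawn_every=20` is still absent on frame 25 and reappears on frame 26. The infinity sentinel never satisfies the comparison, so uncollected apples are left untouched.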
|
||||
|
||||
|
||||
class TimerSprite(plab_things.Sprite):
|
||||
"""Sprite for the timer.
|
||||
|
||||
The timer is in charge of stopping the current chapter. Timer sprite should be
|
||||
placed last in the update order to make sure everything is updated before the
|
||||
chapter terminates.
|
||||
"""
|
||||
|
||||
def __init__(self, corner, position, character, max_frames,
|
||||
track_chapter_reward=False):
|
||||
super(TimerSprite, self).__init__(corner, position, character)
|
||||
if not isinstance(max_frames, int):
|
||||
raise ValueError('max_frames must be of type integer.')
|
||||
self._max_frames = max_frames
|
||||
self._visible = False
|
||||
self._track_chapter_reward = track_chapter_reward
|
||||
self._total_chapter_reward = 0.
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
directives = the_plot._get_engine_directives() # pylint: disable=protected-access
|
||||
|
||||
if self._track_chapter_reward:
|
||||
self._total_chapter_reward += directives.summed_reward or 0.
|
||||
|
||||
# Every chapter starts at frame = 0.
|
||||
if the_plot.frame >= self._max_frames or directives.game_over:
|
||||
# Calculate the reward obtained in this phase and send it through the
|
||||
# extra observations channel.
|
||||
if self._track_chapter_reward:
|
||||
the_plot['chapter_reward'] = self._total_chapter_reward
|
||||
the_plot.terminate_episode()
|
||||
@@ -0,0 +1,105 @@
|
||||
# Lint as: python2, python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Pycolab env."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from pycolab import rendering
|
||||
|
||||
from tvt.pycolab import active_visual_match
|
||||
from tvt.pycolab import key_to_door
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
|
||||
nest = contrib_framework.nest
|
||||
|
||||
|
||||
class PycolabEnvironment(object):
|
||||
"""A simple environment adapter for pycolab games."""
|
||||
|
||||
def __init__(self, game,
|
||||
num_apples=10,
|
||||
apple_reward=1.,
|
||||
fix_apple_reward_in_episode=False,
|
||||
final_reward=10.,
|
||||
crop=True,
|
||||
default_reward=0):
|
||||
"""Construct an `environment.Base` adapter that wraps a pycolab game."""
|
||||
rng = np.random.RandomState()
|
||||
if game == 'key_to_door':
|
||||
self._game = key_to_door.Game(rng,
|
||||
num_apples,
|
||||
apple_reward,
|
||||
fix_apple_reward_in_episode,
|
||||
final_reward,
|
||||
crop)
|
||||
elif game == 'active_visual_match':
|
||||
self._game = active_visual_match.Game(rng,
|
||||
num_apples,
|
||||
apple_reward,
|
||||
fix_apple_reward_in_episode,
|
||||
final_reward)
|
||||
else:
|
||||
raise ValueError('Unsupported game "%s".' % game)
|
||||
self._default_reward = default_reward
|
||||
|
||||
self._num_actions = self._game.num_actions
|
||||
|
||||
# Agents expect HWC uint8 observations; Pycolab uses CHW float observations.
|
||||
colours = nest.map_structure(lambda c: float(c) * 255 / 1000,
|
||||
self._game.colours)
|
||||
self._rgb_converter = rendering.ObservationToArray(
|
||||
value_mapping=colours, permute=(1, 2, 0), dtype=np.uint8)
|
||||
|
||||
episode = self._game.make_episode()
|
||||
observation, _, _ = episode.its_showtime()
|
||||
self._image_shape = self._rgb_converter(observation).shape
|
||||
|
||||
def _process_outputs(self, observation, reward):
|
||||
if reward is None:
|
||||
reward = self._default_reward
|
||||
image = self._rgb_converter(observation)
|
||||
return image, reward
|
||||
|
||||
def reset(self):
|
||||
"""Start a new episode."""
|
||||
self._episode = self._game.make_episode()
|
||||
observation, reward, _ = self._episode.its_showtime()
|
||||
return self._process_outputs(observation, reward)
|
||||
|
||||
def step(self, action):
|
||||
"""Take a step in the episode."""
|
||||
observation, reward, _ = self._episode.play(action)
|
||||
return self._process_outputs(observation, reward)
|
||||
|
||||
@property
|
||||
def num_actions(self):
|
||||
return self._num_actions
|
||||
|
||||
@property
|
||||
def observation_shape(self):
|
||||
return self._image_shape
|
||||
|
||||
@property
|
||||
def episode_length(self):
|
||||
return self._game.episode_length
|
||||
|
||||
def last_phase_reward(self):
|
||||
# The Pycolab games here only track chapter_reward for the final chapter.
|
||||
return float(self._episode.the_plot['chapter_reward'])
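The adapter rescales the games' colour values from the 0..1000 range to the 0..255 range before handing them to the renderer. A small sketch of that arithmetic on a plain dictionary, assuming a hypothetical helper `to_255_scale` in place of the `nest.map_structure` call above:

```python
def to_255_scale(colours):
  """Rescales a symbol -> (r, g, b) map from the 0..1000 range to 0..255."""
  return {symbol: tuple(float(c) * 255 / 1000 for c in rgb)
          for symbol, rgb in colours.items()}


# The border and indicator colours used by the key-to-door game.
scaled = to_255_scale({'#': (100, 100, 100), 'i': (1000, 1000, 1000)})
```

The results are still floats here; in the adapter the conversion to `uint8` is left to the renderer through its `dtype` argument.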
|
||||
@@ -0,0 +1,44 @@
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Pycolab Game interface."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import six
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class AbstractGame(object):
|
||||
"""Abstract base class for Pycolab games."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, rng, **settings):
|
||||
"""Initialize the game."""
|
||||
|
||||
@abc.abstractproperty
|
||||
def num_actions(self):
|
||||
"""Number of possible actions in the game."""
|
||||
|
||||
@abc.abstractproperty
|
||||
def colours(self):
|
||||
"""Symbol to colour map for the game."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def make_episode(self):
|
||||
"""Factory method for generating new episodes of the game."""
|
||||
@@ -0,0 +1,67 @@
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Pycolab human player."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import curses
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
from pycolab import human_ui
|
||||
|
||||
from tvt.pycolab import active_visual_match
|
||||
from tvt.pycolab import common
|
||||
from tvt.pycolab import key_to_door
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_enum('game', 'key_to_door',
|
||||
['key_to_door', 'active_visual_match'],
|
||||
'The name of the game')
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
|
||||
rng = np.random.RandomState()
|
||||
|
||||
if FLAGS.game == 'key_to_door':
|
||||
game = key_to_door.Game(rng)
|
||||
elif FLAGS.game == 'active_visual_match':
|
||||
game = active_visual_match.Game(rng)
|
||||
else:
|
||||
raise ValueError('Unsupported game "%s".' % FLAGS.game)
|
||||
episode = game.make_episode()
|
||||
|
||||
ui = human_ui.CursesUi(
|
||||
keys_to_actions={
|
||||
curses.KEY_UP: common.ACTION_NORTH,
|
||||
curses.KEY_DOWN: common.ACTION_SOUTH,
|
||||
curses.KEY_LEFT: common.ACTION_WEST,
|
||||
curses.KEY_RIGHT: common.ACTION_EAST,
|
||||
-1: common.ACTION_DELAY,
|
||||
'q': common.ACTION_QUIT,
|
||||
'Q': common.ACTION_QUIT},
|
||||
delay=-1,
|
||||
colour_fg=game.colours
|
||||
)
|
||||
ui.play(episode)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,214 @@
|
||||
# Lint as: python2, python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Key to door task.
|
||||
|
||||
The game is split up into three phases:
|
||||
1. (exploration phase) player can collect a key,
|
||||
2. (distractor phase) player is collecting apples,
|
||||
3. (reward phase) player can open the door and get the reward if the key was
|
||||
previously collected.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from pycolab import ascii_art
|
||||
from pycolab import storytelling
|
||||
from pycolab import things as plab_things
|
||||
|
||||
from tvt.pycolab import common
|
||||
from tvt.pycolab import game
|
||||
from tvt.pycolab import objects
|
||||
|
||||
|
||||
COLOURS = {
|
||||
'i': (1000, 1000, 1000), # Indicator.
|
||||
}
|
||||
|
||||
EXPLORE_GRID = [
|
||||
' ####### ',
|
||||
' #kkkkk# ',
|
||||
' #kkkkk# ',
|
||||
' ## ## ',
|
||||
' #+++++# ',
|
||||
' #+++++# ',
|
||||
' ####### '
|
||||
]
|
||||
|
||||
REWARD_GRID = [
|
||||
' ',
|
||||
' ##d## ',
|
||||
' # # ',
|
||||
' # + # ',
|
||||
' # # ',
|
||||
' ##### ',
|
||||
' ',
|
||||
]
|
||||
|
||||
|
||||
class KeySprite(plab_things.Sprite):
|
||||
"""Sprite for the key."""
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
player_position = things[common.PLAYER].position
|
||||
pick_up = self.position == player_position
|
||||
|
||||
if self.visible and pick_up:
|
||||
# Pass information to all phases.
|
||||
the_plot['has_key'] = True
|
||||
self._visible = False
|
||||
|
||||
|
||||
class DoorSprite(plab_things.Sprite):
|
||||
"""Sprite for the door."""
|
||||
|
||||
def __init__(self, corner, position, character, pickup_reward):
|
||||
super(DoorSprite, self).__init__(corner, position, character)
|
||||
self._pickup_reward = pickup_reward
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
player_position = things[common.PLAYER].position
|
||||
pick_up = self.position == player_position
|
||||
|
||||
if pick_up and the_plot.get('has_key'):
|
||||
the_plot.add_reward(self._pickup_reward)
|
||||
# The key is lost after the first time opening the door
|
||||
# to ensure only one reward per episode.
|
||||
the_plot['has_key'] = False
|
||||
|
||||
|
||||
class PlayerSprite(common.PlayerSprite):
|
||||
"""Sprite for the actor."""
|
||||
|
||||
def __init__(self, corner, position, character):
|
||||
super(PlayerSprite, self).__init__(
|
||||
corner, position, character,
|
||||
impassable=common.BORDER + common.INDICATOR + common.DOOR)
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
|
||||
# Allow moving through the door if the key was previously collected.
|
||||
if common.DOOR in self.impassable and the_plot.get('has_key'):
|
||||
self._impassable.remove(common.DOOR)
|
||||
|
||||
super(PlayerSprite, self).update(actions, board, layers, backdrop, things,
|
||||
the_plot)
|
||||
|
||||
|
||||
class Game(game.AbstractGame):
|
||||
"""Key To Door Game."""
|
||||
|
||||
def __init__(self,
|
||||
rng,
|
||||
num_apples=10,
|
||||
apple_reward=(1, 10),
|
||||
fix_apple_reward_in_episode=True,
|
||||
final_reward=10.,
|
||||
crop=True,
|
||||
max_frames=common.DEFAULT_MAX_FRAMES_PER_PHASE):
|
||||
del rng # Each episode is identical and colours are not randomised.
|
||||
self._num_apples = num_apples
|
||||
self._apple_reward = apple_reward
|
||||
self._fix_apple_reward_in_episode = fix_apple_reward_in_episode
|
||||
self._final_reward = final_reward
|
||||
self._crop = crop
|
||||
self._max_frames = max_frames
|
||||
self._episode_length = sum(self._max_frames.values())
|
||||
self._num_actions = common.NUM_ACTIONS
|
||||
self._colours = common.FIXED_COLOURS.copy()
|
||||
self._colours.update(COLOURS)
|
||||
self._extra_observation_fields = ['chapter_reward_as_string']
|
||||
|
||||
@property
|
||||
def extra_observation_fields(self):
|
||||
"""The field names of extra observations."""
|
||||
return self._extra_observation_fields
|
||||
|
||||
@property
|
||||
def num_actions(self):
|
||||
"""Number of possible actions in the game."""
|
||||
return self._num_actions
|
||||
|
||||
@property
|
||||
def episode_length(self):
|
||||
return self._episode_length
|
||||
|
||||
@property
|
||||
def colours(self):
|
||||
"""Symbol to colour map for key to door."""
|
||||
return self._colours
|
||||
|
||||
def _make_explore_phase(self):
|
||||
# Keep only one key and one player position.
|
||||
explore_grid = common.keep_n_characters_in_grid(
|
||||
EXPLORE_GRID, common.KEY, 1)
|
||||
explore_grid = common.keep_n_characters_in_grid(
|
||||
explore_grid, common.PLAYER, 1)
|
||||
return ascii_art.ascii_art_to_game(
|
||||
art=explore_grid,
|
||||
what_lies_beneath=' ',
|
||||
sprites={
|
||||
common.PLAYER: PlayerSprite,
|
||||
common.KEY: KeySprite,
|
||||
common.INDICATOR: ascii_art.Partial(objects.IndicatorObjectSprite,
|
||||
char_to_track=common.KEY,
|
||||
override_position=(0, 5)),
|
||||
common.TIMER: ascii_art.Partial(common.TimerSprite,
|
||||
self._max_frames['explore']),
|
||||
},
|
||||
update_schedule=[
|
||||
common.PLAYER, common.KEY, common.INDICATOR, common.TIMER],
|
||||
z_order=[common.KEY, common.INDICATOR, common.PLAYER, common.TIMER],
|
||||
)
|
||||
|
||||
def _make_distractor_phase(self):
|
||||
return common.distractor_phase(
|
||||
player_sprite=PlayerSprite,
|
||||
num_apples=self._num_apples,
|
||||
max_frames=self._max_frames['distractor'],
|
||||
apple_reward=self._apple_reward,
|
||||
fix_apple_reward_in_episode=self._fix_apple_reward_in_episode)
|
||||
|
||||
def _make_reward_phase(self):
|
||||
return ascii_art.ascii_art_to_game(
|
||||
art=REWARD_GRID,
|
||||
what_lies_beneath=' ',
|
||||
sprites={
|
||||
common.PLAYER: PlayerSprite,
|
||||
common.DOOR: ascii_art.Partial(DoorSprite,
|
||||
pickup_reward=self._final_reward),
|
||||
common.TIMER: ascii_art.Partial(common.TimerSprite,
|
||||
self._max_frames['reward'],
|
||||
track_chapter_reward=True),
|
||||
},
|
||||
update_schedule=[common.PLAYER, common.DOOR, common.TIMER],
|
||||
z_order=[common.PLAYER, common.DOOR, common.TIMER],
|
||||
)
|
||||
|
||||
def make_episode(self):
|
||||
"""Factory method for generating new episodes of the game."""
|
||||
if self._crop:
|
||||
croppers = common.get_cropper()
|
||||
else:
|
||||
croppers = None
|
||||
|
||||
return storytelling.Story([
|
||||
self._make_explore_phase,
|
||||
self._make_distractor_phase,
|
||||
self._make_reward_phase,
|
||||
], croppers=croppers)
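The key and door sprites communicate only through the plot, which carries `has_key` from the explore phase to the reward phase. A minimal sketch of that hand-off, assuming a plain dictionary as a stand-in for pycolab's `the_plot` and hypothetical helper names `pick_up_key` and `open_door`:

```python
def pick_up_key(plot, player_position, key_position):
  """Explore phase: records the key in the shared plot dictionary."""
  if player_position == key_position:
    plot['has_key'] = True


def open_door(plot, player_position, door_position, final_reward=10.0):
  """Reward phase: pays out once, then consumes the key."""
  if player_position == door_position and plot.get('has_key'):
    plot['has_key'] = False
    return final_reward
  return 0.0


plot = {}  # Stand-in for the plot shared by all phases of an episode.
pick_up_key(plot, player_position=(1, 3), key_position=(1, 3))
first = open_door(plot, player_position=(1, 5), door_position=(1, 5))
second = open_door(plot, player_position=(1, 5), door_position=(1, 5))
```

Clearing `has_key` after the first payout is what limits the episode to a single final reward, even if the player stays on the door for several frames.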
|
||||
@@ -0,0 +1,123 @@
|
||||
# Lint as: python2, python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Pycolab sprites."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from pycolab import things as plab_things
|
||||
from pycolab.prefab_parts import sprites as prefab_sprites
|
||||
import six
|
||||
from tvt.pycolab import common
|
||||
|
||||
|
||||
class PlayerSprite(prefab_sprites.MazeWalker):
|
||||
"""Sprite representing the agent."""
|
||||
|
||||
def __init__(self, corner, position, character,
|
||||
max_steps_per_act, moving_player):
|
||||
|
||||
"""Indicates to the superclass that we can't walk off the board."""
|
||||
super(PlayerSprite, self).__init__(
|
||||
corner, position, character, impassable=[common.BORDER],
|
||||
confined_to_board=True)
|
||||
|
||||
self._moving_player = moving_player
|
||||
self._max_steps_per_act = max_steps_per_act
|
||||
self._num_steps = 0
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
del backdrop # Unused.
|
||||
|
||||
if actions is not None:
|
||||
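# common.py defines no ACTIONS collection, so list the action IDs explicitly.
assert actions in (common.ACTION_QUIT, common.ACTION_DELAY,
common.ACTION_NORTH, common.ACTION_SOUTH,
common.ACTION_WEST, common.ACTION_EAST)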
|
||||
|
||||
the_plot.log("Step {} | Action {}".format(self._num_steps, actions))
|
||||
the_plot.add_reward(0.0)
|
||||
self._num_steps += 1
|
||||
|
||||
if actions == common.ACTION_QUIT:
|
||||
the_plot.terminate_episode()
|
||||
|
||||
if self._moving_player:
|
||||
if actions == common.ACTION_WEST:
|
||||
self._west(board, the_plot)
|
||||
elif actions == common.ACTION_EAST:
|
||||
self._east(board, the_plot)
|
||||
elif actions == common.ACTION_NORTH:
|
||||
self._north(board, the_plot)
|
||||
elif actions == common.ACTION_SOUTH:
|
||||
self._south(board, the_plot)
|
||||
|
||||
if self._max_steps_per_act == self._num_steps:
|
||||
the_plot.terminate_episode()
|
||||
|
||||
|
||||
class ObjectSprite(plab_things.Sprite):
|
||||
"""Sprite for a generic object which can be collectable."""
|
||||
|
||||
def __init__(self, corner, position, character, reward=0., collectable=True,
|
||||
terminate=True):
|
||||
super(ObjectSprite, self).__init__(corner, position, character)
|
||||
self._reward = reward # Reward on pickup.
|
||||
self._collectable = collectable
|
||||
|
||||
def set_visibility(self, visible):
|
||||
self._visible = visible
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
player_position = things[common.PLAYER].position
|
||||
pick_up = self.position == player_position
|
||||
|
||||
if pick_up and self.visible:
|
||||
the_plot.add_reward(self._reward)
|
||||
if self._collectable:
|
||||
self.set_visibility(False)
|
||||
# set all other objects to be invisible
|
||||
for v in six.itervalues(things):
|
||||
if isinstance(v, ObjectSprite):
|
||||
v.set_visibility(False)
|
||||
|
||||
|
||||
class IndicatorObjectSprite(plab_things.Sprite):
|
||||
"""Sprite for the indicator object.
|
||||
|
||||
The indicator object is an object that spawns at a designated position once
|
||||
the player picks up an object defined by the `char_to_track` argument.
|
||||
The indicator object is spawned for just a single frame.
|
||||
"""
|
||||
|
||||
def __init__(self, corner, position, character, char_to_track,
|
||||
override_position=None):
|
||||
super(IndicatorObjectSprite, self).__init__(corner, position, character)
|
||||
if override_position is not None:
|
||||
self._position = override_position
|
||||
self._char_to_track = char_to_track
|
||||
self._visible = False
|
||||
self._pickup_frame = None
|
||||
|
||||
def update(self, actions, board, layers, backdrop, things, the_plot):
|
||||
player_position = things[common.PLAYER].position
|
||||
pick_up = things[self._char_to_track].position == player_position
|
||||
|
||||
if self._pickup_frame:
|
||||
self._visible = False
|
||||
|
||||
if pick_up and not self._pickup_frame:
|
||||
self._visible = True
|
||||
self._pickup_frame = the_plot.frame
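The single-frame behaviour of `IndicatorObjectSprite` follows from the order of the two checks in `update`: visibility is cleared whenever a pickup frame has already been recorded, and set only on the first pickup. A small simulation of that logic, assuming a hypothetical function `indicator_visibility` and frames numbered from 1:

```python
def indicator_visibility(on_key_per_frame):
  """Returns the indicator's visibility for each frame, numbered from 1."""
  visible = False
  pickup_frame = None
  history = []
  for frame, on_key in enumerate(on_key_per_frame, start=1):
    if pickup_frame:
      # A pickup was recorded on an earlier frame: hide the indicator.
      visible = False
    if on_key and not pickup_frame:
      # First pickup: show the indicator and remember when it happened.
      visible = True
      pickup_frame = frame
    history.append(visible)
  return history


# The player reaches the key on frame 2 and lingers there on frame 3.
history = indicator_visibility([False, True, True, False])
```

Standing on the key's position again after the pickup does not re-trigger the indicator, because the recorded pickup frame blocks the second branch.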
|
||||
@@ -0,0 +1,8 @@
|
||||
absl-py
|
||||
dm-sonnet==1.34
|
||||
numpy
|
||||
pycolab
|
||||
six
|
||||
trfl
|
||||
tensorflow==1.13.2
|
||||
tensorflow-probability==0.6.0
|
||||
+584
File diff suppressed because it is too large
Executable
+20
@@ -0,0 +1,20 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2019 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
python3 -m venv tvt_venv
|
||||
source tvt_venv/bin/activate
|
||||
pip install -r tvt/requirements.txt
|
||||
|
||||
python3 -m tvt.main
|
||||
@@ -0,0 +1,247 @@
# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Temporal Value Transport implementation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from concurrent import futures
import numpy as np
from six.moves import range
from six.moves import zip


def _unstack(array, axis):
  """Opposite of np.stack."""
  split_array = np.split(array, array.shape[axis], axis=axis)
  return [np.squeeze(a, axis=axis) for a in split_array]


def _top_k_args(array, k):
  """Return top k arguments or all arguments if array size is less than k."""
  if len(array) <= k:
    return np.arange(len(array))
  return np.argpartition(array, kth=-k)[-k:]


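
As a small illustration of the top-k selection above, the following sketch re-states the helper under a different name (`top_k_args`, so it runs on its own) and applies it to made-up read weights. `np.argpartition` does not promise any ordering among the selected indices, so the example sorts them before printing.

```python
import numpy as np


def top_k_args(array, k):
  """Indices of the k largest entries, or all indices if len(array) <= k."""
  if len(array) <= k:
    return np.arange(len(array))
  # After partitioning, the last k positions hold the indices of the k
  # largest values, in no particular order.
  return np.argpartition(array, kth=-k)[-k:]


# Made-up read weights for illustration only.
weights = np.array([0.1, 0.7, 0.05, 0.9, 0.3])
top2 = sorted(top_k_args(weights, 2).tolist())
everything = top_k_args(weights, 10).tolist()

print(top2)
print(everything)
```

Partitioning avoids a full sort, which is why the selected indices come back unordered; callers that need the reads ranked would have to sort the selection themselves.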
def _threshold_read_event_times(read_strengths, threshold):
  """Return the times of max read strengths within one threshold read event."""
  chosen_times = []
  over_threshold = False
  max_read_strength = 0.
  # Wait until the threshold is crossed then keep track of max read strength and
  # time of max read strength until the read strengths go back under the
  # threshold, then add that max read strength time to the chosen times. Wait
  # until threshold is crossed again and then repeat the process.
  for time, strength in enumerate(read_strengths):
    if strength > threshold:
      over_threshold = True
      if strength > max_read_strength:
        max_read_strength = strength
        max_read_strength_time = time
    else:
      # If coming back under threshold, add the time of the last max read.
      if over_threshold:
        chosen_times.append(max_read_strength_time)
        max_read_strength = 0.
        over_threshold = False
  # Add max read strength time if episode finishes before going under threshold.
  if over_threshold:
    chosen_times.append(max_read_strength_time)
  return np.array(chosen_times)


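
To make the read-event detection concrete, here is a minimal standalone sketch of the same thresholding loop applied to a made-up sequence of read strengths. The function name `threshold_read_event_times` and the input values are assumptions for this example; the logic mirrors the helper above.

```python
import numpy as np


def threshold_read_event_times(read_strengths, threshold):
  """Time of the strongest read within each run of above-threshold reads."""
  chosen_times = []
  over_threshold = False
  max_strength = 0.
  max_time = None
  for time, strength in enumerate(read_strengths):
    if strength > threshold:
      # Inside a read event: remember where its peak is.
      over_threshold = True
      if strength > max_strength:
        max_strength = strength
        max_time = time
    elif over_threshold:
      # The event just ended: record its peak and reset.
      chosen_times.append(max_time)
      max_strength = 0.
      over_threshold = False
  # An event still open at the end of the episode also counts.
  if over_threshold:
    chosen_times.append(max_time)
  return np.array(chosen_times)


# Made-up strengths: one event over steps 1-2, another over steps 5-6.
strengths = [0.5, 3.0, 5.0, 1.0, 0.2, 4.0, 2.5]
events = threshold_read_event_times(strengths, threshold=2.)
print(events.tolist())
```

Each contiguous run above the threshold contributes exactly one time step, the one with the strongest read, so a long read event transports value once rather than once per step.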
def _tvt_rewards_single_head(read_weights, read_strengths, read_times,
                             baselines, alpha, top_k_t1,
                             read_strength_threshold, no_transport_period):
  """Compute TVT rewards for a single read head, no batch dimension.

  This performs the updates for one read head.
  `t1` and `t2` refer to times to where and from where the value is being
  transported, respectively. I.e. the rewards at `t1` times are being modified
  based on values at times `t2`.

  Args:
    read_weights: shape (ep_length, top_k).
    read_strengths: shape (ep_length,).
    read_times: shape (ep_length, top_k).
    baselines: shape (ep_length,).
    alpha: The multiplier for the temporal value transport rewards.
    top_k_t1: For each read event time, this determines how many time points
      to send tvt reward to.
    read_strength_threshold: Read strengths below this value are ignored.
    no_transport_period: Number of time steps. Reads that point back to a time
      no more than this many steps before the current time are masked out, so
      no value is transported to them.

  Returns:
    An array of TVT rewards with shape (ep_length,).
  """
  tvt_rewards = np.zeros_like(baselines)

  # Mask read_weights for reads that read back to times within
  # no_transport_period of current time.
  ep_length = read_times.shape[0]
  times = np.arange(ep_length)
  # Expand dims for correct broadcasting when subtracting read_times.
  times = np.expand_dims(times, -1)
  read_past_no_transport_period = (times - read_times) > no_transport_period
  read_weights_masked = np.where(read_past_no_transport_period,
                                 read_weights,
                                 np.zeros_like(read_weights))

  # Find t2 times with maximum read weights. Ignore t2 times whose maximum
  # read weights fall inside the no_transport_period.
  max_read_weight_args = np.argmax(read_weights, axis=1)  # (ep_length,)
  times = np.arange(ep_length)
  max_read_weight_times = read_times[times,
                                     max_read_weight_args]  # (ep_length,)
  read_strengths_cut = np.where(
      times - max_read_weight_times > no_transport_period,
      read_strengths,
      np.zeros_like(read_strengths))

  # Filter t2 candidates to perform value transport on local maximums
  # above a threshold.
  t2_times_with_largest_reads = _threshold_read_event_times(
      read_strengths_cut, read_strength_threshold)

  # Loop through all t2 candidates and transport value to top_k_t1 read times.
  for t2 in t2_times_with_largest_reads:
    try:
      baseline_value_when_reading = baselines[t2]
    except IndexError:
      raise RuntimeError("Attempting to access baselines array with length {}"
                         " at index {}. Make sure output_baseline is set in"
                         " the agent config.".format(len(baselines), t2))
    read_times_from_t2 = read_times[t2]
    read_weights_from_t2 = read_weights_masked[t2]

    # Find the top_k_t1 read times for this t2 and their corresponding read
    # weights. The call to _top_k_args() here gives the array indices for the
    # times and weights of the top_k_t1 reads from t2.
    top_t1_indices = _top_k_args(read_weights_from_t2, top_k_t1)
    top_t1_read_times = np.take(read_times_from_t2, top_t1_indices)
    top_t1_read_weights = np.take(read_weights_from_t2, top_t1_indices)

    # For each of the top_k_t1 read times t and corresponding read weight w,
    # find the trajectory that contains step_num (t + shift) and modify the
    # reward at step_num (t + shift) using w and the baseline value at t2.
    # We ignore any read times t >= t2. These can emerge because if nothing
    # in memory matches positively with the read query, the top reads may be
    # in the empty region of the memory.
    for step_num, read_weight in zip(top_t1_read_times, top_t1_read_weights):
      if step_num >= t2:
        # Skip this step_num as it is not really a memory time.
        continue

      # Compute the tvt reward and add it on.
      tvt_reward = alpha * read_weight * baseline_value_when_reading
      tvt_rewards[step_num] += tvt_reward

  return tvt_rewards


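
The masking and transport steps in `_tvt_rewards_single_head` can be traced by hand on a tiny episode. The sketch below is illustrative only: all arrays and the choice of `t2 = 4` are made-up values, and the read-strength thresholding that normally selects `t2` is skipped.

```python
import numpy as np

no_transport_period = 2
alpha = 0.9

# read_times[t, k]: the memory time that the k-th retained read at step t
# points to. Values are invented for this example.
read_times = np.array([[0, 0],
                       [0, 1],
                       [1, 0],
                       [0, 2],
                       [4, 1]])
read_weights = np.full(read_times.shape, 0.5)
baselines = np.array([0., 0., 0., 0., 2.0])

# Broadcast (ep_length, 1) against (ep_length, top_k) to get, for every read,
# how far back in time it points.
times = np.arange(read_times.shape[0])[:, np.newaxis]
read_past_period = (times - read_times) > no_transport_period
masked_weights = np.where(read_past_period, read_weights, 0.)

# Transport value from an assumed read event at t2 = 4 to the times it read.
t2 = 4
tvt_rewards = np.zeros_like(baselines)
for t1, weight in zip(read_times[t2], masked_weights[t2]):
  if t1 >= t2:
    continue  # Not a genuine memory of the past.
  tvt_rewards[t1] += alpha * weight * baselines[t2]

print(masked_weights.tolist())
print(tvt_rewards.tolist())
```

Only reads that reach back more than `no_transport_period` steps keep their weight. Here the read from step 4 to step 1 survives, so step 1 receives `alpha * 0.5 * baselines[4]`, while the read that points at step 4 itself is skipped.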
def _compute_tvt_rewards_from_read_info(
    read_weights, read_strengths, read_times, baselines, gamma,
    alpha=0.9, top_k_t1=50,
    read_strength_threshold=2.,
    no_transport_period_when_gamma_1=25):
  """Compute TVT rewards given supplied read information, no batch dimension.

  Args:
    read_weights: shape (ep_length, num_read_heads, top_k).
    read_strengths: shape (ep_length, num_read_heads).
    read_times: shape (ep_length, num_read_heads, top_k).
    baselines: shape (ep_length,).
    gamma: Scalar discount factor used to calculate the no_transport_period.
    alpha: The multiplier for the temporal value transport rewards.
    top_k_t1: For each read event time, this determines how many time points
      to send tvt reward to.
    read_strength_threshold: Read strengths below this value are ignored.
    no_transport_period_when_gamma_1: no transport period when gamma == 1.

  Returns:
    An array of TVT rewards with shape (ep_length,).
  """

  if gamma < 1:
    no_transport_period = int(1 / (1 - gamma))
  else:
    if no_transport_period_when_gamma_1 is None:
      raise ValueError("No transport period must be defined when gamma == 1.")
    no_transport_period = no_transport_period_when_gamma_1

  # Split read infos by read head.
  num_read_heads = read_weights.shape[1]
  read_weights = _unstack(read_weights, axis=1)
  read_strengths = _unstack(read_strengths, axis=1)
  read_times = _unstack(read_times, axis=1)

  # Calculate TVT rewards for each read head separately and add to total.
  tvt_rewards = np.zeros_like(baselines)
  for i in range(num_read_heads):
    tvt_rewards += _tvt_rewards_single_head(
        read_weights[i], read_strengths[i], read_times[i],
        baselines, alpha, top_k_t1, read_strength_threshold,
        no_transport_period)

  return tvt_rewards


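
The way the no-transport period is derived from the discount factor can be checked in isolation. This is a minimal sketch with a hypothetical helper name (`no_transport_period_from_gamma`); it copies the branch above and uses values of `gamma` for which `1 / (1 - gamma)` is exact in floating point.

```python
def no_transport_period_from_gamma(gamma, period_when_gamma_1=25):
  """Effective discount horizon 1 / (1 - gamma), truncated to an int."""
  if gamma < 1:
    return int(1 / (1 - gamma))
  # With gamma == 1 the horizon is unbounded, so a fixed period is required.
  if period_when_gamma_1 is None:
    raise ValueError("No transport period must be defined when gamma == 1.")
  return period_when_gamma_1


half = no_transport_period_from_gamma(0.5)
three_quarters = no_transport_period_from_gamma(0.75)
undiscounted = no_transport_period_from_gamma(1.0)
print(half, three_quarters, undiscounted)
```

Because `int()` truncates rather than rounds, a `gamma` that is not exactly representable in binary floating point can yield a period one step shorter than the nominal `1 / (1 - gamma)`.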
def compute_tvt_rewards(read_infos, baselines, gamma=.96):
  """Compute TVT rewards from EpisodeOutputs.

  Args:
    read_infos: A memory_reader.ReadInformation namedtuple, where each element
      has shape (ep_length, batch_size, num_read_heads, ...).
    baselines: A numpy float array with shape (ep_length, batch_size).
    gamma: Discount factor.

  Returns:
    An array of TVT rewards with shape (ep_length, batch_size).
  """
  if not read_infos:
    return np.zeros_like(baselines)

  # TVT reward computation is done without a batch dimension, so we need to
  # split read_infos and baselines into per-batch-element components.
  batch_size = baselines.shape[1]

  # Split each element of read info on batch dim.
  read_weights = _unstack(read_infos.weights, axis=1)
  read_strengths = _unstack(read_infos.strengths, axis=1)
  read_indices = _unstack(read_infos.indices, axis=1)
  # Split baselines on batch dim.
  baselines = _unstack(baselines, axis=1)

  # Compute TVT rewards for each element in the batch (threading over batch).
  tvt_rewards = []
  with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
    for i in range(batch_size):
      tvt_rewards.append(
          executor.submit(
              _compute_tvt_rewards_from_read_info,
              read_weights[i],
              read_strengths[i],
              read_indices[i],
              baselines[i],
              gamma)
      )
    tvt_rewards = [f.result() for f in tvt_rewards]

  # Process TVT rewards back into an array of shape (ep_length, batch_size).
  return np.stack(tvt_rewards, axis=1)

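
The split, per-episode compute, and re-stack pattern used by `compute_tvt_rewards` can be sketched on its own. In the example below `per_episode_fn` is a hypothetical placeholder for the real per-episode reward computation, and the input array is made up; only the shape handling and the thread pool usage follow the function above.

```python
from concurrent import futures

import numpy as np


def unstack(array, axis):
  """Split an array into a list of arrays along `axis`, dropping that axis."""
  pieces = np.split(array, array.shape[axis], axis=axis)
  return [np.squeeze(p, axis=axis) for p in pieces]


def per_episode_fn(episode_baselines):
  """Hypothetical placeholder for the per-episode TVT computation."""
  return episode_baselines * 2.


# Made-up baselines with shape (ep_length=3, batch_size=2).
baselines = np.arange(6, dtype=np.float64).reshape(3, 2)
episodes = unstack(baselines, axis=1)  # batch_size arrays of shape (3,)

# One worker per batch element; results are collected in submission order.
with futures.ThreadPoolExecutor(max_workers=len(episodes)) as executor:
  pending = [executor.submit(per_episode_fn, ep) for ep in episodes]
  results = [f.result() for f in pending]

# Re-assemble into (ep_length, batch_size).
out = np.stack(results, axis=1)
print(out.shape)
```

Collecting `f.result()` in the order the futures were submitted keeps each output column aligned with its batch element, regardless of which thread finishes first.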