Internal change.

PiperOrigin-RevId: 372158522
Ravichandra Addanki
2021-05-05 17:42:43 +00:00
committed by Louise Deason
parent 2e866f1937
commit 59f5fb1268
6 changed files with 915 additions and 0 deletions

File diff suppressed because one or more lines are too long

cmtouch/README.md (new file, 209 lines added)

@@ -0,0 +1,209 @@
--------------------------------------------------------------------------------
# CMTouch Dataset
This repository contains datasets for cross-modal representation learning, used
in developing rich touch representations in "Learning rich touch representations
through cross-modal self-supervision" [1].
The datasets we provide are:
1. CMTouch-Props
2. CMTouch-YCB
The datasets consist of episodes collected by running a reinforcement learning
agent on a simulated Shadow Dexterous Hand [2] interacting with different
objects. From these interactions, observations from different sensory modalities
are collected at each time step, including vision, proprioception (joint
positions and velocities), touch, actions, and object IDs. We used these data to
learn rich touch representations using cross-modal self-supervision.
## Bibtex
If you use one of these datasets in your work, please cite the reference paper
as follows:
```
@InProceedings{zambelli20learning,
  author    = "Zambelli, Martina and Aytar, Yusuf and Visin, Francesco and Zhou, Yuxiang and Hadsell, Raia",
  title     = "Learning rich touch representations through cross-modal self-supervision",
  booktitle = "Conference on Robot Learning (CoRL)",
  year      = "2020",
}
```
## Descriptions
### Experimental setup
We run experiments in simulation with MuJoCo [3] and we use the simulated Shadow
Dexterous Hand [2], with five fingers and 24 degrees of freedom, actuated by 20
motors. In simulation, each fingertip has a spatial touch sensor attached with a
spatial resolution of 4×4 and three channels: one for normal force and two for
tangential forces. We simplify this by summing across the spatial dimensions,
to obtain a single force vector for each fingertip representing one normal force
and two tangential forces. The state consists of proprioception (joint positions
and joint velocities) and touch.
Visual inputs are collected at a 64×64 resolution and are used only for
representation learning; they are not provided as observations to control the
robot's actions. The action space is 20-dimensional. We use velocity control and
a control rate of 30 Hz. Each episode has 200 time steps, which correspond to
about 6 seconds. The environment consists of the Shadow Hand, facing down, and
interacting with different objects. These objects have different shapes, sizes
and physical properties (e.g. rigid or soft). We develop two versions of the
task, the first using simple props and the second using YCB objects. In both
cases, objects are fixed to their frame of reference, while their position and
orientation are randomized.
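To make the touch preprocessing concrete, here is a minimal sketch (using NumPy; the array name and the per-step layout are assumptions based on the description above) of collapsing the 4×4 spatial sensor readings into one force vector per fingertip:
```
import numpy as np

# Hypothetical raw reading for one time step: 5 fingertips, a 4x4 spatial grid,
# and 3 channels (1 normal force + 2 tangential forces).
spatial_touch = np.random.rand(5, 4, 4, 3)

# Sum over the two spatial dimensions to get one 3-D force vector per fingertip.
fingertip_forces = spatial_touch.sum(axis=(1, 2))  # shape: (5, 3)
```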
### CMTouch-Props
This is a dataset based on simple geometric 3D shapes (referred to as "props").
Props are simple 3D-shaped objects that include cubes, spheres, cylinders and
ellipsoids of different sizes. We also generated a soft version of each prop,
which can deform under the pressure of the touching fingers.
Soft deformable objects are complex entities to simulate: they are defined
through a composition of multiple bodies (capsules) that are tied together to
form a shape, such as a cube or a sphere. The main characteristic of these
objects is their elastic behaviour: they change shape when touched. The most
difficult aspect to simulate in this context is contacts, whose number grows
quickly as the number of colliding bodies increases.
Forty-eight different objects are generated by combining 6 different sizes,
4 different shapes (sphere, cylinder, cube, ellipsoid), and two material types
(rigid or soft).
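As a sanity check on the combinatorics, the 48 props are the cross product of sizes, shapes and materials; a minimal sketch (the size indices are hypothetical labels):
```
import itertools

sizes = range(6)  # hypothetical size labels
shapes = ["sphere", "cylinder", "cube", "ellipsoid"]
materials = ["rigid", "soft"]

props = list(itertools.product(sizes, shapes, materials))
assert len(props) == 48  # 6 x 4 x 2
```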
![](https://i.imgur.com/Hps38z5.jpg)
### CMTouch-YCB
This is a dataset based on YCB objects. The YCB objects dataset [4] consists of
everyday objects with different shapes, sizes, textures, weight and rigidity.
We chose a set of ten objects: cracker box, sugar box, mustard bottle, potted
meat can, banana, pitcher base, bleach cleanser, mug, power drill, scissors.
These are generated in simulation at their standard size, which is
proportionate to the default dimensions of the simulated Shadow Hand.
The pose of each object is randomly selected from a set of 60 different poses,
in which we vary the orientation of the object. These variations make the
identification of each object more complex than in CMTouch-Props and require
greater generalization capability from the learning method.
![](https://i.imgur.com/Mf3KYbn.jpg)
## Download
The datasets can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/dm_cmtouch).
Each dataset is a single
[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file.
To download a particular dataset, use the web interface or, on Linux, run
`wget` with the appropriate filename as follows:
```
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_test.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_train.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_val.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_test.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_train.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_val.tfrecords
```
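If you prefer to fetch files from Python rather than the shell, a minimal sketch using only the standard library (the local filename is an assumption):
```
import urllib.request

url = ('https://storage.googleapis.com/dm_cmtouch/datasets/'
       'cmtouch_props_all_test.tfrecords')
urllib.request.urlretrieve(url, 'cmtouch_props_all_test.tfrecords')
```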
## Usage
After downloading the dataset files, you can read them as `tf.data.Dataset`
instances with the readers provided. The example below shows how to read the
CMTouch-Props dataset:
```
import tensorflow as tf

record_file = 'cmtouch_props_all_test.tfrecords'
dataset = tf.data.TFRecordDataset(record_file)
parsed_dataset = dataset.map(_parse_tf_example)
```
(a complete example is provided in the Colab).
All dataset readers return the following set of observations (a minimal parsing
sketch is given after the list below):
```
'camera': tf.io.FixedLenFeature([], tf.string),
'camera/height': tf.io.FixedLenFeature([], tf.int64),
'camera/width': tf.io.FixedLenFeature([], tf.int64),
'camera/channel': tf.io.FixedLenFeature([], tf.int64),
'object_id': tf.io.FixedLenFeature([], tf.string),  # for both
'object_id/dim': tf.io.FixedLenFeature([], tf.int64),
'orientation_id': tf.io.FixedLenFeature([], tf.string),  # only for ycb
'orientation_id/dim': tf.io.FixedLenFeature([], tf.int64),
'shadowhand_motor/joints_vel': tf.io.FixedLenFeature([], tf.string),
'shadowhand_motor/joints_vel/dim': tf.io.FixedLenFeature([], tf.int64),
'shadowhand_motor/joints_pos': tf.io.FixedLenFeature([], tf.string),
'shadowhand_motor/joints_pos/dim': tf.io.FixedLenFeature([], tf.int64),
'shadowhand_motor/spatial_touch': tf.io.FixedLenFeature([], tf.string),
'shadowhand_motor/spatial_touch/dim': tf.io.FixedLenFeature([], tf.int64),
'actions'
```
* 'camera': `Tensor` of shape [sequence_length, height, width, channels] and type
uint8
* 'shadowhand_motor/spatial_touch': `Tensor` of shape [sequence_length, num_fingers x 3] and type float32
* 'shadowhand_motor/joints_pos': `Tensor` of shape [sequence_length, num_joint_positions] and type float32
* 'shadowhand_motor/joints_vel': `Tensor` of shape [sequence_length,
num_joint_velocities] and type float32
* 'actions': `Tensor` of shape [sequence_length, num_actuated_joints] and type
float32
* 'object_id': `Scalar` indicating an object identification number
* 'orientation_id': `Scalar` indicating a YCB object pose identification number
(CMTouch-YCB only)
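The feature entries listed above can be collected into a Python dictionary and used with `tf.io.parse_single_example`. The sketch below makes some assumptions: the dictionary name `feature_spec` is hypothetical, and we assume the string features hold serialized tensors decodable with `tf.io.parse_tensor`; the Colab contains the authoritative `_parse_tf_example`.
```
import tensorflow as tf

feature_spec = {
    'shadowhand_motor/spatial_touch': tf.io.FixedLenFeature([], tf.string),
    'shadowhand_motor/spatial_touch/dim': tf.io.FixedLenFeature([], tf.int64),
    # ... remaining entries from the list above ...
}

def _parse_tf_example(example_proto):
  parsed = tf.io.parse_single_example(example_proto, feature_spec)
  # Assumption: string features store serialized tensors.
  touch = tf.io.parse_tensor(
      parsed['shadowhand_motor/spatial_touch'], out_type=tf.float32)
  return parsed, touch
```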
Few-shot evaluations can be performed by creating subsets of the data to train
and evaluate the models.
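For instance, a minimal sketch (the subset size is hypothetical, and we assume each record is one episode) that builds a small training subset from a parsed dataset:
```
few_shot_train = parsed_dataset.take(100)  # keep only the first 100 episodes
```
The download script in this directory also lists pre-generated `train<N>` subsets of various sizes.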
## References
[1] M. Zambelli, Y. Aytar, F. Visin, Y. Zhou, R. Hadsell. Learning rich touch
representations through cross-modal self-supervision. Conference on Robot
Learning (CoRL), 2020.
[2] ShadowRobot, Shadow Dexterous Hand.
https://www.shadowrobot.com/products/dexterous-hand/.
[3] E. Todorov, T. Erez, and Y. Tassa. MuJoCo: A physics engine for model-based
control. In Proceedings of the International Conference on Intelligent Robots
and Systems (IROS), 2012.
[4] B. Calli, A. Singh, J. Bruce, A. Walsman, K. Konolige, S. Srinivasa, P.
Abbeel, and A. M. Dollar. Yale-CMU-Berkeley dataset for robotic manipulation
research. The International Journal of Robotics Research, 36(3):261–268, 2017.
## Disclaimers
This is not an official Google product.

cmtouch/download_datasets.sh (new executable file, 123 lines added)

@@ -0,0 +1,123 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train10.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train100.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train1000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train250.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train30.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train50.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train1200.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train144.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train240.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train2400.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train48.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train480.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train4800.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_val.tfrecords

synthetic_returns/README.md (new file, 125 lines added)

@@ -0,0 +1,125 @@
# Code for Synthetic Returns
This repository contains code for the arXiv preprint
["Synthetic Returns for Long-Term Credit Assignment"](https://arxiv.org/abs/2102.12425)
by David Raposo, Sam Ritter, Adam Santoro, Greg Wayne, Theophane Weber, Matt
Botvinick, Hado van Hasselt, and Francis Song.
To cite this work:
```
@article{raposo2021synthetic,
  title={Synthetic Returns for Long-Term Credit Assignment},
  author={Raposo, David and Ritter, Sam and Santoro, Adam and Wayne, Greg and
  Weber, Theophane and Botvinick, Matt and van Hasselt, Hado and Song, Francis},
  journal={arXiv preprint arXiv:2102.12425},
  year={2021}
}
```
### Agent core wrapper
We implemented the Synthetic Returns (SR) module as a wrapper around a
recurrent neural network (RNN), so it should be compatible with any deep RL
agent that uses an arbitrary RNN core whose inputs consist of batches of
vectors. This could be an LSTM, as in the example below, or a more
sophisticated core, as long as it implements `hk.RNNCore`.
```python
agent_core = hk.LSTM(128)
```
To build the SR wrapper, simply pass the existing agent core to the constructor,
along with the SR configuration:
```python
sr_config = {
    "memory_size": 128,
    "capacity": 300,
    "hidden_layers": (128, 128),
    "alpha": 0.3,
    "beta": 1.0,
}
sr_agent_core = hk.ResetCore(
    SyntheticReturnsCoreWrapper(core=agent_core, **sr_config))
```
Typically, the SR wrapper should itself be wrapped in a `hk.ResetCore` in order
to reset the core state at the beginning of each new episode. This resets not
only the episodic memory but also the original agent core that was passed to the
SR wrapper constructor.
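For completeness, a minimal sketch of how the wrapped core is used (the batch size is hypothetical; as with all Haiku modules, this must run inside an `hk.transform`):
```python
batch_size = 16  # hypothetical
# The initial state covers both the episodic memory and the inner agent core.
initial_state = sr_agent_core.initial_state(batch_size)
# Because of hk.ResetCore, each step's input is a (core_inputs, should_reset)
# pair; see the learner example below.
```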
### Learner
Consider the distributed setting, wherein a learner receives mini-batches of
trajectories of length `T` produced by the actors.
`trajectory` is a nested structure of tensors of size `[T,B,...]` (where `B` is
the batch size) containing observations, agent states, rewards and step type
indicators.
We start by producing inputs to the SR core, which consist of tuples of current
state embeddings and return targets. The current state embeddings can be
produced by a ConvNet, for example. In our experiments we used the current-step
reward as the target. Note that the current-step reward corresponds to the
rewards in the trajectory shifted by one relative to the observations:
```python
observations = jax.tree_map(lambda x: x[:-1], trajectory.observation)
vision_output = hk.BatchApply(vision_net)(observations)
return_targets = trajectory.reward[1:]
sr_core_inputs = (vision_output, return_targets)
```
For purposes of core resetting at the beginning of a new episode, we also need
to pass an indicator of which steps correspond to the first step of an episode.
```python
should_reset = jnp.equal(
    trajectory.step_type[:-1], int(dm_env.StepType.FIRST))
core_inputs = (sr_core_inputs, should_reset)
```
We can now produce an unroll using `hk.dynamic_unroll` and passing it the SR
core, the core inputs we produced, and the initial state of the unroll, which
corresponds to the agent state in the first step of the trajectory:
```python
state = jax.tree_map(lambda t: t[0], trajectory.agent_state)
core_output, state = hk.dynamic_unroll(
    sr_agent_core, core_inputs, state)
```
The SR wrapper produces 4 output tensors: the output of the agent core, the
synthetic returns, the SR-augmented return, and the SR loss.
The synthetic returns are already taken into account when computing the
augmented return and the SR loss, so they do not need to be consumed separately
and can be discarded or kept for logging purposes.
The agent core outputs should be used, as usual, for producing a policy. In an
actor-critic, policy gradient set-up, like IMPALA, we would produce policy
logits and values:
```python
policy_logits = hk.BatchApply(policy_net)(core_output.output)
value = hk.BatchApply(baseline_net)(core_output.output)
```
Similarly, in a Q-learning setting we would use the agent core outputs to
produce q-values.
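A minimal sketch (with `q_net` standing in for a hypothetical action-value head, e.g. an MLP with one output per action):
```python
q_values = hk.BatchApply(q_net)(core_output.output)
```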
The SR-augmented returns should be used in place of the environment rewards for
the policy updates (e.g. when computing the policy gradient and baseline
losses):
```python
rewards = core_output.augmented_return
```
Finally, the SR loss, summed over batch and time dimensions, should be added to
the total learner loss to be minimized:
```python
total_loss += jnp.sum(core_output.sr_loss)
```


@@ -0,0 +1,2 @@
dm-haiku>=0.0.3
jax>=0.2.8


@@ -0,0 +1,187 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Episodic Memory and Synthetic Returns Core Wrapper modules."""
import collections
import haiku as hk
import jax
import jax.numpy as jnp

SRCoreWrapperOutput = collections.namedtuple(
    "SRCoreWrapperOutput", ["output", "synthetic_return", "augmented_return",
                            "sr_loss"])


class EpisodicMemory(hk.RNNCore):
  """Episodic Memory module."""

  def __init__(self, memory_size, capacity, name="episodic_memory"):
    """Constructor.

    Args:
      memory_size: Integer. The size of the vectors to be stored.
      capacity: Integer. The maximum number of memories to store before it
        becomes necessary to overwrite old memories.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._memory_size = memory_size
    self._capacity = capacity

  def __call__(self, inputs, prev_state):
    """Writes a new memory into the episodic memory.

    Args:
      inputs: A Tensor of shape ``[batch_size, memory_size]``.
      prev_state: The previous state of the episodic memory, which is a tuple
        with a (i) counter of shape ``[batch_size, 1]`` indicating how many
        memories have been written so far, and (ii) a tensor of shape
        ``[batch_size, capacity, memory_size]`` with the full content of the
        episodic memory.

    Returns:
      A tuple with (i) a tensor of shape ``[batch_size, capacity, memory_size]``
      with the full content of the episodic memory, including the newly
      written memory, and (ii) the new state of the episodic memory.
    """
    inputs = jax.lax.stop_gradient(inputs)
    counter, memories = prev_state
    counter_mod = jnp.mod(counter, self._capacity)
    slot_selector = jnp.expand_dims(
        jax.nn.one_hot(counter_mod, self._capacity), axis=2)
    memories = memories * (1 - slot_selector) + (
        slot_selector * jnp.expand_dims(inputs, 1))
    counter = counter + 1
    return memories, (counter, memories)

  def initial_state(self, batch_size):
    """Creates the initial state of the episodic memory.

    Args:
      batch_size: Integer. The batch size of the episodic memory.

    Returns:
      A tuple with (i) a counter of shape ``[batch_size, 1]`` and (ii) a tensor
      of shape ``[batch_size, capacity, memory_size]`` with the full content
      of the episodic memory.
    """
    if batch_size is None:
      shape = []
    else:
      shape = [batch_size]
    counter = jnp.zeros(shape)
    memories = jnp.zeros(shape + [self._capacity, self._memory_size])
    return (counter, memories)


class SyntheticReturnsCoreWrapper(hk.RNNCore):
  """Synthetic Returns core wrapper."""

  def __init__(self, core, memory_size, capacity, hidden_layers, alpha, beta,
               loss_func=(lambda x, y: 0.5 * jnp.square(x - y)),
               apply_core_to_input=False, name="synthetic_returns_wrapper"):
    """Constructor.

    Args:
      core: hk.RNNCore. The recurrent core of the agent. E.g. an LSTM.
      memory_size: Integer. The size of the vectors to be stored in the
        episodic memory.
      capacity: Integer. The maximum number of memories to store before it
        becomes necessary to overwrite old memories.
      hidden_layers: Tuple or list of integers, indicating the size of the
        hidden layers of the MLPs used to produce synthetic returns, current
        state bias, and gate.
      alpha: The multiplier of the synthetic returns term in the augmented
        return.
      beta: The multiplier of the environment returns term in the augmented
        return.
      loss_func: A function of two arguments (predictions and targets) to
        compute the SR loss.
      apply_core_to_input: Boolean. Whether to apply the core on the inputs. If
        true, the synthetic returns will be computed from the outputs of the
        RNN core passed to the constructor. If false, the RNN core will be
        applied only at the output of this wrapper, and the synthetic returns
        will be computed from the inputs.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._em = EpisodicMemory(memory_size, capacity)
    self._capacity = capacity
    hidden_layers = list(hidden_layers)
    self._synthetic_return = hk.nets.MLP(hidden_layers + [1])
    self._bias = hk.nets.MLP(hidden_layers + [1])
    self._gate = hk.Sequential([
        hk.nets.MLP(hidden_layers + [1]),
        jax.nn.sigmoid,
    ])
    self._apply_core_to_input = apply_core_to_input
    self._core = core
    self._alpha = alpha
    self._beta = beta
    self._loss = loss_func

  def initial_state(self, batch_size):
    return (
        self._em.initial_state(batch_size),
        self._core.initial_state(batch_size)
    )

  def __call__(self, inputs, prev_state):
    current_input, return_target = inputs
    em_state, core_state = prev_state
    (counter, memories) = em_state
    if self._apply_core_to_input:
      current_input, core_state = self._core(current_input, core_state)
    # Synthetic return for the current state
    synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)
    # Current state bias term
    bias = self._bias(current_input)
    # Gate computed from current state
    gate = self._gate(current_input)
    # When counter > capacity, mask will be all ones
    mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
    mask = jnp.expand_dims(mask, axis=2)
    # Synthetic returns for each state in memory
    past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)
    # Sum of synthetic returns from previous states
    sr_sum = jnp.sum(past_synth_returns * mask, axis=1)
    prediction = jnp.squeeze(sr_sum * gate + bias, -1)
    sr_loss = self._loss(prediction, return_target)
    augmented_return = jax.lax.stop_gradient(
        self._alpha * synth_return + self._beta * return_target)
    # Write current state to memory
    _, em_state = self._em(current_input, em_state)
    if not self._apply_core_to_input:
      output, core_state = self._core(current_input, core_state)
    else:
      output = current_input
    output = SRCoreWrapperOutput(
        output=output,
        synthetic_return=synth_return,
        augmented_return=augmented_return,
        sr_loss=sr_loss,
    )
    return output, (em_state, core_state)