mirror of https://github.com/google-deepmind/deepmind-research.git
synced 2025-12-17 14:14:15 +08:00

Internal change.

PiperOrigin-RevId: 372158522

committed by Louise Deason
parent 2e866f1937
commit 59f5fb1268

269  cmtouch/CMTouch_Dataset_Visulization.ipynb  Normal file
File diff suppressed because one or more lines are too long

209  cmtouch/README.md  Normal file
@@ -0,0 +1,209 @@
# CMTouch Dataset

This repository contains datasets for cross-modal representation learning, used
in developing rich touch representations in "Learning rich touch representations
through cross-modal self-supervision" [1].

The datasets we provide are:

1. CMTouch-Props
2. CMTouch-YCB

The datasets consist of episodes collected by running a reinforcement learning
agent on a simulated Shadow Dexterous Hand [2] interacting with different
objects. From these interactions, observations from different sensory modalities
are collected at each time step, including vision, proprioception (joint
positions and velocities), touch, actions, and object IDs. We used these data to
learn rich touch representations using cross-modal self-supervision.

## Bibtex

If you use one of these datasets in your work, please cite the reference paper
as follows:

```
@InProceedings{zambelli20learning,
  author = "Zambelli, Martina and Aytar, Yusuf and Visin, Francesco and Zhou, Yuxiang and Hadsell, Raia",
  title = "Learning rich touch representations through cross-modal self-supervision",
  year = "2020",
}
```

<!--
@misc{cmtouchdatasets}, title={CMTouch Datasets}, author={Zambelli, Martina and
Aytar, Yusuf and Visin, Francesco and Zhou, Yuxiang and Hadsell, Raia},
howpublished={https://github.com/deepmind/deepmind-research/tree/master/cmtouch},
year={2020} }
-->

## Descriptions

### Experimental setup

We run experiments in simulation with MuJoCo [3] and use the simulated Shadow
Dexterous Hand [2], with five fingers and 24 degrees of freedom, actuated by 20
motors. In simulation, each fingertip has a spatial touch sensor attached, with a
spatial resolution of 4×4 and three channels: one for normal force and two for
tangential forces. We simplify this by summing across the spatial dimensions,
to obtain a single force vector for each fingertip representing one normal force
and two tangential forces. The state consists of proprioception (joint positions
and joint velocities) and touch.
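
As a concrete illustration of this summation, here is a minimal sketch (illustrative only; the array below is a hypothetical raw sensor reading, not a field of the dataset):

```python
import numpy as np

# Hypothetical raw reading from one fingertip sensor: 4x4 taxels, 3 channels
# (one normal force, two tangential forces).
raw_reading = np.random.rand(4, 4, 3)

# Summing across the two spatial dimensions yields a single 3-D force vector.
force_vector = raw_reading.sum(axis=(0, 1))  # shape: (3,)

# With five fingertips, the per-step touch observation is 5 x 3 = 15-dimensional.
```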

Visual inputs are collected at a 64×64 resolution and are only used for
representation learning; they are not provided as observations to control the
robot’s actions. The action space is 20-dimensional. We use velocity control and
a control rate of 30 Hz. Each episode has 200 time steps, which corresponds to
about 6 seconds. The environment consists of the Shadow Hand, facing down,
interacting with different objects. These objects have different shapes, sizes
and physical properties (e.g. rigid or soft). We develop two versions of the
task, the first using simple props and the second using YCB objects. In both
cases, objects are fixed to their frame of reference, while their position and
orientation are randomized.

### CMTouch-Props

This is a dataset based on simple geometric 3D shapes (referred to as "props").
Props are simple 3D objects that include cubes, spheres, cylinders and
ellipsoids of different sizes. We also generated a soft version of each prop,
which can deform under the pressure of the touching fingers.

Soft deformable objects are complex entities to simulate: they are defined
through a composition of multiple bodies (capsules) that are tied together to
form a shape, such as a cube or a sphere. The main characteristic of these
objects is their elastic behaviour: they change shape when touched. The most
difficult aspect to simulate in this context is contacts, whose number grows
exponentially with the number of colliding bodies.

Forty-eight different objects are generated by sampling from 6 sizes and
4 shapes (sphere, cylinder, cube, ellipsoid), each of which can be either
rigid or soft.



### CMTouch-YCB

This is a dataset based on YCB objects. The YCB object dataset [4] consists of
everyday objects with different shapes, sizes, textures, weights and rigidity.

We chose a set of ten objects: cracker box, sugar box, mustard bottle, potted
meat can, banana, pitcher base, bleach cleanser, mug, power drill, and scissors.
These are generated in simulation at their standard size, which is also
proportionate to the default dimensions of the simulated Shadow Hand.

The pose of each object is randomly selected from a set of 60 different poses,
in which we vary the orientation of the object. These variations make the
identification of each object more complex than in CMTouch-Props and require
greater generalization capability from the learning method.



## Download

The datasets can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/dm_cmtouch).
Each dataset is a single
[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file.

On Linux, to download a particular dataset, use the web interface, or run `wget`
with the appropriate filename as follows:

```
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_test.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_train.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_val.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_test.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_train.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_val.tfrecords
```

## Usage

After downloading the dataset files, you can read them as `tf.data.Dataset`
instances with the readers provided. The example below shows how to read the
CMTouch-Props dataset:

```
record_file = 'test.tfrecords'
dataset = tf.data.TFRecordDataset(record_file)
parsed_dataset = dataset.map(_parse_tf_example)
```

(A complete example is provided in the Colab.)

All dataset readers return the following set of observations:

    'camera': tf.io.FixedLenFeature([], tf.string),
    'camera/height': tf.io.FixedLenFeature([], tf.int64),
    'camera/width': tf.io.FixedLenFeature([], tf.int64),
    'camera/channel': tf.io.FixedLenFeature([], tf.int64),
    'object_id': tf.io.FixedLenFeature([], tf.string),  # for both
    'object_id/dim': tf.io.FixedLenFeature([], tf.int64),
    'orientation_id': tf.io.FixedLenFeature([], tf.string),  # only for ycb
    'orientation_id/dim': tf.io.FixedLenFeature([], tf.int64),
    'shadowhand_motor/joints_vel': tf.io.FixedLenFeature([], tf.string),
    'shadowhand_motor/joints_vel/dim': tf.io.FixedLenFeature([], tf.int64),
    'shadowhand_motor/joints_pos': tf.io.FixedLenFeature([], tf.string),
    'shadowhand_motor/joints_pos/dim': tf.io.FixedLenFeature([], tf.int64),
    'shadowhand_motor/spatial_touch': tf.io.FixedLenFeature([], tf.string),
    'shadowhand_motor/spatial_touch/dim': tf.io.FixedLenFeature([], tf.int64),
    'actions'

* 'camera': `Tensor` of shape [sequence_length, height, width, channels] and
  type uint8

* 'shadowhand_motor/spatial_touch': `Tensor` of shape [sequence_length, num_fingers x 3] and type float32

* 'shadowhand_motor/joints_pos': `Tensor` of shape [sequence_length, num_joint_positions] and type float32

* 'shadowhand_motor/joints_vel': `Tensor` of shape [sequence_length, num_joint_velocities] and type float32

* 'actions': `Tensor` of shape [sequence_length, num_actuated_joints] and type float32

* 'object_id': `Scalar` indicating an object identification number

* 'orientation_id': `Scalar` indicating a YCB object pose identification number
  (CMTouch-YCB only)
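
The parser `_parse_tf_example` used in the Usage snippet above is provided with the Colab. Purely as an illustration, a minimal sketch of such a parser is shown below; the decoding dtypes and the reshaping are assumptions, not the exact reader implementation:

```python
import tensorflow as tf

# Assumed feature spec, taken from the table above (the 'actions' entry is
# omitted because its exact spec is not listed).
_FEATURES = {
    'camera': tf.io.FixedLenFeature([], tf.string),
    'camera/height': tf.io.FixedLenFeature([], tf.int64),
    'camera/width': tf.io.FixedLenFeature([], tf.int64),
    'camera/channel': tf.io.FixedLenFeature([], tf.int64),
    'object_id': tf.io.FixedLenFeature([], tf.string),
    'object_id/dim': tf.io.FixedLenFeature([], tf.int64),
    'shadowhand_motor/spatial_touch': tf.io.FixedLenFeature([], tf.string),
    'shadowhand_motor/spatial_touch/dim': tf.io.FixedLenFeature([], tf.int64),
}


def _parse_tf_example(example_proto):
  """Parses one serialized tf.Example into a dict of tensors (a sketch)."""
  parsed = tf.io.parse_single_example(example_proto, _FEATURES)
  height = tf.cast(parsed['camera/height'], tf.int32)
  width = tf.cast(parsed['camera/width'], tf.int32)
  channels = tf.cast(parsed['camera/channel'], tf.int32)
  # Serialized tensors are assumed to be raw bytes of the dtypes listed above.
  camera = tf.io.decode_raw(parsed['camera'], tf.uint8)
  camera = tf.reshape(camera, [-1, height, width, channels])
  touch = tf.io.decode_raw(
      parsed['shadowhand_motor/spatial_touch'], tf.float32)
  return {'camera': camera, 'shadowhand_motor/spatial_touch': touch}
```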

Few-shot evaluations can be made by creating subsets of the data to train and
evaluate the models.
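
For example, a quick way to carve out such subsets from a parsed dataset (a sketch; the split sizes below are arbitrary):

```python
# `parsed_dataset` is the dataset built in the Usage example above.
few_shot_train = parsed_dataset.take(100)
few_shot_eval = parsed_dataset.skip(100).take(20)
```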

<!--
```diff=
- TODO
```
-->

## References

[1] M. Zambelli, Y. Aytar, F. Visin, Y. Zhou, R. Hadsell. Learning rich touch
representations through cross-modal self-supervision. Conference on Robot
Learning (CoRL), 2020.

[2] ShadowRobot. Shadow Dexterous Hand.
https://www.shadowrobot.com/products/dexterous-hand/.

[3] E. Todorov, T. Erez, and Y. Tassa. MuJoCo: A physics engine for model-based
control. In Proceedings of the International Conference on Intelligent Robots
and Systems (IROS), 2012.

[4] B. Calli, A. Singh, J. Bruce, A. Walsman, K. Konolige, S. Srinivasa, P.
Abbeel, and A. M. Dollar. Yale-CMU-Berkeley dataset for robotic manipulation
research. The International Journal of Robotics Research, 36(3):261–268, 2017.

## Disclaimers

This is not an official Google product.

## Appendix and FAQ

**Find this document incomplete?** Leave a comment!
123  cmtouch/download_datasets.sh  Executable file
@@ -0,0 +1,123 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train10.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train100.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train1000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train250.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train30.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train50.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train1200.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train144.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train240.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train2400.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train48.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train480.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train4800.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_val.tfrecords
125  synthetic_returns/README.md  Normal file
@@ -0,0 +1,125 @@
# Code for Synthetic Returns

This repository contains code for the arXiv preprint
["Synthetic Returns for Long-Term Credit Assignment"](https://arxiv.org/abs/2102.12425)
by David Raposo, Sam Ritter, Adam Santoro, Greg Wayne, Theophane Weber, Matt
Botvinick, Hado van Hasselt, and Francis Song.

To cite this work:

```
@article{raposo2021synthetic,
  title={Synthetic Returns for Long-Term Credit Assignment},
  author={Raposo, David and Ritter, Sam and Santoro, Adam and Wayne, Greg and
  Weber, Theophane and Botvinick, Matt and van Hasselt, Hado and Song, Francis},
  journal={arXiv preprint arXiv:2102.12425},
  year={2021}
}
```

### Agent core wrapper

We implemented the Synthetic Returns module as a wrapper around a recurrent
neural network (RNN), so it should be compatible with any deep-RL agent with an
arbitrary RNN core whose inputs consist of batches of vectors. This could be an
LSTM, as in the example below, or a more sophisticated core, as long as it
implements `hk.RNNCore`.

```python
agent_core = hk.LSTM(128)
```

To build the SR wrapper, simply pass the existing agent core to the constructor,
along with the SR configuration:

```python
sr_config = {
    "memory_size": 128,
    "capacity": 300,
    "hidden_layers": (128, 128),
    "alpha": 0.3,
    "beta": 1.0,
}
sr_agent_core = hk.ResetCore(
    SyntheticReturnsCoreWrapper(core=agent_core, **sr_config))
```

Typically, the SR wrapper should itself be wrapped in an `hk.ResetCore` in order
to reset the core state at the beginning of a new episode. This resets not
only the episodic memory but also the original agent core that was passed to the
SR wrapper constructor.
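
Since both the wrapper and `hk.ResetCore` implement the `hk.RNNCore` interface, an initial recurrent state can be created in the usual way. A sketch (to be called inside an `hk.transform`-ed function; `batch_size` is whatever your setup uses):

```python
# Initial state for a batch of episodes: a pair of (episodic memory state,
# wrapped agent core state), as defined by SyntheticReturnsCoreWrapper.
initial_state = sr_agent_core.initial_state(batch_size)
```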

### Learner

Consider the distributed setting, wherein a learner receives mini-batches of
trajectories of length `T` produced by the actors.

`trajectory` is a nested structure of tensors of size `[T, B, ...]` (where `B` is
the batch size) containing observations, agent states, rewards and step-type
indicators.
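
The snippets below only assume that `trajectory` exposes these fields; a minimal stand-in (hypothetical, for illustration) could be:

```python
import collections

# Minimal stand-in for the trajectory structure assumed below; each field is a
# tensor (or nest of tensors) with leading dimensions [T, B, ...].
Trajectory = collections.namedtuple(
    "Trajectory", ["observation", "agent_state", "reward", "step_type"])
```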

We start by producing inputs to the SR core, which consist of tuples of current
state embeddings and return targets. The current state embeddings can be
produced by a ConvNet, for example. In our experiments we used the current step
reward as the target. Note that the current step reward corresponds to the
rewards in the trajectory shifted by one, relative to the observations:

```python
observations = jax.tree_map(lambda x: x[:-1], trajectory.observation)
vision_output = hk.BatchApply(vision_net)(observations)
return_targets = trajectory.reward[1:]
sr_core_inputs = (vision_output, return_targets)
```

For purposes of core resetting at the beginning of a new episode, we also need
to pass an indicator of which steps correspond to the first step of an episode.

```python
should_reset = jnp.equal(
    trajectory.step_type[:-1], int(dm_env.StepType.FIRST))
core_inputs = (sr_core_inputs, should_reset)
```

We can now produce an unroll using `hk.dynamic_unroll`, passing it the SR
core, the core inputs we produced, and the initial state of the unroll, which
corresponds to the agent state at the first step of the trajectory:

```python
state = jax.tree_map(lambda t: t[0], trajectory.agent_state)
core_output, state = hk.dynamic_unroll(
    sr_agent_core, core_inputs, state)
```

The SR wrapper produces 4 output tensors: the output of the agent core, the
synthetic returns, the SR-augmented return, and the SR loss.
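
Concretely, these are the fields of the `SRCoreWrapperOutput` namedtuple returned by the wrapper (defined in `synthetic_returns.py`):

```python
agent_core_output = core_output.output            # used to produce the policy
synthetic_return = core_output.synthetic_return   # SR for the current state
augmented_return = core_output.augmented_return   # alpha * SR + beta * reward
sr_loss = core_output.sr_loss                     # loss on the SR prediction
```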

The synthetic returns are already taken into account when computing the
augmented return and the SR loss. They are therefore not needed elsewhere and
can be discarded or used for logging purposes.

The agent core outputs should be used, as usual, for producing a policy. In an
actor-critic, policy-gradient setup like IMPALA, we would produce policy
logits and values:

```python
policy_logits = hk.BatchApply(policy_net)(core_output.output)
value = hk.BatchApply(baseline_net)(core_output.output)
```

Similarly, in a Q-learning setting we would use the agent core outputs to
produce Q-values.

The SR-augmented returns should be used in place of the environment rewards for
the policy updates (e.g. when computing the policy gradient and baseline
losses):

```python
rewards = core_output.augmented_return
```

Finally, the SR loss, summed over the batch and time dimensions, should be added
to the total learner loss to be minimized:

```python
total_loss += jnp.sum(core_output.sr_loss)
```
2  synthetic_returns/requirements.txt  Normal file
@@ -0,0 +1,2 @@
dm-haiku>=0.0.3
jax>=0.2.8
187  synthetic_returns/synthetic_returns.py  Normal file
@@ -0,0 +1,187 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Episodic Memory and Synthetic Returns Core Wrapper modules."""

import collections

import haiku as hk
import jax
import jax.numpy as jnp

SRCoreWrapperOutput = collections.namedtuple(
    "SRCoreWrapperOutput", ["output", "synthetic_return", "augmented_return",
                            "sr_loss"])


class EpisodicMemory(hk.RNNCore):
  """Episodic Memory module."""

  def __init__(self, memory_size, capacity, name="episodic_memory"):
    """Constructor.

    Args:
      memory_size: Integer. The size of the vectors to be stored.
      capacity: Integer. The maximum number of memories to store before it
        becomes necessary to overwrite old memories.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._memory_size = memory_size
    self._capacity = capacity

  def __call__(self, inputs, prev_state):
    """Writes a new memory into the episodic memory.

    Args:
      inputs: A Tensor of shape ``[batch_size, memory_size]``.
      prev_state: The previous state of the episodic memory, which is a tuple
        with (i) a counter of shape ``[batch_size, 1]`` indicating how many
        memories have been written so far, and (ii) a tensor of shape
        ``[batch_size, capacity, memory_size]`` with the full content of the
        episodic memory.

    Returns:
      A tuple with (i) a tensor of shape ``[batch_size, capacity, memory_size]``
      with the full content of the episodic memory, including the newly
      written memory, and (ii) the new state of the episodic memory.
    """
    inputs = jax.lax.stop_gradient(inputs)
    counter, memories = prev_state
    counter_mod = jnp.mod(counter, self._capacity)
    slot_selector = jnp.expand_dims(
        jax.nn.one_hot(counter_mod, self._capacity), axis=2)
    memories = memories * (1 - slot_selector) + (
        slot_selector * jnp.expand_dims(inputs, 1))
    counter = counter + 1
    return memories, (counter, memories)

  def initial_state(self, batch_size):
    """Creates the initial state of the episodic memory.

    Args:
      batch_size: Integer. The batch size of the episodic memory.

    Returns:
      A tuple with (i) a counter of shape ``[batch_size, 1]`` and (ii) a tensor
      of shape ``[batch_size, capacity, memory_size]`` with the full content
      of the episodic memory.
    """
    if batch_size is None:
      shape = []
    else:
      shape = [batch_size]
    counter = jnp.zeros(shape)
    memories = jnp.zeros(shape + [self._capacity, self._memory_size])
    return (counter, memories)


class SyntheticReturnsCoreWrapper(hk.RNNCore):
  """Synthetic Returns core wrapper."""

  def __init__(self, core, memory_size, capacity, hidden_layers, alpha, beta,
               loss_func=(lambda x, y: 0.5 * jnp.square(x - y)),
               apply_core_to_input=False, name="synthetic_returns_wrapper"):
    """Constructor.

    Args:
      core: hk.RNNCore. The recurrent core of the agent. E.g. an LSTM.
      memory_size: Integer. The size of the vectors to be stored in the
        episodic memory.
      capacity: Integer. The maximum number of memories to store before it
        becomes necessary to overwrite old memories.
      hidden_layers: Tuple or list of integers, indicating the size of the
        hidden layers of the MLPs used to produce synthetic returns, current
        state bias, and gate.
      alpha: The multiplier of the synthetic returns term in the augmented
        return.
      beta: The multiplier of the environment returns term in the augmented
        return.
      loss_func: A function of two arguments (predictions and targets) to
        compute the SR loss.
      apply_core_to_input: Boolean. Whether to apply the core on the inputs. If
        true, the synthetic returns will be computed from the outputs of the
        RNN core passed to the constructor. If false, the RNN core will be
        applied only at the output of this wrapper, and the synthetic returns
        will be computed from the inputs.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._em = EpisodicMemory(memory_size, capacity)
    self._capacity = capacity
    hidden_layers = list(hidden_layers)
    self._synthetic_return = hk.nets.MLP(hidden_layers + [1])
    self._bias = hk.nets.MLP(hidden_layers + [1])
    self._gate = hk.Sequential([
        hk.nets.MLP(hidden_layers + [1]),
        jax.nn.sigmoid,
    ])
    self._apply_core_to_input = apply_core_to_input
    self._core = core
    self._alpha = alpha
    self._beta = beta
    self._loss = loss_func

  def initial_state(self, batch_size):
    return (
        self._em.initial_state(batch_size),
        self._core.initial_state(batch_size)
    )

  def __call__(self, inputs, prev_state):
    current_input, return_target = inputs

    em_state, core_state = prev_state
    (counter, memories) = em_state

    if self._apply_core_to_input:
      current_input, core_state = self._core(current_input, core_state)

    # Synthetic return for the current state
    synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)

    # Current state bias term
    bias = self._bias(current_input)

    # Gate computed from current state
    gate = self._gate(current_input)

    # When counter > capacity, mask will be all ones
    mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
    mask = jnp.expand_dims(mask, axis=2)

    # Synthetic returns for each state in memory
    past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)

    # Sum of synthetic returns from previous states
    sr_sum = jnp.sum(past_synth_returns * mask, axis=1)

    prediction = jnp.squeeze(sr_sum * gate + bias, -1)
    sr_loss = self._loss(prediction, return_target)

    augmented_return = jax.lax.stop_gradient(
        self._alpha * synth_return + self._beta * return_target)

    # Write current state to memory
    _, em_state = self._em(current_input, em_state)

    if not self._apply_core_to_input:
      output, core_state = self._core(current_input, core_state)
    else:
      output = current_input

    output = SRCoreWrapperOutput(
        output=output,
        synthetic_return=synth_return,
        augmented_return=augmented_return,
        sr_loss=sr_loss,
    )
    return output, (em_state, core_state)