Internal change.
PiperOrigin-RevId: 372158522

Commit 59f5fb1268 (parent 2e866f1937), committed by Louise Deason.
cmtouch/CMTouch_Dataset_Visulization.ipynb (new file, 269 lines)
File diff suppressed because one or more lines are too long
cmtouch/README.md (new file, 209 lines)
@@ -0,0 +1,209 @@
# CMTouch Dataset

<!--



-->

This repository contains datasets for cross-modal representation learning, used
in developing rich touch representations in "Learning rich touch representations
through cross-modal self-supervision" [1].

The datasets we provide are:

1. CMTouch-Props
2. CMTouch-YCB

The datasets consist of episodes collected by running a reinforcement learning
agent on a simulated Shadow Dexterous Hand [2] interacting with different
objects. From these interactions, observations from different sensory modalities
are collected at each time step, including vision, proprioception (joint
positions and velocities), touch, actions, and object IDs. We used these data to
learn rich touch representations using cross-modal self-supervision.

## Bibtex

If you use one of these datasets in your work, please cite the reference paper
as follows:

```
@InProceedings{zambelli20learning,
author = "Zambelli, Martina and Aytar, Yusuf and Visin, Francesco and Zhou, Yuxiang and Hadsell, Raia",
title = "Learning rich touch representations through cross-modal self-supervision",
year = "2020",
}
```

<!--
@misc{cmtouchdatasets}, title={CMTouch Datasets}, author={Zambelli, Martina and
Aytar, Yusuf and Visin, Francesco and Zhou, Yuxiang and Hadsell, Raia},
howpublished={https://github.com/deepmind/deepmind-research/tree/master/cmtouch},
year={2020} }
-->

## Descriptions

### Experimental setup

We run experiments in simulation with MuJoCo [3] and use the simulated Shadow
Dexterous Hand [2], which has five fingers and 24 degrees of freedom actuated by
20 motors. In simulation, each fingertip has a spatial touch sensor attached,
with a spatial resolution of 4×4 and three channels: one for the normal force
and two for the tangential forces. We simplify this by summing across the
spatial dimensions to obtain a single force vector for each fingertip,
representing one normal force and two tangential forces. The state consists of
proprioception (joint positions and joint velocities) and touch.
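
For illustration only, this simplification can be sketched as below. The raw
4×4×3 per-fingertip reading is an assumption about the simulator output; the
released datasets already contain the summed per-fingertip force vectors.

```python
import numpy as np

# Hypothetical raw reading for one fingertip: 4x4 taxels, 3 force channels
# (one normal force + two tangential forces).
raw_touch = np.random.rand(4, 4, 3).astype(np.float32)

# Sum over the two spatial dimensions to get a single 3-D force vector.
force_vector = raw_touch.sum(axis=(0, 1))  # shape: (3,)
```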

Visual inputs are collected at a 64×64 resolution and are only used for
representation learning; they are not provided as observations to control the
robot's actions. The action space is 20-dimensional. We use velocity control at
a control rate of 30 Hz. Each episode has 200 time steps, which corresponds to
about 6 seconds. The environment consists of the Shadow Hand, facing down and
interacting with different objects. These objects have different shapes, sizes
and physical properties (e.g. rigid or soft). We develop two versions of the
task, the first using simple props and the second using YCB objects. In both
cases, objects are fixed to their frame of reference, while their position and
orientation are randomized.

### CMTouch-Props

This is a dataset based on simple geometric 3D shapes (referred to as "props").
Props are simple 3D objects that include cubes, spheres, cylinders and
ellipsoids of different sizes. We also generated a soft version of each prop,
which can deform under the pressure of the touching fingers.

Soft deformable objects are complex entities to simulate: they are defined
through a composition of multiple bodies (capsules) that are tied together to
form a shape, such as a cube or a sphere. The main characteristic of these
objects is their elastic behaviour, that is, they change shape when touched. The
most difficult thing to simulate in this context is contacts, whose number grows
exponentially with the number of colliding bodies.

Forty-eight different objects are generated by combining 6 different sizes and
4 different shapes (sphere, cylinder, cube, ellipsoid), each of which can be
either rigid or soft (6 × 4 × 2 = 48).



### CMTouch-YCB

This is a dataset based on YCB objects. The YCB object dataset [4] consists of
everyday objects with different shapes, sizes, textures, weights and rigidity.

We chose a set of ten objects: cracker box, sugar box, mustard bottle, potted
meat can, banana, pitcher base, bleach cleanser, mug, power drill and scissors.
These are generated in simulation at their standard size, which is also
proportionate to the default dimensions of the simulated Shadow Hand.

The pose of each object is randomly selected from a set of 60 different poses,
in which we vary the orientation of the object. These variations make the
identification of each object more complex than in CMTouch-Props and require a
higher generalization capability from the learning method applied.



## Download

The datasets can be downloaded from
[Google Cloud Storage](https://console.cloud.google.com/storage/browser/dm_cmtouch).
Each dataset is a single
[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file.

On Linux, to download a particular dataset, use the web interface, or run `wget`
with the appropriate filename as follows:

```
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_test.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_train.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_props_all_val.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_test.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_train.tfrecords
wget https://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_ycb_all_val.tfrecords
```

## Usage

After downloading the dataset files, you can read them as `tf.data.Dataset`
instances with the readers provided. The example below shows how to read the
CMTouch-Props dataset:

```
record_file = 'test.tfrecords'
dataset = tf.data.TFRecordDataset(record_file)
parsed_dataset = dataset.map(_parse_tf_example)
```

(A complete example is provided in the Colab; an illustrative sketch of a
parsing function is also given after the feature descriptions below.)

All dataset readers return the following set of observations:

'camera': tf.io.FixedLenFeature([], tf.string),
'camera/height': tf.io.FixedLenFeature([], tf.int64),
'camera/width': tf.io.FixedLenFeature([], tf.int64),
'camera/channel': tf.io.FixedLenFeature([], tf.int64),
'object_id': tf.io.FixedLenFeature([], tf.string),  # for both
'object_id/dim': tf.io.FixedLenFeature([], tf.int64),
'orientation_id': tf.io.FixedLenFeature([], tf.string),  # only for ycb
'orientation_id/dim': tf.io.FixedLenFeature([], tf.int64),
'shadowhand_motor/joints_vel': tf.io.FixedLenFeature([], tf.string),
'shadowhand_motor/joints_vel/dim': tf.io.FixedLenFeature([], tf.int64),
'shadowhand_motor/joints_pos': tf.io.FixedLenFeature([], tf.string),
'shadowhand_motor/joints_pos/dim': tf.io.FixedLenFeature([], tf.int64),
'shadowhand_motor/spatial_touch': tf.io.FixedLenFeature([], tf.string),
'shadowhand_motor/spatial_touch/dim': tf.io.FixedLenFeature([], tf.int64),
'actions'

* 'camera': `Tensor` of shape [sequence_length, height, width, channels] and
  type uint8

* 'shadowhand_motor/spatial_touch': `Tensor` of shape
  [sequence_length, num_fingers x 3] and type float32

* 'shadowhand_motor/joints_pos': `Tensor` of shape
  [sequence_length, num_joint_positions] and type float32

* 'shadowhand_motor/joints_vel': `Tensor` of shape
  [sequence_length, num_joint_velocities] and type float32

* 'actions': `Tensor` of shape [sequence_length, num_actuated_joints] and type
  float32

* 'object_id': `Scalar` indicating an object identification number

* 'orientation_id': `Scalar` indicating a YCB object pose identification number
  (CMTouch-YCB only)
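
The snippet below is a minimal, illustrative sketch of what a
`_parse_tf_example` function could look like for a subset of these features. It
assumes that the string features hold raw bytes (uint8 for images, float32 for
vectors) and that the `/dim` features give the per-step vector sizes; the exact
serialization used by the released readers may differ, so treat this as a
starting point rather than the reference implementation from the Colab.

```python
import tensorflow as tf

# Feature spec for a subset of the observations listed above.
_FEATURES = {
    'camera': tf.io.FixedLenFeature([], tf.string),
    'camera/height': tf.io.FixedLenFeature([], tf.int64),
    'camera/width': tf.io.FixedLenFeature([], tf.int64),
    'camera/channel': tf.io.FixedLenFeature([], tf.int64),
    'shadowhand_motor/spatial_touch': tf.io.FixedLenFeature([], tf.string),
    'shadowhand_motor/spatial_touch/dim': tf.io.FixedLenFeature([], tf.int64),
}


def _parse_tf_example(example_proto):
  """Hypothetical parser: decodes raw bytes and reshapes using the dims."""
  parsed = tf.io.parse_single_example(example_proto, _FEATURES)

  height = tf.cast(parsed['camera/height'], tf.int32)
  width = tf.cast(parsed['camera/width'], tf.int32)
  channel = tf.cast(parsed['camera/channel'], tf.int32)
  camera = tf.io.decode_raw(parsed['camera'], tf.uint8)
  camera = tf.reshape(camera, [-1, height, width, channel])

  touch_dim = tf.cast(parsed['shadowhand_motor/spatial_touch/dim'], tf.int32)
  touch = tf.io.decode_raw(parsed['shadowhand_motor/spatial_touch'], tf.float32)
  touch = tf.reshape(touch, [-1, touch_dim])

  return {'camera': camera, 'shadowhand_motor/spatial_touch': touch}
```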

Few-shot evaluations can be made by creating subsets of data to train and
evaluate the models.
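
For instance, with `tf.data` a small training subset can be taken directly from
the parsed dataset (the `trainN` files listed in `download_datasets.sh` already
provide such subsets; the `10` below is just an illustrative episode count):

```python
# Keep only the first 10 parsed examples for a few-shot training split.
few_shot_train = parsed_dataset.take(10)
```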

<!--
```diff=
- TODO
```
-->

## References

[1] M. Zambelli, Y. Aytar, F. Visin, Y. Zhou, R. Hadsell. Learning rich touch
representations through cross-modal self-supervision. Conference on Robot
Learning (CoRL), 2020.

[2] ShadowRobot. Shadow Dexterous Hand.
https://www.shadowrobot.com/products/dexterous-hand/.

[3] E. Todorov, T. Erez, and Y. Tassa. MuJoCo: A physics engine for model-based
control. In Proceedings of the International Conference on Intelligent Robots
and Systems (IROS), 2012.

[4] B. Calli, A. Singh, J. Bruce, A. Walsman, K. Konolige, S. Srinivasa, P.
Abbeel, and A. M. Dollar. Yale-CMU-Berkeley dataset for robotic manipulation
research. The International Journal of Robotics Research, 36(3):261–268, 2017.

## Disclaimers

This is not an official Google product.

## Appendix and FAQ

**Find this document incomplete?** Leave a comment!

cmtouch/download_datasets.sh (new executable file, 123 lines)
@@ -0,0 +1,123 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train10.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train100.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train1000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train250.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train30.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train50.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_train500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_all_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj0_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj1_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj2_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj3_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj4_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj5_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj6_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj7_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj8_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train1500.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train180.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train300.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train3000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train60.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train600.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_train6000.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_objects_ycb_obj9_im64_val.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train1200.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train144.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train240.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train2400.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train48.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train480.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_train4800.tfrecords
wget http://storage.googleapis.com/dm_cmtouch/datasets/cmtouch_touch_props_all_im64_val.tfrecords

synthetic_returns/README.md (new file, 125 lines)
@@ -0,0 +1,125 @@
# Code for Synthetic Returns

This repository contains code for the arXiv preprint
["Synthetic Returns for Long-Term Credit Assignment"](https://arxiv.org/abs/2102.12425)
by David Raposo, Sam Ritter, Adam Santoro, Greg Wayne, Theophane Weber, Matt
Botvinick, Hado van Hasselt, and Francis Song.

To cite this work:

```
@article{raposo2021synthetic,
title={Synthetic Returns for Long-Term Credit Assignment},
author={Raposo, David and Ritter, Sam and Santoro, Adam and Wayne, Greg and
Weber, Theophane and Botvinick, Matt and van Hasselt, Hado and Song, Francis},
journal={arXiv preprint arXiv:2102.12425},
year={2021}
}
```

### Agent core wrapper

We implemented the Synthetic Returns module as a wrapper around a recurrent
neural network (RNN), so it should be compatible with any Deep-RL agent with an
arbitrary RNN core whose inputs consist of batches of vectors. This could be an
LSTM, as in the example below, or a more sophisticated core, as long as it
implements an `hk.RNNCore`.

```python
agent_core = hk.LSTM(128)
```

To build the SR wrapper, simply pass the existing agent core to the constructor,
along with the SR configuration:

```python
sr_config = {
    "memory_size": 128,
    "capacity": 300,
    "hidden_layers": (128, 128),
    "alpha": 0.3,
    "beta": 1.0,
}
sr_agent_core = hk.ResetCore(
    SyntheticReturnsCoreWrapper(core=agent_core, **sr_config))
```

Typically, the SR wrapper should itself be wrapped in an `hk.ResetCore` in order
to reset the core state at the beginning of a new episode. This resets not only
the episodic memory but also the original agent core that was passed to the SR
wrapper constructor.

### Learner

Consider the distributed setting, wherein a learner receives mini-batches of
trajectories of length `T` produced by the actors.

`trajectory` is a nested structure of tensors of size `[T,B,...]` (where `B` is
the batch size) containing observations, agent states, rewards and step type
indicators.
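
For concreteness, such a structure might look like the hypothetical container
below; the field names mirror those used in the snippets that follow, but the
actual layout depends on your actor/learner framework:

```python
import collections

# Hypothetical trajectory container with [T, B, ...] leading dimensions.
Trajectory = collections.namedtuple(
    "Trajectory",
    ["observation",  # [T, B, ...] environment observations
     "agent_state",  # [T, B, ...] recurrent agent state at each step
     "reward",       # [T, B] rewards
     "step_type"])   # [T, B] dm_env.StepType indicators
```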

We start by producing inputs to the SR core, which consist of tuples of current
state embeddings and return targets. The current state embeddings can be
produced by a ConvNet, for example. In our experiments we used the current step
reward as the target. Note that the current step reward corresponds to the
rewards in the trajectory shifted by one, relative to the observations:

```python
observations = jax.tree_map(lambda x: x[:-1], trajectory.observation)
vision_output = hk.BatchApply(vision_net)(observations)
return_targets = trajectory.reward[1:]
sr_core_inputs = (vision_output, return_targets)
```

For purposes of core resetting at the beginning of a new episode, we also need
to pass an indicator of which steps correspond to the first step of an episode.

```python
should_reset = jnp.equal(
    trajectory.step_type[:-1], int(dm_env.StepType.FIRST))
core_inputs = (sr_core_inputs, should_reset)
```

We can now produce an unroll using `hk.dynamic_unroll` and passing it the SR
core, the core inputs we produced, and the initial state of the unroll, which
corresponds to the agent state at the first step of the trajectory:

```python
state = jax.tree_map(lambda t: t[0], trajectory.agent_state)
core_output, state = hk.dynamic_unroll(
    sr_agent_core, core_inputs, state)
```

The SR wrapper produces four output tensors: the output of the agent core, the
synthetic returns, the SR-augmented return, and the SR loss.

The synthetic returns are already taken into account when computing the
augmented return and the SR loss inside the wrapper, so they are not needed
downstream and can be discarded or kept for logging purposes.

The agent core outputs should be used, as usual, for producing a policy. In an
actor-critic, policy-gradient set-up like IMPALA, we would produce policy
logits and values:

```python
policy_logits = hk.BatchApply(policy_net)(core_output.output)
value = hk.BatchApply(baseline_net)(core_output.output)
```

Similarly, in a Q-learning setting we would use the agent core outputs to
produce Q-values.
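
For example, a minimal sketch with a hypothetical Q-value head (`q_net` and
`num_actions` are placeholders, not part of this repository):

```python
num_actions = 18  # placeholder action-set size
q_net = hk.nets.MLP([128, num_actions])
q_values = hk.BatchApply(q_net)(core_output.output)
```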

The SR-augmented returns should be used in place of the environment rewards for
the policy updates (e.g. when computing the policy gradient and baseline
losses):

```python
rewards = core_output.augmented_return
```

Finally, the SR loss, summed over batch and time dimensions, should be added to
the total learner loss to be minimized:

```python
total_loss += jnp.sum(core_output.sr_loss)
```

synthetic_returns/requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
dm-haiku>=0.0.3
jax>=0.2.8

synthetic_returns/synthetic_returns.py (new file, 187 lines)
@@ -0,0 +1,187 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Episodic Memory and Synthetic Returns Core Wrapper modules."""
import collections

import haiku as hk
import jax
import jax.numpy as jnp

SRCoreWrapperOutput = collections.namedtuple(
    "SRCoreWrapperOutput", ["output", "synthetic_return", "augmented_return",
                            "sr_loss"])


class EpisodicMemory(hk.RNNCore):
  """Episodic Memory module."""

  def __init__(self, memory_size, capacity, name="episodic_memory"):
    """Constructor.

    Args:
      memory_size: Integer. The size of the vectors to be stored.
      capacity: Integer. The maximum number of memories to store before it
        becomes necessary to overwrite old memories.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._memory_size = memory_size
    self._capacity = capacity

  def __call__(self, inputs, prev_state):
    """Writes a new memory into the episodic memory.

    Args:
      inputs: A Tensor of shape ``[batch_size, memory_size]``.
      prev_state: The previous state of the episodic memory, which is a tuple
        with (i) a counter of shape ``[batch_size, 1]`` indicating how many
        memories have been written so far, and (ii) a tensor of shape
        ``[batch_size, capacity, memory_size]`` with the full content of the
        episodic memory.

    Returns:
      A tuple with (i) a tensor of shape ``[batch_size, capacity, memory_size]``
      with the full content of the episodic memory, including the newly
      written memory, and (ii) the new state of the episodic memory.
    """
    inputs = jax.lax.stop_gradient(inputs)
    counter, memories = prev_state
    # Index of the slot to write next; the modulo makes the memory behave as a
    # circular buffer once capacity is exceeded.
    counter_mod = jnp.mod(counter, self._capacity)
    slot_selector = jnp.expand_dims(
        jax.nn.one_hot(counter_mod, self._capacity), axis=2)
    # Overwrite the selected slot with the new memory; all other slots are
    # left unchanged.
    memories = memories * (1 - slot_selector) + (
        slot_selector * jnp.expand_dims(inputs, 1))
    counter = counter + 1
    return memories, (counter, memories)

  def initial_state(self, batch_size):
    """Creates the initial state of the episodic memory.

    Args:
      batch_size: Integer. The batch size of the episodic memory.

    Returns:
      A tuple with (i) a counter of shape ``[batch_size, 1]`` and (ii) a tensor
      of shape ``[batch_size, capacity, memory_size]`` with the full content
      of the episodic memory.
    """
    if batch_size is None:
      shape = []
    else:
      shape = [batch_size]
    counter = jnp.zeros(shape)
    memories = jnp.zeros(shape + [self._capacity, self._memory_size])
    return (counter, memories)


class SyntheticReturnsCoreWrapper(hk.RNNCore):
  """Synthetic Returns core wrapper."""

  def __init__(self, core, memory_size, capacity, hidden_layers, alpha, beta,
               loss_func=(lambda x, y: 0.5 * jnp.square(x - y)),
               apply_core_to_input=False, name="synthetic_returns_wrapper"):
    """Constructor.

    Args:
      core: hk.RNNCore. The recurrent core of the agent. E.g. an LSTM.
      memory_size: Integer. The size of the vectors to be stored in the
        episodic memory.
      capacity: Integer. The maximum number of memories to store before it
        becomes necessary to overwrite old memories.
      hidden_layers: Tuple or list of integers, indicating the size of the
        hidden layers of the MLPs used to produce synthetic returns, current
        state bias, and gate.
      alpha: The multiplier of the synthetic returns term in the augmented
        return.
      beta: The multiplier of the environment returns term in the augmented
        return.
      loss_func: A function of two arguments (predictions and targets) to
        compute the SR loss.
      apply_core_to_input: Boolean. Whether to apply the core on the inputs. If
        true, the synthetic returns will be computed from the outputs of the
        RNN core passed to the constructor. If false, the RNN core will be
        applied only at the output of this wrapper, and the synthetic returns
        will be computed from the inputs.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._em = EpisodicMemory(memory_size, capacity)
    self._capacity = capacity
    hidden_layers = list(hidden_layers)
    self._synthetic_return = hk.nets.MLP(hidden_layers + [1])
    self._bias = hk.nets.MLP(hidden_layers + [1])
    self._gate = hk.Sequential([
        hk.nets.MLP(hidden_layers + [1]),
        jax.nn.sigmoid,
    ])
    self._apply_core_to_input = apply_core_to_input
    self._core = core
    self._alpha = alpha
    self._beta = beta
    self._loss = loss_func

  def initial_state(self, batch_size):
    return (
        self._em.initial_state(batch_size),
        self._core.initial_state(batch_size)
    )

  def __call__(self, inputs, prev_state):
    current_input, return_target = inputs

    em_state, core_state = prev_state
    (counter, memories) = em_state

    if self._apply_core_to_input:
      current_input, core_state = self._core(current_input, core_state)

    # Synthetic return for the current state
    synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)

    # Current state bias term
    bias = self._bias(current_input)

    # Gate computed from current state
    gate = self._gate(current_input)

    # When counter > capacity, mask will be all ones
    mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
    mask = jnp.expand_dims(mask, axis=2)

    # Synthetic returns for each state in memory
    past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)

    # Sum of synthetic returns from previous states
    sr_sum = jnp.sum(past_synth_returns * mask, axis=1)

    prediction = jnp.squeeze(sr_sum * gate + bias, -1)
    sr_loss = self._loss(prediction, return_target)

    augmented_return = jax.lax.stop_gradient(
        self._alpha * synth_return + self._beta * return_target)

    # Write current state to memory
    _, em_state = self._em(current_input, em_state)

    if not self._apply_core_to_input:
      output, core_state = self._core(current_input, core_state)
    else:
      output = current_input

    output = SRCoreWrapperOutput(
        output=output,
        synthetic_return=synth_return,
        augmented_return=augmented_return,
        sr_loss=sr_loss,
    )
    return output, (em_state, core_state)