Release of sketchy

PiperOrigin-RevId: 314317695
This commit is contained in:
Louise Deason
2020-06-02 13:17:05 +00:00
parent b105a3646b
commit 38f4dabc6e
12 changed files with 824 additions and 0 deletions

@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020

sketchy/README.md Normal file

@@ -0,0 +1,218 @@
# Sketchy data
This is a dataset accompanying the paper
[Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200).
If you use this dataset in your research please cite
```
@article{cabi2019,
title={Scaling data-driven robotics with reward sketching and batch reinforcement learning},
author={Serkan Cabi and
Sergio G{\'o}mez Colmenarejo and
Alexander Novikov and
Ksenia Konyushkova and
Scott Reed and
Rae Jeong and
Konrad {\.Z}o\l{}na and
Yusuf Aytar and
David Budden and
Mel Vecerik and
Oleg Sushkov and
David Barker and
Jonathan Scholz and
Misha Denil and
Nando de Freitas and
Ziyu Wang},
journal={arXiv preprint arXiv:1909.12200},
year={2019}
}
```
## See example data
There is a small amount of example data included in this repository. To examine
it, run the following commands from the repository root (i.e. one level up from
this folder):
```
python3 -m venv .sketchy_env
source .sketchy_env/bin/activate
pip install --upgrade pip
pip install -r sketchy/requirements.txt
python -m sketchy.dataset_example --show_images
```
For an example of loading rewards for episodes see `reward_example.py`.
## Download the full dataset
Run `./download.sh path/to/download/folder` to download the full dataset. The
full dataset requires ~5.0TB of disk space to download, and extracts to approximately the same size.
You can edit `download.sh` to download subsets of the data.
Once the dataset has been downloaded it can be extracted with
`./extract.sh path/to/download/folder`.
### Named subsets
We provide several named subsets of the full dataset, which can be easily
downloaded on their own. See `download.sh` for a description of the subsets
that are provided.
The episodes in each of these named subsets are identified by a tag in the
metadata.
If you would like to curate your own subset you can download the metadata
file and inspect the `ArchiveFiles` table (see below) to figure out which
archive files contain the episodes you want.
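For instance, finding the archive files that hold all episodes carrying a given
tag is a single join over the tables described below. Here is a minimal sketch
using Python's standard `sqlite3` module; the toy in-memory database and its
rows are stand-ins for the real `metadata.sqlite`:

```python
import sqlite3

# Toy in-memory stand-in for metadata.sqlite, using the table and column
# names described in this README (Episodes, EpisodeTags, ArchiveFiles).
conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE Episodes (EpisodeId TEXT PRIMARY KEY, DataPath TEXT);
    CREATE TABLE EpisodeTags (EpisodeId TEXT, Tag TEXT);
    CREATE TABLE ArchiveFiles (EpisodeId TEXT, ArchiveFile TEXT);
""")
conn.executemany('INSERT INTO Episodes VALUES (?, ?)',
                 [('1001', 'episode_1001'), ('1002', 'episode_1002')])
conn.executemany('INSERT INTO EpisodeTags VALUES (?, ?)',
                 [('1001', 'lift_green__demos'), ('1002', 'rgb30__all')])
conn.executemany('INSERT INTO ArchiveFiles VALUES (?, ?)',
                 [('1001', 'lift_green__demos-00000-of-00002.tar.bz2'),
                  ('1002', 'rgb30__all-00000-of-00205.tar.bz2')])

# Which archive files contain the episodes tagged 'lift_green__demos'?
archives = sorted(set(row[0] for row in conn.execute(
    """
    SELECT ArchiveFiles.ArchiveFile
    FROM ArchiveFiles
    JOIN EpisodeTags ON ArchiveFiles.EpisodeId = EpisodeTags.EpisodeId
    WHERE EpisodeTags.Tag = ?
    """, ('lift_green__demos',))))
print(archives)  # ['lift_green__demos-00000-of-00002.tar.bz2']
```

The same query run against the real metadata file yields the list of archives
to pass to your download tool.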
# Dataset Contents
The dataset is distributed as a *metadata file* (`metadata.sqlite`) and a
collection of *archive files* (with names ending in `.tar.bz2`).
The metadata file contains information about the episodes, including annotated
rewards for a subset of the episodes.
Each archive file contains several *episode files*, which have names like
`10000313341320364033_b615a417-ce34-41a8-8411-2a1ce3f3bd07`.
Each episode file is a
[tfrecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file,
containing a sequence of *timesteps* for a single episode.
Each timestep is a `tf.train.Example` proto containing features corresponding to
the observations and actions from a particular point in time.
## Metadata
The metadata file, `metadata.sqlite`, is a sqlite database containing metadata
describing the contents of the files in the dataset.
The following sections describe the important metadata tables. You can find the
full schema by running
```
sqlite3 metadata.sqlite <<< .schema
```
### Episodes
- `EpisodeId`: A string of digits that uniquely identifies the episode.
- `TaskId`: A human readable name for the task corresponding to the behavior
that generated the episode.
- `DataPath`: The name of the episode file holding the data for this episode.
- `EpisodeType`: A string describing the type of policy that generated the
episode. Possible values are:
- `EPISODE_ROBOT_AGENT`: The behavior policy is a learned or scripted
controller.
- `EPISODE_ROBOT_TELEOPERATION`: The behavior policy is a human teleoperating
the robot.
- `EPISODE_ROBOT_DAGGER`: The behavior policy is a mix of controller and human
generated actions.
- `Timestamp`: A unix timestamp recording when the episode was generated.
### EpisodeTags
- `EpisodeId`: Foreign key into the `Episodes` table.
- `Tag`: A human readable identifier for some aspect of the episode (e.g. which
object set is used).
### RewardSequences
- `EpisodeId`: Foreign key into the `Episodes` table.
- `RewardSequenceId`: Distinguishes multiple rewards for the same episode.
- `RewardTaskId`: A human readable name of the task for this reward signal.
Typically the same as the corresponding `TaskId` in the `Episodes` table.
- `Type`: A string describing the type of reward signal. Currently the only
value is `REWARD_SKETCH`.
- `Values`: A sequence of float32 values, packed as a binary blob. There is one
float value for each frame of the episode, corresponding to the annotated
reward.
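As a minimal sketch of decoding such a blob using only the standard library's
`struct` module (the byte string below is a stand-in for a real `Values` cell;
`reward_example.py` does the same decoding with NumPy):

```python
import struct

# Stand-in for a real `Values` cell: four frames' rewards packed as
# little-endian float32 (the values chosen are exactly representable).
blob = struct.pack('<4f', 0.0, 0.25, 0.5, 1.0)

num_frames = len(blob) // 4  # four bytes per float32 value
rewards = struct.unpack(f'<{num_frames}f', blob)
print(rewards)  # (0.0, 0.25, 0.5, 1.0)
```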
### ArchiveFiles
- `EpisodeId`: Foreign key into the `Episodes` table.
- `ArchiveFile`: Name of the archive file containing the corresponding episode.
## Episodes
Each episode file is a
[tfrecords](https://www.tensorflow.org/tutorials/load_data/tfrecord) file
containing a sequence of timesteps, encoded as
[`tf.train.Example`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto)
protos.
Each episode file contains a single episode, and each timestep within an episode
contains all of the observations and actions associated with that timestep as
a single `tf.train.Example`. Within each episode file the timesteps are
temporally ordered, so reading a file from beginning to end will visit all of
the timesteps from the episode in the order they occurred.
Observations and actions occur at 10Hz.
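If you want to inspect an episode file without TensorFlow, the TFRecord
container itself is simple: each record is framed as an 8-byte little-endian
length, a 4-byte masked CRC of the length, the serialized payload, and a
4-byte masked CRC of the payload. The sketch below counts records using only
the standard library, skipping CRC verification, over a fabricated in-memory
file:

```python
import io
import struct

def count_tfrecords(stream):
    """Counts framed records in a TFRecord byte stream (CRCs ignored)."""
    count = 0
    while True:
        header = stream.read(8)
        if not header:
            return count
        (length,) = struct.unpack('<Q', header)
        stream.read(4)       # masked CRC32 of the length (not verified here)
        stream.read(length)  # the serialized tf.train.Example payload
        stream.read(4)       # masked CRC32 of the payload (not verified here)
        count += 1

# Fabricate a two-record "episode file" in memory, with zeroed CRCs.
buf = io.BytesIO()
for payload in (b'timestep-0', b'timestep-1'):
    buf.write(struct.pack('<Q', len(payload)))
    buf.write(b'\x00' * 4)
    buf.write(payload)
    buf.write(b'\x00' * 4)
buf.seek(0)

num_timesteps = count_tfrecords(buf)
print(num_timesteps)  # 2
```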
## Timesteps
Each timestep is a collection of observations and actions. Actions stored with a
timestep correspond to actions taken in response to the observations they are
stored with.
For a description of the shapes and types of the timestep data, see the data
loader in `sketchy.py`.
# Dataset Metadata
The following table is necessary for this dataset to be indexed by search
engines such as <a href="https://g.co/datasetsearch">Google Dataset Search</a>.
<div itemscope itemtype="http://schema.org/Dataset">
<table>
<tr>
<th>property</th>
<th>value</th>
</tr>
<tr>
<td>name</td>
<td><code itemprop="name">Sketchy</code></td>
</tr>
<tr>
<td>url</td>
<td><code itemprop="url">https://github.com/deepmind/deepmind_research/sketchy</code></td>
</tr>
<tr>
<td>sameAs</td>
<td><code itemprop="sameAs">https://github.com/deepmind/deepmind_research/sketchy</code></td>
</tr>
<tr>
<td>description</td>
<td><code itemprop="description">
Data accompanying
[Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200).
</code></td>
</tr>
<tr>
<td>provider</td>
<td>
<div itemscope itemtype="http://schema.org/Organization" itemprop="provider">
<table>
<tr>
<th>property</th>
<th>value</th>
</tr>
<tr>
<td>name</td>
<td><code itemprop="name">DeepMind</code></td>
</tr>
<tr>
<td>sameAs</td>
<td><code itemprop="sameAs">https://en.wikipedia.org/wiki/DeepMind</code></td>
</tr>
</table>
</div>
</td>
</tr>
<tr>
<td>citation</td>
<td><code itemprop="citation">https://identifiers.org/arxiv:1909.12200</code></td>
</tr>
</table>
</div>

sketchy/__init__.py Normal file

@@ -0,0 +1,13 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

sketchy/dataset_example.py Normal file

@@ -0,0 +1,50 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of loading sketchy data in tensorflow."""

from absl import app
from absl import flags
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf

from sketchy import sketchy

flags.DEFINE_boolean('show_images', False, 'Enable to show example images.')
FLAGS = flags.FLAGS


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  # The example file contains only a few timesteps from a single episode.
  dataset = sketchy.load_frames('sketchy/example_data.tfrecords')
  dataset = dataset.prefetch(5)

  for example in dataset:
    print('---')
    for name, value in sorted(example.items()):
      print(name, value.dtype, value.shape)
    if FLAGS.show_images:
      plt.imshow(example['pixels/basket_front_left'])
      plt.show()


if __name__ == '__main__':
  app.run(main)

sketchy/download.sh Executable file

@@ -0,0 +1,178 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Use this script to download the sketchy dataset.
#
# You will need to extract the archive files before using them. Each archive
# (except for the last in each group) contains 100 episodes.
if [ -z "$1" ]; then
  echo "Usage: $(basename "$0") download_folder" >&2
  exit 1
fi
DOWNLOAD_FOLDER="$1"
NUM_PARALLEL_DOWNLOADS="4" # set the number of download workers
DATA_URL="https://storage.cloud.google.com/sketchy-data"
function download_shards {
  # Usage: download_shards prefix num_shards
  local PREFIX="$1"
  local LIMIT="$(printf "%05d" "$2")"
  # Avoid leading zeros or this will be interpreted as an octal number.
  local MAX="$(("$2"-1))"
  (
    for IDX in $(seq -f'%05.0f' 0 "$MAX"); do
      echo "${PREFIX}-${IDX}-of-${LIMIT}.tar.bz2"
    done
  ) | xargs -I{} -n1 -P"${NUM_PARALLEL_DOWNLOADS}" \
    curl "${DATA_URL}/{}" --output "${DOWNLOAD_FOLDER}/{}"
}
# This is the metadata. The file is small; you always want it.
curl "${DATA_URL}/metadata.sqlite" --output "${DOWNLOAD_FOLDER}/metadata.sqlite"
# Download these files if you want all and only the episodes with an associated
# reward sequence.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT DISTINCT(Episodes.DataPath)
# FROM Episodes, RewardSequences
# WHERE Episodes.EpisodeId = RewardSequences.EpisodeId
# EOF
#
# If you are downloading the full dataset then you do not need these files. The
# episodes they contain are included in the other subsets.
#
# download_shards episodes_with_rewards 58
# These files contain a curated set of high quality demonstrations for the
# lift_green task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='lift_green__demos'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# EOF
#
download_shards lift_green__demos 2
# These files contain a broader set of episodes for the lift_green task. If you
# download these you should also download the lift_green__demos files.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='lift_green__episodes'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
# SELECT ET.EpisodeId
# FROM EpisodeTags AS ET
# WHERE ET.Tag IN ('lift_green__demos')
# )
# EOF
#
download_shards lift_green__episodes 70
# These files contain a curated set of high quality demonstrations for the
# stack_green_on_red task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='stack_green_on_red__demos'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
# SELECT ET.EpisodeId
# FROM EpisodeTags AS ET
# WHERE ET.Tag IN ('lift_green__demos', 'lift_green__episodes')
# )
# EOF
#
download_shards stack_green_on_red__demos 2
# These files contain a broader set of episodes for the stack_green_on_red task.
# If you download these you should also download the stack_green_on_red__demos
# files.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='stack_green_on_red__episodes'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
# SELECT ET.EpisodeId
# FROM EpisodeTags AS ET
# WHERE ET.Tag IN (
# 'lift_green__demos',
# 'lift_green__episodes',
# 'stack_green_on_red__demos')
# )
# EOF
#
download_shards stack_green_on_red__episodes 101
# These files contain a large variety of episodes using the same object set as
# the lift_green and stack_green_on_red tasks. There are many tasks represented
# here, and the episode quality is highly variable.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='rgb30__all'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
# SELECT ET.EpisodeId
# FROM EpisodeTags AS ET
# WHERE ET.Tag IN (
# 'lift_green__demos',
# 'lift_green__episodes',
# 'stack_green_on_red__demos',
# 'stack_green_on_red__episodes')
# )
# EOF
#
download_shards rgb30__all 205
# These files contain a broad set of episodes for the pull_cloth_up task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='pull_cloth_up__episodes'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# EOF
#
download_shards pull_cloth_up__episodes 133
# These files contain a large variety of episodes using the same object set as
# the pull_cloth_up task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='deform8__all'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
# SELECT ET.EpisodeId
# FROM EpisodeTags AS ET
# WHERE ET.Tag IN ('pull_cloth_up__episodes')
# )
# EOF
#
download_shards deform8__all 233

Binary file not shown.

sketchy/extract.sh Executable file

@@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Use this script to extract the downloaded sketchy dataset.
if [ -z "$1" ]; then
  echo "Usage: $(basename "$0") download_folder" >&2
  exit 1
fi
DOWNLOAD_FOLDER="$1"
NUM_PARALLEL_WORKERS="$(grep processor /proc/cpuinfo | wc -l)"
cd "$DOWNLOAD_FOLDER"
find . -name '*.tar.bz2' -print0 \
  | xargs -0 -n1 -P"$NUM_PARALLEL_WORKERS" tar xf

sketchy/metadata_schema.py Normal file

@@ -0,0 +1,138 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sqlalchemy schema for the metadata db."""

import sqlalchemy
from sqlalchemy.ext import declarative

Column = sqlalchemy.Column
Integer = sqlalchemy.Integer
String = sqlalchemy.String
LargeBinary = sqlalchemy.LargeBinary
ForeignKey = sqlalchemy.ForeignKey

# pylint: disable=invalid-name
# https://docs.sqlalchemy.org/en/13/orm/tutorial.html
Base = declarative.declarative_base()

EpisodeTag = sqlalchemy.Table(
    'EpisodeTags', Base.metadata,
    Column(
        'EpisodeId', String, ForeignKey('Episodes.EpisodeId'),
        primary_key=True),
    Column('Tag', String, ForeignKey('Tags.Name'), primary_key=True))
"""Table relating episodes and tags.

Attributes:
  EpisodeId: A string of digits that uniquely identifies the episode.
  Tag: Human readable tag name.
"""


class Episode(Base):
  """Table describing individual episodes.

  Attributes:
    EpisodeId: A string of digits that uniquely identifies the episode.
    TaskId: A human readable name for the task corresponding to the behavior
      that generated the episode.
    DataPath: The name of the episode file holding the data for this episode.
    Timestamp: A unix timestamp recording when the episode was generated.
    EpisodeType: A string describing the type of policy that generated the
      episode. Possible values are:
      - `EPISODE_ROBOT_AGENT`: The behavior policy is a learned or scripted
        controller.
      - `EPISODE_ROBOT_TELEOPERATION`: The behavior policy is a human
        teleoperating the robot.
      - `EPISODE_ROBOT_DAGGER`: The behavior policy is a mix of controller
        and human generated actions.
    Tags: A list of tags attached to this episode.
    Rewards: A list of `RewardSequence`s containing sketched rewards for this
      episode.
  """
  __tablename__ = 'Episodes'

  EpisodeId = Column(String, primary_key=True)
  TaskId = Column(String)
  DataPath = Column(String)
  Timestamp = Column(Integer)
  EpisodeType = Column(String)

  Tags = sqlalchemy.orm.relationship(
      'Tag', secondary=EpisodeTag, back_populates='Episodes')
  Rewards = sqlalchemy.orm.relationship(
      'RewardSequence', backref='Episode')


class Tag(Base):
  """Table of tags that can be attached to episodes.

  Attributes:
    Name: Human readable tag name.
    Episodes: The episodes that have been annotated with this tag.
  """
  __tablename__ = 'Tags'

  Name = Column(String, primary_key=True)
  Episodes = sqlalchemy.orm.relationship(
      'Episode', secondary=EpisodeTag, back_populates='Tags')


class RewardSequence(Base):
  """Table describing reward sequences for episodes.

  Attributes:
    EpisodeId: Foreign key into the `Episodes` table.
    RewardSequenceId: Distinguishes multiple rewards for the same episode.
    RewardTaskId: A human readable name of the task for this reward signal.
      Typically the same as the corresponding `TaskId` in the `Episodes`
      table.
    Type: A string describing the type of reward signal. Currently the only
      value is `REWARD_SKETCH`.
    User: The name of the user who produced this reward sequence.
    Values: A sequence of float32 values, packed as a binary blob. There is
      one float value for each frame of the episode, corresponding to the
      annotated reward.
  """
  __tablename__ = 'RewardSequences'

  EpisodeId = Column(
      'EpisodeId', String, ForeignKey('Episodes.EpisodeId'), primary_key=True)
  RewardSequenceId = Column(String, primary_key=True)
  RewardTaskId = Column('RewardTaskId', String)
  Type = Column(String)
  User = Column(String)
  Values = Column(LargeBinary)


class ArchiveFile(Base):
  """Table describing where episodes are stored in archives.

  This information is relevant if you want to download or extract a specific
  episode from the archives they are distributed in.

  Attributes:
    EpisodeId: Foreign key into the `Episodes` table.
    ArchiveFile: Name of the archive file containing the corresponding episode.
  """
  __tablename__ = 'ArchiveFiles'

  EpisodeId = Column(
      'EpisodeId', String, ForeignKey('Episodes.EpisodeId'), primary_key=True)
  ArchiveFile = Column(String)

# pylint: enable=invalid-name

sketchy/requirements.txt Normal file

@@ -0,0 +1,38 @@
absl-py==0.9.0
astor==0.8.1
cachetools==4.0.0
certifi==2019.11.28
chardet==3.0.4
cycler==0.10.0
gast==0.2.2
google-auth==1.10.0
google-auth-oauthlib==0.4.1
google-pasta==0.1.8
grpcio==1.26.0
h5py==2.10.0
idna==2.8
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
Markdown==3.1.1
matplotlib==3.1.2
numpy==1.18.1
oauthlib==3.1.0
opt-einsum==3.1.0
protobuf==3.11.2
pyasn1==0.4.8
pyasn1-modules==0.2.7
pyparsing==2.4.6
python-dateutil==2.8.1
requests==2.22.0
requests-oauthlib==1.3.0
rsa==4.0
six==1.13.0
SQLAlchemy==1.3.12
tensorboard==2.0.2
tensorflow==2.0.0
tensorflow-estimator==2.0.1
termcolor==1.1.0
urllib3==1.25.7
Werkzeug==0.16.0
wrapt==1.11.2

sketchy/reward_example.py Normal file

@@ -0,0 +1,51 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of loading rewards from the metadata file."""

from absl import app
from absl import flags
import numpy as np
import sqlalchemy

from sketchy import metadata_schema

flags.DEFINE_string(
    'metadata', '/tmp/metadata.sqlite', 'Path to metadata file.')
FLAGS = flags.FLAGS


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  engine = sqlalchemy.create_engine('sqlite:///' + FLAGS.metadata)
  session = sqlalchemy.orm.sessionmaker(bind=engine)()

  episodes = session.query(metadata_schema.Episode).join(
      metadata_schema.RewardSequence).limit(5)

  for episode in episodes:
    rewards = np.frombuffer(episode.Rewards[0].Values, dtype=np.float32)
    print('---')
    print(f'Episode: {episode.EpisodeId}')
    print(f'Episode file: {episode.DataPath}')
    print(f'Reward type: {episode.Rewards[0].Type}')
    print(f'Reward values: {rewards}')


if __name__ == '__main__':
  app.run(main)

sketchy/run.sh Executable file

@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python3 -m venv .sketchy_env
source .sketchy_env/bin/activate
pip install --upgrade pip
pip install -r sketchy/requirements.txt
python -m sketchy.dataset_example --noshow_images

sketchy/sketchy.py Normal file

@@ -0,0 +1,86 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for loading sketchy data into tensorflow."""

import tensorflow.compat.v2 as tf


def load_frames(filenames, num_parallel_reads=1, num_map_threads=None):
  """Returns a tf.data.Dataset of parsed timesteps from episode files."""
  if not num_map_threads:
    num_map_threads = num_parallel_reads
  dataset = tf.data.TFRecordDataset(
      filenames, num_parallel_reads=num_parallel_reads)
  return dataset.map(_parse_example, num_parallel_calls=num_map_threads)


_FEATURES = {
    # Actions
    'actions':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),

    # Observations
    'gripper/joints/velocity':
        tf.io.FixedLenFeature(shape=1, dtype=tf.float32),
    'gripper/joints/torque':
        tf.io.FixedLenFeature(shape=1, dtype=tf.float32),
    'gripper/grasp':
        tf.io.FixedLenFeature(shape=1, dtype=tf.int64),
    'gripper/joints/angle':
        tf.io.FixedLenFeature(shape=1, dtype=tf.float32),
    'sawyer/joints/velocity':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/pinch/pose':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/tcp/pose':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/tcp/effort':
        tf.io.FixedLenFeature(shape=6, dtype=tf.float32),
    'sawyer/joints/torque':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/tcp/velocity':
        tf.io.FixedLenFeature(shape=6, dtype=tf.float32),
    'sawyer/joints/angle':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'wrist/torque':
        tf.io.FixedLenFeature(shape=3, dtype=tf.float32),
    'wrist/force':
        tf.io.FixedLenFeature(shape=3, dtype=tf.float32),
    'pixels/basket_front_left':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/basket_back_left':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/basket_front_right':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/royale_camera_driver_depth':
        tf.io.FixedLenFeature(shape=(171, 224, 1), dtype=tf.float32),
    'pixels/royale_camera_driver_gray':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/usbcam0':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/usbcam1':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
}


def _parse_example(example):
  return _decode_images(tf.io.parse_single_example(example, _FEATURES))


def _decode_images(record):
  # Image observations are stored as JPEG-encoded strings; decode them to
  # uint8 image tensors.
  for name, value in list(record.items()):
    if value.dtype == tf.string:
      record[name] = tf.io.decode_jpeg(value[0])
  return record