mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00
Release of sketchy
PiperOrigin-RevId: 314317695
This commit is contained in:
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/

## Projects

* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
218  sketchy/README.md  Normal file
@@ -0,0 +1,218 @@
# Sketchy data

This is a dataset accompanying the paper
[Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200).
If you use this dataset in your research please cite

```
@article{cabi2019,
  title={Scaling data-driven robotics with reward sketching and batch reinforcement learning},
  author={Serkan Cabi and
          Sergio G{\'o}mez Colmenarejo and
          Alexander Novikov and
          Ksenia Konyushkova and
          Scott Reed and
          Rae Jeong and
          Konrad {\.Z}o\l{}na and
          Yusuf Aytar and
          David Budden and
          Mel Vecerik and
          Oleg Sushkov and
          David Barker and
          Jonathan Scholz and
          Misha Denil and
          Nando de Freitas and
          Ziyu Wang},
  journal={arXiv preprint arXiv:1909.12200},
  year={2019}
}
```

## See example data

There is a small amount of example data included in this repository. To examine
it, run the following commands from the repository root (i.e. one level up from
this folder):

```
python3 -m venv .sketchy_env
source .sketchy_env/bin/activate
pip install --upgrade pip
pip install -r sketchy/requirements.txt
python -m sketchy.dataset_example --show_images
```

For an example of loading rewards for episodes see `reward_example.py`.

## Download the full dataset

Run `./download.sh path/to/download/folder` to download the full dataset. The
full dataset requires ~5.0TB of disk space to download, and extracts to
approximately the same size.

You can edit `download.sh` to download subsets of the data.

Once the dataset has been downloaded it can be extracted with
`./extract.sh path/to/download/folder`.

### Named subsets

We provide several named subsets of the full dataset, which can be easily
downloaded on their own. See `download.sh` for a description of the subsets
that are provided.

The episodes in each of these named subsets are identified by a tag in the
metadata. If you would like to curate your own subset you can download the
metadata file and inspect the `ArchiveFiles` table (see below) to figure out
which archive files contain the episodes you want.
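
The tag-to-archive lookup can be sketched with Python's built-in `sqlite3` module. The in-memory database below (tables abridged, rows invented) is a stand-in for the real `metadata.sqlite`, kept just large enough to show the join:

```python
import sqlite3

# Miniature stand-in for metadata.sqlite: schema abridged, rows invented.
conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE EpisodeTags (EpisodeId TEXT, Tag TEXT);
    CREATE TABLE ArchiveFiles (EpisodeId TEXT, ArchiveFile TEXT);
    INSERT INTO EpisodeTags VALUES
        ('101', 'lift_green__demos'), ('202', 'rgb30__all');
    INSERT INTO ArchiveFiles VALUES
        ('101', 'lift_green__demos-00000-of-00002.tar.bz2'),
        ('202', 'rgb30__all-00000-of-00205.tar.bz2');
""")

# Which archive files hold the episodes carrying a given tag?
archives = sorted({row[0] for row in conn.execute(
    """SELECT ArchiveFiles.ArchiveFile
       FROM ArchiveFiles, EpisodeTags
       WHERE EpisodeTags.Tag = ?
         AND ArchiveFiles.EpisodeId = EpisodeTags.EpisodeId""",
    ('lift_green__demos',))})
print(archives)  # ['lift_green__demos-00000-of-00002.tar.bz2']
```

Against the real metadata file you would connect to `metadata.sqlite` instead of `:memory:` and skip the `INSERT` statements.
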

# Dataset Contents

The dataset is distributed as a *metadata file* (`metadata.sqlite`) and a
collection of *archive files* (with names ending in `.tar.bz2`).

The metadata file contains information about the episodes, including annotated
rewards for a subset of the episodes.

Each archive file contains several *episode files*, which have names like
`10000313341320364033_b615a417-ce34-41a8-8411-2a1ce3f3bd07`.

Each episode file is a
[tfrecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file,
containing a sequence of *timesteps* for a single episode.

Each timestep is a `tf.train.Example` proto containing features corresponding to
the observations and actions from a particular point in time.

## Metadata

The metadata file, `metadata.sqlite`, is a sqlite database containing metadata
describing the contents of the files in the dataset.

The following sections describe the important metadata tables. You can find the
full schema by running

```
sqlite3 metadata.sqlite <<< .schema
```

### Episodes

- `EpisodeId`: A string of digits that uniquely identifies the episode.
- `TaskId`: A human readable name for the task corresponding to the behavior
  that generated the episode.
- `DataPath`: The name of the episode file holding the data for this episode.
- `EpisodeType`: A string describing the type of policy that generated the
  episode. Possible values are:
  - `EPISODE_ROBOT_AGENT`: The behavior policy is a learned or scripted
    controller.
  - `EPISODE_ROBOT_TELEOPERATION`: The behavior policy is a human teleoperating
    the robot.
  - `EPISODE_ROBOT_DAGGER`: The behavior policy is a mix of controller and human
    generated actions.
- `Timestamp`: A unix timestamp recording when the episode was generated.

### EpisodeTags

- `EpisodeId`: Foreign key into the `Episodes` table.
- `Tag`: A human readable identifier for some aspect of the episode (e.g. which
  object set is used).

### RewardSequences

- `EpisodeId`: Foreign key into the `Episodes` table.
- `RewardSequenceId`: Distinguishes multiple rewards for the same episode.
- `RewardTaskId`: A human readable name of the task for this reward signal.
  Typically the same as the corresponding `TaskId` in the `Episodes` table.
- `Type`: A string describing the type of reward signal. Currently the only
  value is `REWARD_SKETCH`.
- `User`: The name of the user who produced this reward sequence.
- `Values`: A sequence of float32 values, packed as a binary blob. There is one
  float value for each frame of the episode, corresponding to the annotated
  reward.

### ArchiveFiles

- `EpisodeId`: Foreign key into the `Episodes` table.
- `ArchiveFile`: Name of the archive file containing the corresponding episode.
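
A `Values` blob is the per-frame float32 rewards packed back to back, so Python's `struct` module (or `numpy.frombuffer` with `dtype=numpy.float32`, as in `reward_example.py`) recovers them. The blob below is fabricated for illustration:

```python
import struct

# Fabricated stand-in for a RewardSequences.Values blob: four per-frame
# annotated rewards packed as raw float32 bytes.
blob = struct.pack('<4f', 0.0, 0.25, 0.5, 1.0)

# Decode: one float32 per frame of the episode, in frame order.
num_frames = len(blob) // 4
rewards = struct.unpack(f'<{num_frames}f', blob)
print(rewards)  # (0.0, 0.25, 0.5, 1.0)
```
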
## Episodes

Each episode file is a
[tfrecords](https://www.tensorflow.org/tutorials/load_data/tfrecord) file
containing a sequence of timesteps, encoded as
[`tf.train.Example`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto)
protos.

Each episode file contains a single episode, and each timestep within an episode
contains all of the observations and actions associated with that timestep as
a single `tf.train.Example`. Within each episode file the timesteps are
temporally ordered, so reading a file from beginning to end will visit all of
the timesteps from the episode in the order they occurred.

Observations and actions occur at 10Hz.
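
The TFRecord container itself has a simple framing: each record is a little-endian uint64 payload length, a 4-byte length CRC, the payload, and a 4-byte payload CRC. A minimal pure-Python reader sketch follows (CRC checks skipped; in practice you would use `tf.data.TFRecordDataset`, and each real payload is a serialized `tf.train.Example`):

```python
import io
import struct

def read_tfrecord_payloads(stream):
    """Yields raw record payloads from a TFRecord byte stream (CRCs not verified)."""
    while True:
        header = stream.read(12)  # 8-byte length + 4-byte length CRC
        if len(header) < 12:
            return
        (length,) = struct.unpack('<Q', header[:8])
        payload = stream.read(length)
        stream.read(4)  # skip the payload CRC
        yield payload

def frame(payload):
    """Frames a payload as a TFRecord record, with zeroed (unchecked) CRCs."""
    return struct.pack('<Q', len(payload)) + b'\x00' * 4 + payload + b'\x00' * 4

# Round-trip two hand-built records; they come back in file order.
stream = io.BytesIO(frame(b'timestep-0') + frame(b'timestep-1'))
payloads = list(read_tfrecord_payloads(stream))
print(payloads)  # [b'timestep-0', b'timestep-1']
```
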

## Timesteps

Each timestep is a collection of observations and actions. Actions stored with a
timestep correspond to actions taken in response to the observations they are
stored with.

For a description of the shapes and types of the timestep data, see the data
loader in `sketchy.py`.
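
A toy illustration of this alignment (all values invented): index `t` holds observation `t` together with the action taken in response to it, and at 10Hz index `t` corresponds to time `t / 10` seconds:

```python
# Invented per-timestep stand-ins: observation t and action t are stored
# together; action t was taken in response to observation t.
observations = ['obs0', 'obs1', 'obs2']
actions = ['act0', 'act1', 'act2']

# At 10Hz, timestep t occurs t / 10 seconds into the episode.
timeline = [(t / 10.0, obs, act)
            for t, (obs, act) in enumerate(zip(observations, actions))]
print(timeline[0])  # (0.0, 'obs0', 'act0')
```
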

# Dataset Metadata

The following table is necessary for this dataset to be indexed by search
engines such as <a href="https://g.co/datasetsearch">Google Dataset Search</a>.
<div itemscope itemtype="http://schema.org/Dataset">
<table>
  <tr>
    <th>property</th>
    <th>value</th>
  </tr>
  <tr>
    <td>name</td>
    <td><code itemprop="name">Sketchy</code></td>
  </tr>
  <tr>
    <td>url</td>
    <td><code itemprop="url">https://github.com/deepmind/deepmind_research/sketchy</code></td>
  </tr>
  <tr>
    <td>sameAs</td>
    <td><code itemprop="sameAs">https://github.com/deepmind/deepmind_research/sketchy</code></td>
  </tr>
  <tr>
    <td>description</td>
    <td><code itemprop="description">
      Data accompanying
      [Scaling data-driven robotics with reward sketching and batch reinforcement learning](https://arxiv.org/abs/1909.12200).
    </code></td>
  </tr>
  <tr>
    <td>provider</td>
    <td>
      <div itemscope itemtype="http://schema.org/Organization" itemprop="provider">
      <table>
        <tr>
          <th>property</th>
          <th>value</th>
        </tr>
        <tr>
          <td>name</td>
          <td><code itemprop="name">DeepMind</code></td>
        </tr>
        <tr>
          <td>sameAs</td>
          <td><code itemprop="sameAs">https://en.wikipedia.org/wiki/DeepMind</code></td>
        </tr>
      </table>
      </div>
    </td>
  </tr>
  <tr>
    <td>citation</td>
    <td><code itemprop="citation">https://identifiers.org/arxiv:1909.12200</code></td>
  </tr>
</table>
</div>
13  sketchy/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
50  sketchy/dataset_example.py  Normal file
@@ -0,0 +1,50 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example of loading sketchy data in tensorflow."""

from absl import app
from absl import flags
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf

from sketchy import sketchy

flags.DEFINE_boolean('show_images', False, 'Enable to show example images.')
FLAGS = flags.FLAGS


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  # The example file contains only a few timesteps from a single episode.
  dataset = sketchy.load_frames('sketchy/example_data.tfrecords')
  dataset = dataset.prefetch(5)

  for example in dataset:
    print('---')
    for name, value in sorted(example.items()):
      print(name, value.dtype, value.shape)

    if FLAGS.show_images:
      plt.imshow(example['pixels/basket_front_left'])
      plt.show()


if __name__ == '__main__':
  app.run(main)
178  sketchy/download.sh  Executable file
@@ -0,0 +1,178 @@
#!/bin/bash

# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use this script to download the sketchy dataset.
#
# You will need to extract the archive files before using them. Each archive
# (except for the last in each group) contains 100 episodes.

if [ -z "$1" ]; then
  echo "Usage: $(basename "$0") download_folder" >&2
  exit 1
fi

DOWNLOAD_FOLDER="$1"
NUM_PARALLEL_DOWNLOADS="4"  # set the number of download workers
DATA_URL="https://storage.cloud.google.com/sketchy-data"

function download_shards {
  # Usage: download_shards prefix num_shards
  local PREFIX="$1"
  local LIMIT="$(printf "%05d" "$2")"
  # Avoid leading zeros or this will be interpreted as an octal number.
  local MAX="$(("$2"-1))"

  (
    for IDX in $(seq -f'%05.0f' 0 "$MAX"); do
      echo "${PREFIX}-${IDX}-of-${LIMIT}.tar.bz2"
    done
  ) | xargs -I{} -n1 -P"${NUM_PARALLEL_DOWNLOADS}" \
    curl "${DATA_URL}/{}" --output "${DOWNLOAD_FOLDER}/{}"
}

# This is the metadata. This file is small, you always want it.
curl "${DATA_URL}/metadata.sqlite" --output "${DOWNLOAD_FOLDER}/metadata.sqlite"

# Download these files if you want all and only the episodes with an associated
# reward sequence.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT DISTINCT(Episodes.DataPath)
# FROM Episodes, RewardSequences
# WHERE Episodes.EpisodeId = RewardSequences.EpisodeId
# EOF
#
# If you are downloading the full dataset then you do not need these files. The
# episodes they contain are included in the other subsets.
#
# download_shards episodes_with_rewards 58

# These files contain a curated set of high quality demonstrations for the
# lift_green task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='lift_green__demos'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# EOF
#
download_shards lift_green__demos 2

# These files contain a broader set of episodes for the lift_green task. If you
# download these you should also download the lift_green__demos files.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='lift_green__episodes'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
#   SELECT ET.EpisodeId
#   FROM EpisodeTags AS ET
#   WHERE ET.Tag IN ('lift_green__demos')
# )
# EOF
#
download_shards lift_green__episodes 70

# These files contain a curated set of high quality demonstrations for the
# stack_green_on_red task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='stack_green_on_red__demos'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
#   SELECT ET.EpisodeId
#   FROM EpisodeTags AS ET
#   WHERE ET.Tag IN ('lift_green__demos', 'lift_green__episodes')
# )
# EOF
#
download_shards stack_green_on_red__demos 2

# These files contain a broader set of episodes for the stack_green_on_red task.
# If you download these you should also download the stack_green_on_red__demos
# files.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='stack_green_on_red__episodes'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
#   SELECT ET.EpisodeId
#   FROM EpisodeTags AS ET
#   WHERE ET.Tag IN (
#     'lift_green__demos',
#     'lift_green__episodes',
#     'stack_green_on_red__demos')
# )
# EOF
#
download_shards stack_green_on_red__episodes 101

# These files contain a large variety of episodes using the same object set as
# the lift_green and stack_green_on_red tasks. There are many tasks represented
# here, and the episode quality is highly variable.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='rgb30__all'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
#   SELECT ET.EpisodeId
#   FROM EpisodeTags AS ET
#   WHERE ET.Tag IN (
#     'lift_green__demos',
#     'lift_green__episodes',
#     'stack_green_on_red__demos',
#     'stack_green_on_red__episodes')
# )
# EOF
#
download_shards rgb30__all 205

# These files contain a broad set of episodes for the pull_cloth_up task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='pull_cloth_up__episodes'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# EOF
#
download_shards pull_cloth_up__episodes 133

# These files contain a large variety of episodes using the same object set as
# the pull_cloth_up task.
#
# sqlite3 metadata.sqlite <<EOF
# SELECT Episodes.DataPath
# FROM Episodes, EpisodeTags
# WHERE EpisodeTags.Tag='deform8__all'
# AND Episodes.EpisodeId = EpisodeTags.EpisodeId
# AND EpisodeTags.EpisodeId NOT IN (
#   SELECT ET.EpisodeId
#   FROM EpisodeTags AS ET
#   WHERE ET.Tag IN ('pull_cloth_up__episodes')
# )
# EOF
#
download_shards deform8__all 233
BIN  sketchy/example_data.tfrecords  Normal file
Binary file not shown.
30  sketchy/extract.sh  Executable file
@@ -0,0 +1,30 @@
#!/bin/bash

# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use this script to extract the downloaded sketchy dataset.

if [ -z "$1" ]; then
  echo "Usage: $(basename "$0") download_folder" >&2
  exit 1
fi

DOWNLOAD_FOLDER="$1"
NUM_PARALLEL_WORKERS="$(grep processor /proc/cpuinfo | wc -l)"

cd "$DOWNLOAD_FOLDER"

find . -name '*.tar.bz2' -print0 \
  | xargs -0 -n1 -P"$NUM_PARALLEL_WORKERS" tar xf
138  sketchy/metadata_schema.py  Normal file
@@ -0,0 +1,138 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sqlalchemy schema for the metadata db."""

import sqlalchemy

from sqlalchemy.ext import declarative


Column = sqlalchemy.Column
Integer = sqlalchemy.Integer
String = sqlalchemy.String
LargeBinary = sqlalchemy.LargeBinary
ForeignKey = sqlalchemy.ForeignKey

# pylint: disable=invalid-name

# https://docs.sqlalchemy.org/en/13/orm/tutorial.html

Base = declarative.declarative_base()


EpisodeTag = sqlalchemy.Table(
    'EpisodeTags', Base.metadata,
    Column(
        'EpisodeId', String, ForeignKey('Episodes.EpisodeId'),
        primary_key=True),
    Column('Tag', String, ForeignKey('Tags.Name'), primary_key=True))
"""Table relating episodes and tags.

Attributes:
  EpisodeId: A string of digits that uniquely identifies the episode.
  Tag: Human readable tag name.
"""


class Episode(Base):
  """Table describing individual episodes.

  Attributes:
    EpisodeId: A string of digits that uniquely identifies the episode.
    TaskId: A human readable name for the task corresponding to the behavior
      that generated the episode.
    DataPath: The name of the episode file holding the data for this episode.
    Timestamp: A unix timestamp recording when the episode was generated.
    EpisodeType: A string describing the type of policy that generated the
      episode. Possible values are:
      - `EPISODE_ROBOT_AGENT`: The behavior policy is a learned or scripted
        controller.
      - `EPISODE_ROBOT_TELEOPERATION`: The behavior policy is a human
        teleoperating the robot.
      - `EPISODE_ROBOT_DAGGER`: The behavior policy is a mix of controller
        and human generated actions.
    Tags: A list of tags attached to this episode.
    Rewards: A list of `RewardSequence`s containing sketched rewards for this
      episode.
  """
  __tablename__ = 'Episodes'
  EpisodeId = Column(String, primary_key=True)
  TaskId = Column(String)
  DataPath = Column(String)
  Timestamp = Column(Integer)
  EpisodeType = Column(String)
  Tags = sqlalchemy.orm.relationship(
      'Tag', secondary=EpisodeTag, back_populates='Episodes')
  Rewards = sqlalchemy.orm.relationship(
      'RewardSequence', backref='Episode')


class Tag(Base):
  """Table of tags that can be attached to episodes.

  Attributes:
    Name: Human readable tag name.
    Episodes: The episodes that have been annotated with this tag.
  """
  __tablename__ = 'Tags'
  Name = Column(String, primary_key=True)
  Episodes = sqlalchemy.orm.relationship(
      'Episode', secondary=EpisodeTag, back_populates='Tags')


class RewardSequence(Base):
  """Table describing reward sequences for episodes.

  Attributes:
    EpisodeId: Foreign key into the `Episodes` table.
    RewardSequenceId: Distinguishes multiple rewards for the same episode.
    RewardTaskId: A human readable name of the task for this reward signal.
      Typically the same as the corresponding `TaskId` in the `Episodes`
      table.
    Type: A string describing the type of reward signal. Currently the only
      value is `REWARD_SKETCH`.
    User: The name of the user who produced this reward sequence.
    Values: A sequence of float32 values, packed as a binary blob. There is one
      float value for each frame of the episode, corresponding to the
      annotated reward.
  """
  __tablename__ = 'RewardSequences'
  EpisodeId = Column(
      'EpisodeId', String, ForeignKey('Episodes.EpisodeId'), primary_key=True)
  RewardSequenceId = Column(String, primary_key=True)
  RewardTaskId = Column('RewardTaskId', String)
  Type = Column(String)
  User = Column(String)
  Values = Column(LargeBinary)


class ArchiveFile(Base):
  """Table describing where episodes are stored in archives.

  This information is relevant if you want to download or extract a specific
  episode from the archives they are distributed in.

  Attributes:
    EpisodeId: Foreign key into the `Episodes` table.
    ArchiveFile: Name of the archive file containing the corresponding episode.
  """
  __tablename__ = 'ArchiveFiles'
  EpisodeId = Column(
      'EpisodeId', String, ForeignKey('Episodes.EpisodeId'), primary_key=True)
  ArchiveFile = Column(String)


# pylint: enable=invalid-name
38  sketchy/requirements.txt  Normal file
@@ -0,0 +1,38 @@
absl-py==0.9.0
astor==0.8.1
cachetools==4.0.0
certifi==2019.11.28
chardet==3.0.4
cycler==0.10.0
gast==0.2.2
google-auth==1.10.0
google-auth-oauthlib==0.4.1
google-pasta==0.1.8
grpcio==1.26.0
h5py==2.10.0
idna==2.8
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
Markdown==3.1.1
matplotlib==3.1.2
numpy==1.18.1
oauthlib==3.1.0
opt-einsum==3.1.0
protobuf==3.11.2
pyasn1==0.4.8
pyasn1-modules==0.2.7
pyparsing==2.4.6
python-dateutil==2.8.1
requests==2.22.0
requests-oauthlib==1.3.0
rsa==4.0
six==1.13.0
SQLAlchemy==1.3.12
tensorboard==2.0.2
tensorflow==2.0.0
tensorflow-estimator==2.0.1
termcolor==1.1.0
urllib3==1.25.7
Werkzeug==0.16.0
wrapt==1.11.2
51  sketchy/reward_example.py  Normal file
@@ -0,0 +1,51 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example of loading rewards from the metadata file."""

from absl import app
from absl import flags
import numpy as np
import sqlalchemy

from sketchy import metadata_schema

flags.DEFINE_string(
    'metadata', '/tmp/metadata.sqlite', 'Path to metadata file.')

FLAGS = flags.FLAGS


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  engine = sqlalchemy.create_engine('sqlite:///' + FLAGS.metadata)
  session = sqlalchemy.orm.sessionmaker(bind=engine)()

  episodes = session.query(metadata_schema.Episode).join(
      metadata_schema.RewardSequence).limit(5)

  for episode in episodes:
    rewards = np.frombuffer(episode.Rewards[0].Values, dtype=np.float32)
    print('---')
    print(f'Episode: {episode.EpisodeId}')
    print(f'Episode file: {episode.DataPath}')
    print(f'Reward type: {episode.Rewards[0].Type}')
    print(f'Reward values: {rewards}')


if __name__ == '__main__':
  app.run(main)
21  sketchy/run.sh  Executable file
@@ -0,0 +1,21 @@
#!/bin/bash

# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

python3 -m venv .sketchy_env
source .sketchy_env/bin/activate
pip install --upgrade pip
pip install -r sketchy/requirements.txt
python -m sketchy.dataset_example --noshow_images
86  sketchy/sketchy.py  Normal file
@@ -0,0 +1,86 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Interface for loading sketchy data into tensorflow."""

import tensorflow.compat.v2 as tf


def load_frames(filenames, num_parallel_reads=1, num_map_threads=None):
  """Loads timesteps from episode tfrecord files as a `tf.data.Dataset`."""
  if not num_map_threads:
    num_map_threads = num_parallel_reads
  dataset = tf.data.TFRecordDataset(
      filenames, num_parallel_reads=num_parallel_reads)
  return dataset.map(_parse_example, num_parallel_calls=num_map_threads)


_FEATURES = {
    # Actions
    'actions':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),

    # Observations
    'gripper/joints/velocity':
        tf.io.FixedLenFeature(shape=1, dtype=tf.float32),
    'gripper/joints/torque':
        tf.io.FixedLenFeature(shape=1, dtype=tf.float32),
    'gripper/grasp':
        tf.io.FixedLenFeature(shape=1, dtype=tf.int64),
    'gripper/joints/angle':
        tf.io.FixedLenFeature(shape=1, dtype=tf.float32),
    'sawyer/joints/velocity':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/pinch/pose':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/tcp/pose':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/tcp/effort':
        tf.io.FixedLenFeature(shape=6, dtype=tf.float32),
    'sawyer/joints/torque':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'sawyer/tcp/velocity':
        tf.io.FixedLenFeature(shape=6, dtype=tf.float32),
    'sawyer/joints/angle':
        tf.io.FixedLenFeature(shape=7, dtype=tf.float32),
    'wrist/torque':
        tf.io.FixedLenFeature(shape=3, dtype=tf.float32),
    'wrist/force':
        tf.io.FixedLenFeature(shape=3, dtype=tf.float32),
    'pixels/basket_front_left':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/basket_back_left':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/basket_front_right':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/royale_camera_driver_depth':
        tf.io.FixedLenFeature(shape=(171, 224, 1), dtype=tf.float32),
    'pixels/royale_camera_driver_gray':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/usbcam0':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
    'pixels/usbcam1':
        tf.io.FixedLenFeature(shape=1, dtype=tf.string),
}


def _parse_example(example):
  return _decode_images(tf.io.parse_single_example(example, _FEATURES))


def _decode_images(record):
  for name, value in list(record.items()):
    if value.dtype == tf.string:
      record[name] = tf.io.decode_jpeg(value[0])
  return record