mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-20 03:11:31 +08:00
Add RL Unplugged data loading code and examples
PiperOrigin-RevId: 321746296
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
bd29e1b710
commit
1ea4cc033c
@@ -24,7 +24,7 @@ https://deepmind.com/research/publications/
|
||||
|
||||
## Projects
|
||||
|
||||
* [RL Unplugged: Benchmarks for Offline Reinforcement Learning] (rl_unplugged)
|
||||
* [RL Unplugged: Benchmarks for Offline Reinforcement Learning](rl_unplugged)
|
||||
* [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer)
|
||||
* [What can I do here? A theory of affordances in reinforcmenet learning](affordances_theory), ICML 2020
|
||||
* [Scaling data-driven robotics with reward sketching and batch reinforcement learning](sketchy), RSS 2020
|
||||
|
||||
+25
-4
@@ -26,8 +26,6 @@ In this suite of benchmarks, we try to focus on the following problems:
|
||||
The data is available under
|
||||
[RL Unplugged GCP bucket](https://console.cloud.google.com/storage/browser/rl_unplugged).
|
||||
|
||||
Data loading code and examples will be available soon.
|
||||
|
||||
## Atari Dataset
|
||||
|
||||
We are releasing a large and diverse dataset of gameplay following the protocol
|
||||
@@ -40,7 +38,7 @@ transition include stacks of four frames to be able to do frame-stacking with
|
||||
our baselines. We release datasets for 46 Atari games. For details on how the
|
||||
dataset was generated, please refer to the paper.
|
||||
|
||||
## Deepmind Locomotion Dataset
|
||||
## DeepMind Locomotion Dataset
|
||||
|
||||
These tasks are made up of the corridor locomotion tasks involving the CMU
|
||||
Humanoid, for which prior efforts have either used motion capture data [Merel et
|
||||
@@ -51,7 +49,7 @@ Locomotion tasks feature the combination of challenging high-DoF continuous
|
||||
control along with perception from rich egocentric observations. For details on
|
||||
how the dataset was generated, please refer to the paper.
|
||||
|
||||
## Deepmind Control Suite Dataset
|
||||
## DeepMind Control Suite Dataset
|
||||
|
||||
DeepMind Control Suite [Tassa et al., 2018] is a set of control tasks
|
||||
implemented in MuJoCo [Todorov et al., 2012]. We consider a subset of the tasks
|
||||
@@ -73,6 +71,29 @@ We release 8 datasets in total -- with no combined challenge and easy combined
|
||||
challenge on the cartpole, walker, quadruped, and humanoid tasks. For details on
|
||||
how the dataset was generated, please refer to the paper.
|
||||
|
||||
## Running the code
|
||||
|
||||
### Installation
|
||||
|
||||
* Install dependencies: `pip install requirements.txt`
|
||||
* (Optional) Setup MuJoCo license key for DM Control environments
|
||||
([instructions](https://github.com/deepmind/dm_control#requirements-and-installation)).
|
||||
* (Optional) Install
|
||||
[realworldrl_suite](https://github.com/google-research/realworldrl_suite#installation).
|
||||
|
||||
### Atari example
|
||||
|
||||
```
|
||||
mkdir -p /tmp/dataset/Asterix
|
||||
gsutil cp gs://rl_unplugged/atari/Asterix/run_1-00000-of-00100 \
|
||||
/tmp/dataset/Asterix/run_1-00000-of-00001
|
||||
python atari_example.py --path=/tmp/dataset --game=Asterix
|
||||
```
|
||||
|
||||
This copies a single shard from one of the Asterix datasets from GCP to a local
|
||||
folder, and then runs a script that loads a single example and runs a step on
|
||||
the Atari environment.
|
||||
|
||||
## Citation
|
||||
|
||||
Please use the following bibtex for citations:
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Atari RL Unplugged datasets.
|
||||
|
||||
Examples in the dataset represent SARSA transitions stored during a
|
||||
DQN training run as described in https://arxiv.org/pdf/1907.04543.
|
||||
|
||||
For every training run we have recorded all 50 million transitions corresponding
|
||||
to 200 million environment steps (4x factor because of frame skipping). There
|
||||
are 5 separate datasets for each of the 45 games.
|
||||
|
||||
Every transition in the dataset is a tuple containing the following features:
|
||||
|
||||
* o_t: Observation at time t. Observations have been processed using the
|
||||
canonical Atari frame processing, including 4x frame stacking. The shape
|
||||
of a single observation is [84, 84, 4].
|
||||
* a_t: Action taken at time t.
|
||||
* r_t: Reward after a_t.
|
||||
* d_t: Discount after a_t.
|
||||
* o_tp1: Observation at time t+1.
|
||||
* a_tp1: Action at time t+1.
|
||||
* extras:
|
||||
* episode_id: Episode identifier.
|
||||
* episode_return: Total episode return computed using per-step [-1, 1]
|
||||
clipping.
|
||||
"""
|
||||
import functools
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from acme import wrappers
|
||||
import dm_env
|
||||
from dm_env import specs
|
||||
from dopamine.discrete_domains import atari_lib
|
||||
import reverb
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# 9 tuning games.
|
||||
TUNING_SUITE = [
|
||||
'BeamRider',
|
||||
'DemonAttack',
|
||||
'DoubleDunk',
|
||||
'IceHockey',
|
||||
'MsPacman',
|
||||
'Pooyan',
|
||||
'RoadRunner',
|
||||
'Robotank',
|
||||
'Zaxxon',
|
||||
]
|
||||
|
||||
# 36 testing games.
|
||||
TESTING_SUITE = [
|
||||
'Alien',
|
||||
'Amidar',
|
||||
'Assault',
|
||||
'Asterix',
|
||||
'Atlantis',
|
||||
'BankHeist',
|
||||
'BattleZone',
|
||||
'Boxing',
|
||||
'Breakout',
|
||||
'Carnival',
|
||||
'Centipede',
|
||||
'ChopperCommand',
|
||||
'CrazyClimber',
|
||||
'Enduro',
|
||||
'FishingDerby',
|
||||
'Freeway',
|
||||
'Frostbite',
|
||||
'Gopher',
|
||||
'Gravitar',
|
||||
'Hero',
|
||||
'Jamesbond',
|
||||
'Kangaroo',
|
||||
'Krull',
|
||||
'KungFuMaster',
|
||||
'NameThisGame',
|
||||
'Phoenix',
|
||||
'Pong',
|
||||
'Qbert',
|
||||
'Riverraid',
|
||||
'Seaquest',
|
||||
'SpaceInvaders',
|
||||
'StarGunner',
|
||||
'TimePilot',
|
||||
'UpNDown',
|
||||
'VideoPinball',
|
||||
'WizardOfWor',
|
||||
'YarsRevenge',
|
||||
]
|
||||
|
||||
# Total of 45 games.
|
||||
ALL = TUNING_SUITE + TESTING_SUITE
|
||||
|
||||
|
||||
def _decode_frames(pngs):
|
||||
"""Decode PNGs.
|
||||
|
||||
Args:
|
||||
pngs: String Tensor of size (4,) containing PNG encoded images.
|
||||
|
||||
Returns:
|
||||
4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
|
||||
"""
|
||||
# Statically unroll png decoding
|
||||
frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
|
||||
frames = tf.concat(frames, axis=2)
|
||||
frames.set_shape((84, 84, 4))
|
||||
return frames
|
||||
|
||||
|
||||
def _make_reverb_sample(o_t,
|
||||
a_t,
|
||||
r_t,
|
||||
d_t,
|
||||
o_tp1,
|
||||
a_tp1,
|
||||
extras):
|
||||
"""Create Reverb sample with offline data.
|
||||
|
||||
Args:
|
||||
o_t: Observation at time t.
|
||||
a_t: Action at time t.
|
||||
r_t: Reward at time t.
|
||||
d_t: Discount at time t.
|
||||
o_tp1: Observation at time t+1.
|
||||
a_tp1: Action at time t+1.
|
||||
extras: Dictionary with extra features.
|
||||
|
||||
Returns:
|
||||
Replay sample with fake info: key=0, probability=1, table_size=0.
|
||||
"""
|
||||
info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
|
||||
probability=tf.constant(1.0, tf.float64),
|
||||
table_size=tf.constant(0, tf.int64),
|
||||
priority=tf.constant(1.0, tf.float64))
|
||||
data = (o_t, a_t, r_t, d_t, o_tp1, a_tp1, extras)
|
||||
return reverb.ReplaySample(info=info, data=data)
|
||||
|
||||
|
||||
def _tf_example_to_reverb_sample(tf_example
|
||||
):
|
||||
"""Create a Reverb replay sample from a TF example."""
|
||||
|
||||
# Parse tf.Example.
|
||||
feature_description = {
|
||||
'o_t': tf.io.FixedLenFeature([4], tf.string),
|
||||
'o_tp1': tf.io.FixedLenFeature([4], tf.string),
|
||||
'a_t': tf.io.FixedLenFeature([], tf.int64),
|
||||
'a_tp1': tf.io.FixedLenFeature([], tf.int64),
|
||||
'r_t': tf.io.FixedLenFeature([], tf.float32),
|
||||
'd_t': tf.io.FixedLenFeature([], tf.float32),
|
||||
'episode_id': tf.io.FixedLenFeature([], tf.int64),
|
||||
'episode_return': tf.io.FixedLenFeature([], tf.float32),
|
||||
}
|
||||
data = tf.io.parse_single_example(tf_example, feature_description)
|
||||
|
||||
# Process data.
|
||||
o_t = _decode_frames(data['o_t'])
|
||||
o_tp1 = _decode_frames(data['o_tp1'])
|
||||
a_t = tf.cast(data['a_t'], tf.int32)
|
||||
a_tp1 = tf.cast(data['a_tp1'], tf.int32)
|
||||
episode_id = tf.bitcast(data['episode_id'], tf.uint64)
|
||||
|
||||
# Build Reverb replay sample.
|
||||
extras = {
|
||||
'episode_id': episode_id,
|
||||
'return': data['episode_return']
|
||||
}
|
||||
return _make_reverb_sample(o_t, a_t, data['r_t'], data['d_t'], o_tp1, a_tp1,
|
||||
extras)
|
||||
|
||||
|
||||
def dataset(path,
|
||||
game,
|
||||
run,
|
||||
num_shards = 100,
|
||||
shuffle_buffer_size = 100000):
|
||||
"""TF dataset of Atari SARSA tuples."""
|
||||
path = os.path.join(path, f'{game}/run_{run}')
|
||||
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
|
||||
file_ds = tf.data.Dataset.from_tensor_slices(filenames)
|
||||
file_ds = file_ds.repeat().shuffle(num_shards)
|
||||
example_ds = file_ds.interleave(
|
||||
functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'),
|
||||
cycle_length=tf.data.experimental.AUTOTUNE,
|
||||
block_length=5)
|
||||
example_ds = example_ds.shuffle(shuffle_buffer_size)
|
||||
return example_ds.map(_tf_example_to_reverb_sample,
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
|
||||
class AtariDopamineWrapper(dm_env.Environment):
|
||||
"""Wrapper for Atari Dopamine environmnet."""
|
||||
|
||||
def __init__(self, env, max_episode_steps=108000):
|
||||
self._env = env
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._episode_steps = 0
|
||||
self._reset_next_episode = True
|
||||
|
||||
def reset(self):
|
||||
self._episode_steps = 0
|
||||
self._reset_next_step = False
|
||||
observation = self._env.reset()
|
||||
return dm_env.restart(observation.squeeze(-1))
|
||||
|
||||
def step(self, action):
|
||||
if self._reset_next_step:
|
||||
return self.reset()
|
||||
|
||||
observation, reward, terminal, _ = self._env.step(action.item())
|
||||
observation = observation.squeeze(-1)
|
||||
discount = 1 - float(terminal)
|
||||
self._episode_steps += 1
|
||||
if terminal:
|
||||
self._reset_next_episode = True
|
||||
return dm_env.termination(reward, observation)
|
||||
elif self._episode_steps == self._max_episode_steps:
|
||||
self._reset_next_episode = True
|
||||
return dm_env.truncation(reward, observation, discount)
|
||||
else:
|
||||
return dm_env.transition(reward, observation, discount)
|
||||
|
||||
def observation_spec(self):
|
||||
space = self._env.observation_space
|
||||
return specs.Array(space.shape[:-1], space.dtype)
|
||||
|
||||
def action_spec(self):
|
||||
return specs.DiscreteArray(self._env.action_space.n)
|
||||
|
||||
|
||||
def environment(game):
|
||||
"""Atari environment."""
|
||||
env = atari_lib.create_atari_environment(game_name=game,
|
||||
sticky_actions=True)
|
||||
env = AtariDopamineWrapper(env)
|
||||
env = wrappers.FrameStackingWrapper(env, num_frames=4)
|
||||
return wrappers.SinglePrecisionWrapper(env)
|
||||
@@ -0,0 +1,407 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "KDiJzbb8KFvP"
|
||||
},
|
||||
"source": [
|
||||
"Copyright 2020 DeepMind Technologies Limited.\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
||||
"this file except in compliance with the License. You may obtain a copy of the\n",
|
||||
"License at\n",
|
||||
"\n",
|
||||
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
||||
"\n",
|
||||
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
||||
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
||||
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
||||
"specific language governing permissions and limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "ULdrhOaVbsdO"
|
||||
},
|
||||
"source": [
|
||||
"# RL Unplugged: Offline DQN - Atari\n",
|
||||
"## Guide to training an Acme DQN agent on Atari data.\n",
|
||||
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/rl_unplugged/atari_dqn.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "xaJxoatMhJ71"
|
||||
},
|
||||
"source": [
|
||||
"## Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "both",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "KH3O0zcXUeun"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install dm-acme\n",
|
||||
"!pip install dm-acme[reverb]\n",
|
||||
"!pip install dm-acme[tf]\n",
|
||||
"!pip install dm-sonnet\n",
|
||||
"!pip install dopamine-rl==3.0.1\n",
|
||||
"!pip install atari-py\n",
|
||||
"!git clone https://github.com/deepmind/deepmind-research.git\n",
|
||||
"%cd deepmind-research"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "c-H2d6UZi7Sf"
|
||||
},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "both",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "HJ74Id-8MERq"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import copy\n",
|
||||
"\n",
|
||||
"import acme\n",
|
||||
"from acme.agents.tf import actors\n",
|
||||
"from acme.agents.tf.dqn import learning as dqn\n",
|
||||
"from acme.tf import utils as acme_utils\n",
|
||||
"from acme.utils import loggers\n",
|
||||
"from rl_unplugged import atari\n",
|
||||
"import sonnet as snt\n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "JrOSnoWiY4Xl"
|
||||
},
|
||||
"source": [
|
||||
"## Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Vi3_H_h1zy_0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"game = 'Pong' #@param\n",
|
||||
"run = 1 #@param\n",
|
||||
"\n",
|
||||
"tmp_path = '/tmp/atari'\n",
|
||||
"gs_path = 'gs://rl_unplugged/atari'\n",
|
||||
"\n",
|
||||
"!mkdir -p {tmp_path}/{game}\n",
|
||||
"!gsutil cp {gs_path}/{game}/run_{run}-00000-of-00001 {tmp_path}/{game}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "a9vF7LtYvLzy"
|
||||
},
|
||||
"source": [
|
||||
"## Dataset and environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "01AHHNd9cEX2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_size = 10 #@param\n",
|
||||
"\n",
|
||||
"def discard_extras(sample):\n",
|
||||
" return sample._replace(data=sample.data[:5])\n",
|
||||
"\n",
|
||||
"dataset = atari.dataset(path=tmp_path, game='Pong', run=1, num_shards=1)\n",
|
||||
"# Small batch size, experiments in the paper were run with batch size 256.\n",
|
||||
"dataset = dataset.map(discard_extras).batch(batch_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "KoYBhjPtI_N6"
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "4b4_rHwCmQg-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"environment = atari.environment(game='Pong')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "BukOfOsmtSQn"
|
||||
},
|
||||
"source": [
|
||||
"## DQN learner"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 34
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 83,
|
||||
"status": "ok",
|
||||
"timestamp": 1593614657342,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -60
|
||||
},
|
||||
"id": "3Jcjk1w6oHVX",
|
||||
"outputId": "1746b0bb-5a5c-45dd-b5a1-c77852545e12"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TensorSpec(shape=(6,), dtype=tf.float32, name=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get total number of actions.\n",
|
||||
"num_actions = environment.action_spec().num_values\n",
|
||||
"\n",
|
||||
"# Create the Q network.\n",
|
||||
"network = snt.Sequential([\n",
|
||||
" lambda x: tf.image.convert_image_dtype(x, tf.float32),\n",
|
||||
" snt.Conv2D(32, [8, 8], [4, 4]),\n",
|
||||
" tf.nn.relu,\n",
|
||||
" snt.Conv2D(64, [4, 4], [2, 2]),\n",
|
||||
" tf.nn.relu,\n",
|
||||
" snt.Conv2D(64, [3, 3], [1, 1]),\n",
|
||||
" tf.nn.relu,\n",
|
||||
" snt.Flatten(),\n",
|
||||
" snt.nets.MLP([512, num_actions])\n",
|
||||
"])\n",
|
||||
"acme_utils.create_variables(network, [environment.observation_spec()])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "9CD2sNK-oA9S"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create a logger.\n",
|
||||
"logger = loggers.TerminalLogger(label='learner', time_delta=1.)\n",
|
||||
"\n",
|
||||
"# Create the DQN learner.\n",
|
||||
"learner = dqn.DQNLearner(\n",
|
||||
" network=network,\n",
|
||||
" target_network=copy.deepcopy(network),\n",
|
||||
" discount=0.99,\n",
|
||||
" learning_rate=3e-4,\n",
|
||||
" importance_sampling_exponent=0.2,\n",
|
||||
" target_update_period=2500,\n",
|
||||
" dataset=dataset,\n",
|
||||
" logger=logger)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "oKeGQxzitXYC"
|
||||
},
|
||||
"source": [
|
||||
"## Training loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 51
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 4694,
|
||||
"status": "ok",
|
||||
"timestamp": 1593614662237,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -60
|
||||
},
|
||||
"id": "VWZd5N-Qoz82",
|
||||
"outputId": "5ee2ce7c-b3fe-483b-8893-5a6e13519f48"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Learner] Loss = 0.003 | Steps = 1 | Walltime = 0\n",
|
||||
"[Learner] Loss = 0.004 | Steps = 54 | Walltime = 1.126\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for _ in range(100):\n",
|
||||
" learner.step()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "qFQDrp0CgIzU"
|
||||
},
|
||||
"source": [
|
||||
"## Evaluation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 102
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 15099,
|
||||
"status": "ok",
|
||||
"timestamp": 1593614677360,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -60
|
||||
},
|
||||
"id": "DWYHBalygIDF",
|
||||
"outputId": "4ec412c3-810a-4208-b521-919a8ece40df"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Evaluation] Episode Length = 842 | Episode Return = -20.000 | Episodes = 1 | Steps = 842 | Steps Per Second = 265.850\n",
|
||||
"[Evaluation] Episode Length = 792 | Episode Return = -21.000 | Episodes = 2 | Steps = 1634 | Steps Per Second = 270.043\n",
|
||||
"[Evaluation] Episode Length = 812 | Episode Return = -21.000 | Episodes = 3 | Steps = 2446 | Steps Per Second = 274.792\n",
|
||||
"[Evaluation] Episode Length = 812 | Episode Return = -21.000 | Episodes = 4 | Steps = 3258 | Steps Per Second = 270.967\n",
|
||||
"[Evaluation] Episode Length = 812 | Episode Return = -21.000 | Episodes = 5 | Steps = 4070 | Steps Per Second = 274.253\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a logger.\n",
|
||||
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
|
||||
"\n",
|
||||
"# Create an environment loop.\n",
|
||||
"policy_network = snt.Sequential([\n",
|
||||
" network,\n",
|
||||
" lambda q: tf.argmax(q, axis=-1),\n",
|
||||
"])\n",
|
||||
"loop = acme.EnvironmentLoop(\n",
|
||||
" environment=environment,\n",
|
||||
" actor=actors.FeedForwardActor(policy_network=policy_network),\n",
|
||||
" logger=logger)\n",
|
||||
"\n",
|
||||
"loop.run(5)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"last_runtime": {
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "RL Unplugged: Offline DQN - Atari",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1g9yTbTuk9aeERxWflOWqUGpx2M3osx0l",
|
||||
"timestamp": 1593685504110
|
||||
}
|
||||
]
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
r"""Atari dataset example.
|
||||
|
||||
Instructions:
|
||||
> mkdir -p /tmp/dataset/Asterix
|
||||
> gsutil cp gs://rl_unplugged/atari/Asterix/run_1-00000-of-00100 \
|
||||
/tmp/dataset/Asterix/run_1-00000-of-00001
|
||||
> python atari_example.py --path=/tmp/dataset --game=Asterix
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from acme import specs
|
||||
import tree
|
||||
|
||||
from rl_unplugged import atari
|
||||
|
||||
flags.DEFINE_string('path', '/tmp/dataset', 'Path to dataset.')
|
||||
flags.DEFINE_string('game', 'Asterix', 'Game.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(_):
|
||||
ds = atari.dataset(FLAGS.path, FLAGS.game, 1,
|
||||
num_shards=1,
|
||||
shuffle_buffer_size=1)
|
||||
for sample in ds.take(1):
|
||||
print('Data spec')
|
||||
print(tree.map_structure(lambda x: (x.dtype, x.shape), sample.data))
|
||||
|
||||
env = atari.environment(FLAGS.game)
|
||||
print('Environment spec')
|
||||
print(specs.make_environment_spec(env))
|
||||
print('Environment observation')
|
||||
timestep = env.reset()
|
||||
print(tree.map_structure(lambda x: (x.dtype, x.shape), timestep.observation))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,475 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "KDiJzbb8KFvP"
|
||||
},
|
||||
"source": [
|
||||
"Copyright 2020 DeepMind Technologies Limited.\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
||||
"this file except in compliance with the License. You may obtain a copy of the\n",
|
||||
"License at\n",
|
||||
"\n",
|
||||
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
||||
"\n",
|
||||
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
||||
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
||||
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
||||
"specific language governing permissions and limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "zzJlIvx4tnrM"
|
||||
},
|
||||
"source": [
|
||||
"# RL Unplugged: Offline D4PG - DM control\n",
|
||||
"\n",
|
||||
"## Guide to training an Acme D4PG agent on DM control data.\n",
|
||||
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/rl_unplugged/dm_control_d4gp.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "o1eig5zGEL4y"
|
||||
},
|
||||
"source": [
|
||||
"## Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "WbpMoLbgEL41"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install dm-acme\n",
|
||||
"!pip install dm-acme[reverb]\n",
|
||||
"!pip install dm-acme[tf]\n",
|
||||
"!pip install dm-sonnet\n",
|
||||
"!git clone https://github.com/deepmind/deepmind-research.git\n",
|
||||
"%cd deepmind-research"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "04bMANoeEPM3"
|
||||
},
|
||||
"source": [
|
||||
"### dm_control\n",
|
||||
"\n",
|
||||
"More detailed instructions in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "VEEj3Qw60y73"
|
||||
},
|
||||
"source": [
|
||||
"#### Institutional MuJoCo license."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "both",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "IbZxYDxzoz5R"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Edit and run\n",
|
||||
"mjkey = \"\"\"\n",
|
||||
"\n",
|
||||
"REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
|
||||
"\n",
|
||||
"\"\"\".strip()\n",
|
||||
"\n",
|
||||
"mujoco_dir = \"$HOME/.mujoco\"\n",
|
||||
"\n",
|
||||
"# Install OpenGL deps\n",
|
||||
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
|
||||
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
|
||||
"\n",
|
||||
"# Fetch MuJoCo binaries from Roboti\n",
|
||||
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
|
||||
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
|
||||
"\n",
|
||||
"# Copy over MuJoCo license\n",
|
||||
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Configure dm_control to use the OSMesa rendering backend\n",
|
||||
"%env MUJOCO_GL=osmesa\n",
|
||||
"\n",
|
||||
"# Install dm_control, including extra dependencies needed for the locomotion\n",
|
||||
"# mazes.\n",
|
||||
"!pip install dm_control[locomotion_mazes]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "-_7tVg-zzjzW"
|
||||
},
|
||||
"source": [
|
||||
"#### Machine-locked MuJoCo license."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "OvMLEDE-D9oF"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Add your MuJoCo License and run\n",
|
||||
"mjkey = \"\"\"\n",
|
||||
"\"\"\".strip()\n",
|
||||
"\n",
|
||||
"mujoco_dir = \"$HOME/.mujoco\"\n",
|
||||
"\n",
|
||||
"# Install OpenGL dependencies\n",
|
||||
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
|
||||
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
|
||||
"\n",
|
||||
"# Get MuJoCo binaries\n",
|
||||
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
|
||||
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
|
||||
"\n",
|
||||
"# Copy over MuJoCo license\n",
|
||||
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
|
||||
"\n",
|
||||
"# Install dm_control\n",
|
||||
"!pip install dm_control[locomotion_mazes]\n",
|
||||
"\n",
|
||||
"# Configure dm_control to use the OSMesa rendering backend\n",
|
||||
"%env MUJOCO_GL=osmesa"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "IE2nV9Hivnv5"
|
||||
},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "RI7NgnJIvs4s"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import collections\n",
|
||||
"import copy\n",
|
||||
"from typing import Mapping, Sequence\n",
|
||||
"\n",
|
||||
"import acme\n",
|
||||
"from acme import specs\n",
|
||||
"from acme.agents.tf import actors\n",
|
||||
"from acme.agents.tf import d4pg\n",
|
||||
"from acme.tf import networks\n",
|
||||
"from acme.tf import utils as tf2_utils\n",
|
||||
"from acme.utils import loggers\n",
|
||||
"from acme.wrappers import single_precision\n",
|
||||
"from acme.tf import utils as tf2_utils\n",
|
||||
"import numpy as np\n",
|
||||
"from rl_unplugged import dm_control_suite\n",
|
||||
"import sonnet as snt\n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "a2PCwF3bwBII"
|
||||
},
|
||||
"source": [
|
||||
"## Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "both",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "VaEJbXjampPy"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"task_name = 'cartpole_swingup' #@param\n",
|
||||
"tmp_path = '/tmp/dm_control_suite'\n",
|
||||
"gs_path = 'gs://rl_unplugged/dm_control_suite'\n",
|
||||
"\n",
|
||||
"!mkdir -p {tmp_path}/{task_name}\n",
|
||||
"!gsutil cp {gs_path}/{task_name}/* {tmp_path}/{task_name}\n",
|
||||
"\n",
|
||||
"num_shards_str, = !ls {tmp_path}/{task_name}/* | wc -l\n",
|
||||
"num_shards = int(num_shards_str)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "mQ1as51Mww7X"
|
||||
},
|
||||
"source": [
|
||||
"## Dataset and environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "5kHzJpfcw306"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_size = 10 #@param\n",
|
||||
"\n",
|
||||
"task = dm_control_suite.ControlSuite(task_name)\n",
|
||||
"\n",
|
||||
"environment = task.environment\n",
|
||||
"environment_spec = specs.make_environment_spec(environment)\n",
|
||||
"\n",
|
||||
"dataset = dm_control_suite.dataset(\n",
|
||||
" '/tmp',\n",
|
||||
" data_path=task.data_path,\n",
|
||||
" shapes=task.shapes,\n",
|
||||
" uint8_features=task.uint8_features,\n",
|
||||
" num_threads=1,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" num_shards=num_shards)\n",
|
||||
"\n",
|
||||
"def discard_extras(sample):\n",
|
||||
" return sample._replace(data=sample.data[:5])\n",
|
||||
"\n",
|
||||
"dataset = dataset.map(discard_extras).batch(batch_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "adb0cyE5qu9G"
|
||||
},
|
||||
"source": [
|
||||
"## D4PG learner"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "83naOY7a_A4I"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create the networks to optimize.\n",
|
||||
"action_spec = environment_spec.actions\n",
|
||||
"action_size = np.prod(action_spec.shape, dtype=int)\n",
|
||||
"\n",
|
||||
"policy_network = snt.Sequential([\n",
|
||||
" tf2_utils.batch_concat,\n",
|
||||
" networks.LayerNormMLP(layer_sizes=(300, 200, action_size)),\n",
|
||||
" networks.TanhToSpec(spec=environment_spec.actions)])\n",
|
||||
"\n",
|
||||
"critic_network = snt.Sequential([\n",
|
||||
" networks.CriticMultiplexer(\n",
|
||||
" observation_network=tf2_utils.batch_concat,\n",
|
||||
" action_network=tf.identity,\n",
|
||||
" critic_network=networks.LayerNormMLP(\n",
|
||||
" layer_sizes=(400, 300),\n",
|
||||
" activate_final=True)),\n",
|
||||
" # Value-head gives a 51-atomed delta distribution over state-action values.\n",
|
||||
" networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51)])\n",
|
||||
"\n",
|
||||
"# Create the target networks\n",
|
||||
"target_policy_network = copy.deepcopy(policy_network)\n",
|
||||
"target_critic_network = copy.deepcopy(critic_network)\n",
|
||||
"\n",
|
||||
"# Create variables.\n",
|
||||
"tf2_utils.create_variables(network=policy_network,\n",
|
||||
" input_spec=[environment_spec.observations])\n",
|
||||
"tf2_utils.create_variables(network=critic_network,\n",
|
||||
" input_spec=[environment_spec.observations,\n",
|
||||
" environment_spec.actions])\n",
|
||||
"tf2_utils.create_variables(network=target_policy_network,\n",
|
||||
" input_spec=[environment_spec.observations])\n",
|
||||
"tf2_utils.create_variables(network=target_critic_network,\n",
|
||||
" input_spec=[environment_spec.observations,\n",
|
||||
" environment_spec.actions])\n",
|
||||
"\n",
|
||||
"# The learner updates the parameters (and initializes them).\n",
|
||||
"learner = d4pg.D4PGLearner(\n",
|
||||
" policy_network=policy_network,\n",
|
||||
" critic_network=critic_network,\n",
|
||||
" target_policy_network=target_policy_network,\n",
|
||||
" target_critic_network=target_critic_network,\n",
|
||||
" dataset=dataset,\n",
|
||||
" discount=0.99,\n",
|
||||
" target_update_period=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PYkjKaduy_xj"
|
||||
},
|
||||
"source": [
|
||||
"## Training loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 34
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 3493,
|
||||
"status": "ok",
|
||||
"timestamp": 1593622068277,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -60
|
||||
},
|
||||
"id": "HbQOyCG4zCwa",
|
||||
"outputId": "cfb99d00-da2d-4ce8-e010-034a26e2ada0"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Learner] Critic Loss = 3.919 | Policy Loss = 0.326 | Steps = 1 | Walltime = 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for _ in range(100):\n",
|
||||
" learner.step()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "LJ_XsuQSzFSV"
|
||||
},
|
||||
"source": [
|
||||
"## Evaluation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 51
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 4197,
|
||||
"status": "ok",
|
||||
"timestamp": 1593620604870,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -60
|
||||
},
|
||||
"id": "blvNCANKb22J",
|
||||
"outputId": "af5ae073-9847-45cc-e51e-a803fc2148b0"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Evaluation] Episode Length = 1000 | Episode Return = 129.717 | Episodes = 2 | Steps = 2000 | Steps Per Second = 1480.399\n",
|
||||
"[Evaluation] Episode Length = 1000 | Episode Return = 34.790 | Episodes = 4 | Steps = 4000 | Steps Per Second = 1449.009\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a logger.\n",
|
||||
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
|
||||
"\n",
|
||||
"# Create an environment loop.\n",
|
||||
"loop = acme.EnvironmentLoop(\n",
|
||||
" environment=environment,\n",
|
||||
" actor=actors.FeedForwardActor(policy_network),\n",
|
||||
" logger=logger)\n",
|
||||
"\n",
|
||||
"loop.run(5)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"last_runtime": {
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "RL Unplugged: Offline D4PG - DM control",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ",
|
||||
"timestamp": 1593080049369
|
||||
}
|
||||
]
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
r"""DM control suite and locomotion dataset examples.
|
||||
|
||||
Example:
|
||||
Instructions:
|
||||
> export TMP_PATH=/tmp/dataset
|
||||
> export TASK_NAME=humanoid_run
|
||||
> mkdir -p $TMP_PATH/$TASK_NAME
|
||||
> gsutil cp gs://rl_unplugged/dm_control_suite/$TASK_NAME/train-00000-of-00100 \
|
||||
$TMP_PATH/dm_control_suite/$TASK_NAME/train-00000-of-00001
|
||||
> python dm_control_suite_example.py --path=$TMP_PATH \
|
||||
--task_class=control_suite --task_name=$TASK_NAME
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tree
|
||||
|
||||
from rl_unplugged import dm_control_suite
|
||||
|
||||
flags.DEFINE_string('path', '/tmp/dataset', 'Path to dataset.')
|
||||
flags.DEFINE_string('task_name', 'humanoid_run', 'Game.')
|
||||
flags.DEFINE_enum('task_class', 'control_suite',
|
||||
['humanoid', 'rodent', 'control_suite'],
|
||||
'Task classes.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(_):
|
||||
|
||||
if FLAGS.task_class == 'control_suite':
|
||||
task = dm_control_suite.ControlSuite(task_name=FLAGS.task_name)
|
||||
elif FLAGS.task_class == 'humanoid':
|
||||
task = dm_control_suite.CmuThirdParty(task_name=FLAGS.task_name)
|
||||
elif FLAGS.task_class == 'rodent':
|
||||
task = dm_control_suite.Rodent(task_name=FLAGS.task_name)
|
||||
|
||||
ds = dm_control_suite.dataset(root_path=FLAGS.path,
|
||||
data_path=task.data_path,
|
||||
shapes=task.shapes,
|
||||
num_threads=1,
|
||||
batch_size=2,
|
||||
uint8_features=task.uint8_features,
|
||||
num_shards=1,
|
||||
shuffle_buffer_size=10)
|
||||
|
||||
for sample in ds.take(1):
|
||||
print('Data spec')
|
||||
print(tree.map_structure(lambda x: (x.dtype, x.shape), sample.data))
|
||||
|
||||
environment = task.environment
|
||||
timestep = environment.reset()
|
||||
print(tree.map_structure(lambda x: (x.dtype, x.shape), timestep.observation))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,58 @@
|
||||
absl-py==0.9.0
|
||||
astunparse==1.6.3
|
||||
atari-py==0.2.6
|
||||
cachetools==4.1.1
|
||||
certifi==2020.6.20
|
||||
chardet==3.0.4
|
||||
cloudpickle==1.3.0
|
||||
decorator==4.4.2
|
||||
dm-acme==0.1.7
|
||||
dm-control==0.0.319497192
|
||||
dm-env==1.2
|
||||
dm-reverb-nightly==0.1.0.dev20200616
|
||||
dm-sonnet==2.0.0
|
||||
dm-tree==0.1.5
|
||||
dopamine-rl==3.0.1
|
||||
future==0.18.2
|
||||
gast==0.3.3
|
||||
gin-config==0.3.0
|
||||
glfw==1.11.2
|
||||
google-auth==1.18.0
|
||||
google-auth-oauthlib==0.4.1
|
||||
google-pasta==0.2.0
|
||||
grpcio==1.30.0
|
||||
gym==0.17.2
|
||||
h5py==2.10.0
|
||||
idna==2.10
|
||||
Keras-Preprocessing==1.1.2
|
||||
lxml==4.5.1
|
||||
Markdown==3.2.2
|
||||
numpy==1.19.0
|
||||
oauthlib==3.1.0
|
||||
opencv-python==4.3.0.36
|
||||
opt-einsum==3.2.1
|
||||
Pillow==7.2.0
|
||||
pkg-resources==0.0.0
|
||||
portpicker==1.3.1
|
||||
protobuf==3.12.2
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pyglet==1.5.0
|
||||
PyOpenGL==3.1.5
|
||||
pyparsing==2.4.7
|
||||
requests==2.24.0
|
||||
requests-oauthlib==1.3.0
|
||||
rsa==4.6
|
||||
scipy==1.4.1
|
||||
six==1.15.0
|
||||
tabulate==0.8.7
|
||||
tb-nightly==2.3.0a20200706
|
||||
tensorboard-plugin-wit==1.7.0
|
||||
termcolor==1.1.0
|
||||
tf-estimator-nightly==2.4.0.dev2020070701
|
||||
tf-nightly==2.3.0.dev20200616
|
||||
tfp-nightly==0.11.0.dev20200707
|
||||
trfl==1.1.0
|
||||
urllib3==1.25.9
|
||||
Werkzeug==1.0.1
|
||||
wrapt==1.12.1
|
||||
@@ -0,0 +1,193 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Real World RL for RL Unplugged datasets.
|
||||
|
||||
Examples in the dataset represent SARS transitions stored when running a
|
||||
partially online trained agent as described in https://arxiv.org/abs/1904.12901.
|
||||
|
||||
We release 8 datasets in total -- with no combined challenge and easy combined
|
||||
challenge on the cartpole, walker, quadruped, and humanoid tasks. For details
|
||||
on how the dataset was generated, please refer to the paper.
|
||||
|
||||
Every transition in the dataset is a tuple containing the following features:
|
||||
|
||||
* o_t: Observation at time t. Observations have been processed using the
|
||||
canonical
|
||||
* a_t: Action taken at time t.
|
||||
* r_t: Reward at time t.
|
||||
* d_t: Discount at time t.
|
||||
* o_tp1: Observation at time t+1.
|
||||
* a_tp1: Action taken at time t+1. This is set to equal to the last action
|
||||
for the last timestep.
|
||||
|
||||
Note that this serves as an example. For optimal data loading speed, consider
|
||||
separating out data preprocessing from the data loading loop during training,
|
||||
e.g. saving the preprocessed data.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
from acme import wrappers
|
||||
import dm_env
|
||||
import realworldrl_suite.environments as rwrl_envs
|
||||
import reverb
|
||||
import tensorflow as tf
|
||||
import tree
|
||||
|
||||
DELIMITER = ':'
|
||||
# Control suite tasks have 1000 timesteps per episode. One additional timestep
|
||||
# accounts for the very first observation where no action has been taken yet.
|
||||
DEFAULT_NUM_TIMESTEPS = 1001
|
||||
|
||||
|
||||
def _decombine_key(k, delimiter = DELIMITER):
|
||||
return k.split(delimiter)
|
||||
|
||||
|
||||
def tf_example_to_feature_description(example,
|
||||
num_timesteps=DEFAULT_NUM_TIMESTEPS):
|
||||
"""Takes a string tensor encoding an tf example and returns its features."""
|
||||
if not tf.executing_eagerly():
|
||||
raise AssertionError(
|
||||
'tf_example_to_reverb_sample() only works under eager mode.')
|
||||
example = tf.train.Example.FromString(example.numpy())
|
||||
|
||||
ret = {}
|
||||
for k, v in example.features.feature.items():
|
||||
l = len(v.float_list.value)
|
||||
if l % num_timesteps:
|
||||
raise ValueError('Unexpected feature length %d. It should be divisible '
|
||||
'by num_timesteps: %d' % (l, num_timesteps))
|
||||
size = l // num_timesteps
|
||||
ret[k] = tf.io.FixedLenFeature([num_timesteps, size], tf.float32)
|
||||
return ret
|
||||
|
||||
|
||||
def tree_deflatten_with_delimiter(
|
||||
flat_dict, delimiter = DELIMITER):
|
||||
"""De-flattens a dict to its originally nested structure.
|
||||
|
||||
Does the opposite of {combine_nested_keys(k) :v
|
||||
for k, v in tree.flatten_with_path(nested_dicts)}
|
||||
Example: {'a:b': 1} -> {'a': {'b': 1}}
|
||||
Args:
|
||||
flat_dict: the keys of which equals the `path` separated by `delimiter`.
|
||||
delimiter: the delimiter that separates the keys of the nested dict.
|
||||
|
||||
Returns:
|
||||
An un-flattened dict.
|
||||
"""
|
||||
root = collections.defaultdict(dict)
|
||||
for delimited_key, v in flat_dict.items():
|
||||
keys = _decombine_key(delimited_key, delimiter=delimiter)
|
||||
node = root
|
||||
for k in keys[:-1]:
|
||||
node = node[k]
|
||||
node[keys[-1]] = v
|
||||
return dict(root)
|
||||
|
||||
|
||||
def get_slice_of_nested(nested, start,
|
||||
end):
|
||||
return tree.map_structure(lambda item: item[start:end], nested)
|
||||
|
||||
|
||||
def repeat_last_and_append_to_nested(nested):
|
||||
return tree.map_structure(
|
||||
lambda item: tf.concat((item, item[-1:]), axis=0), nested)
|
||||
|
||||
|
||||
def tf_example_to_reverb_sample(example,
|
||||
feature_description,
|
||||
num_timesteps=DEFAULT_NUM_TIMESTEPS):
|
||||
"""Converts the episode encoded as a tf example into SARSA reverb samples."""
|
||||
example = tf.io.parse_single_example(example, feature_description)
|
||||
kv = tree_deflatten_with_delimiter(example)
|
||||
output = (
|
||||
get_slice_of_nested(kv['observation'], 0, num_timesteps - 1),
|
||||
get_slice_of_nested(kv['action'], 1, num_timesteps),
|
||||
kv['reward'][1:num_timesteps],
|
||||
# The two fields below aren't needed for learning,
|
||||
# but are kept here to be compatible with acme learner format.
|
||||
kv['discount'][1:num_timesteps],
|
||||
get_slice_of_nested(kv['observation'], 1, num_timesteps),
|
||||
repeat_last_and_append_to_nested(
|
||||
get_slice_of_nested(kv['action'], 2, num_timesteps)))
|
||||
ret = tf.data.Dataset.from_tensor_slices(output)
|
||||
ret = ret.map(lambda *x: reverb.ReplaySample(info=b'None', data=x)) # pytype: disable=wrong-arg-types
|
||||
return ret
|
||||
|
||||
|
||||
def dataset(path,
|
||||
combined_challenge,
|
||||
domain,
|
||||
task,
|
||||
difficulty,
|
||||
num_shards = 100,
|
||||
shuffle_buffer_size = 100000):
|
||||
"""TF dataset of RWRL SARSA tuples."""
|
||||
path = os.path.join(
|
||||
path,
|
||||
f'combined_challenge_{combined_challenge}/{domain}/{task}/'
|
||||
f'offline_rl_challenge_{difficulty}'
|
||||
)
|
||||
filenames = [
|
||||
f'{path}/episodes.tfrecord-{i:05d}-of-{num_shards:05d}'
|
||||
for i in range(num_shards)
|
||||
]
|
||||
file_ds = tf.data.Dataset.from_tensor_slices(filenames)
|
||||
file_ds = file_ds.repeat().shuffle(num_shards)
|
||||
tf_example_ds = file_ds.interleave(
|
||||
tf.data.TFRecordDataset,
|
||||
cycle_length=tf.data.experimental.AUTOTUNE,
|
||||
block_length=5)
|
||||
|
||||
# Take one item to get the output types and shapes.
|
||||
example_item = None
|
||||
for example_item in tf.data.TFRecordDataset(filenames[:1]).take(1):
|
||||
break
|
||||
if example_item is None:
|
||||
raise ValueError('Empty dataset')
|
||||
|
||||
feature_description = tf_example_to_feature_description(example_item)
|
||||
|
||||
reverb_ds = tf_example_ds.interleave(
|
||||
functools.partial(
|
||||
tf_example_to_reverb_sample, feature_description=feature_description),
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE,
|
||||
deterministic=False)
|
||||
reverb_ds = reverb_ds.prefetch(100)
|
||||
reverb_ds = reverb_ds.shuffle(shuffle_buffer_size)
|
||||
|
||||
return reverb_ds
|
||||
|
||||
|
||||
def environment(
|
||||
combined_challenge,
|
||||
domain,
|
||||
task,
|
||||
log_output = None,
|
||||
environment_kwargs = None):
|
||||
"""RWRL environment."""
|
||||
env = rwrl_envs.load(
|
||||
domain_name=domain,
|
||||
task_name=task,
|
||||
log_output=log_output,
|
||||
environment_kwargs=environment_kwargs,
|
||||
combined_challenge=combined_challenge)
|
||||
return wrappers.SinglePrecisionWrapper(env)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,51 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=line-too-long
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
r"""RWRL dataset example.
|
||||
|
||||
Instructions:
|
||||
> export TMP_PATH=/tmp/dataset/rwrl
|
||||
> export DATA_PATH=combined_challenge_easy/quadruped/walk/offline_rl_challenge_easy
|
||||
> mkdir -p $TMP_PATH/$DATA_PATH
|
||||
> gsutil cp gs://rl_unplugged/rwrl/$DATA_PATH/episodes.tfrecord-00001-of-00015 \
|
||||
$TMP_PATH/$DATA_PATH/episodes.tfrecord-00000-of-00001
|
||||
> python rwrl_example.py --path=$TMP_PATH
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tree
|
||||
|
||||
from rl_unplugged import rwrl
|
||||
|
||||
flags.DEFINE_string('path', '/tmp/dataset', 'Path to dataset.')
|
||||
|
||||
|
||||
def main(_):
|
||||
ds = rwrl.dataset(
|
||||
flags.FLAGS.path,
|
||||
combined_challenge='easy',
|
||||
domain='quadruped',
|
||||
task='walk',
|
||||
difficulty='easy',
|
||||
num_shards=1,
|
||||
shuffle_buffer_size=1)
|
||||
for replay_sample in ds.take(1):
|
||||
print(tree.map_structure(lambda x: (x.dtype, x.shape), replay_sample.data))
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
Reference in New Issue
Block a user