Files
deepmind-research/rl_unplugged/bsuite.ipynb
T
Srivatsan Srinivasan cfcf752cf9 Add colab example of DQN agent training on BSUITE RL Unplugged dataset.
PiperOrigin-RevId: 358396165
2021-03-03 10:11:12 +00:00

414 lines
14 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "KDiJzbb8KFvP"
},
"source": [
"Copyright 2021 DeepMind Technologies Limited.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
"this file except in compliance with the License. You may obtain a copy of the\n",
"License at\n",
"\n",
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed\n",
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
"specific language governing permissions and limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ULdrhOaVbsdO"
},
"source": [
"# RL Unplugged: Offline DQN - Bsuite\n",
"## Guide to training an Acme DQN agent on Bsuite data.\n",
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/bsuite.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"collapsed": true,
"id": "KH3O0zcXUeun"
},
"outputs": [],
"source": [
"# @title Installation\n",
"!pip install dm-acme\n",
"!pip install dm-acme[reverb]\n",
"!pip install dm-acme[tf]\n",
"!pip install dm-sonnet\n",
"!pip install dopamine-rl==3.1.2\n",
"!pip install atari-py\n",
"!pip install dm_env\n",
"!git clone https://github.com/deepmind/deepmind-research.git\n",
"%cd deepmind-research\n",
"\n",
"!git clone https://github.com/deepmind/bsuite.git\n",
"!pip install -q bsuite/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "HJ74Id-8MERq"
},
"outputs": [],
"source": [
"# @title Imports\n",
"import copy\n",
"import functools\n",
"from typing import Dict, Tuple\n",
"\n",
"\n",
"import acme\n",
"from acme.agents.tf import actors\n",
"from acme.agents.tf.dqn import learning as dqn\n",
"from acme.tf import utils as acme_utils\n",
"from acme.utils import loggers\n",
"\n",
"import sonnet as snt\n",
"import tensorflow as tf\n",
"\n",
"import numpy as np\n",
"import tree\n",
"import dm_env\n",
"import reverb\n",
"from acme.wrappers import base as wrapper_base\n",
"from acme.wrappers import single_precision\n",
"\n",
"import bsuite"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "j0Aahizs2ff-"
},
"outputs": [],
"source": [
"# @title Data Loading Utilities\n",
"def _parse_seq_tf_example(example, shapes, dtypes):\n",
"  \"\"\"Parse a serialized tf.Example containing one or two episode steps.\n",
"\n",
"  Args:\n",
"    example: Serialized tf.Example, as accepted by\n",
"      `tf.io.parse_single_example`.\n",
"    shapes: Dict mapping feature name to the shape handed to\n",
"      `tf.io.FixedLenSequenceFeature` for that feature.\n",
"    dtypes: Dict mapping feature name to the numpy dtype the parsed feature\n",
"      is returned as. Must contain every key of `shapes`.\n",
"\n",
"  Returns:\n",
"    Dict mapping each feature name to its parsed tensor, cast to the dtype\n",
"    requested in `dtypes` when the stored dtype differs.\n",
"\n",
"  Raises:\n",
"    ValueError: If a requested dtype is not floating, integer or boolean.\n",
"  \"\"\"\n",
"\n",
"  def to_feature(shape, dtype):\n",
"    # Floating features are read as float32, integer and boolean features as\n",
"    # int64; the cast to the requested dtype happens after parsing.\n",
"    # `np.bool_` is used because the `np.bool` alias was removed in\n",
"    # NumPy 1.24 and raises AttributeError there.\n",
"    if np.issubdtype(dtype, np.floating):\n",
"      return tf.io.FixedLenSequenceFeature(\n",
"          shape=shape, dtype=tf.float32, allow_missing=True)\n",
"    elif np.issubdtype(dtype, np.bool_) or np.issubdtype(dtype, np.integer):\n",
"      return tf.io.FixedLenSequenceFeature(\n",
"          shape=shape, dtype=tf.int64, allow_missing=True)\n",
"    else:\n",
"      raise ValueError(f'Unsupported type {dtype} to '\n",
"                       f'convert from TF Example.')\n",
"\n",
"  feature_map = {k: to_feature(v, dtypes[k]) for k, v in shapes.items()}\n",
"\n",
"  parsed = tf.io.parse_single_example(example, features=feature_map)\n",
"\n",
"  restructured = {}\n",
"  for k, v in parsed.items():\n",
"    dtype = tf.as_dtype(dtypes[k])\n",
"    # Only cast when the stored dtype differs from the requested one.\n",
"    restructured[k] = v if v.dtype == dtype else tf.cast(v, dtype)\n",
"\n",
"  return restructured\n",
"\n",
"\n",
"def _build_sars_example(sequences):\n",
"  \"\"\"Turn a two-step sequence dict into a Reverb SARS' replay sample.\n",
"\n",
"  Index 0 of each field is treated as the previous step and index 1 as the\n",
"  current step.\n",
"\n",
"  Args:\n",
"    sequences: Dict with 'observation', 'action', 'reward', 'discount' and\n",
"      'step_type' entries, each indexable along its leading axis.\n",
"\n",
"  Returns:\n",
"    A `reverb.ReplaySample` whose data is the tuple\n",
"    (o_tm1, a_tm1, r_t, p_t, o_t) and whose info holds constant values.\n",
"  \"\"\"\n",
"\n",
"  def step(index, nest):\n",
"    # Select one time step from every leaf of a (possibly nested) structure.\n",
"    return tree.map_structure(lambda leaf: leaf[index], nest)\n",
"\n",
"  def masked_discount(discount, step_type):\n",
"    # First discount, multiplied by 0 when the second step is LAST, else 1.\n",
"    is_not_last = step_type[1] != dm_env.StepType.LAST\n",
"    return discount[0] * tf.cast(is_not_last, discount.dtype)\n",
"\n",
"  observations = sequences['observation']\n",
"  o_tm1 = step(0, observations)\n",
"  o_t = step(1, observations)\n",
"  a_tm1 = step(0, sequences['action'])\n",
"  r_t = step(0, sequences['reward'])\n",
"  p_t = tree.map_structure(\n",
"      masked_discount, sequences['discount'], sequences['step_type'])\n",
"\n",
"  # Sample metadata is filled with constants.\n",
"  info = reverb.SampleInfo(\n",
"      key=tf.constant(0, tf.uint64),\n",
"      probability=tf.constant(1.0, tf.float64),\n",
"      table_size=tf.constant(0, tf.int64),\n",
"      priority=tf.constant(1.0, tf.float64))\n",
"  return reverb.ReplaySample(info=info, data=(o_tm1, a_tm1, r_t, p_t, o_t))\n",
"\n",
"\n",
"def bsuite_dataset_params(env):\n",
" \"\"\"Return shapes and dtypes parameters for bsuite offline dataset.\"\"\"\n",
" shapes = {\n",
" 'observation': env.observation_spec().shape,\n",
" 'action': env.action_spec().shape,\n",
" 'discount': env.discount_spec().shape,\n",
" 'reward': env.reward_spec().shape,\n",
" 'episodic_reward': env.reward_spec().shape,\n",
" 'step_type': (),\n",
" }\n",
"\n",
" dtypes = {\n",
" 'observation': env.observation_spec().dtype,\n",
" 'action': env.action_spec().dtype,\n",
" 'discount': env.discount_spec().dtype,\n",
" 'reward': env.reward_spec().dtype,\n",
" 'episodic_reward': env.reward_spec().dtype,\n",
" 'step_type': np.int64,\n",
" }\n",
"\n",
" return {'shapes': shapes, 'dtypes': dtypes}\n",
"\n",
"\n",
"def bsuite_dataset(path: str,\n",
"                   shapes: Dict[str, Tuple[int]],\n",
"                   dtypes: Dict[str, type],  # pylint:disable=g-bare-generic\n",
"                   num_threads: int,\n",
"                   batch_size: int,\n",
"                   num_shards: int,\n",
"                   shuffle_buffer_size: int = 100000,\n",
"                   shuffle: bool = True) -\u003e tf.data.Dataset:\n",
"  \"\"\"Build the batched tf.data pipeline of SARS' samples used for training.\n",
"\n",
"  Args:\n",
"    path: Prefix of the shard files; shard i is read from\n",
"      `{path}-{i:05d}-of-{num_shards:05d}`.\n",
"    shapes: Per-feature shapes, forwarded to `_parse_seq_tf_example`.\n",
"    dtypes: Per-feature dtypes, forwarded to `_parse_seq_tf_example`.\n",
"    num_threads: `num_parallel_calls` used for the parsing map.\n",
"    batch_size: Number of samples per batch; incomplete batches are dropped.\n",
"    num_shards: Number of shard files.\n",
"    shuffle_buffer_size: Buffer size of the shuffle applied to the raw\n",
"      records.\n",
"    shuffle: If True, shard names, raw records and parsed examples are\n",
"      shuffled and the dataset is repeated; if False none of these stages\n",
"      is added.\n",
"\n",
"  Returns:\n",
"    A `tf.data.Dataset` of batched `reverb.ReplaySample`s.\n",
"  \"\"\"\n",
"  autotune = tf.data.experimental.AUTOTUNE\n",
"\n",
"  shard_names = [\n",
"      f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)\n",
"  ]\n",
"  ds = tf.data.Dataset.from_tensor_slices(shard_names)\n",
"  if shuffle:\n",
"    ds = ds.repeat().shuffle(num_shards)\n",
"\n",
"  # Shards are GZIP-compressed TFRecord files, read interleaved.\n",
"  read_records = functools.partial(\n",
"      tf.data.TFRecordDataset, compression_type='GZIP')\n",
"  ds = ds.interleave(read_records, cycle_length=autotune, block_length=5)\n",
"  if shuffle:\n",
"    ds = ds.shuffle(shuffle_buffer_size)\n",
"\n",
"  def parse(serialized):\n",
"    return _parse_seq_tf_example(serialized, shapes, dtypes)\n",
"\n",
"  ds = ds.map(parse, num_parallel_calls=num_threads)\n",
"  if shuffle:\n",
"    ds = ds.repeat().shuffle(batch_size * 10)\n",
"\n",
"  ds = ds.map(_build_sars_example, num_parallel_calls=autotune)\n",
"  ds = ds.batch(batch_size, drop_remainder=True)\n",
"\n",
"  return ds.prefetch(autotune)\n",
"\n",
"\n",
"def load_offline_bsuite_dataset(\n",
"    bsuite_id: str,\n",
"    path: str,\n",
"    batch_size: int,\n",
"    num_shards: int = 1,\n",
"    num_threads: int = 1,\n",
"    single_precision_wrapper: bool = True,\n",
"    shuffle: bool = True,\n",
"    shuffle_buffer_size: int = 2) -\u003e Tuple[tf.data.Dataset,\n",
"                                           dm_env.Environment]:\n",
"  \"\"\"Load a bsuite offline dataset together with its environment.\n",
"\n",
"  Args:\n",
"    bsuite_id: Id passed to `bsuite.load_from_id`.\n",
"    path: Prefix of the data files, which are named\n",
"      `{path}-?????-of-{num_shards:05d}`.\n",
"    batch_size: Number of samples per batch.\n",
"    num_shards: Number of data files.\n",
"    num_threads: Parallelism used when parsing examples.\n",
"    single_precision_wrapper: If True, wrap the environment in\n",
"      `SinglePrecisionWrapper` before reading its specs.\n",
"    shuffle: Forwarded to `bsuite_dataset`. When False, no shuffle or repeat\n",
"      stage is added there, so the dataset is a single pass over the files.\n",
"    shuffle_buffer_size: Buffer size for shuffling raw records, forwarded to\n",
"      `bsuite_dataset`. Defaults to 2, the value previously hard-coded here.\n",
"\n",
"  Returns:\n",
"    A `(dataset, environment)` tuple.\n",
"  \"\"\"\n",
"  environment = bsuite.load_from_id(bsuite_id)\n",
"  if single_precision_wrapper:\n",
"    environment = single_precision.SinglePrecisionWrapper(environment)\n",
"  # Shapes and dtypes are taken from the (possibly wrapped) environment.\n",
"  params = bsuite_dataset_params(environment)\n",
"  dataset = bsuite_dataset(path=path,\n",
"                           num_threads=num_threads,\n",
"                           batch_size=batch_size,\n",
"                           num_shards=num_shards,\n",
"                           shuffle_buffer_size=shuffle_buffer_size,\n",
"                           shuffle=shuffle,\n",
"                           **params)\n",
"  return dataset, environment"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9vF7LtYvLzy"
},
"source": [
"## Dataset and environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2I4lHR5jLeXm"
},
"outputs": [],
"source": [
"tmp_path = 'gs://rl_unplugged/bsuite'\n",
"level = 'catch'\n",
"dir = '0_0.0'\n",
"filename = '0_full'\n",
"path = f'{tmp_path}/{level}/{dir}/{filename}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "01AHHNd9cEX2"
},
"outputs": [],
"source": [
"batch_size = 2 #@param\n",
"# Environment id: the level name followed by '/0'.\n",
"bsuite_id = f'{level}/0'\n",
"dataset, environment = load_offline_bsuite_dataset(\n",
"    bsuite_id=bsuite_id, path=path, batch_size=batch_size)\n",
"# Add a prefetch stage with a buffer of one element.\n",
"dataset = dataset.prefetch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BukOfOsmtSQn"
},
"source": [
"## DQN learner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Jcjk1w6oHVX"
},
"outputs": [],
"source": [
"# Number of discrete actions and the observation spec of the environment.\n",
"num_actions = environment.action_spec().num_values\n",
"obs_spec = environment.observation_spec()\n",
"print(obs_spec)\n",
"\n",
"# Q-network: flatten the observation, a [56, 56] MLP, then an MLP with one\n",
"# output per action.\n",
"network = snt.Sequential([\n",
"    snt.flatten,\n",
"    snt.nets.MLP([56, 56]),\n",
"    snt.nets.MLP([num_actions]),\n",
"])\n",
"# Create the network variables from the observation spec.\n",
"acme_utils.create_variables(network, [obs_spec])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9CD2sNK-oA9S"
},
"outputs": [],
"source": [
"# Terminal logger for learner metrics.\n",
"logger = loggers.TerminalLogger(time_delta=1., label='learner')\n",
"\n",
"# DQN learner reading batches from the offline dataset; the target network\n",
"# starts as a deep copy of the online network.\n",
"learner = dqn.DQNLearner(\n",
"    network=network,\n",
"    target_network=copy.deepcopy(network),\n",
"    dataset=dataset,\n",
"    discount=0.99,\n",
"    learning_rate=3e-4,\n",
"    target_update_period=2500,\n",
"    importance_sampling_exponent=0.2,\n",
"    logger=logger)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKeGQxzitXYC"
},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VWZd5N-Qoz82"
},
"outputs": [],
"source": [
"# Number of learner updates to run; exposed as a Colab form parameter so it\n",
"# can be tuned without editing the loop.\n",
"num_learner_steps = 10000 #@param\n",
"for _ in range(num_learner_steps):\n",
"  learner.step()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFQDrp0CgIzU"
},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DWYHBalygIDF"
},
"outputs": [],
"source": [
"# Terminal logger for evaluation episodes.\n",
"logger = loggers.TerminalLogger(time_delta=1., label='evaluation')\n",
"\n",
"# Greedy policy: run the trained Q-network and take the argmax over its\n",
"# last axis.\n",
"policy_network = snt.Sequential([network, lambda q: tf.argmax(q, axis=-1)])\n",
"\n",
"# Run the policy in the environment for 400 episodes.\n",
"loop = acme.EnvironmentLoop(\n",
"    environment=environment,\n",
"    actor=actors.FeedForwardActor(policy_network=policy_network),\n",
"    logger=logger)\n",
"loop.run(num_episodes=400)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "RL_Unplugged_Offline_DQN_Bsuite.ipynb",
"provenance": [
{
"file_id": "1rhoeTFuan8B_q8eBeCEtQPiIeMux1pWU",
"timestamp": 1613499802780
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}