diff --git a/rl_unplugged/bsuite.ipynb b/rl_unplugged/bsuite.ipynb new file mode 100644 index 0000000..1b4961f --- /dev/null +++ b/rl_unplugged/bsuite.ipynb @@ -0,0 +1,413 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "KDiJzbb8KFvP" + }, + "source": [ + "Copyright 2021 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n", + "this file except in compliance with the License. You may obtain a copy of the\n", + "License at\n", + "\n", + "[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed\n", + "under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n", + "CONDITIONS OF ANY KIND, either express or implied. See the License for the\n", + "specific language governing permissions and limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ULdrhOaVbsdO" + }, + "source": [ + "# RL Unplugged: Offline DQN - Bsuite\n", + "## Guide to training an Acme DQN agent on Bsuite data.\n", + "# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/rl_unplugged/atari_dqn.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "collapsed": true, + "id": "KH3O0zcXUeun" + }, + "outputs": [], + "source": [ + "# @title Installation\n", + "!pip install dm-acme\n", + "!pip install dm-acme[reverb]\n", + "!pip install dm-acme[tf]\n", + "!pip install dm-sonnet\n", + "!pip install dopamine-rl==3.1.2\n", + "!pip install atari-py\n", + "!pip install dm_env\n", + "!git clone https://github.com/deepmind/deepmind-research.git\n", + "%cd deepmind-research\n", + "\n", + "!git clone https://github.com/deepmind/bsuite.git\n", + "!pip install -q bsuite/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "HJ74Id-8MERq" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "import copy\n", + "import functools\n", + "from typing import Dict, Tuple\n", + "\n", + "\n", + "import acme\n", + "from acme.agents.tf import actors\n", + "from acme.agents.tf.dqn import learning as dqn\n", + "from acme.tf import utils as acme_utils\n", + "from acme.utils import loggers\n", + "\n", + "import sonnet as snt\n", + "import tensorflow as tf\n", + "\n", + "import numpy as np\n", + "import tree\n", + "import dm_env\n", + "import reverb\n", + "from acme.wrappers import base as wrapper_base\n", + "from acme.wrappers import single_precision\n", + "\n", + "import bsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "j0Aahizs2ff-" + }, + "outputs": [], + "source": [ + "# @title Data Loading Utilities\n", + "def _parse_seq_tf_example(example, shapes, dtypes):\n", + " \"\"\"Parse tf.Example containing one or two episode steps.\"\"\"\n", + "\n", + " def to_feature(shape, dtype):\n", + " if np.issubdtype(dtype, np.floating):\n", + " return tf.io.FixedLenSequenceFeature(\n", + " shape=shape, dtype=tf.float32, allow_missing=True)\n", + " elif dtype == np.bool or np.issubdtype(dtype, np.integer):\n", + " return tf.io.FixedLenSequenceFeature(\n", + " shape=shape, dtype=tf.int64, allow_missing=True)\n", + " else:\n", + " raise ValueError(f'Unsupported type {dtype} to '\n", + " f'convert from TF Example.')\n", + "\n", + " feature_map = {}\n", + " for k, v in shapes.items():\n", + " feature_map[k] = to_feature(v, dtypes[k])\n", + "\n", + " parsed = tf.io.parse_single_example(example, features=feature_map)\n", + "\n", + " restructured = {}\n", + " for k, v in parsed.items():\n", + " dtype = tf.as_dtype(dtypes[k])\n", + " if v.dtype == dtype:\n", + " restructured[k] = parsed[k]\n", + " else:\n", + " restructured[k] = tf.cast(parsed[k], dtype)\n", + "\n", + " return restructured\n", + "\n", + "\n", + "def _build_sars_example(sequences):\n", + " \"\"\"Convert raw sequences into a Reverb SARS' sample.\"\"\"\n", + "\n", + " o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])\n", + " o_t = tree.map_structure(lambda t: t[1], sequences['observation'])\n", + " a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])\n", + " r_t = tree.map_structure(lambda t: t[0], sequences['reward'])\n", + " p_t = tree.map_structure(\n", + " lambda d, st: d[0] * tf.cast(st[1] != dm_env.StepType.LAST, d.dtype),\n", + " sequences['discount'], sequences['step_type'])\n", + "\n", + " info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),\n", + " probability=tf.constant(1.0, tf.float64),\n", + " table_size=tf.constant(0, tf.int64),\n", + " priority=tf.constant(1.0, tf.float64))\n", + " return reverb.ReplaySample(info=info, data=(\n", + " o_tm1, a_tm1, r_t, p_t, o_t))\n", + "\n", + "\n", + "def bsuite_dataset_params(env):\n", + " \"\"\"Return shapes and dtypes parameters for bsuite offline dataset.\"\"\"\n", + " shapes = {\n", + " 'observation': env.observation_spec().shape,\n", + " 'action': env.action_spec().shape,\n", + " 'discount': env.discount_spec().shape,\n", + " 'reward': env.reward_spec().shape,\n", + " 'episodic_reward': env.reward_spec().shape,\n", + " 'step_type': (),\n", + " }\n", + "\n", + " dtypes = {\n", + " 'observation': env.observation_spec().dtype,\n", + " 'action': env.action_spec().dtype,\n", + " 'discount': env.discount_spec().dtype,\n", + " 'reward': env.reward_spec().dtype,\n", + " 'episodic_reward': env.reward_spec().dtype,\n", + " 'step_type': np.int64,\n", + " }\n", + "\n", + " return {'shapes': shapes, 'dtypes': dtypes}\n", + "\n", + "\n", + "def bsuite_dataset(path: str,\n", + " shapes: Dict[str, Tuple[int]],\n", + " dtypes: Dict[str, type], # pylint:disable=g-bare-generic\n", + " num_threads: int,\n", + " batch_size: int,\n", + " num_shards: int,\n", + " shuffle_buffer_size: int = 100000,\n", + " shuffle: bool = True) -\u003e tf.data.Dataset:\n", + " \"\"\"Create tf dataset for training.\"\"\"\n", + "\n", + " filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(\n", + " num_shards)]\n", + " file_ds = tf.data.Dataset.from_tensor_slices(filenames)\n", + " if shuffle:\n", + " file_ds = file_ds.repeat().shuffle(num_shards)\n", + "\n", + " example_ds = file_ds.interleave(\n", + " functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'),\n", + " cycle_length=tf.data.experimental.AUTOTUNE,\n", + " block_length=5)\n", + " if shuffle:\n", + " example_ds = example_ds.shuffle(shuffle_buffer_size)\n", + "\n", + " def map_func(example):\n", + " example = _parse_seq_tf_example(example, shapes, dtypes)\n", + " return example\n", + "\n", + " example_ds = example_ds.map(map_func, num_parallel_calls=num_threads)\n", + " if shuffle:\n", + " example_ds = example_ds.repeat().shuffle(batch_size * 10)\n", + "\n", + " example_ds = example_ds.map(\n", + " _build_sars_example,\n", + " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", + " example_ds = example_ds.batch(batch_size, drop_remainder=True)\n", + "\n", + " example_ds = example_ds.prefetch(tf.data.experimental.AUTOTUNE)\n", + "\n", + " return example_ds\n", + "\n", + "\n", + "def load_offline_bsuite_dataset(\n", + " bsuite_id: str,\n", + " path: str,\n", + " batch_size: int,\n", + " num_shards: int = 1,\n", + " num_threads: int = 1,\n", + " single_precision_wrapper: bool = True,\n", + " shuffle: bool = True) -\u003e Tuple[tf.data.Dataset,\n", + " dm_env.Environment]:\n", + " \"\"\"Load bsuite offline dataset.\"\"\"\n", + " # Data file path format: {path}-?????-of-{num_shards:05d}\n", + " # The dataset is not deterministic and not repeated if shuffle = False.\n", + " environment = bsuite.load_from_id(bsuite_id)\n", + " if single_precision_wrapper:\n", + " environment = single_precision.SinglePrecisionWrapper(environment)\n", + " params = bsuite_dataset_params(environment)\n", + " dataset = bsuite_dataset(path=path,\n", + " num_threads=num_threads,\n", + " batch_size=batch_size,\n", + " num_shards=num_shards,\n", + " shuffle_buffer_size=2,\n", + " shuffle=shuffle,\n", + " **params)\n", + " return dataset, environment" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a9vF7LtYvLzy" + }, + "source": [ + "## Dataset and environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2I4lHR5jLeXm" + }, + "outputs": [], + "source": [ + "tmp_path = 'gs://rl_unplugged/bsuite'\n", + "level = 'catch'\n", + "dir = '0_0.0'\n", + "filename = '0_full'\n", + "path = f'{tmp_path}/{level}/{dir}/{filename}'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "01AHHNd9cEX2" + }, + "outputs": [], + "source": [ + "batch_size = 2 #@param\n", + "bsuite_id = level + '/0'\n", + "dataset, environment = load_offline_bsuite_dataset(bsuite_id=bsuite_id,\n", + " path=path,\n", + " batch_size=batch_size)\n", + "dataset = dataset.prefetch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BukOfOsmtSQn" + }, + "source": [ + "## DQN learner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3Jcjk1w6oHVX" + }, + "outputs": [], + "source": [ + "# Get total number of actions.\n", + "num_actions = environment.action_spec().num_values\n", + "obs_spec = environment.observation_spec()\n", + "print(environment.observation_spec())\n", + "# Create the Q network.\n", + "network = snt.Sequential([\n", + " snt.flatten,\n", + " snt.nets.MLP([56, 56]),\n", + " snt.nets.MLP([num_actions])\n", + " ])\n", + "acme_utils.create_variables(network, [environment.observation_spec()])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9CD2sNK-oA9S" + }, + "outputs": [], + "source": [ + "# Create a logger.\n", + "logger = loggers.TerminalLogger(label='learner', time_delta=1.)\n", + "\n", + "# Create the DQN learner.\n", + "learner = dqn.DQNLearner(\n", + " network=network,\n", + " target_network=copy.deepcopy(network),\n", + " discount=0.99,\n", + " learning_rate=3e-4,\n", + " importance_sampling_exponent=0.2,\n", + " target_update_period=2500,\n", + " dataset=dataset,\n", + " logger=logger)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oKeGQxzitXYC" + }, + "source": [ + "## Training loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VWZd5N-Qoz82" + }, + "outputs": [], + "source": [ + "for _ in range(10000):\n", + " learner.step()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qFQDrp0CgIzU" + }, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DWYHBalygIDF" + }, + "outputs": [], + "source": [ + "# Create a logger.\n", + "logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n", + "\n", + "# Create an environment loop.\n", + "policy_network = snt.Sequential([\n", + " network,\n", + " lambda q: tf.argmax(q, axis=-1),\n", + "])\n", + "loop = acme.EnvironmentLoop(\n", + " environment=environment,\n", + " actor=actors.FeedForwardActor(policy_network=policy_network),\n", + " logger=logger)\n", + "\n", + "loop.run(400)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "RL_Unplugged_Offline_DQN_Bsuite.ipynb", + "provenance": [ + { + "file_id": "1rhoeTFuan8B_q8eBeCEtQPiIeMux1pWU", + "timestamp": 1613499802780 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}