mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
cfcf752cf9
PiperOrigin-RevId: 358396165
414 lines
14 KiB
Plaintext
414 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "KDiJzbb8KFvP"
|
|
},
|
|
"source": [
|
|
"Copyright 2021 DeepMind Technologies Limited.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
|
"this file except in compliance with the License. You may obtain a copy of the\n",
|
|
"License at\n",
|
|
"\n",
|
|
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
|
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
|
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
|
"specific language governing permissions and limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ULdrhOaVbsdO"
|
|
},
|
|
"source": [
|
|
"# RL Unplugged: Offline DQN - Bsuite\n",
|
|
"## Guide to training an Acme DQN agent on Bsuite data.\n",
|
|
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/bsuite.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
|
|
"\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"collapsed": true,
|
|
"id": "KH3O0zcXUeun"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Installation\n",
|
|
"!pip install dm-acme\n",
|
|
"!pip install dm-acme[reverb]\n",
|
|
"!pip install dm-acme[tf]\n",
|
|
"!pip install dm-sonnet\n",
|
|
"!pip install dopamine-rl==3.1.2\n",
|
|
"!pip install atari-py\n",
|
|
"!pip install dm_env\n",
|
|
"!git clone https://github.com/deepmind/deepmind-research.git\n",
|
|
"%cd deepmind-research\n",
|
|
"\n",
|
|
"!git clone https://github.com/deepmind/bsuite.git\n",
|
|
"!pip install -q bsuite/"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "HJ74Id-8MERq"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Imports\n",
|
|
"import copy\n",
|
|
"import functools\n",
|
|
"from typing import Dict, Tuple\n",
|
|
"\n",
|
|
"\n",
|
|
"import acme\n",
|
|
"from acme.agents.tf import actors\n",
|
|
"from acme.agents.tf.dqn import learning as dqn\n",
|
|
"from acme.tf import utils as acme_utils\n",
|
|
"from acme.utils import loggers\n",
|
|
"\n",
|
|
"import sonnet as snt\n",
|
|
"import tensorflow as tf\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import tree\n",
|
|
"import dm_env\n",
|
|
"import reverb\n",
|
|
"from acme.wrappers import base as wrapper_base\n",
|
|
"from acme.wrappers import single_precision\n",
|
|
"\n",
|
|
"import bsuite"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "j0Aahizs2ff-"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Data Loading Utilities\n",
|
|
"def _parse_seq_tf_example(example, shapes, dtypes):\n",
|
|
" \"\"\"Parse tf.Example containing one or two episode steps.\"\"\"\n",
|
|
"\n",
|
|
" def to_feature(shape, dtype):\n",
|
|
" if np.issubdtype(dtype, np.floating):\n",
|
|
" return tf.io.FixedLenSequenceFeature(\n",
|
|
" shape=shape, dtype=tf.float32, allow_missing=True)\n",
|
|
" elif dtype == np.bool_ or np.issubdtype(dtype, np.integer):\n",
|
|
" return tf.io.FixedLenSequenceFeature(\n",
|
|
" shape=shape, dtype=tf.int64, allow_missing=True)\n",
|
|
" else:\n",
|
|
" raise ValueError(f'Unsupported type {dtype} to '\n",
|
|
" f'convert from TF Example.')\n",
|
|
"\n",
|
|
" feature_map = {}\n",
|
|
" for k, v in shapes.items():\n",
|
|
" feature_map[k] = to_feature(v, dtypes[k])\n",
|
|
"\n",
|
|
" parsed = tf.io.parse_single_example(example, features=feature_map)\n",
|
|
"\n",
|
|
" restructured = {}\n",
|
|
" for k, v in parsed.items():\n",
|
|
" dtype = tf.as_dtype(dtypes[k])\n",
|
|
" if v.dtype == dtype:\n",
|
|
" restructured[k] = parsed[k]\n",
|
|
" else:\n",
|
|
" restructured[k] = tf.cast(parsed[k], dtype)\n",
|
|
"\n",
|
|
" return restructured\n",
|
|
"\n",
|
|
"\n",
|
|
"def _build_sars_example(sequences):\n",
|
|
" \"\"\"Convert raw sequences into a Reverb SARS' sample.\"\"\"\n",
|
|
"\n",
|
|
" o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])\n",
|
|
" o_t = tree.map_structure(lambda t: t[1], sequences['observation'])\n",
|
|
" a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])\n",
|
|
" r_t = tree.map_structure(lambda t: t[0], sequences['reward'])\n",
|
|
" p_t = tree.map_structure(\n",
|
|
" lambda d, st: d[0] * tf.cast(st[1] != dm_env.StepType.LAST, d.dtype),\n",
|
|
" sequences['discount'], sequences['step_type'])\n",
|
|
"\n",
|
|
" info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),\n",
|
|
" probability=tf.constant(1.0, tf.float64),\n",
|
|
" table_size=tf.constant(0, tf.int64),\n",
|
|
" priority=tf.constant(1.0, tf.float64))\n",
|
|
" return reverb.ReplaySample(info=info, data=(\n",
|
|
" o_tm1, a_tm1, r_t, p_t, o_t))\n",
|
|
"\n",
|
|
"\n",
|
|
"def bsuite_dataset_params(env):\n",
|
|
" \"\"\"Return shapes and dtypes parameters for bsuite offline dataset.\"\"\"\n",
|
|
" shapes = {\n",
|
|
" 'observation': env.observation_spec().shape,\n",
|
|
" 'action': env.action_spec().shape,\n",
|
|
" 'discount': env.discount_spec().shape,\n",
|
|
" 'reward': env.reward_spec().shape,\n",
|
|
" 'episodic_reward': env.reward_spec().shape,\n",
|
|
" 'step_type': (),\n",
|
|
" }\n",
|
|
"\n",
|
|
" dtypes = {\n",
|
|
" 'observation': env.observation_spec().dtype,\n",
|
|
" 'action': env.action_spec().dtype,\n",
|
|
" 'discount': env.discount_spec().dtype,\n",
|
|
" 'reward': env.reward_spec().dtype,\n",
|
|
" 'episodic_reward': env.reward_spec().dtype,\n",
|
|
" 'step_type': np.int64,\n",
|
|
" }\n",
|
|
"\n",
|
|
" return {'shapes': shapes, 'dtypes': dtypes}\n",
|
|
"\n",
|
|
"\n",
|
|
"def bsuite_dataset(path: str,\n",
|
|
" shapes: Dict[str, Tuple[int]],\n",
|
|
" dtypes: Dict[str, type], # pylint:disable=g-bare-generic\n",
|
|
" num_threads: int,\n",
|
|
" batch_size: int,\n",
|
|
" num_shards: int,\n",
|
|
" shuffle_buffer_size: int = 100000,\n",
|
|
" shuffle: bool = True) -\u003e tf.data.Dataset:\n",
|
|
" \"\"\"Create tf dataset for training.\"\"\"\n",
|
|
"\n",
|
|
" filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(\n",
|
|
" num_shards)]\n",
|
|
" file_ds = tf.data.Dataset.from_tensor_slices(filenames)\n",
|
|
" if shuffle:\n",
|
|
" file_ds = file_ds.repeat().shuffle(num_shards)\n",
|
|
"\n",
|
|
" example_ds = file_ds.interleave(\n",
|
|
" functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'),\n",
|
|
" cycle_length=tf.data.experimental.AUTOTUNE,\n",
|
|
" block_length=5)\n",
|
|
" if shuffle:\n",
|
|
" example_ds = example_ds.shuffle(shuffle_buffer_size)\n",
|
|
"\n",
|
|
" def map_func(example):\n",
|
|
" example = _parse_seq_tf_example(example, shapes, dtypes)\n",
|
|
" return example\n",
|
|
"\n",
|
|
" example_ds = example_ds.map(map_func, num_parallel_calls=num_threads)\n",
|
|
" if shuffle:\n",
|
|
" example_ds = example_ds.repeat().shuffle(batch_size * 10)\n",
|
|
"\n",
|
|
" example_ds = example_ds.map(\n",
|
|
" _build_sars_example,\n",
|
|
" num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
|
|
" example_ds = example_ds.batch(batch_size, drop_remainder=True)\n",
|
|
"\n",
|
|
" example_ds = example_ds.prefetch(tf.data.experimental.AUTOTUNE)\n",
|
|
"\n",
|
|
" return example_ds\n",
|
|
"\n",
|
|
"\n",
|
|
"def load_offline_bsuite_dataset(\n",
|
|
" bsuite_id: str,\n",
|
|
" path: str,\n",
|
|
" batch_size: int,\n",
|
|
" num_shards: int = 1,\n",
|
|
" num_threads: int = 1,\n",
|
|
" single_precision_wrapper: bool = True,\n",
|
|
" shuffle: bool = True) -\u003e Tuple[tf.data.Dataset,\n",
|
|
" dm_env.Environment]:\n",
|
|
" \"\"\"Load bsuite offline dataset.\"\"\"\n",
|
|
" # Data file path format: {path}-?????-of-{num_shards:05d}\n",
|
|
" # The dataset is not deterministic and not repeated if shuffle = False.\n",
|
|
" environment = bsuite.load_from_id(bsuite_id)\n",
|
|
" if single_precision_wrapper:\n",
|
|
" environment = single_precision.SinglePrecisionWrapper(environment)\n",
|
|
" params = bsuite_dataset_params(environment)\n",
|
|
" dataset = bsuite_dataset(path=path,\n",
|
|
" num_threads=num_threads,\n",
|
|
" batch_size=batch_size,\n",
|
|
" num_shards=num_shards,\n",
|
|
" shuffle_buffer_size=2,\n",
|
|
" shuffle=shuffle,\n",
|
|
" **params)\n",
|
|
" return dataset, environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "a9vF7LtYvLzy"
|
|
},
|
|
"source": [
|
|
"## Dataset and environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "2I4lHR5jLeXm"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"tmp_path = 'gs://rl_unplugged/bsuite'\n",
|
|
"level = 'catch'\n",
|
|
"dir = '0_0.0'\n",
|
|
"filename = '0_full'\n",
|
|
"path = f'{tmp_path}/{level}/{dir}/{filename}'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"id": "01AHHNd9cEX2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_size = 2 #@param\n",
|
|
"bsuite_id = level + '/0'\n",
|
|
"dataset, environment = load_offline_bsuite_dataset(bsuite_id=bsuite_id,\n",
|
|
" path=path,\n",
|
|
" batch_size=batch_size)\n",
|
|
"dataset = dataset.prefetch(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "BukOfOsmtSQn"
|
|
},
|
|
"source": [
|
|
"## DQN learner"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "3Jcjk1w6oHVX"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get total number of actions.\n",
|
|
"num_actions = environment.action_spec().num_values\n",
|
|
"obs_spec = environment.observation_spec()\n",
|
|
"print(environment.observation_spec())\n",
|
|
"# Create the Q network.\n",
|
|
"network = snt.Sequential([\n",
|
|
" snt.flatten,\n",
|
|
" snt.nets.MLP([56, 56]),\n",
|
|
" snt.nets.MLP([num_actions])\n",
|
|
" ])\n",
|
|
"acme_utils.create_variables(network, [environment.observation_spec()])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "9CD2sNK-oA9S"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a logger.\n",
|
|
"logger = loggers.TerminalLogger(label='learner', time_delta=1.)\n",
|
|
"\n",
|
|
"# Create the DQN learner.\n",
|
|
"learner = dqn.DQNLearner(\n",
|
|
" network=network,\n",
|
|
" target_network=copy.deepcopy(network),\n",
|
|
" discount=0.99,\n",
|
|
" learning_rate=3e-4,\n",
|
|
" importance_sampling_exponent=0.2,\n",
|
|
" target_update_period=2500,\n",
|
|
" dataset=dataset,\n",
|
|
" logger=logger)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "oKeGQxzitXYC"
|
|
},
|
|
"source": [
|
|
"## Training loop"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "VWZd5N-Qoz82"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"for _ in range(10000):\n",
|
|
" learner.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "qFQDrp0CgIzU"
|
|
},
|
|
"source": [
|
|
"## Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "DWYHBalygIDF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a logger.\n",
|
|
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
|
|
"\n",
|
|
"# Create an environment loop.\n",
|
|
"policy_network = snt.Sequential([\n",
|
|
" network,\n",
|
|
" lambda q: tf.argmax(q, axis=-1),\n",
|
|
"])\n",
|
|
"loop = acme.EnvironmentLoop(\n",
|
|
" environment=environment,\n",
|
|
" actor=actors.FeedForwardActor(policy_network=policy_network),\n",
|
|
" logger=logger)\n",
|
|
"\n",
|
|
"loop.run(400)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "RL_Unplugged_Offline_DQN_Bsuite.ipynb",
|
|
"provenance": [
|
|
{
|
|
"file_id": "1rhoeTFuan8B_q8eBeCEtQPiIeMux1pWU",
|
|
"timestamp": 1613499802780
|
|
}
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|