mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
99aaa6930a
PiperOrigin-RevId: 324071731
411 lines
11 KiB
Plaintext
411 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "KDiJzbb8KFvP"
|
|
},
|
|
"source": [
|
|
"Copyright 2020 DeepMind Technologies Limited.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
|
"this file except in compliance with the License. You may obtain a copy of the\n",
|
|
"License at\n",
|
|
"\n",
|
|
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
|
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
|
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
|
"specific language governing permissions and limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "ULdrhOaVbsdO"
|
|
},
|
|
"source": [
|
|
"# RL Unplugged: Offline DQN - Atari\n",
|
|
"## Guide to training an Acme DQN agent on Atari data.\n",
|
|
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/atari_dqn.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "xaJxoatMhJ71"
|
|
},
|
|
"source": [
|
|
"## Installation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "both",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "KH3O0zcXUeun"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# `%pip` installs into the environment of the running kernel (unlike `!pip`),\n",
"# and the extras are quoted so the shell cannot treat `[...]` as a glob.\n",
"%pip install dm-acme\n",
"%pip install 'dm-acme[reverb]'\n",
"%pip install 'dm-acme[tf]'\n",
"%pip install dm-sonnet\n",
"%pip install dopamine-rl==3.1.2\n",
"%pip install atari-py\n",
"\n",
"# Work from the repository root so that `rl_unplugged` is importable.\n",
"!git clone https://github.com/deepmind/deepmind-research.git\n",
"%cd deepmind-research"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "c-H2d6UZi7Sf"
|
|
},
|
|
"source": [
|
|
"## Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "both",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "HJ74Id-8MERq"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import copy\n",
|
|
"\n",
|
|
"import acme\n",
|
|
"from acme.agents.tf import actors\n",
|
|
"from acme.agents.tf.dqn import learning as dqn\n",
|
|
"from acme.tf import utils as acme_utils\n",
|
|
"from acme.utils import loggers\n",
|
|
"from rl_unplugged import atari\n",
|
|
"import sonnet as snt\n",
|
|
"import tensorflow as tf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "JrOSnoWiY4Xl"
|
|
},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Vi3_H_h1zy_0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = 'Pong' #@param\n",
"run = 1 #@param\n",
"\n",
"gs_path = 'gs://rl_unplugged/atari'\n",
"tmp_path = '/tmp/atari'\n",
"\n",
"# Only the first of the run's 100 shards is fetched. It is stored locally\n",
"# under a single-shard name (`-of-00001`), presumably so that reading it back\n",
"# with `num_shards=1` finds it -- verify against `atari.dataset`.\n",
"src = '{}/{}/run_{}-00000-of-00100'.format(gs_path, game, run)\n",
"dest = '{}/{}/run_{}-00000-of-00001'.format(tmp_path, game, run)\n",
"\n",
"!mkdir -p {tmp_path}/{game}\n",
"!gsutil cp {src} {dest}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "a9vF7LtYvLzy"
|
|
},
|
|
"source": [
|
|
"## Dataset and environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "01AHHNd9cEX2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Small batch size for this demo; the experiments in the paper were run with\n",
"# batch size 256.\n",
"batch_size = 10 #@param\n",
"\n",
"def discard_extras(sample):\n",
"  \"\"\"Returns `sample` with its `data` tuple truncated to the first five fields.\n",
"\n",
"  NOTE(review): presumably these five fields are the transition consumed by\n",
"  the DQN learner and the rest are dataset extras -- verify against\n",
"  `rl_unplugged.atari`.\n",
"  \"\"\"\n",
"  return sample._replace(data=sample.data[:5])\n",
"\n",
"# Read the single shard downloaded above. Use the `game` / `run` form values\n",
"# instead of hard-coding them, so changing the form loads the matching data.\n",
"dataset = atari.dataset(path=tmp_path, game=game, run=run, num_shards=1)\n",
"dataset = dataset.map(discard_extras).batch(batch_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "KoYBhjPtI_N6"
|
|
},
|
|
"source": [
|
|
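"Create the Atari environment. It supplies the observation and action specs used to build the Q-network, and is used below to evaluate the learned policy."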
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "4b4_rHwCmQg-"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Evaluate in the same game the offline data was downloaded for.\n",
"environment = atari.environment(game=game)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "BukOfOsmtSQn"
|
|
},
|
|
"source": [
|
|
"## DQN learner"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"executionInfo": {
|
|
"elapsed": 83,
|
|
"status": "ok",
|
|
"timestamp": 1593614657342,
|
|
"user": {
|
|
"displayName": "",
|
|
"photoUrl": "",
|
|
"userId": ""
|
|
},
|
|
"user_tz": -60
|
|
},
|
|
"id": "3Jcjk1w6oHVX",
|
|
"outputId": "1746b0bb-5a5c-45dd-b5a1-c77852545e12"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"TensorSpec(shape=(6,), dtype=tf.float32, name=None)"
|
|
]
|
|
},
|
|
"execution_count": 20,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# The Q-network emits one value per discrete action.\n",
"num_actions = environment.action_spec().num_values\n",
"\n",
"# Convolutional torso followed by an MLP head; the layer sizes are those of\n",
"# the standard DQN architecture.\n",
"network = snt.Sequential([\n",
"    # Cast observations to float32. For integer inputs `convert_image_dtype`\n",
"    # also rescales to [0, 1] (frames are presumably uint8 -- verify).\n",
"    lambda x: tf.image.convert_image_dtype(x, tf.float32),\n",
"    snt.Conv2D(output_channels=32, kernel_shape=[8, 8], stride=[4, 4]),\n",
"    tf.nn.relu,\n",
"    snt.Conv2D(output_channels=64, kernel_shape=[4, 4], stride=[2, 2]),\n",
"    tf.nn.relu,\n",
"    snt.Conv2D(output_channels=64, kernel_shape=[3, 3], stride=[1, 1]),\n",
"    tf.nn.relu,\n",
"    snt.Flatten(),\n",
"    snt.nets.MLP(output_sizes=[512, num_actions]),\n",
"])\n",
"\n",
"# Build the network's variables from the observation spec; the value shown\n",
"# below the cell is the resulting output spec.\n",
"acme_utils.create_variables(network, [environment.observation_spec()])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "9CD2sNK-oA9S"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Terminal logger for training metrics; `time_delta` presumably throttles how\n",
"# often (in seconds) it writes.\n",
"logger = loggers.TerminalLogger(label='learner', time_delta=1.0)\n",
"\n",
"# Offline DQN learner: it trains from the fixed `dataset` only. The target\n",
"# network starts as an independent deep copy of the online Q-network.\n",
"learner = dqn.DQNLearner(\n",
"    network=network,\n",
"    target_network=copy.deepcopy(network),\n",
"    dataset=dataset,\n",
"    discount=0.99,\n",
"    learning_rate=3e-4,\n",
"    importance_sampling_exponent=0.2,\n",
"    target_update_period=2500,\n",
"    logger=logger,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "oKeGQxzitXYC"
|
|
},
|
|
"source": [
|
|
"## Training loop"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 51
|
|
},
|
|
"colab_type": "code",
|
|
"executionInfo": {
|
|
"elapsed": 4694,
|
|
"status": "ok",
|
|
"timestamp": 1593614662237,
|
|
"user": {
|
|
"displayName": "",
|
|
"photoUrl": "",
|
|
"userId": ""
|
|
},
|
|
"user_tz": -60
|
|
},
|
|
"id": "VWZd5N-Qoz82",
|
|
"outputId": "5ee2ce7c-b3fe-483b-8893-5a6e13519f48"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Learner] Loss = 0.003 | Steps = 1 | Walltime = 0\n",
|
|
"[Learner] Loss = 0.004 | Steps = 54 | Walltime = 1.126\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Number of learner updates. The default is only a quick smoke test (the\n",
"# evaluation below still plays at random level); raise it for real training.\n",
"num_learner_steps = 100 #@param\n",
"\n",
"for _ in range(num_learner_steps):\n",
"  learner.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "qFQDrp0CgIzU"
|
|
},
|
|
"source": [
|
|
"## Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 102
|
|
},
|
|
"colab_type": "code",
|
|
"executionInfo": {
|
|
"elapsed": 15099,
|
|
"status": "ok",
|
|
"timestamp": 1593614677360,
|
|
"user": {
|
|
"displayName": "",
|
|
"photoUrl": "",
|
|
"userId": ""
|
|
},
|
|
"user_tz": -60
|
|
},
|
|
"id": "DWYHBalygIDF",
|
|
"outputId": "4ec412c3-810a-4208-b521-919a8ece40df"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Evaluation] Episode Length = 842 | Episode Return = -20.000 | Episodes = 1 | Steps = 842 | Steps Per Second = 265.850\n",
|
|
"[Evaluation] Episode Length = 792 | Episode Return = -21.000 | Episodes = 2 | Steps = 1634 | Steps Per Second = 270.043\n",
|
|
"[Evaluation] Episode Length = 812 | Episode Return = -21.000 | Episodes = 3 | Steps = 2446 | Steps Per Second = 274.792\n",
|
|
"[Evaluation] Episode Length = 812 | Episode Return = -21.000 | Episodes = 4 | Steps = 3258 | Steps Per Second = 270.967\n",
|
|
"[Evaluation] Episode Length = 812 | Episode Return = -21.000 | Episodes = 5 | Steps = 4070 | Steps Per Second = 274.253\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"num_eval_episodes = 5 #@param\n",
"\n",
"# A dedicated name keeps the learner cell's `logger` global intact.\n",
"eval_logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
"\n",
"# Greedy policy: choose the action with the highest predicted Q-value.\n",
"policy_network = snt.Sequential([\n",
"    network,\n",
"    lambda q: tf.argmax(q, axis=-1),\n",
"])\n",
"\n",
"# Create an environment loop that runs the greedy actor in `environment`.\n",
"loop = acme.EnvironmentLoop(\n",
"    environment=environment,\n",
"    actor=actors.FeedForwardActor(policy_network=policy_network),\n",
"    logger=eval_logger)\n",
"\n",
"loop.run(num_eval_episodes)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"last_runtime": {
|
|
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
|
"kind": "private"
|
|
},
|
|
"name": "RL Unplugged: Offline DQN - Atari",
|
|
"provenance": [
|
|
{
|
|
"file_id": "1g9yTbTuk9aeERxWflOWqUGpx2M3osx0l",
|
|
"timestamp": 1593685504110
|
|
}
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|