mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
59c0cf5044
PiperOrigin-RevId: 323321348
475 lines
14 KiB
Plaintext
475 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "KDiJzbb8KFvP"
|
|
},
|
|
"source": [
|
|
"Copyright 2020 DeepMind Technologies Limited.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
|
"this file except in compliance with the License. You may obtain a copy of the\n",
|
|
"License at\n",
|
|
"\n",
|
|
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
|
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
|
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
|
"specific language governing permissions and limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "zzJlIvx4tnrM"
|
|
},
|
|
"source": [
|
|
"# RL Unplugged: Offline D4PG - DM control\n",
|
|
"\n",
|
|
"## Guide to training an Acme D4PG agent on DM control data.\n",
|
|
"\u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/dm_control_suite_d4pg.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "o1eig5zGEL4y"
|
|
},
|
|
"source": [
|
|
"## Installation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "WbpMoLbgEL41"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install dm-acme\n",
|
|
"!pip install dm-acme[reverb]\n",
|
|
"!pip install dm-acme[tf]\n",
|
|
"!pip install dm-sonnet\n",
|
|
"!git clone https://github.com/deepmind/deepmind-research.git\n",
|
|
"%cd deepmind-research"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "04bMANoeEPM3"
|
|
},
|
|
"source": [
|
|
"### dm_control\n",
|
|
"\n",
|
|
"More detailed installation instructions are available in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "VEEj3Qw60y73"
|
|
},
|
|
"source": [
|
|
"#### Institutional MuJoCo license."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "both",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "IbZxYDxzoz5R"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#@title Edit and run\n",
|
|
"mjkey = \"\"\"\n",
|
|
"\n",
|
|
"REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
|
|
"\n",
|
|
"\"\"\".strip()\n",
|
|
"\n",
|
|
"mujoco_dir = \"$HOME/.mujoco\"\n",
|
|
"\n",
|
|
"# Install OpenGL dependencies\n",
|
|
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
|
|
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
|
|
"\n",
|
|
"# Fetch MuJoCo binaries from Roboti\n",
|
|
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
|
|
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
|
|
"\n",
|
|
"# Write the MuJoCo license key to mjkey.txt\n",
|
|
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
|
|
"\n",
|
|
"\n",
|
|
"# Configure dm_control to use the OSMesa rendering backend\n",
|
|
"%env MUJOCO_GL=osmesa\n",
|
|
"\n",
|
|
"# Install dm_control\n",
|
|
"!pip install dm_control"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "-_7tVg-zzjzW"
|
|
},
|
|
"source": [
|
|
"#### Machine-locked MuJoCo license."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "OvMLEDE-D9oF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#@title Add your MuJoCo License and run\n",
|
|
"mjkey = \"\"\"\n",
|
|
"\"\"\".strip()\n",
|
|
"\n",
|
|
"mujoco_dir = \"$HOME/.mujoco\"\n",
|
|
"\n",
|
|
"# Install OpenGL dependencies\n",
|
|
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
|
|
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
|
|
"\n",
|
|
"# Get MuJoCo binaries\n",
|
|
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
|
|
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
|
|
"\n",
|
|
"# Write the MuJoCo license key to mjkey.txt\n",
|
|
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
|
|
"\n",
|
|
"# Install dm_control\n",
|
|
"!pip install dm_control[locomotion_mazes]\n",
|
|
"\n",
|
|
"# Configure dm_control to use the OSMesa rendering backend\n",
|
|
"%env MUJOCO_GL=osmesa"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "IE2nV9Hivnv5"
|
|
},
|
|
"source": [
|
|
"## Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "RI7NgnJIvs4s"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import collections\n",
|
|
"import copy\n",
|
|
"from typing import Mapping, Sequence\n",
|
|
"\n",
|
|
"import acme\n",
|
|
"from acme import specs\n",
|
|
"from acme.agents.tf import actors\n",
|
|
"from acme.agents.tf import d4pg\n",
|
|
"from acme.tf import networks\n",
|
|
"from acme.tf import utils as tf2_utils\n",
|
|
"from acme.utils import loggers\n",
|
|
"from acme.wrappers import single_precision\n",
|
|
"from acme.tf import utils as tf2_utils\n",
|
|
"import numpy as np\n",
|
|
"from rl_unplugged import dm_control_suite\n",
|
|
"import sonnet as snt\n",
|
|
"import tensorflow as tf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "a2PCwF3bwBII"
|
|
},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "both",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "VaEJbXjampPy"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"task_name = 'cartpole_swingup' #@param\n",
|
|
"tmp_path = '/tmp/dm_control_suite'\n",
|
|
"gs_path = 'gs://rl_unplugged/dm_control_suite'\n",
|
|
"\n",
|
|
"!mkdir -p {tmp_path}/{task_name}\n",
|
|
"!gsutil cp {gs_path}/{task_name}/* {tmp_path}/{task_name}\n",
|
|
"\n",
|
|
"num_shards_str, = !ls {tmp_path}/{task_name}/* | wc -l\n",
|
|
"num_shards = int(num_shards_str)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "mQ1as51Mww7X"
|
|
},
|
|
"source": [
|
|
"## Dataset and environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "5kHzJpfcw306"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_size = 10 #@param\n",
|
|
"\n",
|
|
"task = dm_control_suite.ControlSuite(task_name)\n",
|
|
"\n",
|
|
"environment = task.environment\n",
|
|
"environment_spec = specs.make_environment_spec(environment)\n",
|
|
"\n",
|
|
"dataset = dm_control_suite.dataset(\n",
|
|
" '/tmp',\n",
|
|
" data_path=task.data_path,\n",
|
|
" shapes=task.shapes,\n",
|
|
" uint8_features=task.uint8_features,\n",
|
|
" num_threads=1,\n",
|
|
" batch_size=batch_size,\n",
|
|
" num_shards=num_shards)\n",
|
|
"\n",
|
|
"def discard_extras(sample):\n",
|
|
" return sample._replace(data=sample.data[:5])\n",
|
|
"\n",
|
|
"dataset = dataset.map(discard_extras).batch(batch_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "adb0cyE5qu9G"
|
|
},
|
|
"source": [
|
|
"## D4PG learner"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "83naOY7a_A4I"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create the networks to optimize.\n",
|
|
"action_spec = environment_spec.actions\n",
|
|
"action_size = np.prod(action_spec.shape, dtype=int)\n",
|
|
"\n",
|
|
"policy_network = snt.Sequential([\n",
|
|
" tf2_utils.batch_concat,\n",
|
|
" networks.LayerNormMLP(layer_sizes=(300, 200, action_size)),\n",
|
|
" networks.TanhToSpec(spec=environment_spec.actions)])\n",
|
|
"\n",
|
|
"critic_network = snt.Sequential([\n",
|
|
" networks.CriticMultiplexer(\n",
|
|
" observation_network=tf2_utils.batch_concat,\n",
|
|
" action_network=tf.identity,\n",
|
|
" critic_network=networks.LayerNormMLP(\n",
|
|
" layer_sizes=(400, 300),\n",
|
|
" activate_final=True)),\n",
|
|
" # The value head outputs a discrete distribution over 51 atoms for state-action values.\n",
|
|
" networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51)])\n",
|
|
"\n",
|
|
"# Create the target networks.\n",
|
|
"target_policy_network = copy.deepcopy(policy_network)\n",
|
|
"target_critic_network = copy.deepcopy(critic_network)\n",
|
|
"\n",
|
|
"# Create variables.\n",
|
|
"tf2_utils.create_variables(network=policy_network,\n",
|
|
" input_spec=[environment_spec.observations])\n",
|
|
"tf2_utils.create_variables(network=critic_network,\n",
|
|
" input_spec=[environment_spec.observations,\n",
|
|
" environment_spec.actions])\n",
|
|
"tf2_utils.create_variables(network=target_policy_network,\n",
|
|
" input_spec=[environment_spec.observations])\n",
|
|
"tf2_utils.create_variables(network=target_critic_network,\n",
|
|
" input_spec=[environment_spec.observations,\n",
|
|
" environment_spec.actions])\n",
|
|
"\n",
|
|
"# The learner updates the parameters (and initializes them).\n",
|
|
"learner = d4pg.D4PGLearner(\n",
|
|
" policy_network=policy_network,\n",
|
|
" critic_network=critic_network,\n",
|
|
" target_policy_network=target_policy_network,\n",
|
|
" target_critic_network=target_critic_network,\n",
|
|
" dataset=dataset,\n",
|
|
" discount=0.99,\n",
|
|
" target_update_period=100)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "PYkjKaduy_xj"
|
|
},
|
|
"source": [
|
|
"## Training loop"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"executionInfo": {
|
|
"elapsed": 3493,
|
|
"status": "ok",
|
|
"timestamp": 1593622068277,
|
|
"user": {
|
|
"displayName": "",
|
|
"photoUrl": "",
|
|
"userId": ""
|
|
},
|
|
"user_tz": -60
|
|
},
|
|
"id": "HbQOyCG4zCwa",
|
|
"outputId": "cfb99d00-da2d-4ce8-e010-034a26e2ada0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Learner] Critic Loss = 3.919 | Policy Loss = 0.326 | Steps = 1 | Walltime = 0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for _ in range(100):\n",
|
|
" learner.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "LJ_XsuQSzFSV"
|
|
},
|
|
"source": [
|
|
"## Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 51
|
|
},
|
|
"colab_type": "code",
|
|
"executionInfo": {
|
|
"elapsed": 4197,
|
|
"status": "ok",
|
|
"timestamp": 1593620604870,
|
|
"user": {
|
|
"displayName": "",
|
|
"photoUrl": "",
|
|
"userId": ""
|
|
},
|
|
"user_tz": -60
|
|
},
|
|
"id": "blvNCANKb22J",
|
|
"outputId": "af5ae073-9847-45cc-e51e-a803fc2148b0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Evaluation] Episode Length = 1000 | Episode Return = 129.717 | Episodes = 2 | Steps = 2000 | Steps Per Second = 1480.399\n",
|
|
"[Evaluation] Episode Length = 1000 | Episode Return = 34.790 | Episodes = 4 | Steps = 4000 | Steps Per Second = 1449.009\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Create a logger.\n",
|
|
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
|
|
"\n",
|
|
"# Create an environment loop.\n",
|
|
"loop = acme.EnvironmentLoop(\n",
|
|
" environment=environment,\n",
|
|
" actor=actors.FeedForwardActor(policy_network),\n",
|
|
" logger=logger)\n",
|
|
"\n",
|
|
"loop.run(5)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"last_runtime": {
|
|
"build_target": "",
|
|
"kind": "local"
|
|
},
|
|
"name": "RL Unplugged: Offline D4PG - DM control",
|
|
"provenance": [
|
|
{
|
|
"file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ",
|
|
"timestamp": 1593080049369
|
|
}
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|