mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
99976cfaa9
PiperOrigin-RevId: 334117027
515 lines
16 KiB
Plaintext
515 lines
16 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "KDiJzbb8KFvP"
|
|
},
|
|
"source": [
|
|
"Copyright 2020 DeepMind Technologies Limited.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
|
"this file except in compliance with the License. You may obtain a copy of the\n",
|
|
"License at\n",
|
|
"\n",
|
|
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
|
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
|
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
|
"specific language governing permissions and limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "zzJlIvx4tnrM"
|
|
},
|
|
"source": [
|
|
"# RL Unplugged: CRR agent with GPU/TPU support - DM control\n",
|
|
"\n",
|
|
"## Guide to training an Acme CRR agent on DM control data.\n",
|
|
"\u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/dm_control_suite_crr.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "o1eig5zGEL4y"
|
|
},
|
|
"source": [
|
|
"## Installation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "WbpMoLbgEL41"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install git+https://github.com/deepmind/acme.git#egg=dm-acme\n",
|
|
"!pip install dm-acme[reverb]\n",
|
|
"!pip install dm-acme[tf]\n",
|
|
"!pip install dm-sonnet\n",
|
|
"!git clone https://github.com/deepmind/deepmind-research.git\n",
|
|
"%cd deepmind-research"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "04bMANoeEPM3"
|
|
},
|
|
"source": [
|
|
"### dm_control\n",
|
|
"\n",
|
|
"More detailed instructions are available in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "VEEj3Qw60y73"
|
|
},
|
|
"source": [
|
|
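"#### Institutional MuJoCo license.\n\nRun only one of the two license cells (this one or the machine-locked one below): both write `mjkey.txt`, so whichever runs last overwrites the other's key."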
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "both",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "IbZxYDxzoz5R"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#@title Edit and run\n",
|
|
"mjkey = \"\"\"\n",
|
|
"\n",
|
|
"REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
|
|
"\n",
|
|
"\"\"\".strip()\n",
|
|
"\n",
|
|
"mujoco_dir = \"$HOME/.mujoco\"\n",
|
|
"\n",
|
|
"# Install OpenGL deps\n",
|
|
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
|
|
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
|
|
"\n",
|
|
"# Fetch MuJoCo binaries from Roboti\n",
|
|
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
|
|
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
|
|
"\n",
|
|
"# Copy over MuJoCo license\n",
|
|
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
|
|
"\n",
|
|
"\n",
|
|
"# Configure dm_control to use the OSMesa rendering backend\n",
|
|
"%env MUJOCO_GL=osmesa\n",
|
|
"\n",
|
|
"# Install dm_control\n",
|
|
"!pip install dm_control"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "-_7tVg-zzjzW"
|
|
},
|
|
"source": [
|
|
"#### Machine-locked MuJoCo license."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "OvMLEDE-D9oF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#@title Add your MuJoCo License and run\n",
|
|
"mjkey = \"\"\"\n",
|
|
"\"\"\".strip()\n",
|
|
"\n",
|
|
"mujoco_dir = \"$HOME/.mujoco\"\n",
|
|
"\n",
|
|
"# Install OpenGL dependencies\n",
|
|
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
|
|
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
|
|
"\n",
|
|
"# Get MuJoCo binaries\n",
|
|
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
|
|
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
|
|
"\n",
|
|
"# Copy over MuJoCo license\n",
|
|
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
|
|
"\n",
|
|
"# Install dm_control\n",
|
|
"!pip install dm_control[locomotion_mazes]\n",
|
|
"\n",
|
|
"# Configure dm_control to use the OSMesa rendering backend\n",
|
|
"%env MUJOCO_GL=osmesa"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "IE2nV9Hivnv5"
|
|
},
|
|
"source": [
|
|
"## Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "RI7NgnJIvs4s"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import copy\n",
|
|
"from typing import Sequence\n",
|
|
"import acme\n",
|
|
"from acme import specs\n",
|
|
"from acme.agents.tf import actors\n",
|
|
"from acme.agents.tf import crr\n",
|
|
"from acme.tf import networks as acme_networks\n",
|
|
"from acme.tf import utils as tf2_utils\n",
|
|
"from acme.utils import loggers\n",
|
|
"import numpy as np\n",
|
|
"from rl_unplugged import dm_control_suite\n",
|
|
"from rl_unplugged import networks\n",
|
|
"import sonnet as snt\n",
|
|
"import tensorflow as tf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "a2PCwF3bwBII"
|
|
},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "both",
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "VaEJbXjampPy"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"task_name = 'cartpole_swingup' #@param\n",
|
|
"gs_path = 'gs://rl_unplugged/dm_control_suite'\n",
|
|
"\n",
|
|
"num_shards_str, = !gsutil ls {gs_path}/{task_name}/* | wc -l\n",
|
|
"num_shards = int(num_shards_str)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "mQ1as51Mww7X"
|
|
},
|
|
"source": [
|
|
"## Dataset and environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "5kHzJpfcw306"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_size = 256 #@param\n",
|
|
"\n",
|
|
"task = dm_control_suite.ControlSuite(task_name)\n",
|
|
"\n",
|
|
"environment = task.environment\n",
|
|
"environment_spec = specs.make_environment_spec(environment)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "T2wd9sHeGrD-"
|
|
},
|
|
"source": [
|
|
"## Networks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Gc2Qsdi7GpCI"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def make_networks(\n",
|
|
" action_spec: specs.BoundedArray,\n",
|
|
" policy_lstm_sizes: Sequence[int] = None,\n",
|
|
" critic_lstm_sizes: Sequence[int] = None,\n",
|
|
" num_components: int = 5,\n",
|
|
" vmin: float = 0.,\n",
|
|
" vmax: float = 100.,\n",
|
|
" num_atoms: int = 21,\n",
|
|
"):\n",
|
|
" \"\"\"Creates recurrent networks with GMM head used by the agents.\"\"\"\n",
|
|
"\n",
|
|
" action_size = np.prod(action_spec.shape, dtype=int)\n",
|
|
" actor_head = acme_networks.MultivariateGaussianMixture(\n",
|
|
" num_components=num_components, num_dimensions=action_size)\n",
|
|
"\n",
|
|
" if policy_lstm_sizes is None:\n",
|
|
" policy_lstm_sizes = [1024, 1024]\n",
|
|
" if critic_lstm_sizes is None:\n",
|
|
" critic_lstm_sizes = [1024, 1024]\n",
|
|
"\n",
|
|
" actor_neck = acme_networks.LayerNormAndResidualMLP(hidden_size=1024,\n",
|
|
" num_blocks=4)\n",
|
|
" actor_encoder = networks.ControlNetwork(\n",
|
|
" proprio_encoder_size=300,\n",
|
|
" activation=tf.nn.relu)\n",
|
|
"\n",
|
|
" policy_lstms = [snt.LSTM(s) for s in policy_lstm_sizes]\n",
|
|
"\n",
|
|
" policy_network = snt.DeepRNN([actor_encoder, actor_neck] + policy_lstms +\n",
|
|
" [actor_head])\n",
|
|
"\n",
|
|
" critic_encoder = networks.ControlNetwork(\n",
|
|
" proprio_encoder_size=400,\n",
|
|
" activation=tf.nn.relu)\n",
|
|
" critic_neck = acme_networks.LayerNormAndResidualMLP(\n",
|
|
" hidden_size=1024, num_blocks=4)\n",
|
|
" distributional_head = acme_networks.DiscreteValuedHead(\n",
|
|
" vmin=vmin, vmax=vmax, num_atoms=num_atoms)\n",
|
|
" critic_lstms = [snt.LSTM(s) for s in critic_lstm_sizes]\n",
|
|
" critic_network = acme_networks.CriticDeepRNN([critic_encoder, critic_neck] +\n",
|
|
" critic_lstms + [\n",
|
|
" distributional_head,\n",
|
|
" ])\n",
|
|
"\n",
|
|
" return {\n",
|
|
" 'policy': policy_network,\n",
|
|
" 'critic': critic_network,\n",
|
|
" }"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "zL1fAYN8GvCf"
|
|
},
|
|
"source": [
|
|
"## Set up TPU if present"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "V3fQxmeiGtrQ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"try:\n",
|
|
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n",
|
|
" tf.config.experimental_connect_to_cluster(tpu)\n",
|
|
" tf.tpu.experimental.initialize_tpu_system(tpu)\n",
|
|
" accelerator_strategy = snt.distribute.TpuReplicator()\n",
|
|
" print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
|
|
"except ValueError:\n",
|
|
" print('Running on CPU or GPU (no TPUs available)')\n",
|
|
" accelerator_strategy = snt.distribute.Replicator()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "adb0cyE5qu9G"
|
|
},
|
|
"source": [
|
|
"## CRR learner"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "83naOY7a_A4I"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"action_spec = environment_spec.actions\n",
|
|
"action_size = np.prod(action_spec.shape, dtype=int)\n",
|
|
"\n",
|
|
"with accelerator_strategy.scope():\n",
|
|
" dataset = dm_control_suite.dataset(\n",
|
|
" 'gs://rl_unplugged/',\n",
|
|
" data_path=task.data_path,\n",
|
|
" shapes=task.shapes,\n",
|
|
" uint8_features=task.uint8_features,\n",
|
|
" num_threads=1,\n",
|
|
" batch_size=batch_size,\n",
|
|
" num_shards=num_shards,\n",
|
|
" sarsa=False)\n",
|
|
" # CRR learner assumes that the dataset samples don't have metadata,\n",
|
|
" # so let's remove it here.\n",
|
|
" dataset = dataset.map(lambda sample: sample.data)\n",
|
|
" nets = make_networks(action_spec)\n",
|
|
" policy_network, critic_network = nets['policy'], nets['critic']\n",
|
|
"\n",
|
|
" # Create the target networks\n",
|
|
" target_policy_network = copy.deepcopy(policy_network)\n",
|
|
" target_critic_network = copy.deepcopy(critic_network)\n",
|
|
"\n",
|
|
" # Create variables.\n",
|
|
" tf2_utils.create_variables(network=policy_network,\n",
|
|
" input_spec=[environment_spec.observations])\n",
|
|
" tf2_utils.create_variables(network=critic_network,\n",
|
|
" input_spec=[environment_spec.observations,\n",
|
|
" environment_spec.actions])\n",
|
|
" tf2_utils.create_variables(network=target_policy_network,\n",
|
|
" input_spec=[environment_spec.observations])\n",
|
|
" tf2_utils.create_variables(network=target_critic_network,\n",
|
|
" input_spec=[environment_spec.observations,\n",
|
|
" environment_spec.actions])\n",
|
|
"\n",
|
|
"# The learner updates the parameters (and initializes them).\n",
|
|
"learner = crr.RCRRLearner(\n",
|
|
" policy_network=policy_network,\n",
|
|
" critic_network=critic_network,\n",
|
|
" accelerator_strategy=accelerator_strategy,\n",
|
|
" target_policy_network=target_policy_network,\n",
|
|
" target_critic_network=target_critic_network,\n",
|
|
" dataset=dataset,\n",
|
|
" discount=0.99,\n",
|
|
" target_update_period=100)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "PYkjKaduy_xj"
|
|
},
|
|
"source": [
|
|
"## Training loop"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "HbQOyCG4zCwa"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Run\n",
|
|
"# tf.config.run_functions_eagerly(True)\n",
|
|
"# if you want to debug the code in eager mode.\n",
|
|
"\n",
|
|
"for _ in range(100):\n",
|
|
" learner.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "LJ_XsuQSzFSV"
|
|
},
|
|
"source": [
|
|
"## Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "blvNCANKb22J"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a logger.\n",
|
|
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
|
|
"\n",
|
|
"# Create an environment loop.\n",
|
|
"loop = acme.EnvironmentLoop(\n",
|
|
" environment=environment,\n",
|
|
" actor=actors.RecurrentActor(policy_network),\n",
|
|
" logger=logger)\n",
|
|
"\n",
|
|
"loop.run(5)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"last_runtime": {
|
|
"build_target": "",
|
|
"kind": "local"
|
|
},
|
|
"name": "RL Unplugged: Offline CRR - DM control",
|
|
"provenance": [
|
|
{
|
|
"file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ",
|
|
"timestamp": 1593080049369
|
|
}
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|