Files
deepmind-research/rl_unplugged/dm_control_suite_crr.ipynb
T
Alexander Novikov 99976cfaa9 An example of training CRR agent with TPU support
PiperOrigin-RevId: 334117027
2020-10-12 16:03:27 +01:00

515 lines
16 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KDiJzbb8KFvP"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
"this file except in compliance with the License. You may obtain a copy of the\n",
"License at\n",
"\n",
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed\n",
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
"specific language governing permissions and limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zzJlIvx4tnrM"
},
"source": [
"# RL Unplugged: CRR agent with GPU/TPU support - DM control\n",
"\n",
"## Guide to training an Acme CRR agent on DM control data.\n",
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/dm_control_suite_crr.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "o1eig5zGEL4y"
},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "WbpMoLbgEL41"
},
"outputs": [],
"source": [
"# Install Acme from its GitHub repository, then its reverb and tf extras, and\n",
"# Sonnet.\n",
"!pip install git+https://github.com/deepmind/acme.git#egg=dm-acme\n",
"!pip install dm-acme[reverb]\n",
"!pip install dm-acme[tf]\n",
"!pip install dm-sonnet\n",
"# Clone deepmind-research and change into it; presumably this is what makes\n",
"# the `rl_unplugged` package imported in the Imports cell resolvable.\n",
"!git clone https://github.com/deepmind/deepmind-research.git\n",
"%cd deepmind-research"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "04bMANoeEPM3"
},
"source": [
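"### dm_control\n",
"\n",
"Run one of the two setup cells below, depending on which type of MuJoCo license you have.\n",
"\n",
"More detailed instructions are available in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ)."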
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VEEj3Qw60y73"
},
"source": [
"#### Institutional MuJoCo license."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "IbZxYDxzoz5R"
},
"outputs": [],
"source": [
"#@title Edit and run\n",
"# Put the text of your MuJoCo license key between the triple quotes; the\n",
"# surrounding whitespace is stripped.\n",
"mjkey = \"\"\"\n",
"\n",
"REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
"\n",
"\"\"\".strip()\n",
"\n",
"# IPython substitutes the Python variables $mujoco_dir and $mjkey into the `!`\n",
"# commands below; the $HOME inside mujoco_dir is then expanded by the shell.\n",
"mujoco_dir = \"$HOME/.mujoco\"\n",
"\n",
"# Install OpenGL deps\n",
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
"\n",
"# Download the MuJoCo 2.0 Linux binaries from Roboti and unpack them into\n",
"# mujoco_dir (overwriting existing files without prompting).\n",
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
"\n",
"# Write the license key text to mjkey.txt inside mujoco_dir\n",
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
"\n",
"\n",
"# Configure dm_control to use the OSMesa rendering backend\n",
"%env MUJOCO_GL=osmesa\n",
"\n",
"# Install dm_control\n",
"!pip install dm_control"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-_7tVg-zzjzW"
},
"source": [
"#### Machine-locked MuJoCo license."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "OvMLEDE-D9oF"
},
"outputs": [],
"source": [
"#@title Add your MuJoCo License and run\n",
"# Put the text of your MuJoCo license key between the triple quotes; the\n",
"# surrounding whitespace is stripped.\n",
"mjkey = \"\"\"\n",
"\"\"\".strip()\n",
"\n",
"# IPython substitutes the Python variables $mujoco_dir and $mjkey into the `!`\n",
"# commands below; the $HOME inside mujoco_dir is then expanded by the shell.\n",
"mujoco_dir = \"$HOME/.mujoco\"\n",
"\n",
"# Install OpenGL dependencies\n",
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
"\n",
"# Download the MuJoCo 2.0 Linux binaries and unpack them into mujoco_dir\n",
"# (overwriting existing files without prompting).\n",
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
"\n",
"# Write the license key text to mjkey.txt inside mujoco_dir\n",
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
"\n",
"# Install dm_control (with the locomotion_mazes extra)\n",
"!pip install dm_control[locomotion_mazes]\n",
"\n",
"# Configure dm_control to use the OSMesa rendering backend\n",
"%env MUJOCO_GL=osmesa"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IE2nV9Hivnv5"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RI7NgnJIvs4s"
},
"outputs": [],
"source": [
"import copy\n",
"from typing import Sequence\n",
"import acme\n",
"from acme import specs\n",
"from acme.agents.tf import actors\n",
"from acme.agents.tf import crr\n",
"from acme.tf import networks as acme_networks\n",
"from acme.tf import utils as tf2_utils\n",
"from acme.utils import loggers\n",
"import numpy as np\n",
"from rl_unplugged import dm_control_suite\n",
"from rl_unplugged import networks\n",
"import sonnet as snt\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "a2PCwF3bwBII"
},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "VaEJbXjampPy"
},
"outputs": [],
"source": [
"task_name = 'cartpole_swingup' #@param\n",
"gs_path = 'gs://rl_unplugged/dm_control_suite'\n",
"\n",
"# Count the entries `gsutil ls` reports under the task's directory. The `!`\n",
"# capture yields a list of output lines; `wc -l` prints a single line, which\n",
"# the one-element unpacking extracts before it is parsed as an int.\n",
"num_shards_str, = !gsutil ls {gs_path}/{task_name}/* | wc -l\n",
"num_shards = int(num_shards_str)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mQ1as51Mww7X"
},
"source": [
"## Dataset and environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5kHzJpfcw306"
},
"outputs": [],
"source": [
"batch_size = 256 #@param\n",
"\n",
"# Look up the task by name. The resulting object supplies the environment\n",
"# here, and the data path / shapes / uint8 features read when the dataset is\n",
"# built in the CRR learner cell.\n",
"task = dm_control_suite.ControlSuite(task_name)\n",
"\n",
"environment = task.environment\n",
"# Observation and action specs, used later to create the network variables.\n",
"environment_spec = specs.make_environment_spec(environment)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "T2wd9sHeGrD-"
},
"source": [
"## Networks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Gc2Qsdi7GpCI"
},
"outputs": [],
"source": [
"def make_networks(\n",
"    action_spec: specs.BoundedArray,\n",
"    policy_lstm_sizes: Sequence[int] = None,\n",
"    critic_lstm_sizes: Sequence[int] = None,\n",
"    num_components: int = 5,\n",
"    vmin: float = 0.,\n",
"    vmax: float = 100.,\n",
"    num_atoms: int = 21,\n",
"):\n",
"  \"\"\"Creates the recurrent policy and critic networks used by the agent.\n",
"\n",
"  The policy is encoder -\u003e residual MLP -\u003e LSTM stack -\u003e Gaussian-mixture head;\n",
"  the critic is encoder -\u003e residual MLP -\u003e LSTM stack -\u003e discrete-valued head.\n",
"\n",
"  Args:\n",
"    action_spec: Spec of the actions; only its shape is read, to size the\n",
"      mixture head.\n",
"    policy_lstm_sizes: Sizes of the policy LSTM layers, one LSTM per entry.\n",
"      Defaults to [1024, 1024] when None.\n",
"    critic_lstm_sizes: Sizes of the critic LSTM layers, one LSTM per entry.\n",
"      Defaults to [1024, 1024] when None.\n",
"    num_components: Number of components of the policy's mixture head.\n",
"    vmin: Forwarded to the critic's DiscreteValuedHead.\n",
"    vmax: Forwarded to the critic's DiscreteValuedHead.\n",
"    num_atoms: Forwarded to the critic's DiscreteValuedHead.\n",
"\n",
"  Returns:\n",
"    A dict holding the policy network under 'policy' and the critic network\n",
"    under 'critic'.\n",
"  \"\"\"\n",
"\n",
"  # Flatten the action shape into a single dimension count for the head.\n",
"  action_size = np.prod(action_spec.shape, dtype=int)\n",
"  actor_head = acme_networks.MultivariateGaussianMixture(\n",
"      num_components=num_components, num_dimensions=action_size)\n",
"\n",
"  # None means \"use the default\": two LSTM layers of 1024 units each.\n",
"  if policy_lstm_sizes is None:\n",
"    policy_lstm_sizes = [1024, 1024]\n",
"  if critic_lstm_sizes is None:\n",
"    critic_lstm_sizes = [1024, 1024]\n",
"\n",
"  actor_neck = acme_networks.LayerNormAndResidualMLP(hidden_size=1024,\n",
"                                                     num_blocks=4)\n",
"  actor_encoder = networks.ControlNetwork(\n",
"      proprio_encoder_size=300,\n",
"      activation=tf.nn.relu)\n",
"\n",
"  policy_lstms = [snt.LSTM(s) for s in policy_lstm_sizes]\n",
"\n",
"  policy_network = snt.DeepRNN([actor_encoder, actor_neck] + policy_lstms +\n",
"                               [actor_head])\n",
"\n",
"  critic_encoder = networks.ControlNetwork(\n",
"      proprio_encoder_size=400,\n",
"      activation=tf.nn.relu)\n",
"  critic_neck = acme_networks.LayerNormAndResidualMLP(\n",
"      hidden_size=1024, num_blocks=4)\n",
"  distributional_head = acme_networks.DiscreteValuedHead(\n",
"      vmin=vmin, vmax=vmax, num_atoms=num_atoms)\n",
"  critic_lstms = [snt.LSTM(s) for s in critic_lstm_sizes]\n",
"  critic_network = acme_networks.CriticDeepRNN(\n",
"      [critic_encoder, critic_neck] + critic_lstms + [distributional_head])\n",
"\n",
"  return {\n",
"      'policy': policy_network,\n",
"      'critic': critic_network,\n",
"  }"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zL1fAYN8GvCf"
},
"source": [
"## Set up TPU if present"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "V3fQxmeiGtrQ"
},
"outputs": [],
"source": [
"# Choose the replication strategy: the TPU replicator when a TPU can be set\n",
"# up, the default replicator otherwise.\n",
"try:\n",
"  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection\n",
"  tf.config.experimental_connect_to_cluster(tpu)\n",
"  tf.tpu.experimental.initialize_tpu_system(tpu)\n",
"  accelerator_strategy = snt.distribute.TpuReplicator()\n",
"  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
"except ValueError:\n",
"  # Any ValueError raised in the block above is treated as \"no TPU\".\n",
"  print('Running on CPU or GPU (no TPUs available)')\n",
"  accelerator_strategy = snt.distribute.Replicator()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "adb0cyE5qu9G"
},
"source": [
"## CRR learner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "83naOY7a_A4I"
},
"outputs": [],
"source": [
"action_spec = environment_spec.actions\n",
"action_size = np.prod(action_spec.shape, dtype=int)\n",
"\n",
"# Build the dataset and all four networks inside the replication scope.\n",
"with accelerator_strategy.scope():\n",
"  dataset = dm_control_suite.dataset(\n",
"      'gs://rl_unplugged/',\n",
"      data_path=task.data_path,\n",
"      shapes=task.shapes,\n",
"      uint8_features=task.uint8_features,\n",
"      num_threads=1,\n",
"      batch_size=batch_size,\n",
"      num_shards=num_shards,\n",
"      sarsa=False)\n",
"  # The CRR learner expects samples without metadata, so keep only the\n",
"  # `data` field of each sample.\n",
"  dataset = dataset.map(lambda sample: sample.data)\n",
"\n",
"  nets = make_networks(action_spec)\n",
"  policy_network = nets['policy']\n",
"  critic_network = nets['critic']\n",
"\n",
"  # Deep-copy the online networks to obtain the target networks.\n",
"  target_policy_network = copy.deepcopy(policy_network)\n",
"  target_critic_network = copy.deepcopy(critic_network)\n",
"\n",
"  # Create the variables of each network: the policies are fed observations,\n",
"  # the critics observations and actions.\n",
"  tf2_utils.create_variables(\n",
"      network=policy_network,\n",
"      input_spec=[environment_spec.observations])\n",
"  tf2_utils.create_variables(\n",
"      network=critic_network,\n",
"      input_spec=[environment_spec.observations, environment_spec.actions])\n",
"  tf2_utils.create_variables(\n",
"      network=target_policy_network,\n",
"      input_spec=[environment_spec.observations])\n",
"  tf2_utils.create_variables(\n",
"      network=target_critic_network,\n",
"      input_spec=[environment_spec.observations, environment_spec.actions])\n",
"\n",
"# Create the learner that is stepped in the training loop below.\n",
"learner = crr.RCRRLearner(\n",
"    policy_network=policy_network,\n",
"    critic_network=critic_network,\n",
"    accelerator_strategy=accelerator_strategy,\n",
"    target_policy_network=target_policy_network,\n",
"    target_critic_network=target_critic_network,\n",
"    dataset=dataset,\n",
"    discount=0.99,\n",
"    target_update_period=100)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PYkjKaduy_xj"
},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "HbQOyCG4zCwa"
},
"outputs": [],
"source": [
"# To debug the code in eager mode, first run:\n",
"#   tf.config.run_functions_eagerly(True)\n",
"\n",
"# Take 100 learner steps.\n",
"for _ in range(100):\n",
"  learner.step()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "LJ_XsuQSzFSV"
},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "blvNCANKb22J"
},
"outputs": [],
"source": [
"# Log evaluation results to the terminal.\n",
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
"\n",
"# Let the policy network act in the environment through a recurrent actor.\n",
"loop = acme.EnvironmentLoop(\n",
"    environment=environment,\n",
"    actor=actors.RecurrentActor(policy_network),\n",
"    logger=logger)\n",
"\n",
"loop.run(5)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "RL Unplugged: Offline CRR - DM control",
"provenance": [
{
"file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ",
"timestamp": 1593080049369
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}