diff --git a/rl_unplugged/dm_control_suite.py b/rl_unplugged/dm_control_suite.py index 5360e03..59bb0fe 100644 --- a/rl_unplugged/dm_control_suite.py +++ b/rl_unplugged/dm_control_suite.py @@ -33,6 +33,7 @@ import os from typing import Dict, Optional, Tuple, Set from acme import wrappers +from acme.adders import reverb as adders from dm_control import composer from dm_control import suite from dm_control.composer.variation import colors @@ -726,16 +727,19 @@ def _parse_seq_tf_example(example, uint8_features, shapes): def _build_sequence_example(sequences): """Convert raw sequences into a Reverb sequence sample.""" - o = sequences['observation'] - a = sequences['action'] - r = sequences['reward'] - p = sequences['discount'] + data = adders.Step( + observation=sequences['observation'], + action=sequences['action'], + reward=sequences['reward'], + discount=sequences['discount'], + start_of_episode=(), + extras=()) info = reverb.SampleInfo(key=tf.constant(0, tf.uint64), probability=tf.constant(1.0, tf.float64), table_size=tf.constant(0, tf.int64), priority=tf.constant(1.0, tf.float64)) - return reverb.ReplaySample(info=info, data=(o, a, r, p)) + return reverb.ReplaySample(info=info, data=data) def _build_sarsa_example(sequences): diff --git a/rl_unplugged/dm_control_suite_crr.ipynb b/rl_unplugged/dm_control_suite_crr.ipynb new file mode 100644 index 0000000..d44e01a --- /dev/null +++ b/rl_unplugged/dm_control_suite_crr.ipynb @@ -0,0 +1,514 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KDiJzbb8KFvP" + }, + "source": [ + "Copyright 2020 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n", + "this file except in compliance with the License. You may obtain a copy of the\n", + "License at\n", + "\n", + "[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed\n", + "under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n", + "CONDITIONS OF ANY KIND, either express or implied. See the License for the\n", + "specific language governing permissions and limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zzJlIvx4tnrM" + }, + "source": [ + "# RL Unplugged: CRR agent with GPU/TPU support - DM control\n", + "\n", + "## Guide to training an Acme CRR agent on DM control data.\n", + "# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/dm_control_suite_crr.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "o1eig5zGEL4y" + }, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WbpMoLbgEL41" + }, + "outputs": [], + "source": [ + "!pip install git+https://github.com/deepmind/acme.git#egg=dm-acme\n", + "!pip install dm-acme[reverb]\n", + "!pip install dm-acme[tf]\n", + "!pip install dm-sonnet\n", + "!git clone https://github.com/deepmind/deepmind-research.git\n", + "%cd deepmind-research" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "04bMANoeEPM3" + }, + "source": [ + "### dm_control\n", + "\n", + "More detailed instructions in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VEEj3Qw60y73" + }, + "source": [ + "#### Institutional MuJoCo license." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "IbZxYDxzoz5R" + }, + "outputs": [], + "source": [ + "#@title Edit and run\n", + "mjkey = \"\"\"\n", + "\n", + "REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n", + "\n", + "\"\"\".strip()\n", + "\n", + "mujoco_dir = \"$HOME/.mujoco\"\n", + "\n", + "# Install OpenGL deps\n", + "!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n", + " libgl1-mesa-glx libosmesa6 libglew2.0\n", + "\n", + "# Fetch MuJoCo binaries from Roboti\n", + "!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n", + "!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n", + "\n", + "# Copy over MuJoCo license\n", + "!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n", + "\n", + "\n", + "# Configure dm_control to use the OSMesa rendering backend\n", + "%env MUJOCO_GL=osmesa\n", + "\n", + "# Install dm_control\n", + "!pip install dm_control" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-_7tVg-zzjzW" + }, + "source": [ + "#### Machine-locked MuJoCo license." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "OvMLEDE-D9oF" + }, + "outputs": [], + "source": [ + "#@title Add your MuJoCo License and run\n", + "mjkey = \"\"\"\n", + "\"\"\".strip()\n", + "\n", + "mujoco_dir = \"$HOME/.mujoco\"\n", + "\n", + "# Install OpenGL dependencies\n", + "!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n", + " libgl1-mesa-glx libosmesa6 libglew2.0\n", + "\n", + "# Get MuJoCo binaries\n", + "!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n", + "!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n", + "\n", + "# Copy over MuJoCo license\n", + "!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n", + "\n", + "# Install dm_control\n", + "!pip install dm_control[locomotion_mazes]\n", + "\n", + "# Configure dm_control to use the OSMesa rendering backend\n", + "%env MUJOCO_GL=osmesa" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "IE2nV9Hivnv5" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "RI7NgnJIvs4s" + }, + "outputs": [], + "source": [ + "import copy\n", + "from typing import Sequence\n", + "import acme\n", + "from acme import specs\n", + "from acme.agents.tf import actors\n", + "from acme.agents.tf import crr\n", + "from acme.tf import networks as acme_networks\n", + "from acme.tf import utils as tf2_utils\n", + "from acme.utils import loggers\n", + "import numpy as np\n", + "from rl_unplugged import dm_control_suite\n", + "from rl_unplugged import networks\n", + "import sonnet as snt\n", + "import tensorflow as tf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "a2PCwF3bwBII" + }, + "source": [ + "## Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "VaEJbXjampPy" + }, + "outputs": [], + "source": [ + "task_name = 'cartpole_swingup' #@param\n", + "gs_path = 'gs://rl_unplugged/dm_control_suite'\n", + "\n", + "num_shards_str, = !gsutil ls {gs_path}/{task_name}/* | wc -l\n", + "num_shards = int(num_shards_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mQ1as51Mww7X" + }, + "source": [ + "## Dataset and environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5kHzJpfcw306" + }, + "outputs": [], + "source": [ + "batch_size = 256 #@param\n", + "\n", + "task = dm_control_suite.ControlSuite(task_name)\n", + "\n", + "environment = task.environment\n", + "environment_spec = specs.make_environment_spec(environment)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "T2wd9sHeGrD-" + }, + "source": [ + "## Networks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Gc2Qsdi7GpCI" + }, + "outputs": [], + "source": [ + "def make_networks(\n", + " action_spec: specs.BoundedArray,\n", + " policy_lstm_sizes: Sequence[int] = None,\n", + " critic_lstm_sizes: Sequence[int] = None,\n", + " num_components: int = 5,\n", + " vmin: float = 0.,\n", + " vmax: float = 100.,\n", + " num_atoms: int = 21,\n", + "):\n", + " \"\"\"Creates recurrent networks with GMM head used by the agents.\"\"\"\n", + "\n", + " action_size = np.prod(action_spec.shape, dtype=int)\n", + " actor_head = acme_networks.MultivariateGaussianMixture(\n", + " num_components=num_components, num_dimensions=action_size)\n", + "\n", + " if policy_lstm_sizes is None:\n", + " policy_lstm_sizes = [1024, 1024]\n", + " if critic_lstm_sizes is None:\n", + " critic_lstm_sizes = [1024, 1024]\n", + "\n", + " actor_neck = acme_networks.LayerNormAndResidualMLP(hidden_size=1024,\n", + " num_blocks=4)\n", + " actor_encoder = networks.ControlNetwork(\n", + " proprio_encoder_size=300,\n", + " activation=tf.nn.relu)\n", + "\n", + " policy_lstms = [snt.LSTM(s) for s in policy_lstm_sizes]\n", + "\n", + " policy_network = snt.DeepRNN([actor_encoder, actor_neck] + policy_lstms +\n", + " [actor_head])\n", + "\n", + " critic_encoder = networks.ControlNetwork(\n", + " proprio_encoder_size=400,\n", + " activation=tf.nn.relu)\n", + " critic_neck = acme_networks.LayerNormAndResidualMLP(\n", + " hidden_size=1024, num_blocks=4)\n", + " distributional_head = acme_networks.DiscreteValuedHead(\n", + " vmin=vmin, vmax=vmax, num_atoms=num_atoms)\n", + " critic_lstms = [snt.LSTM(s) for s in critic_lstm_sizes]\n", + " critic_network = acme_networks.CriticDeepRNN([critic_encoder, critic_neck] +\n", + " critic_lstms + [\n", + " distributional_head,\n", + " ])\n", + "\n", + " return {\n", + " 'policy': policy_network,\n", + " 'critic': critic_network,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zL1fAYN8GvCf" + }, + "source": [ + "## Set up TPU if present" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "V3fQxmeiGtrQ" + }, + "outputs": [], + "source": [ + "try:\n", + " tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n", + " tf.config.experimental_connect_to_cluster(tpu)\n", + " tf.tpu.experimental.initialize_tpu_system(tpu)\n", + " accelerator_strategy = snt.distribute.TpuReplicator()\n", + " print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n", + "except ValueError:\n", + " print('Running on CPU or GPU (no TPUs available)')\n", + " accelerator_strategy = snt.distribute.Replicator()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "adb0cyE5qu9G" + }, + "source": [ + "## CRR learner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "83naOY7a_A4I" + }, + "outputs": [], + "source": [ + "action_spec = environment_spec.actions\n", + "action_size = np.prod(action_spec.shape, dtype=int)\n", + "\n", + "with accelerator_strategy.scope():\n", + " dataset = dm_control_suite.dataset(\n", + " 'gs://rl_unplugged/',\n", + " data_path=task.data_path,\n", + " shapes=task.shapes,\n", + " uint8_features=task.uint8_features,\n", + " num_threads=1,\n", + " batch_size=batch_size,\n", + " num_shards=num_shards,\n", + " sarsa=False)\n", + " # CRR learner assumes that the dataset samples don't have metadata,\n", + " # so let's remove it here.\n", + " dataset = dataset.map(lambda sample: sample.data)\n", + " nets = make_networks(action_spec)\n", + " policy_network, critic_network = nets['policy'], nets['critic']\n", + "\n", + " # Create the target networks\n", + " target_policy_network = copy.deepcopy(policy_network)\n", + " target_critic_network = copy.deepcopy(critic_network)\n", + "\n", + " # Create variables.\n", + " tf2_utils.create_variables(network=policy_network,\n", + " input_spec=[environment_spec.observations])\n", + " tf2_utils.create_variables(network=critic_network,\n", + " input_spec=[environment_spec.observations,\n", + " environment_spec.actions])\n", + " tf2_utils.create_variables(network=target_policy_network,\n", + " input_spec=[environment_spec.observations])\n", + " tf2_utils.create_variables(network=target_critic_network,\n", + " input_spec=[environment_spec.observations,\n", + " environment_spec.actions])\n", + "\n", + "# The learner updates the parameters (and initializes them).\n", + "learner = crr.RCRRLearner(\n", + " policy_network=policy_network,\n", + " critic_network=critic_network,\n", + " accelerator_strategy=accelerator_strategy,\n", + " target_policy_network=target_policy_network,\n", + " target_critic_network=target_critic_network,\n", + " dataset=dataset,\n", + " discount=0.99,\n", + " target_update_period=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PYkjKaduy_xj" + }, + "source": [ + "## Training loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "HbQOyCG4zCwa" + }, + "outputs": [], + "source": [ + "# Run\n", + "# tf.config.run_functions_eagerly(True)\n", + "# if you want to debug the code in eager mode.\n", + "\n", + "for _ in range(100):\n", + " learner.step()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LJ_XsuQSzFSV" + }, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "blvNCANKb22J" + }, + "outputs": [], + "source": [ + "# Create a logger.\n", + "logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n", + "\n", + "# Create an environment loop.\n", + "loop = acme.EnvironmentLoop(\n", + " environment=environment,\n", + " actor=actors.RecurrentActor(policy_network),\n", + " logger=logger)\n", + "\n", + "loop.run(5)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "RL Unplugged: Offline CRR - DM control", + "provenance": [ + { + "file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ", + "timestamp": 1593080049369 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/rl_unplugged/networks.py b/rl_unplugged/networks.py new file mode 100644 index 0000000..10d1263 --- /dev/null +++ b/rl_unplugged/networks.py @@ -0,0 +1,93 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Networks used for training agents. +""" + +from acme.tf import networks as acme_networks +from acme.tf import utils as tf2_utils +import numpy as np +import sonnet as snt +import tensorflow as tf + + +def instance_norm_and_elu(x): + mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True) + x_ = x - mean + var = tf.reduce_mean(x_**2, axis=[1, 2], keepdims=True) + x_norm = x_ / (var + 1e-6) + return tf.nn.elu(x_norm) + + +class ControlNetwork(snt.Module): + """Image, proprio and optionally action encoder used for actors and critics. + """ + + def __init__(self, + proprio_encoder_size: int, + proprio_keys=None, + activation=tf.nn.elu): + """Creates a ControlNetwork. + + Args: + proprio_encoder_size: Size of the linear layer for the proprio encoder. + proprio_keys: Optional list of names of proprioceptive observations. + Defaults to all observations. Note that if this is specified, any + observation not contained in proprio_keys will be ignored by the agent. + activation: Linear layer activation function. + """ + super().__init__(name='control_network') + self._activation = activation + self._proprio_keys = proprio_keys + + self._proprio_encoder = acme_networks.LayerNormMLP([proprio_encoder_size]) + + def __call__(self, inputs, action: tf.Tensor = None, task=None): + """Evaluates the ControlNetwork. + + Args: + inputs: A dictionary of agent observation tensors. + action: Agent actions. + task: Optional encoding of the task. + + Raises: + ValueError: if neither proprio_input is provided. + ValueError: if some proprio input looks suspiciously like pixel inputs. + + Returns: + Processed network output. + """ + if not isinstance(inputs, dict): + inputs = {'inputs': inputs} + + proprio_input = [] + # By default, treat all observations as proprioceptive. + if self._proprio_keys is None: + self._proprio_keys = list(sorted(inputs.keys())) + for key in self._proprio_keys: + proprio_input.append(snt.Flatten()(inputs[key])) + if np.prod(inputs[key].shape[1:]) > 32*32*3: + raise ValueError( + 'This input does not resemble a proprioceptive ' + 'state: {} with shape {}'.format( + key, inputs[key].shape)) + + # Append optional action input (i.e. for critic networks). + if action is not None: + proprio_input.append(action) + + proprio_input = tf2_utils.batch_concat(proprio_input) + proprio_state = self._proprio_encoder(proprio_input) + + return proprio_state