Files
deepmind-research/affordances_theory/AffordancesInContinuousEnvironment.ipynb

1747 lines
140 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ya9j9pyzkyBZ"
},
"source": [
"Copyright 2020 The \"What Can I do Here? A Theory of Affordances In Reinforcement Learning\" Authors. All rights reserved.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "LbWb35G9UHLO",
"outputId": "280cac1e-76e0-4960-a271-24d351f249bc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%tensorflow_version 2.x\n",
"%pylab inline\n",
"\n",
"# System imports\n",
"import copy\n",
"import dataclasses\n",
"import enum\n",
"import itertools\n",
"import numpy as np\n",
"import operator\n",
"import random\n",
"import time\n",
"from typing import Optional, List, Tuple, Any, Dict, Union, Callable\n",
"\n",
"\n",
"# Library imports.\n",
"from google.colab import files\n",
"from matplotlib import colors\n",
"import matplotlib.animation as animation\n",
"import matplotlib.pylab as plt\n",
"import tensorflow as tf\n",
"\n",
"import tensorflow_probability as tfp"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "UZV6OS_BUklD"
},
"source": [
"# Environment Specification"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Mz4KtBOVUpOV"
},
"outputs": [],
"source": [
"#@title Point Class\n",
"@dataclasses.dataclass(order=True, frozen=True)\n",
"class Point:\n",
"  \"\"\"An immutable, orderable point in the 2D plane.\n",
"\n",
"  Supports component-wise `+` / `-` and a handful of convenience helpers.\n",
"  \"\"\"\n",
"  x: float\n",
"  y: float\n",
"\n",
"  def sum(self):\n",
"    \"\"\"Returns the sum of the two coordinates.\"\"\"\n",
"    return self.x + self.y\n",
"\n",
"  def l2norm(self):\n",
"    \"\"\"Returns the Euclidean distance of the point from the origin.\"\"\"\n",
"    squared_length = self.x * self.x + self.y * self.y\n",
"    return np.sqrt(squared_length)\n",
"\n",
"  def __add__(self, other: 'Point'):\n",
"    \"\"\"Adds coordinate by coordinate; the result is always a plain `Point`.\"\"\"\n",
"    return Point(self.x + other.x, self.y + other.y)\n",
"\n",
"  def __sub__(self, other: 'Point'):\n",
"    \"\"\"Subtracts coordinate by coordinate; the result is a plain `Point`.\"\"\"\n",
"    return Point(self.x - other.x, self.y - other.y)\n",
"\n",
"  def normal_sample_around(self, scale: float):\n",
"    \"\"\"Draws a float32 Gaussian sample centred on this point.\n",
"\n",
"    The draw uses NumPy's global random state (`np.random.normal`) with\n",
"    standard deviation `scale`.\n",
"    \"\"\"\n",
"    centre = dataclasses.astuple(self)\n",
"    sampled_x, sampled_y = np.random.normal(centre, scale).astype(np.float32)\n",
"    return Point(sampled_x, sampled_y)\n",
"\n",
"  def is_close_to(self, other: 'Point', diff: float = 1e-4):\n",
"    \"\"\"Returns True when both coordinates differ by strictly less than `diff`.\"\"\"\n",
"    delta = self - other\n",
"    return bool(abs(delta.x) \u003c diff and abs(delta.y) \u003c diff)\n",
"\n",
"# Test the points.\n",
"z1 = Point(0.4, 0.1)\n",
"z2 = Point(0.1, 0.1)\n",
"z3 = z1 - z2\n",
"assert z1.is_close_to(z1), 'A point must be close to itself.'\n",
"assert z1.is_close_to(Point(0.5, 0.0), 1.0), (\n",
"    'Points within the tolerance not detected as close.')\n",
"assert not z1.is_close_to(Point(5.0, 0.0), 1.0), (\n",
"    'Points outside the tolerance detected as close.')\n",
"assert isinstance(z3, Point), 'Subtraction must return a Point.'\n",
"assert z3.is_close_to(Point(0.3, 0.0)), 'Subtraction gave the wrong point.'\n",
"assert isinstance(z3.normal_sample_around(0.1), Point), (\n",
"    'Sampling must return a Point.')\n",
"\n",
"class Force(Point):\n",
"  \"\"\"A 2D force; adds nothing to `Point` beyond a distinct type name.\"\"\"\n",
"\n",
"\n",
"# Intersection code.\n",
"# See Sedgewick, Robert, and Kevin Wayne. Algorithms, 2011.\n",
"# Chapter 6.1 on Geometric Primitives\n",
"# https://algs4.cs.princeton.edu/91primitives/\n",
"def _check_counter_clockwise(a: Point, b: Point, c: Point):\n",
"  \"\"\"Checks if the points a, b, c are in counter clockwise order.\n",
"\n",
"  The slopes of AB and AC are compared by cross-multiplication, so no division\n",
"  takes place. Equal products (e.g. collinear points) also return True.\n",
"  \"\"\"\n",
"  ab_dx, ab_dy = b.x - a.x, b.y - a.y\n",
"  ac_dx, ac_dy = c.x - a.x, c.y - a.y\n",
"  return ac_dy * ab_dx \u003e= ab_dy * ac_dx\n",
"\n",
"def intersect(segment_1: Tuple[Point, Point], segment_2: Tuple[Point, Point]):\n",
"  \"\"\"Checks if two line segments intersect.\n",
"\n",
"  With segment_1 = AB and segment_2 = CD, the segments are reported as\n",
"  intersecting when A and B get different orientation results relative to CD\n",
"  and, in addition, C and D get different orientation results relative to AB.\n",
"  \"\"\"\n",
"  (a, b), (c, d) = segment_1, segment_2\n",
"\n",
"  # Exactly one of A, B reaches D via C on a counter clockwise path.\n",
"  ab_split_by_cd = (\n",
"      _check_counter_clockwise(a, c, d) != _check_counter_clockwise(b, c, d))\n",
"  # Exactly one of C, D is reached from A via B on a counter clockwise path.\n",
"  cd_split_by_ab = (\n",
"      _check_counter_clockwise(a, b, c) != _check_counter_clockwise(a, b, d))\n",
"\n",
"  return ab_split_by_cd and cd_split_by_ab\n",
"\n",
"# Some simple tests to ensure everything is working.\n",
"assert not intersect(\n",
"    (Point(1, 0), Point(1, 1)), (Point(0, 0), Point(0, 1))\n",
"), 'Parallel lines detected as intersecting.'\n",
"assert not intersect(\n",
"    (Point(0, 0), Point(1, 0)), (Point(0, 1), Point(1, 1))\n",
"), 'Parallel lines detected as intersecting.'\n",
"assert intersect(\n",
"    (Point(3, 5), Point(1, 1)), (Point(2, 2), Point(0, 1))\n",
"), 'Lines that intersect not detected.'\n",
"assert not intersect(\n",
"    (Point(0, 0), Point(2, 2)), (Point(3, 3), Point(5, 1))\n",
"), 'Lines that do not intersect detected as intersecting'\n",
"assert intersect(\n",
"    (Point(0, .5), Point(0, -.5)), (Point(.5, 0), Point(-.5, 0.))\n",
"), 'Lines that intersect not detected.'"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "IaC8khoBVZ2a"
},
"outputs": [],
"source": [
"#@title ContinuousWorld environment.\n",
"\n",
"class ContinuousWorld(object):\n",
"  r\"\"\"The ContinuousWorld Environment.\n",
"\n",
"  An agent can be anywhere in the grid. The agent provides Forces to move. When\n",
"  the agent provides a force, it is perturbed with Gaussian noise and applied.\n",
"\n",
"  When the agent is reset, its location is drawn around a global start position\n",
"  chosen from `drift_between`. This start position is non-stationary and drifts\n",
"  toward the target start position as the environment resets with the speed\n",
"  `drift_speed`.\n",
"\n",
"  For example the start position is (0., 0.). After resetting once, the start\n",
"  position might drift toward (0.5, 0.5). After resetting again it may drift\n",
"  again to (0., 0.). This happens smoothly according to the drifting speed.\n",
"\n",
"  Walls can be specified in this environment. Detection works by checking if the\n",
"  agent's action forces it to go in a direction which collides with a wall.\n",
"  \"\"\"\n",
"\n",
"  def __init__(\n",
"      self,\n",
"      size: float,\n",
"      wall_pairs: Optional[List[Tuple[Point, Point]]] = None,\n",
"      drift_between: Optional[List[Point]] = None,\n",
"      movement_noise: float = 0.1,\n",
"      seed: int = 1,\n",
"      drift_speed: float = 0.5,\n",
"      reset_noise: Optional[float] = None,\n",
"      max_episode_length: int = 10,\n",
"      max_action_force: float = 0.5,\n",
"      verbose_reset: bool = False\n",
"  ):\n",
"    \"\"\"Initializes the Continuous World Environment.\n",
"\n",
"    Args:\n",
"      size: The size of the world.\n",
"      wall_pairs: A list of tuple of points representing the start and end\n",
"        positions of the wall.\n",
"      drift_between: The points between which the starting distribution\n",
"        drifts. Needs at least two distinct points. If None, it will drift\n",
"        between the four points at 1/4 and 3/4 of `size` along each axis.\n",
"      movement_noise: The noise applied to every action before it is executed.\n",
"      seed: The seed for the random number generators. Seeds this\n",
"        environment's `RandomState` (random forces) and the global `random`\n",
"        module (drift targets).\n",
"      drift_speed: How quickly to move in the drift direction.\n",
"      reset_noise: The noise around the reset position. Defaults to\n",
"        movement_noise if not specified.\n",
"      max_episode_length: The maximum length of the episode before resetting.\n",
"      max_action_force: If using random_step() this will be the maximum random\n",
"        force applied in the x and y direction.\n",
"      verbose_reset: Prints out every time the target starting position is\n",
"        changed.\n",
"\n",
"    Raises:\n",
"      ValueError: If `drift_between` has fewer than two distinct points.\n",
"    \"\"\"\n",
"    self._size = size\n",
"    self._wall_pairs = wall_pairs or []\n",
"    self._verbose_reset = verbose_reset\n",
"\n",
"    # Points to drift the start position between.\n",
"    if drift_between is None:\n",
"      self._drift_between = (\n",
"          Point((1/4) * size, (1/4) * size),\n",
"          Point((1/4) * size, (3/4) * size),\n",
"          Point((3/4) * size, (1/4) * size),\n",
"          Point((3/4) * size, (3/4) * size),\n",
"      )\n",
"    else:\n",
"      # random.sample needs a sequence, so materialise whatever was passed.\n",
"      self._drift_between = tuple(drift_between)\n",
"    # Picking a target that differs from the current one would never finish\n",
"    # if all candidates were equal, so reject that up front.\n",
"    if all(p == self._drift_between[0] for p in self._drift_between):\n",
"      raise ValueError(\n",
"          'drift_between must contain at least two distinct points.')\n",
"\n",
"    self._noise = movement_noise\n",
"    # Compare against None (not truthiness) so that 0.0 is a valid choice.\n",
"    self._reset_noise = movement_noise if reset_noise is None else reset_noise\n",
"    self._rng = np.random.RandomState(seed)\n",
"    random.seed(seed)\n",
"\n",
"    # The current and target starting positions.\n",
"    # Internal to this class mu is used to refer to mean \"start position\".\n",
"    # Therefore mu = current start position and end_mu is the target start\n",
"    # position.\n",
"    self._mu, self._end_mu = random.sample(self._drift_between, 2)\n",
"    # The speed at which we will move toward the target position.\n",
"    self._drift_speed = drift_speed\n",
"    self.update_agent_position()\n",
"    self._decide_new_target_mu()\n",
"    self._max_episode_length = max_episode_length\n",
"    self._current_episode_length = 0\n",
"    # Starts terminated: reset() has to be called before the first step().\n",
"    self._terminated = True\n",
"    self._max_action_force = max_action_force\n",
"    self._recent_mu_updated = False\n",
"\n",
"  def _decide_new_target_mu(self):\n",
"    \"\"\"Decide a new target start position that differs from the current one.\"\"\"\n",
"    # Rejection sampling keeps the draws from the `random` module identical to\n",
"    # the historical behaviour for a given seed.\n",
"    (new_end_mu,) = random.sample(self._drift_between, 1)\n",
"    while new_end_mu == self._end_mu:\n",
"      (new_end_mu,) = random.sample(self._drift_between, 1)\n",
"\n",
"    self._end_mu = new_end_mu\n",
"    self._decide_drift_direction()\n",
"    if self._verbose_reset:\n",
"      print(f'Target mu has been updated to: {self._end_mu}')\n",
"    self._recent_mu_updated = True\n",
"\n",
"  def _decide_drift_direction(self):\n",
"    \"\"\"Sets the drift step: a vector of length `drift_speed` toward end_mu.\"\"\"\n",
"    direction = self._end_mu - self._mu\n",
"    l2 = direction.l2norm()\n",
"    if l2 == 0:\n",
"      # mu already sits on the target, so there is nothing to normalise and\n",
"      # dividing would turn the drift into NaNs.\n",
"      self._drift_direction = Point(0.0, 0.0)\n",
"      return\n",
"    self._drift_direction = Point(\n",
"        (direction.x / l2) * self._drift_speed,\n",
"        (direction.y / l2) * self._drift_speed\n",
"    )\n",
"\n",
"  def _should_update_target_mu(self) -\u003e bool:\n",
"    \"\"\"Decide if the drift direction should change.\"\"\"\n",
"    # Condition 1: We are at or past the edge of the environment.\n",
"    if self._past_edge(self._mu.x)[0] or self._past_edge(self._mu.y)[0]:\n",
"      return True\n",
"\n",
"    # Condition 2: The current mu is within one drift step of the end mu.\n",
"    return self._mu.is_close_to(self._end_mu, self._drift_speed)\n",
"\n",
"  def update_current_start_position(self):\n",
"    \"\"\"Update the current mu to drift toward mu_end. Change mu_end if needed.\"\"\"\n",
"    if self._should_update_target_mu():\n",
"      self._decide_new_target_mu()\n",
"    self._decide_drift_direction()\n",
"    proposed_mu = self._mu + self._drift_direction\n",
"    self._mu = self._wrap_coordinate(proposed_mu)\n",
"\n",
"  def _past_edge(self, x: float) -\u003e Tuple[bool, float]:\n",
"    \"\"\"Checks a coordinate against the edges.\n",
"\n",
"    Returns:\n",
"      A pair: whether `x` is on or beyond an edge, and `x` clamped to\n",
"      `[0, size]`.\n",
"    \"\"\"\n",
"    if x \u003e= self._size:\n",
"      return True, self._size\n",
"    elif x \u003c= 0.0:\n",
"      return True, 0.0\n",
"    else:\n",
"      return False, x\n",
"\n",
"  def _wrap_coordinate(self, point: Point) -\u003e Point:\n",
"    \"\"\"Clamps both coordinates of `point` into `[0, size]`.\n",
"\n",
"    Despite the name nothing wraps around: values beyond an edge are clipped\n",
"    to that edge.\n",
"    \"\"\"\n",
"    return Point(*(self._past_edge(c)[1] for c in dataclasses.astuple(point)))\n",
"\n",
"  def update_agent_position(self):\n",
"    \"\"\"Re-samples the agent position around the current start position.\n",
"\n",
"    The sample uses `reset_noise` (which defaults to `movement_noise`).\n",
"    \"\"\"\n",
"    self._current_position = self._wrap_coordinate(\n",
"        self._mu.normal_sample_around(self._reset_noise))\n",
"\n",
"  def set_agent_position(self, new_position: Point):\n",
"    \"\"\"Places the agent at `new_position`, clamped into the world.\"\"\"\n",
"    self._current_position = self._wrap_coordinate(new_position)\n",
"\n",
"  def reset(self) -\u003e Point:\n",
"    \"\"\"Reset the current position of the agent and move the global mu.\n",
"\n",
"    Returns:\n",
"      The new agent position as a `Point`.\n",
"    \"\"\"\n",
"    self.update_current_start_position()\n",
"    self.update_agent_position()\n",
"    self._current_episode_length = 0\n",
"    self._terminated = False\n",
"    return self._current_position\n",
"\n",
"  def get_random_force(self) -\u003e Force:\n",
"    \"\"\"Samples a force uniformly from `[-max_action_force, max_action_force)`.\"\"\"\n",
"    return Force(*self._rng.uniform(\n",
"        -self._max_action_force, self._max_action_force, 2))\n",
"\n",
"  def random_step(self):\n",
"    \"\"\"Steps with a random force and records it in the info dict.\n",
"\n",
"    Returns:\n",
"      The same tuple as `step`, with `info['action_taken']` set to the force.\n",
"    \"\"\"\n",
"    random_action = self.get_random_force()\n",
"    to_be_returned = self.step(random_action)\n",
"    to_be_returned[-1]['action_taken'] = random_action\n",
"    return to_be_returned\n",
"\n",
"  @property\n",
"  def agent_position(self):\n",
"    \"\"\"The agent position as an `(x, y)` tuple.\"\"\"\n",
"    return dataclasses.astuple(self._current_position)\n",
"\n",
"  @property\n",
"  def start_position(self):\n",
"    \"\"\"The current (drifting) start position as an `(x, y)` tuple.\"\"\"\n",
"    return dataclasses.astuple(self._mu)\n",
"\n",
"  @property\n",
"  def size(self):\n",
"    \"\"\"The size of the world.\"\"\"\n",
"    return self._size\n",
"\n",
"  @property\n",
"  def walls(self):\n",
"    \"\"\"The list of wall segments (possibly empty).\"\"\"\n",
"    return self._wall_pairs\n",
"\n",
"  def _check_goes_through_wall(self, start: Point, end: Point):\n",
"    \"\"\"Returns True if the move from `start` to `end` crosses any wall.\"\"\"\n",
"    return any(intersect((start, end), pair) for pair in self._wall_pairs)\n",
"\n",
"  def step(\n",
"      self,\n",
"      action: Force\n",
"  ) -\u003e Tuple[Point, Optional[float], bool, Dict[str, Any]]:\n",
"    \"\"\"Does a step in the environment using the action.\n",
"\n",
"    Args:\n",
"      action: Force applied by the agent.\n",
"\n",
"    Returns:\n",
"      Agent position: A `Point`. Unchanged if the move would cross a wall.\n",
"      The reward: always None.\n",
"      An indicator if the episode terminated.\n",
"      A dictionary containing any information about the step.\n",
"\n",
"    Raises:\n",
"      ValueError: If the episode is over and reset() has not been called.\n",
"    \"\"\"\n",
"    if self._terminated:\n",
"      raise ValueError('Episode is over. Please reset the environment.')\n",
"    perturbed_action = action.normal_sample_around(self._noise)\n",
"\n",
"    proposed_position = self._wrap_coordinate(\n",
"        self._current_position + perturbed_action)\n",
"\n",
"    goes_through_wall = self._check_goes_through_wall(\n",
"        self._current_position, proposed_position)\n",
"\n",
"    if not goes_through_wall:\n",
"      self._current_position = proposed_position\n",
"\n",
"    self._current_episode_length += 1\n",
"\n",
"    # NOTE(review): the strict `\u003e` lets an episode last max_episode_length + 1\n",
"    # steps. Kept as is so collected data does not change; confirm the intent.\n",
"    if self._current_episode_length \u003e self._max_episode_length:\n",
"      self._terminated = True\n",
"\n",
"    # The flag is reported once and then cleared.\n",
"    recent_mu_updated = self._recent_mu_updated\n",
"    self._recent_mu_updated = False\n",
"    return (\n",
"        self._current_position,\n",
"        None,\n",
"        self._terminated,\n",
"        {\n",
"            'goes_through_wall': goes_through_wall,\n",
"            'proposed_position': proposed_position,\n",
"            'recent_start_position_updated': recent_mu_updated\n",
"        }\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "O43gq5h_YS2I"
},
"outputs": [],
"source": [
"#@title Visualization suite.\n",
"\n",
"def visualize_environment(\n",
"    world,\n",
"    ax,\n",
"    scaling=1.0,\n",
"    agent_color='r',\n",
"    agent_size=0.2,\n",
"    start_color='g',\n",
"    draw_agent=True,\n",
"    draw_start_mu=True,\n",
"    draw_target_mu=True,\n",
"    draw_walls=True,\n",
"    write_text=True):\n",
"  \"\"\"Visualize the continuous grid world.\n",
"\n",
"  The agent will be drawn as a circle. The start and target\n",
"  locations will be drawn by a cross. Walls will be drawn in\n",
"  black.\n",
"\n",
"  Args:\n",
"    world: The continuous gridworld to visualize. It must expose `size`,\n",
"      `walls`, `start_position` and `agent_position`; the target start\n",
"      position is read from its private `_end_mu` attribute.\n",
"    ax: The matplotlib axes to draw the gridworld.\n",
"    scaling: Scale the plot by this factor.\n",
"    agent_color: Color of the agent.\n",
"    agent_size: Radius of the agent circle (not multiplied by `scaling`).\n",
"    start_color: Color of the start marker.\n",
"    draw_agent: Boolean that controls drawing agent.\n",
"    draw_start_mu: Boolean that controls drawing starting position.\n",
"    draw_target_mu: Boolean that controls drawing ending position.\n",
"    draw_walls: Boolean that controls drawing the interior walls. The outer\n",
"      boundary of the world is always drawn.\n",
"    write_text: Boolean to write text for each component being drawn.\n",
"\n",
"  Returns:\n",
"    The axes that were passed in.\n",
"  \"\"\"\n",
"  scaled_size = scaling * world.size\n",
"\n",
"  # Draw the outer boundary of the world.\n",
"  ax.hlines(0, 0, scaled_size)\n",
"  ax.hlines(scaled_size, 0, scaled_size)\n",
"  ax.vlines(scaled_size, 0, scaled_size)\n",
"  ax.vlines(0, 0, scaled_size)\n",
"\n",
"  if draw_walls:\n",
"    for wall_pair in world.walls:\n",
"      ax.plot(\n",
"          [p.x * scaling for p in wall_pair],\n",
"          [p.y * scaling for p in wall_pair],\n",
"          color='k')\n",
"\n",
"  if draw_start_mu:\n",
"    # Draw the position of the start dist.\n",
"    x, y = [p * scaling for p in world.start_position]\n",
"    ax.scatter([x], [y], marker='x', c=start_color)\n",
"    if write_text: ax.text(x, y, 'Starting position.')\n",
"\n",
"  if draw_target_mu:\n",
"    # Draw the target position.\n",
"    x, y = [p * scaling for p in dataclasses.astuple(world._end_mu)]\n",
"    ax.scatter([x], [y], marker='x', c='k')\n",
"    if write_text: ax.text(x, y,'Target position.')\n",
"\n",
"  if draw_agent:\n",
"    # Draw the position of the agent as a circle.\n",
"    x, y = [scaling * p for p in world.agent_position]\n",
"    agent_circle = plt.Circle((x, y), agent_size, color=agent_color)\n",
"    ax.add_artist(agent_circle)\n",
"    if write_text: ax.text(x, y, 'Agent position.')\n",
"\n",
"  return ax\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eSfUwRtoY_aN"
},
"source": [
"# Affordance specification"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "IVTYi4YVZAfR"
},
"outputs": [],
"source": [
"#@title Intent detection and plotting code.\n",
"\n",
"class IntentName(enum.IntEnum):\n",
"  \"\"\"The available intents: movement along one axis in one direction.\"\"\"\n",
"  delta_pos_x = 1\n",
"  delta_neg_x = 2\n",
"  delta_pos_y = 3\n",
"  delta_neg_y = 4\n",
"\n",
"\n",
"class IntentStatus(enum.IntEnum):\n",
"  \"\"\"Whether an intent was achieved; the values double as 0/1 targets.\"\"\"\n",
"  complete = 1\n",
"  incomplete = 0\n",
"\n",
"\n",
"@dataclasses.dataclass(eq=False)\n",
"class Intent:\n",
"  \"\"\"An intent together with its completion status (compared by identity).\"\"\"\n",
"  name: 'IntentName'\n",
"  status: 'IntentStatus'\n",
"\n",
"\n",
"PointOrFloatTuple = Union[Point, Tuple[float, float]]\n",
"\n",
"def _get_intent_completed(\n",
"    s_t: PointOrFloatTuple,\n",
"    a_t: Force,\n",
"    s_tp1: PointOrFloatTuple,\n",
"    intent_name: IntentName,\n",
"    threshold: float = 0.0):\n",
"  r\"\"\"Determines if the intent was completed in the transition.\n",
"\n",
"  Every intent asks for a significant displacement along one axis:\n",
"\n",
"    delta_pos_x: s_tp1.x - s_t.x \u003e threshold\n",
"    delta_neg_x: s_tp1.x - s_t.x \u003c -threshold\n",
"    delta_pos_y: s_tp1.y - s_t.y \u003e threshold\n",
"    delta_neg_y: s_tp1.y - s_t.y \u003c -threshold\n",
"\n",
"  Args:\n",
"    s_t: The current position of the agent.\n",
"    a_t: The force for the action. Not used by the check.\n",
"    s_tp1: The position after executing action of the agent.\n",
"    intent_name: The intent that needs to be detected.\n",
"    threshold: The significance threshold for the intent to be detected.\n",
"\n",
"  Returns:\n",
"    `IntentStatus.complete` if the condition of `intent_name` holds, otherwise\n",
"    `IntentStatus.incomplete`.\n",
"\n",
"  Raises:\n",
"    ValueError: If `intent_name` is not a valid `IntentName`.\n",
"  \"\"\"\n",
"  start = s_t if isinstance(s_t, Point) else Point(*s_t)\n",
"  end = s_tp1 if isinstance(s_tp1, Point) else Point(*s_tp1)\n",
"  intent = IntentName(intent_name)  # Rejects unknown intents.\n",
"\n",
"  displacement = end - start\n",
"  achieved_by_intent = {\n",
"      IntentName.delta_pos_x: displacement.x \u003e threshold,\n",
"      IntentName.delta_neg_x: displacement.x \u003c -threshold,\n",
"      IntentName.delta_pos_y: displacement.y \u003e threshold,\n",
"      IntentName.delta_neg_y: displacement.y \u003c -threshold,\n",
"  }\n",
"  if achieved_by_intent[intent]:\n",
"    return IntentStatus.complete\n",
"  return IntentStatus.incomplete\n",
"\n",
"# Some simple test cases.\n",
"_INTENT_TEST_CASES = (\n",
"    # (position after the move, intent, threshold, expected completion)\n",
"    (Point(0.5, 0.0), IntentName.delta_neg_y, 0.0, False),\n",
"    (Point(0.5, 0.0), IntentName.delta_pos_y, 0.0, False),\n",
"    (Point(0.5, 0.0), IntentName.delta_pos_x, 0.0, True),\n",
"    (Point(0.5, 0.0), IntentName.delta_neg_x, 0.0, False),\n",
"    (Point(0.5, 0.5), IntentName.delta_pos_x, 0.0, True),\n",
"    (Point(0.5, 0.5), IntentName.delta_pos_y, 0.0, True),\n",
"    (Point(0.5, 0.5), IntentName.delta_pos_y, 0.6, False),\n",
"    (Point(-0.5, -0.5), IntentName.delta_neg_x, 0.6, False),\n",
")\n",
"for _end, _intent, _threshold, _expected in _INTENT_TEST_CASES:\n",
"  _status = _get_intent_completed(Point(0, 0), None, _end, _intent, _threshold)\n",
"  assert bool(_status) == _expected, (_end, _intent, _threshold)\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "OtWqqse4Zutx"
},
"outputs": [],
"source": [
"#@title Data Collection code\n",
"def get_transitions(\n",
"    world: ContinuousWorld,\n",
"    max_num_transitions: int = 500,\n",
"    max_trajectory_length: Optional[int] = None,\n",
"    policy: Optional[Callable[[Point], Force]] = None,\n",
"    intent_threshold: float = 0.0):\n",
"  \"\"\"Samples transitions from an environment.\n",
"\n",
"  Args:\n",
"    world: The environment to collect trajectories from.\n",
"    max_num_transitions: The total number of transitions to sample.\n",
"    max_trajectory_length: The maximum length of the trajectory. If None\n",
"      trajectories will naturally reset during episode end.\n",
"    policy: The data collection policy. If None is given a random policy\n",
"      (`world.get_random_force`) is used. The policy is called with the\n",
"      current position (a `Point`) and must return the `Force` to apply.\n",
"    intent_threshold: The threshold to use for the intent.\n",
"\n",
"  Returns:\n",
"    The transitions collected from the environment:\n",
"      A list of 5 stacked tensors: state, action, state', reward (always 0)\n",
"      and the intent targets, one column per intent in `IntentName` order.\n",
"    Human Readable transitions:\n",
"      A set containing the unique transitions in the batch and if the intent was\n",
"      completed.\n",
"    Infos:\n",
"      A list containing the info dicts sampled during the batch.\n",
"  \"\"\"\n",
"  max_trajectory_length = max_trajectory_length or float('inf')\n",
"  if policy is None:\n",
"    def policy(_):\n",
"      return world.get_random_force()\n",
"\n",
"  trajectory = []\n",
"  human_readable = set()\n",
"  infos = []\n",
"  trajectory_length = 0\n",
"  s_t = world.reset()\n",
"\n",
"  for _ in range(max_num_transitions):\n",
"    action = policy(s_t)\n",
"    s_tp1, _, done, info = world.step(action)\n",
"    infos.append(info)\n",
"    reward = 0  # The environment does not define a reward.\n",
"\n",
"    # Evaluate every intent on this transition, in enum order.\n",
"    all_intents = tuple(\n",
"        (intent_name, _get_intent_completed(\n",
"            s_t, action, s_tp1, intent_name, intent_threshold))\n",
"        for intent_name in IntentName)\n",
"    intent_status_only = [status for _, status in all_intents]\n",
"\n",
"    # Human readable version:\n",
"    human_readable.add((s_t, action, s_tp1, all_intents))\n",
"\n",
"    # Prepare things for tensorflow:\n",
"    s_t_tf = tf.constant(dataclasses.astuple(s_t), dtype=tf.float32)\n",
"    s_tp1_tf = tf.constant(dataclasses.astuple(s_tp1), dtype=tf.float32)\n",
"    a_t_tf = tf.constant(dataclasses.astuple(action), dtype=tf.float32)\n",
"    intent_statuses_tf = tf.constant(intent_status_only)\n",
"    trajectory.append((s_t_tf, a_t_tf, s_tp1_tf, reward, intent_statuses_tf))\n",
"\n",
"    trajectory_length += 1\n",
"    if done or trajectory_length \u003e max_trajectory_length:\n",
"      s_t = world.reset()\n",
"      trajectory_length = 0\n",
"    else:\n",
"      s_t = s_tp1\n",
"\n",
"  # Turn the list of per-step tuples into one stacked tensor per field.\n",
"  batch = [tf.stack(column) for column in zip(*trajectory)]\n",
"  return batch, human_readable, infos\n",
"\n",
"# Integration test.\n",
"world = ContinuousWorld(\n",
"    size=2, drift_speed=0.1, max_action_force=2.0, max_episode_length=100)\n",
"data, _, _ = get_transitions(world, max_num_transitions=2)\n",
"assert data is not None, 'get_transitions returned no batch.'"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "bvYwGLyubcpa"
},
"outputs": [],
"source": [
"#@title Probabilistic transition model\n",
"\n",
"hidden_nodes = 32\n",
"input_size = 2\n",
"\n",
"class TransitionModel(tf.keras.Model):\n",
"  \"\"\"MLP with two ReLU hidden layers that outputs a Normal over the next state.\"\"\"\n",
"\n",
"  def __init__(self, hidden_nodes, output_size):\n",
"    \"\"\"Builds the layers.\n",
"\n",
"    Args:\n",
"      hidden_nodes: Number of units in each of the two hidden layers.\n",
"      output_size: Number of dimensions of the predicted state.\n",
"    \"\"\"\n",
"    super().__init__()\n",
"    relu = tf.keras.activations.relu\n",
"    self._net1 = tf.keras.layers.Dense(hidden_nodes, activation=relu)\n",
"    self._net2 = tf.keras.layers.Dense(hidden_nodes, activation=relu)\n",
"    # Two outputs per state dimension: a mean and a log standard deviation.\n",
"    self._output = tf.keras.layers.Dense(2*output_size)\n",
"\n",
"  # NOTE(review): this overrides `__call__` rather than Keras' `call` hook;\n",
"  # confirm that skipping the base class `__call__` logic is intended.\n",
"  def __call__(self, st, at):\n",
"    \"\"\"Returns the predicted next-state distribution.\n",
"\n",
"    Args:\n",
"      st: Batch of states; concatenated with `at` along axis 1.\n",
"      at: Batch of actions.\n",
"\n",
"    Returns:\n",
"      A `tfp.distributions.Normal` whose `loc` is the first half of the final\n",
"      layer's output and whose `scale` is the exponential of the second half.\n",
"    \"\"\"\n",
"    hidden = self._net1(tf.concat((st, at), axis=1))\n",
"    hidden = self._net2(hidden)\n",
"    means, logstd = tf.split(self._output(hidden), 2, axis=1)\n",
"    return tfp.distributions.Normal(loc=means, scale=tf.exp(logstd))\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "5S6uqTVAdoKD"
},
"outputs": [],
"source": [
"#@title Training algorithm.\n",
"\n",
"MACHINE_EPSILON = np.finfo(float).eps.item()\n",
"\n",
"def train_networks(\n",
"    world: ContinuousWorld,\n",
"    model_network: Optional[tf.keras.Model] = None,\n",
"    model_optimizer: Optional[tf.keras.optimizers.Optimizer] = None,\n",
"    affordance_network: Optional[tf.keras.Model] = None,\n",
"    affordance_optimizer: Optional[tf.keras.optimizers.Optimizer] = None,\n",
"    use_affordance_to_mask_model: bool = False,\n",
"    affordance_mask_threshold: float = 0.9,\n",
"    num_train_steps: int =10,\n",
"    fresh_data: bool = True,\n",
"    max_num_transitions: int = 1,\n",
"    max_trajectory_length: Optional[int] = None,\n",
"    optimize_performance: bool = False,\n",
"    intent_threshold: float = 1.0,\n",
"    debug: bool = False,\n",
"    print_losses: bool = False,\n",
"    print_every: int = 10):\n",
"  \"\"\"Trains a transition model and/or an affordance network.\n",
"\n",
"  Args:\n",
"    world: The world to collect training data from.\n",
"    model_network: The network for the transition model.\n",
"    model_optimizer: The optimizer for the transition model.\n",
"    affordance_network: The affordance network.\n",
"    affordance_optimizer: The optimizer for the affordance network.\n",
"    use_affordance_to_mask_model: Uses affordances to mask the losses of the\n",
"      transition model. Needs an affordance network; has no effect when no\n",
"      model network is trained.\n",
"    affordance_mask_threshold: The threshold at which the mask should be\n",
"      applied.\n",
"    num_train_steps: The total number of training steps.\n",
"    fresh_data: Collect fresh data before every training step. If False the\n",
"      data collected before the first step is reused for all steps.\n",
"    max_num_transitions: The number of transitions collected per training\n",
"      step.\n",
"    max_trajectory_length: The maximum length of each trajectory. If None then\n",
"      there is no artificially truncated trajectory length.\n",
"    optimize_performance: Use `tf.function` to speed up training steps.\n",
"    intent_threshold: The threshold to consider as a significant completion of\n",
"      the intent.\n",
"    debug: Debug mode prints out the human readable transitions and disables\n",
"      tf.function.\n",
"    print_losses: Prints out the losses during training.\n",
"    print_every: Indicates how often things should be printed out.\n",
"\n",
"  Returns:\n",
"    A 3-tuple of lists with one entry per training step: the model losses and\n",
"    the affordance losses (None where the respective network is not trained),\n",
"    and whether the start position target was updated in the data of the step.\n",
"\n",
"  Raises:\n",
"    ValueError: If a network is given without its optimizer (or vice versa),\n",
"      if neither network is given, or if masking is requested for a model\n",
"      without an affordance network.\n",
"  \"\"\"\n",
"  all_aff_losses = []\n",
"  all_model_losses = []\n",
"\n",
"  use_affordances = affordance_network is not None\n",
"  use_model = model_network is not None\n",
"\n",
"  # Error checking to make sure the correct combinations of model/affordance\n",
"  # nets and optimizers are given or none at all.\n",
"  if use_affordances != (affordance_optimizer is not None):\n",
"    raise ValueError('Both affordance network and optimizer have to be given.')\n",
"\n",
"  if use_model != (model_optimizer is not None):\n",
"    raise ValueError('Both model network and optimizer have to be given.')\n",
"\n",
"  # At least one of affordance network or model network must be specified.\n",
"  if not use_model and not use_affordances:\n",
"    raise ValueError(\n",
"        'This code does not do anything without models or affordances.')\n",
"\n",
"  # Masking the model loss needs affordance predictions. Without this check\n",
"  # the model step would fail later when it reduces over a scalar placeholder.\n",
"  if use_affordance_to_mask_model and use_model and not use_affordances:\n",
"    raise ValueError(\n",
"        'Cannot use_affordance_to_mask model if affordance and model networks'\n",
"        ' are not given!')\n",
"\n",
"  # User friendly print outs indicate what is happening.\n",
"  print(\n",
"      f'Using model? {use_model}. Using affordances? {use_affordances}. Using'\n",
"      f' affordances to mask model? {use_affordance_to_mask_model}.')\n",
"\n",
"  def _train_step_affordances(trajectory):\n",
"    \"\"\"Train affordance network.\"\"\"\n",
"    # Note: Please make sure you understand the shapes here before editing to\n",
"    # prevent accidental broadcast.\n",
"    with tf.GradientTape() as tape:\n",
"      s_t, a_t, _, _, intent_target = trajectory\n",
"      concat_input = tf.concat((s_t, a_t), axis=1)\n",
"      preds = affordance_network(concat_input)\n",
"\n",
"      # Flatten targets and predictions so they line up element by element.\n",
"      intent_target = tf.reshape(intent_target, (-1, 1))\n",
"      unshaped_preds = preds\n",
"      preds = tf.reshape(preds, (-1, 1))\n",
"\n",
"      loss = tf.keras.losses.binary_crossentropy(intent_target, preds)\n",
"      total_loss = tf.reduce_mean(loss)\n",
"    grads = tape.gradient(total_loss, affordance_network.trainable_variables)\n",
"    affordance_optimizer.apply_gradients(\n",
"        zip(grads, affordance_network.trainable_variables))\n",
"\n",
"    return total_loss, unshaped_preds\n",
"\n",
"  def _train_step_model(trajectory, affordances):\n",
"    \"\"\"Train model network.\"\"\"\n",
"    with tf.GradientTape() as tape:\n",
"      s_t, a_t, s_tp1, _, _ = trajectory\n",
"      transition_model = model_network(s_t, a_t)\n",
"      # Sum the per-dimension log probabilities of the observed next state.\n",
"      log_prob = tf.reduce_sum(transition_model.log_prob(s_tp1), -1)\n",
"      num_examples = s_t.shape[0]\n",
"\n",
"      if use_affordance_to_mask_model:\n",
"        # Check if at least one intent is affordable.\n",
"        masks_per_intent = tf.math.greater_equal(\n",
"            affordances, affordance_mask_threshold)\n",
"        masks_per_transition = tf.reduce_any(masks_per_intent, 1)\n",
"        # Explicit reshape to prevent accidental broadcasting.\n",
"        batch_size = len(s_t)\n",
"        log_prob = tf.reshape(log_prob, (batch_size, 1))\n",
"        masks_per_transition = tf.reshape(masks_per_transition, (batch_size, 1))\n",
"        log_prob = log_prob * tf.cast(masks_per_transition, dtype=tf.float32)\n",
"        # num_examples changes if there is masking so take that into account:\n",
"        num_examples = tf.reduce_sum(\n",
"            tf.cast(masks_per_transition, dtype=tf.float32))\n",
"        # Avoid dividing by zero when every transition is masked out.\n",
"        num_examples = tf.math.maximum(num_examples, tf.constant(1.0))\n",
"\n",
"      # Negate log_prob here because we want to maximize this via minimization.\n",
"      total_loss = -tf.reduce_sum(log_prob) / num_examples\n",
"    grads = tape.gradient(total_loss, model_network.trainable_variables)\n",
"    model_optimizer.apply_gradients(\n",
"        zip(grads, model_network.trainable_variables))\n",
"\n",
"    return total_loss\n",
"\n",
"  # Optimize performance using tf.function.\n",
"  if optimize_performance and not debug:\n",
"    _train_step_affordances = tf.function(_train_step_affordances)\n",
"    _train_step_model = tf.function(_train_step_model)\n",
"    print('Training step has been optimized.')\n",
"\n",
"  initial_data_collected = False\n",
"  infos = []\n",
"  for i in range(num_train_steps):\n",
"    # Step 1: Collect data.\n",
"    if not initial_data_collected or fresh_data:\n",
"      initial_data_collected = True\n",
"      collection_start_time = time.time()\n",
"      trajectories, unique_transitions, infos_i = get_transitions(\n",
"          world,\n",
"          max_num_transitions=max_num_transitions,\n",
"          max_trajectory_length=max_trajectory_length,\n",
"          intent_threshold=intent_threshold)\n",
"      collection_running_time = time.time() - collection_start_time\n",
"      if debug: print('unique_transitions:', unique_transitions)\n",
"    # Restart the training timer on every step, including steps that reuse\n",
"    # previously collected data.\n",
"    running_time = time.time()\n",
"\n",
"    # Check if the start state was updated:\n",
"    infos.append(\n",
"        any([info['recent_start_position_updated'] for info in infos_i]))\n",
"\n",
"    # Step 2: Train affordance model.\n",
"    if use_affordances:\n",
"      aff_loss, affordance_predictions = _train_step_affordances(trajectories)\n",
"      aff_loss = aff_loss.numpy().item()\n",
"    else:\n",
"      affordance_predictions = tf.constant(0.0)  # Basically a none.\n",
"      aff_loss = None\n",
"    all_aff_losses.append(aff_loss)\n",
"\n",
"    # Step 3: Train transition model and mask predictions if necessary.\n",
"    if use_model:\n",
"      model_loss = _train_step_model(trajectories, affordance_predictions)\n",
"      model_loss = model_loss.numpy().item()\n",
"    else:\n",
"      model_loss = None\n",
"    all_model_losses.append(model_loss)\n",
"\n",
"    if debug or print_losses:\n",
"      if i % print_every == 0:\n",
"        train_loop_time = time.time() - running_time\n",
"        print(f'i: {i}, aff_loss: {aff_loss}, model_loss: {model_loss}, '\n",
"              f'collection_loop_time: {collection_running_time:.2f}, '\n",
"              f'train_loop_time: {train_loop_time:.2f}')\n",
"\n",
"  return all_model_losses, all_aff_losses, infos"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6hXi7Iy0b50_"
},
"source": [
"# Plotting utilities"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "gMNIsCLEb3Ie"
},
"outputs": [],
"source": [
"#@title Learning curve smoothing\n",
"# From https://github.com/google-research/policy-learning-landscape/blob/master/analysis_tools/data_processing.py#L82\n",
"\n",
"DEFAULT_SMOOTHING_WEIGHT = 0.9\n",
"def apply_linear_smoothing(data, smoothing_weight=DEFAULT_SMOOTHING_WEIGHT):\n",
"  \"\"\"Smooths a curve with an exponential moving average.\n",
"\n",
"  This smoothing algorithm is the same as the one used in tensorboard.\n",
"\n",
"  Each output value is `previous * smoothing_weight +\n",
"  (1 - smoothing_weight) * x`, where `previous` is the previously emitted\n",
"  value (seeded with `data[0]`). Whenever `previous` is not finite the raw\n",
"  value `x` is emitted instead, so the average can recover after a NaN/inf.\n",
"\n",
"  Args:\n",
"    data: The sequence or list containing the data to smooth.\n",
"    smoothing_weight: A float representing the weight to place on the moving\n",
"      average. Values \u003c= 0 disable smoothing.\n",
"\n",
"  Returns:\n",
"    A list containing the smoothed data, or `data` itself (unchanged) when\n",
"    `smoothing_weight` is not positive.\n",
"\n",
"  Raises:\n",
"    ValueError: If `data` is empty.\n",
"  \"\"\"\n",
"  if not len(data):  # pylint: disable=g-explicit-length-test\n",
"    raise ValueError('No data to smooth.')\n",
"  if smoothing_weight \u003c= 0:\n",
"    return data\n",
"\n",
"  previous = data[0]\n",
"  smoothed = []\n",
"  for value in data:\n",
"    if np.isfinite(previous):\n",
"      current = previous * smoothing_weight + (1 - smoothing_weight) * value\n",
"    else:\n",
"      # The running average is NaN/inf: restart it from the raw value.\n",
"      current = value\n",
"    smoothed.append(current)\n",
"    previous = current\n",
"  return smoothed\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "T5cvk1XUb-M0"
},
"outputs": [],
"source": [
"#@title Intent plotting code.\n",
"\n",
"def plot_intents(\n",
"    world: ContinuousWorld,\n",
"    affordance_predictions: np.ndarray,\n",
"    eval_action: Tuple[float, float],\n",
"    num_world_ticks: int = 3,\n",
"    intent_collection: IntentName = IntentName,\n",
"    subplot_configuration: Tuple[int, int] = (2, 2),\n",
"    figsize: Tuple[int, int] = (5, 5)):\n",
"  \"\"\"Plots the intents as a heatmap.\n",
"\n",
"  Given the predictions from the affordance network, we plot a heatmap for each\n",
"  intent indicating how likely the `eval_action` can be used to complete it.\n",
"\n",
"  Args:\n",
"    world: The world to use. Only `world.size` is read, for the tick labels.\n",
"    affordance_predictions: Predictions from the affordance classifier, indexed\n",
"      as [x, y, intent]. The last dimension should be of the same len as\n",
"      intent_collection.\n",
"    eval_action: The eval action being used (For plotting the title).\n",
"    num_world_ticks: The number of ticks on the axes of the world.\n",
"    intent_collection: The intents to plot. Each member is used as the 1-based\n",
"      subplot index and, minus one, as the index into the last dimension of\n",
"      `affordance_predictions`.\n",
"    subplot_configuration: The arrangement of the subplots on the plot.\n",
"    figsize: The size of the matplotlib figure.\n",
"  \"\"\"\n",
"  fig = plt.figure(figsize=figsize)\n",
"\n",
"  # Since we are predicting probabilities, normalize between 0 and 1.\n",
"  norm = mpl.colors.Normalize(vmin=0.0, vmax=1.0)\n",
"\n",
"  # The colorbar axes.\n",
"  cax = fig.add_axes([1.0, 0.1, 0.075, 0.8])\n",
"\n",
"  for intent in intent_collection:\n",
"    ax = fig.add_subplot(*subplot_configuration, intent)\n",
"    afford_sliced = affordance_predictions[:, :, intent-1]\n",
"    # Transpose so that rows run along y and columns along x for imshow.\n",
"    afford_sliced = np.transpose(afford_sliced)\n",
"    ax_ = ax.imshow(afford_sliced, origin='lower')\n",
"    num_rows, num_cols = afford_sliced.shape\n",
"\n",
"    # This code will handle num_world_ticks=0 gracefully.\n",
"    # x ticks span the columns and y ticks span the rows, so non-square grids\n",
"    # are labelled correctly.\n",
"    ax.set_xticks(np.linspace(0, num_cols, num_world_ticks))\n",
"    ax.set_yticks(np.linspace(0, num_rows, num_world_ticks))\n",
"    ax.set_xticklabels(\n",
"        np.linspace(0, world.size, num_world_ticks), fontsize='x-small')\n",
"    ax.set_yticklabels(\n",
"        np.linspace(0, world.size, num_world_ticks), fontsize='x-small')\n",
"\n",
"    ax.set_xlabel('x')\n",
"    ax.set_ylabel('y', rotation=0)\n",
"    ax.set_title('Intent: {}'.format(intent.__repr__()[-10:-2]))\n",
"    ax_.set_norm(norm)\n",
"    # Draw the shared colorbar once, alongside the last intent.\n",
"    if intent == len(intent_collection):\n",
"      plt.colorbar(ax_, cax)\n",
"      cax.set_ylabel('Probability of intent completion')\n",
"\n",
"  plt.suptitle('Evaluating Action: {}'.format(eval_action))\n",
"  plt.tight_layout(rect=[0, 0.03, 1, 0.95])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pHarhdiSdixz"
},
"source": [
"# Main Experiment (Training)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Zq9TXAyDfaP_"
},
"outputs": [],
"source": [
"# Notebook-wide accumulators: the training cell appends one entry per repeat\n",
"# to each list (losses, transition models and affordance networks).\n",
"all_losses_global, all_models_global, all_affordance_global = [], [], []"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"colab_type": "code",
"id": "OP50CV_YdkLi",
"outputId": "58a4677b-5dcf-4015-d795-2e7a4b9384fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Experiments that will be run: [False, True]\n",
"Resetting seed to 0.\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"Using model? True. Using affordances? False. Using affordances to mask model? False.\n",
"Training step has been optimized.\n",
"Target mu has been updated to: Point(x=1.5, y=1.5)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:107: RuntimeWarning: invalid value encountered in double_scalars\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"i: 0, aff_loss: None, model_loss: 3.079850912094116, collection_loop_time: 0.20, train_loop_time: 0.28\n",
"i: 1000, aff_loss: None, model_loss: -1.3993345499038696, collection_loop_time: 0.17, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"i: 2000, aff_loss: None, model_loss: -1.4746335744857788, collection_loop_time: 0.17, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=1.5, y=1.5)\n",
"i: 3000, aff_loss: None, model_loss: -1.926070213317871, collection_loop_time: 0.29, train_loop_time: 0.00\n",
"i: 4000, aff_loss: None, model_loss: -2.129803419113159, collection_loop_time: 0.17, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"i: 5000, aff_loss: None, model_loss: -2.212693452835083, collection_loop_time: 0.18, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=1.5, y=1.5)\n",
"i: 6000, aff_loss: None, model_loss: -2.104724168777466, collection_loop_time: 0.18, train_loop_time: 0.00\n",
"i: 7000, aff_loss: None, model_loss: -2.153459310531616, collection_loop_time: 0.19, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"Resetting seed to 0.\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"Using model? True. Using affordances? True. Using affordances to mask model? True.\n",
"Training step has been optimized.\n",
"Target mu has been updated to: Point(x=1.5, y=1.5)\n",
"i: 0, aff_loss: 0.7048008441925049, model_loss: 3.4398090839385986, collection_loop_time: 0.17, train_loop_time: 0.66\n",
"i: 1000, aff_loss: 0.19042730331420898, model_loss: -1.7565302848815918, collection_loop_time: 0.19, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"i: 2000, aff_loss: 0.23882852494716644, model_loss: -1.7699742317199707, collection_loop_time: 0.18, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=1.5, y=1.5)\n",
"i: 3000, aff_loss: 0.207383394241333, model_loss: -1.983075737953186, collection_loop_time: 0.17, train_loop_time: 0.00\n",
"i: 4000, aff_loss: 0.1946447789669037, model_loss: -2.051856517791748, collection_loop_time: 0.19, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n",
"i: 5000, aff_loss: 0.21516098082065582, model_loss: -1.9648913145065308, collection_loop_time: 0.17, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=1.5, y=1.5)\n",
"i: 6000, aff_loss: 0.20555077493190765, model_loss: -1.8827911615371704, collection_loop_time: 0.18, train_loop_time: 0.00\n",
"i: 7000, aff_loss: 0.17789226770401, model_loss: -2.0646157264709473, collection_loop_time: 0.28, train_loop_time: 0.00\n",
"Target mu has been updated to: Point(x=0.5, y=0.5)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"\u003cFigure size 360x360 with 1 Axes\u003e"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"\u003cFigure size 360x360 with 1 Axes\u003e"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"#@title Train affordance and transition model.\n",
"# Trains num_repeats affordance and model networks.\n",
"#@markdown Experiments to run.\n",
"run_model = True #@param {type:\"boolean\"}\n",
"run_model_with_affordances = True #@param {type:\"boolean\"}\n",
"num_repeats = 1#@param {type:\"integer\"}\n",
"\n",
"#@markdown Training arguments\n",
"# use_affordance_to_mask_model = True #@param {type:\"boolean\"}\n",
"optimize_performance = True #@param {type:\"boolean\"}\n",
"model_learning_rate = 1e-2 #@param {type:\"number\"}\n",
"affordance_learning_rate = 1e-1 #@param {type:\"number\"}\n",
"max_num_transitions = 1000 #@param {type:\"integer\"}\n",
"num_train_steps = 8000 #@param {type:\"integer\"}\n",
"affordance_mask_threshold = 0.5 #@param {type:\"number\"}\n",
"seed = 0 #@param {type:\"integer\"}\n",
"intent_threshold = 0.05 #@param {type:\"number\"}\n",
"\n",
"#@markdown Environment arguments\n",
"drift_speed = 0.001 #@param {type:\"number\"}\n",
"max_action_force = 0.5 #@param {type:\"number\"}\n",
"movement_noise = 0.1 #@param {type:\"number\"}\n",
"max_episode_length = 100000 #@param {type:\"integer\"}\n",
"\n",
"# Fixed (non-form) network and world dimensions.\n",
"input_size = 2\n",
"action_size = 2\n",
"intent_size = len(IntentName)\n",
"hidden_nodes = 32\n",
"world_size = 2\n",
"\n",
"# One experiment per enabled flag: False trains the transition model alone,\n",
"# True trains it together with (and masked by) the affordance network.\n",
"affordance_mask_params = [\n",
"    mask_flag\n",
"    for mask_flag, is_enabled in (\n",
"        (False, run_model), (True, run_model_with_affordances))\n",
"    if is_enabled\n",
"]\n",
"\n",
"print(f'Experiments that will be run: {affordance_mask_params}')\n",
"\n",
"for repeat_number in range(num_repeats):\n",
"  # Per-repeat results, keyed by whether affordance masking was used.\n",
"  all_losses, model_networks, affordance_networks = {}, {}, {}\n",
"  new_seed = seed + repeat_number\n",
"  for use_affordance_to_mask_model in affordance_mask_params:\n",
"    # Re-seed every RNG so each experiment of this repeat uses the same seed.\n",
"    print(f'Resetting seed to {new_seed}.')\n",
"    np.random.seed(new_seed)\n",
"    random.seed(new_seed)\n",
"    tf.random.set_seed(new_seed)\n",
"\n",
"    # Two ReLU hidden layers followed by one sigmoid output per intent.\n",
"    affordance_layers = [\n",
"        tf.keras.layers.Dense(\n",
"            hidden_nodes, activation=tf.keras.activations.relu)\n",
"        for _ in range(2)\n",
"    ]\n",
"    affordance_layers.append(\n",
"        tf.keras.layers.Dense(\n",
"            intent_size, activation=tf.keras.activations.sigmoid))\n",
"    affordance_network = tf.keras.Sequential(affordance_layers)\n",
"\n",
"    affordance_sgd = tf.keras.optimizers.Adam(\n",
"        learning_rate=affordance_learning_rate)\n",
"    model_sgd = tf.keras.optimizers.Adam(learning_rate=model_learning_rate)\n",
"    model_network = TransitionModel(hidden_nodes, input_size)\n",
"\n",
"    # Store models for later use.\n",
"    model_networks[use_affordance_to_mask_model] = model_network\n",
"    affordance_networks[use_affordance_to_mask_model] = affordance_network\n",
"\n",
"    # Drift between the two sides around the wall.\n",
"    drift_endpoints = tuple(\n",
"        Point(fraction * world_size, fraction * world_size)\n",
"        for fraction in (1 / 4, 3 / 4))\n",
"    world = ContinuousWorld(\n",
"        size=world_size,\n",
"        # Slow drift speed to make the transition from L -\u003e R slow.\n",
"        drift_speed=drift_speed,\n",
"        drift_between=drift_endpoints,\n",
"        max_action_force=max_action_force,\n",
"        max_episode_length=max_episode_length,\n",
"        movement_noise=movement_noise,\n",
"        wall_pairs=[(Point(1.0, 0.0), Point(1.0, 2.0))],\n",
"        verbose_reset=True)\n",
"\n",
"    fig = plt.figure(figsize=(5, 5))\n",
"    ax = fig.add_subplot(1, 1, 1)\n",
"\n",
"    visualize_environment(\n",
"        world, ax, scaling=1.0, draw_start_mu=False, draw_target_mu=False)\n",
"\n",
"    # Without masking, no affordance network or optimizer is passed to training.\n",
"    if use_affordance_to_mask_model:\n",
"      affordance_network_arg = affordance_network\n",
"      affordance_optimizer_arg = affordance_sgd\n",
"    else:\n",
"      affordance_network_arg = None\n",
"      affordance_optimizer_arg = None\n",
"\n",
"    model_loss, aff_loss, infos = train_networks(\n",
"        world,\n",
"        model_network=model_network,\n",
"        model_optimizer=model_sgd,\n",
"        affordance_network=affordance_network_arg,\n",
"        affordance_optimizer=affordance_optimizer_arg,\n",
"        print_losses=True,\n",
"        fresh_data=True,\n",
"        affordance_mask_threshold=affordance_mask_threshold,\n",
"        use_affordance_to_mask_model=use_affordance_to_mask_model,\n",
"        max_num_transitions=max_num_transitions,\n",
"        max_trajectory_length=None,\n",
"        optimize_performance=optimize_performance,\n",
"        num_train_steps=num_train_steps,\n",
"        intent_threshold=intent_threshold,\n",
"        print_every=1000)\n",
"\n",
"    all_losses[use_affordance_to_mask_model] = (model_loss, aff_loss, infos)\n",
"\n",
"  # Record this repeat in the notebook-wide accumulators.\n",
"  all_models_global.append(model_networks)\n",
"  all_affordance_global.append(affordance_networks)\n",
"  all_losses_global.append(all_losses)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "90KLuKumiRcB"
},
"outputs": [],
"source": [
"#@title Save weights\n",
"\n",
"# The loop variable is not called `seed`, so the `seed` form value set in the\n",
"# training cell is not overwritten by running this cell.\n",
"for repeat_index, model_networks in enumerate(all_models_global):\n",
"  # Save whichever experiments were actually run (keys are True and/or False).\n",
"  for use_mask, network in model_networks.items():\n",
"    network.save_weights(\n",
"        f'./affordances/seed_{repeat_index}_model_networks_{use_mask}/'\n",
"        'keras.weights')\n",
"# Affordance networks live in `all_affordance_global`, not in the list of\n",
"# transition models.\n",
"for repeat_index, affordance_networks in enumerate(all_affordance_global):\n",
"  # Only the network stored under True was passed to training.\n",
"  if True in affordance_networks:\n",
"    affordance_networks[True].save_weights(\n",
"        f'./affordances/seed_{repeat_index}_affordance_networks_true/'\n",
"        'keras.weights')\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "g1iO3h6Ffucu"
},
"source": [
"# Visualizations\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OQtxS_ovhJxt"
},
"source": [
"## Learning curves"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "06NPUpXDfyFf"
},
"outputs": [],
"source": [
"#@title Code to collect results from a list of lists into a single array.\n",
"def _collect_results(\n",
"    all_losses_g,\n",
"    using_affordances,\n",
"    save_to_disk=False,\n",
"    smooth_weight=0.99,\n",
"    skip_first=10):\n",
"  \"\"\"Collects smoothed model-loss curves and aggregates them across runs.\n",
"\n",
"  Args:\n",
"    all_losses_g: A list with one entry per run. Each entry is indexed by\n",
"      `using_affordances` and its element 0 is the loss curve that is used.\n",
"    using_affordances: The key selecting which experiment to collect.\n",
"    save_to_disk: Whether to also save each raw (unsmoothed, untrimmed) curve\n",
"      as a .npy file under ./affordances/.\n",
"    smooth_weight: The weight handed to `apply_linear_smoothing`.\n",
"    skip_first: The number of leading points dropped before smoothing.\n",
"\n",
"  Returns:\n",
"    A (mean_curve, std_curve) tuple computed across runs (axis 0).\n",
"  \"\"\"\n",
"\n",
"  def _smoothed_curve(run_index, trace):\n",
"    \"\"\"Optionally saves one run's raw curve, then returns it smoothed.\"\"\"\n",
"    raw_curve = trace[using_affordances][0]\n",
"    if save_to_disk:\n",
"      np.save(\n",
"          f'./affordances/curve_seed_{run_index}_{using_affordances}.npy',\n",
"          np.array(raw_curve))\n",
"    # Smooth the curves for plotting.\n",
"    return apply_linear_smoothing(raw_curve[skip_first:], smooth_weight)\n",
"\n",
"  stacked_curves = np.stack([\n",
"      _smoothed_curve(run_index, trace)\n",
"      for run_index, trace in enumerate(all_losses_g)\n",
"  ])\n",
"  return np.mean(stacked_curves, 0), np.std(stacked_curves, 0)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "R4elUiamgc-I"
},
"outputs": [],
"source": [
"#@title Plot averaged learning curves.\n",
"\n",
"mean_curve_aff, std_curve_aff = _collect_results(\n",
"    all_losses_global, True)\n",
"mean_curve_normal, std_curve_normal = _collect_results(\n",
"    all_losses_global, False)\n",
"# Not named `colors`, so the `colors` module imported from matplotlib is not\n",
"# shadowed by this list.\n",
"curve_colors = ['r', 'k']\n",
"plt.plot(\n",
"    mean_curve_aff, color=curve_colors[0], linewidth=4, label='With Affordance')\n",
"plt.plot(\n",
"    mean_curve_normal,\n",
"    color=curve_colors[1],\n",
"    linewidth=4,\n",
"    label='without Affordance')\n",
"\n",
"# Shade the same quantity around both curves: std scaled by sqrt(num_repeats).\n",
"stderr_aff = std_curve_aff / np.sqrt(num_repeats)\n",
"stderr_normal = std_curve_normal / np.sqrt(num_repeats)\n",
"\n",
"plt.fill_between(\n",
"    range(len(mean_curve_aff)),\n",
"    mean_curve_aff + stderr_aff,\n",
"    mean_curve_aff - stderr_aff,\n",
"    alpha=0.25,\n",
"    color=curve_colors[0])\n",
"\n",
"plt.fill_between(\n",
"    range(len(mean_curve_normal)),\n",
"    mean_curve_normal + stderr_normal,\n",
"    mean_curve_normal - stderr_normal,\n",
"    alpha=0.25,\n",
"    color=curve_colors[1])\n",
"\n",
"plt.ylim([-2.2, -1.2])\n",
"plt.yticks(fontsize=15)\n",
"plt.xticks([0, 2500, 5000, 7500],[0, 2500, 5000, 7500], fontsize=15)\n",
"plt.legend(fontsize=15)\n",
"plt.xlabel('Updates', fontsize=20)\n",
"plt.ylabel(r'$-\\log \\hat{P}(s^\\prime|s,a)$',fontsize=20)\n",
"plt.tight_layout()\n",
"plt.savefig('./affordances/model_learning_avg_plot.pdf')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OSAm0ukchLqr"
},
"source": [
"## Intent heatmap plots"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 399
},
"colab_type": "code",
"id": "1NeTHiZkhXKj",
"outputId": "2428756a-8685-4312-f3f3-1bf4be01188b"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:56: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"\u003cFigure size 360x360 with 5 Axes\u003e"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"#@title Evaluating action and plotting intent heatmaps.\n",
"#@markdown What is the action?\n",
"action_x_dir = 0.2 #@param {type:\"number\"}\n",
"action_y_dir = 0.2 #@param {type:\"number\"}\n",
"network_seed = 0 #@param {type:\"integer\"}\n",
"\n",
"# Cover the x-y grid.\n",
"xs = np.linspace(0, world.size)\n",
"ys = np.linspace(0, world.size)\n",
"xy_coords = tf.constant(list(itertools.product(xs, ys)), dtype=tf.float32)\n",
"\n",
"eval_action = [action_x_dir, action_y_dir]\n",
"fixed_action = tf.constant([eval_action], dtype=tf.float32)\n",
"# Repeat the action once per grid point. The count is derived from the grid\n",
"# instead of being hard-coded, so changing the resolution of xs/ys keeps the\n",
"# concat below valid.\n",
"fixed_action = tf.repeat(fixed_action, len(xs) * len(ys), axis=0)\n",
"\n",
"# Each row is (x, y, action_x, action_y), the input of the affordance network.\n",
"concat_matrix = tf.concat((xy_coords, fixed_action), axis=1)\n",
"affordance_network = all_affordance_global[network_seed][True]\n",
"afford_predictions = affordance_network(concat_matrix)\n",
"# Back to a [x, y, intent] grid for plotting.\n",
"affordance_predictions = tf.reshape(\n",
"    afford_predictions,\n",
"    (len(xs), len(ys), intent_size)).numpy()\n",
"\n",
"plot_intents(world, affordance_predictions, eval_action)\n",
"\n",
"plt.savefig(\n",
"    f'intent_eval_FX{action_x_dir}_FY{action_y_dir}.pdf', bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mOK_e3oqhrjI"
},
"source": [
"## Model Predictions"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "61Gb_tXKh3je"
},
"outputs": [],
"source": [
"#@title Round annotation plotting code.\n",
"# Style of the rounded box drawn behind every annotation label.\n",
"ROUND_BOX = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 1.0}\n",
"\n",
"def add_annotation(\n",
"    ax,\n",
"    start: Tuple[float, float],\n",
"    end: Tuple[float, float],\n",
"    connectionstyle, text):\n",
"  \"\"\"Draws an arrow pointing at `start` plus a boxed label at `end`.\n",
"\n",
"  Args:\n",
"    ax: The axes to draw on; its `annotate` and `text` methods are called.\n",
"    start: The (x, y) data coordinates the arrow head points at.\n",
"    end: The (x, y) data coordinates of the label. The arrow tail starts 0.25\n",
"      to the right of this point.\n",
"    connectionstyle: The connection style forwarded to the arrow properties.\n",
"    text: The label to write inside the rounded box.\n",
"  \"\"\"\n",
"  head_x, head_y = start\n",
"  label_x, label_y = end\n",
"\n",
"  arrow_properties = {\n",
"      'arrowstyle': '-\u003e',\n",
"      'color': '0.0',\n",
"      'shrinkA': 5,\n",
"      'shrinkB': 5,\n",
"      'patchA': None,\n",
"      'patchB': None,\n",
"      'connectionstyle': connectionstyle,\n",
"  }\n",
"  # An empty annotation string: only the arrow is drawn by this call.\n",
"  ax.annotate(\n",
"      '',\n",
"      xy=(head_x, head_y),\n",
"      xycoords='data',\n",
"      xytext=(label_x + 0.25, label_y),\n",
"      textcoords='data',\n",
"      size=30.0,\n",
"      arrowprops=arrow_properties)\n",
"\n",
"  ax.text(\n",
"      label_x, label_y, text, size=15, ha='left', va='top', bbox=ROUND_BOX)\n",
"\n",
"# Arc styles available to callers of `add_annotation`.\n",
"connection_styles = [\n",
"    'arc3,rad=-0.3',\n",
"    'arc3,rad=0.3',\n",
"    'arc3,rad=0.0',\n",
"    'arc3,rad=0.5',\n",
"    'arc3,rad=-0.5',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "both",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 689
},
"colab_type": "code",
"id": "ILHEEbHShwCb",
"outputId": "9bac0134-d9d1-46da-9283-975ede52250e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Figures show the predicted position of the transition distribution.\n",
"Gray circle shows what would have been predicted but was masked by affordance model. \n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"\u003cFigure size 360x360 with 1 Axes\u003e"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"\u003cFigure size 360x360 with 1 Axes\u003e"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"#@title Plotting Model predictions.\n",
"\n",
"#@markdown Use affordance based model?\n",
"affordance_mask_threshold = 0.5 #@param {type:\"number\"}\n",
"network_seed = 0 #@param {type:\"integer\"}\n",
"\n",
"#@markdown What is the action?\n",
"action_x_dir = +0.5 #@param {type:\"number\"}\n",
"action_y_dir = 0.0 #@param {type:\"number\"}\n",
"\n",
"#@markdown Where is the agent?\n",
"agent_x = 0.75 #@param {type:\"number\"}\n",
"agent_y = 1.0 #@param {type:\"number\"}\n",
"\n",
"action = tf.constant([[action_x_dir, action_y_dir]])\n",
"pos = tf.constant([[agent_x, agent_y]])\n",
"\n",
"affordance_networks = all_affordance_global[network_seed]\n",
"model_networks = all_models_global[network_seed]\n",
"\n",
"# Multiplier applied to the predicted scale to size the drawn ellipse.\n",
"scale_scale = 2.0\n",
"\n",
"# One figure per experiment: first without, then with affordance masking.\n",
"for use_affordance_to_mask_model in [False, True]:\n",
"  fig = plt.figure(figsize=(5, 5))\n",
"  ax = fig.add_subplot(1, 1, 1)\n",
"  transition_dist = model_networks[use_affordance_to_mask_model](pos, action)\n",
"  transition_loc = tuple(transition_dist.loc[0].numpy())\n",
"  transition_scale = tuple(transition_dist.scale[0].numpy() * scale_scale)\n",
"\n",
"  if use_affordance_to_mask_model:\n",
"    # Classify which intents the affordance network deems completable.\n",
"    aff_network = affordance_networks[use_affordance_to_mask_model]\n",
"    AF = aff_network(tf.concat([pos, action], axis=1))\n",
"    intents_completable = (AF \u003e affordance_mask_threshold)[0].numpy()\n",
"\n",
"  visualize_environment(\n",
"      world,\n",
"      ax,\n",
"      scaling=1.0,\n",
"      draw_start_mu=False,\n",
"      draw_target_mu=False,\n",
"      draw_agent=False,\n",
"      agent_size=0.1,\n",
"      write_text=False)\n",
"  ax.scatter([agent_x], [agent_y], s=150.0, c='green', marker='x')\n",
"  ax.arrow(agent_x, agent_y, action_x_dir, action_y_dir, head_width=0.05)\n",
"\n",
"  # The prediction is greyed out when no intent is classified as completable.\n",
"  if use_affordance_to_mask_model and not np.any(intents_completable):\n",
"    color = 'gray'\n",
"    alpha = 0.25\n",
"    ellipse_text = '(Masked) '\n",
"  else:\n",
"    color = None\n",
"    alpha = 0.7\n",
"    ellipse_text = ''\n",
"\n",
"  elipse = mpl.patches.Ellipse(\n",
"      transition_loc, *transition_scale, alpha=alpha, color=color)\n",
"  ax.add_artist(elipse)\n",
"\n",
"  if use_affordance_to_mask_model:\n",
"    string_built = ' Intent classification\\n'\n",
"    for intent_name, completable in zip(IntentName, intents_completable):\n",
"      string_built += ' ' + str(intent_name)[-5:] + ':' + str(completable)\n",
"      string_built += '\\n'\n",
"    ax.text(\n",
"        0,\n",
"        0,\n",
"        string_built,\n",
"    )\n",
"\n",
"  ax.set_xticks([0.0, 1.0, 2.0])\n",
"  ax.set_xticklabels([0, 1.0, 2.0])\n",
"\n",
"  ax.set_yticks([0.0, 1.0, 2.0])\n",
"  ax.set_yticklabels([0, 1.0, 2.0])\n",
"\n",
"  if use_affordance_to_mask_model:\n",
"    title = 'Using affordances'\n",
"  else:\n",
"    title = 'Without affordances'\n",
"  ax.set_title(title)\n",
"  ax.legend([elipse], [ellipse_text + 'Predicted transition'])\n",
"  # The name carries both action components and the experiment flag, so the\n",
"  # two figures are written to different files.\n",
"  file_name = (f'./empirical_demo{movement_noise}_P{agent_x}_{agent_y}_'\n",
"               f'F{action_x_dir}_{action_y_dir}_'\n",
"               f'affordances{use_affordance_to_mask_model}.pdf')\n",
"  fig.savefig(file_name)\n",
"\n",
"print(\n",
"    'Figures show the predicted position of the transition distribution.'\n",
"    '\\nGray circle shows what would have been predicted but was masked by '\n",
"    'affordance model. ')"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5GHbTbNClfL-"
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "AffordancesInContinuousEnvironment.ipynb",
"provenance": [
{
"file_id": "1W86NFSHwhnx-UEmAY_mhJJzUxPXY3JC4",
"timestamp": 1591715576521
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}