diff --git a/causal_reasoning/Causal_Reasoning_in_Probability_Trees.ipynb b/causal_reasoning/Causal_Reasoning_in_Probability_Trees.ipynb new file mode 100644 index 0000000..1867df5 --- /dev/null +++ b/causal_reasoning/Causal_Reasoning_in_Probability_Trees.ipynb @@ -0,0 +1,4117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "N1EHYIqxd80m" + }, + "source": [ + "\u003e Copyright 2020 DeepMind Technologies Limited.\n", + "\u003e\n", + "\u003e Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "\u003e you may not use this file except in compliance with the License.\n", + "\u003e \n", + "\u003e You may obtain a copy of the License at\n", + "\u003e https://www.apache.org/licenses/LICENSE-2.0\n", + "\u003e \n", + "\u003e Unless required by applicable law or agreed to in writing, software\n", + "\u003e distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "\u003e WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "\u003e See the License for the specific language governing permissions and\n", + "\u003e limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qHP8zlWs1OUN" + }, + "source": [ + "# **Tutorial: Causal Reasoning in Probability Trees**\n", + "\n", + "*By the AGI Safety Analysis Team @ DeepMind.*\n", + "\n", + "**Summary:** This is the companion tutorial for the paper \"Algorithms\n", + "for Causal Reasoning in Probability Trees\" by Genewein et al. (2020).\n", + "\n", + "Probability trees are one of the simplest models of causal\n", + "generative processes. They possess clean semantics and are strictly more\n", + "general than causal Bayesian networks, being able, for example, to represent\n", + "causal relations that causal Bayesian networks can’t. Even so, they have\n", + "received little attention from the AI and ML community.\n", + "\n", + "In this tutorial we present new algorithms for causal reasoning in discrete\n", + "probability trees that cover the entire causal hierarchy (association,\n", + "intervention, and counterfactuals), operating on arbitrary logical and causal\n", + "events." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "XUQVMuc2_VlG" + }, + "source": [ + "\n", + "# Part I: Basics\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4cjtJaR_VlH" + }, + "source": [ + "### Setup\n", + "\n", + "First we install the `graphviz` package:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JlbLG8sy_VlI" + }, + "outputs": [], + "source": [ + "!apt-get install graphviz\n", + "!pip install graphviz" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KxQFfJio_VlM" + }, + "source": [ + "### Imports and data structures\n", + "\n", + "We import NumPy and Matplotlib, and then we define the basic data structures\n", + "for this tutorial: `PTree` (a probability tree), `Node` (a tree node),\n", + "`MinCut` (the representation of an event), and `Critical` (a critical set)."
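, + "\n", + "As a quick preview of the API defined below (a sketch only; the classes are\n", + "defined in the next cells and demonstrated throughout the tutorial):\n", + "\n", + "```python\n", + "pt = PTree()                               # create an empty probability tree\n", + "pt.root('O = 1', [pt.child(0.5, 'X = 0'),  # attach a root node and two children\n", + "                  pt.child(0.5, 'X = 1')])\n", + "cut = pt.prop('X = 1')                     # the event \"X = 1\" as a min-cut\n", + "print(pt.prob(cut))                        # probability of that event: 0.5\n", + "```"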
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "HFs6Q37E_VlM" + }, + "outputs": [], + "source": [ + "#@title Imports\n", + "\n", + "import numpy as np\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "yTTsE6f1_oIA" + }, + "outputs": [], + "source": [ + "#@title Data structures\n", + "\n", + "import graphviz\n", + "import copy\n", + "from random import random\n", + "\n", + "\n", + "class MinCut:\n", + " \"\"\"A representation of an event in a probability tree.\"\"\"\n", + "\n", + " def __init__(self, root, t=frozenset(), f=frozenset()):\n", + " self._root = root\n", + " self.t = t\n", + " self.f = f\n", + "\n", + " def __str__(self):\n", + "\n", + " true_elements = ', '.join([str(id) for id in sorted(self.t)])\n", + " false_elements = ', '.join([str(id) for id in sorted(self.f)])\n", + " return '{true: {' + true_elements + '}, false: {' + false_elements + '}}'\n", + "\n", + " def __reptr__(self):\n", + " return self.__str__()\n", + "\n", + " # Proposition\n", + " def prop(root, statement):\n", + " cond_lst = Node._parse_statements(statement)\n", + "\n", + " # Complain if more than one statement.\n", + " if len(cond_lst) != 1:\n", + " raise Exception('\\'prop\\' takes one and only one statement.')\n", + " return None\n", + "\n", + " # Remove list envelope.\n", + " cond = cond_lst[0]\n", + "\n", + " # Recurse.\n", + " return MinCut._prop(root, root, cond)\n", + "\n", + " def _prop(root, node, cond):\n", + " # Take var and val of condition.\n", + " condvar, condval = cond\n", + "\n", + " # Search for variable.\n", + " for var, val in node.assign:\n", + " if condvar == var:\n", + " if condval == val:\n", + " return MinCut(root, frozenset([node.id]), frozenset())\n", + " else:\n", + " return MinCut(root, frozenset(), frozenset([node.id]))\n", + "\n", + " # If we reach a leaf node and the variable isn't resolved,\n", + " # raise an exception.\n", + " if not node.children:\n", + " msg = 'Node ' + str(node.id) + ': ' \\\n", + " + 'min-cut for condition \"' + condvar + ' = ' \\\n", + " + condval + '\" is undefined.'\n", + " raise Exception(msg)\n", + "\n", + " # Variable not found, recurse.\n", + " t_set = frozenset()\n", + " f_set = frozenset()\n", + " for child in node.children:\n", + " _, subnode = child\n", + " subcut = MinCut._prop(root, subnode, cond)\n", + " t_set = t_set.union(subcut.t)\n", + " f_set = f_set.union(subcut.f)\n", + "\n", + " # Consolidate into node if children are either only true or false nodes.\n", + " cut = MinCut(root, t_set, f_set)\n", + " if not cut.f:\n", + " cut.t = frozenset([node.id])\n", + " elif not cut.t:\n", + " cut.f = frozenset([node.id])\n", + "\n", + " return cut\n", + "\n", + " # Negation\n", + " def neg(self):\n", + " return MinCut(self._root, t=self.f, f=self.t)\n", + "\n", + " def __invert__(self):\n", + " return self.neg()\n", + "\n", + " # Conjunction\n", + " def conj(root, cut1, cut2):\n", + " return MinCut._conj(root, root, cut1, cut2, False, False)\n", + "\n", + " def _conj(root, node, cut1, cut2, end1=False, end2=False):\n", + " # Base case.\n", + " if (node.id in cut1.f) or (node.id in cut2.f):\n", + " return MinCut(root, frozenset(), frozenset([node.id]))\n", + " if node.id in cut1.t:\n", + " end1 = True\n", + " if node.id in cut2.t:\n", + " end2 = True\n", + " if end1 and 
end2:\n", + " return MinCut(root, frozenset([node.id]), frozenset())\n", + "\n", + " # Recurse.\n", + " t_set = frozenset()\n", + " f_set = frozenset()\n", + " for _, subnode in node.children:\n", + " subcut = MinCut._conj(root, subnode, cut1, cut2, end1, end2)\n", + " t_set = t_set.union(subcut.t)\n", + " f_set = f_set.union(subcut.f)\n", + "\n", + " # Consolidate into node if children are either only true or false nodes.\n", + " cut = MinCut(root, t_set, f_set)\n", + " if not cut.f:\n", + " cut.t = frozenset([node.id])\n", + " elif not cut.t:\n", + " cut.f = frozenset([node.id])\n", + " return cut\n", + "\n", + " def __and__(self, operand):\n", + " return MinCut.conj(self._root, self, operand)\n", + "\n", + " # Disjunction\n", + " def disj(root, cut1, cut2):\n", + " return MinCut.neg(MinCut.conj(root, MinCut.neg(cut1), MinCut.neg(cut2)))\n", + "\n", + " def __or__(self, operand):\n", + " return MinCut.disj(self._root, self, operand)\n", + "\n", + " # Causal dependence\n", + " def precedes(root, cut_c, cut_e):\n", + " return MinCut._precedes(root, root, cut_c, cut_e, False)\n", + "\n", + " def _precedes(root, node, cut_c, cut_e, found_c):\n", + " # Base case.\n", + " if not found_c:\n", + " if (node.id in cut_e.t or node.id in cut_e.f or node.id in cut_c.f):\n", + " return MinCut(root, frozenset(), frozenset([node.id]))\n", + " if (node.id in cut_c.t):\n", + " found_c = True\n", + " if found_c:\n", + " if (node.id in cut_e.t):\n", + " return MinCut(root, frozenset([node.id]), frozenset())\n", + " if (node.id in cut_e.f):\n", + " return MinCut(root, frozenset(), frozenset([node.id]))\n", + "\n", + " # Recursion.\n", + " t_set = frozenset()\n", + " f_set = frozenset()\n", + " for _, subnode in node.children:\n", + " subcut = MinCut._precedes(root, subnode, cut_c, cut_e, found_c)\n", + " t_set = t_set.union(subcut.t)\n", + " f_set = f_set.union(subcut.f)\n", + "\n", + " # Consolidate into node if children are either only true or false nodes.\n", + " cut = MinCut(root, t_set, f_set)\n", + " if not cut.f:\n", + " cut.t = frozenset([node.id])\n", + " elif not cut.t:\n", + " cut.f = frozenset([node.id])\n", + " return cut\n", + "\n", + " def __lt__(self, operand):\n", + " return MinCut.precedes(self._root, self, operand)\n", + "\n", + "\n", + "class Critical:\n", + " \"\"\"A representation of the critical set associated to an event.\"\"\"\n", + "\n", + " # Constructor\n", + " def __init__(self, s=frozenset()):\n", + " self.s = s\n", + "\n", + " def __str__(self):\n", + " elements = ', '.join([str(id) for id in sorted(self.s)])\n", + " return '{' + elements + '}'\n", + "\n", + " def __reptr__(self):\n", + " return self.__str__()\n", + "\n", + " def critical(root, cut):\n", + " _, crit = Critical._critical(root, cut)\n", + " return crit\n", + "\n", + " def _critical(node, cut):\n", + " # Base case.\n", + " if node.id in cut.t:\n", + " return (False, Critical(frozenset()))\n", + " if node.id in cut.f:\n", + " return (True, Critical(frozenset()))\n", + " # Recurse.\n", + " s = frozenset()\n", + " for _, subnode in node.children:\n", + " incut, subcrit = Critical._critical(subnode, cut)\n", + " if incut:\n", + " s = s.union(frozenset([node.id]))\n", + " else:\n", + " s = s.union(subcrit.s)\n", + "\n", + " return (False, Critical(s))\n", + "\n", + "\n", + "class Node:\n", + " \"\"\"A node in probability tree.\"\"\"\n", + "\n", + " # Constructor.\n", + " def __init__(self, uid, statements, children=None):\n", + " # Automatically assigned ID.\n", + " self.id = uid\n", + "\n", + " # Assignments.\n", + " 
if isinstance(statements, str):\n", + " self.assign = Node._parse_statements(statements)\n", + " else:\n", + " self.assign = statements\n", + "\n", + " # Children.\n", + " if children is None:\n", + " self.children = []\n", + " else:\n", + " self.children = children\n", + "\n", + " # Parse statements.\n", + " def _parse_statements(statements):\n", + " statement_list = statements.split(',')\n", + " pair_list = [x.split('=') for x in statement_list]\n", + " assign = [(var.strip(), val.strip()) for var, val in pair_list]\n", + " return assign\n", + "\n", + " # Sample.\n", + " def sample(self):\n", + " return self._sample(dict())\n", + "\n", + " def _sample(self, smp):\n", + " # Add new assignments.\n", + " newsmp = {var: val for var, val in self.assign}\n", + " smp = dict(smp, **newsmp)\n", + "\n", + " # Base case.\n", + " if not self.children:\n", + " return smp\n", + "\n", + " # Recurse.\n", + " rnum = random()\n", + " for child in self.children:\n", + " subprob, subnode = child\n", + " rnum -= subprob\n", + " if rnum \u003c= 0:\n", + " return subnode._sample(smp)\n", + "\n", + " # Something went wrong: probabilities aren't normalized.\n", + " msg = 'Node ' + str(self.id) + ': ' \\\n", + " + 'probabilities of transitions do not add up to one.'\n", + " raise Exception(msg)\n", + "\n", + " # Insert.\n", + " def insert(self, prob, node):\n", + " self.children.append((prob, node))\n", + "\n", + " # Compute probability of cut.\n", + " def prob(self, cut):\n", + " return self._prob(cut, 1.0)\n", + "\n", + " def _prob(self, cut, prob):\n", + " # Base case.\n", + " if self.id in cut.t:\n", + " return prob\n", + " if self.id in cut.f:\n", + " return 0.0\n", + "\n", + " # Recurse.\n", + " probsum = 0.0\n", + " for child in self.children:\n", + " subprob, subnode = child\n", + " resprob = subnode._prob(cut, prob * subprob)\n", + " probsum += resprob\n", + "\n", + " return probsum\n", + "\n", + " # Return a dictionary with all the random variables and their values.\n", + " def rvs(self):\n", + " sts = dict()\n", + " return self._rvs(sts)\n", + "\n", + " def _rvs(self, sts):\n", + " for var, val in self.assign:\n", + " if not (var in sts):\n", + " sts[var] = list()\n", + " if not (val in sts[var]):\n", + " sts[var].append(val)\n", + "\n", + " for _, subnode in self.children:\n", + " sts = subnode._rvs(sts)\n", + "\n", + " return sts\n", + "\n", + " # Auxiliary function for computing the list of children.\n", + " def _normalize_children(children, probsum, logsum):\n", + " newchildren = None\n", + " if probsum \u003e 0.0:\n", + " newchildren = [\n", + " (subprob / probsum, subnode) for _, subprob, subnode in children\n", + " ]\n", + " else:\n", + " newchildren = [\n", + " (sublog / logsum, subnode) for sublog, _, subnode in children\n", + " ]\n", + " return newchildren\n", + "\n", + " # Conditioning\n", + " def see(self, cut):\n", + " root = copy.deepcopy(self)\n", + " root._see(cut, 1.0)\n", + " return root\n", + "\n", + " def _see(self, cut, prob):\n", + " # Base case.\n", + " if self.id in cut.t:\n", + " newnode = Node(self.id, self.assign)\n", + " return (1.0, prob)\n", + " if self.id in cut.f:\n", + " newnode = Node(self.id, self.assign)\n", + " return (0.0, 0.0)\n", + "\n", + " # Recurse.\n", + " newchildren = []\n", + " probsum = 0.0\n", + " logsum = 0.0\n", + " for subprob, subnode in self.children:\n", + " reslog, resprob = subnode._see(cut, prob * subprob)\n", + "\n", + " newchildren.append((reslog, resprob, subnode))\n", + " logsum += reslog\n", + " probsum += resprob\n", + "\n", + " # 
Normalize.\n", + " self.children = Node._normalize_children(newchildren, probsum, logsum)\n", + "\n", + " return (1.0, probsum)\n", + "\n", + " # Causal intervention\n", + " def do(self, cut):\n", + " root = copy.deepcopy(self)\n", + " root._do(cut)\n", + " return root\n", + "\n", + " def _do(self, cut):\n", + " # Base case.\n", + " if self.id in cut.t:\n", + " return True\n", + " if self.id in cut.f:\n", + " return False\n", + "\n", + " # Recurse.\n", + " newchildren = []\n", + " probsum = 0.0\n", + " logsum = 0.0\n", + " for subprob, subnode in self.children:\n", + " resdo = subnode._do(cut)\n", + "\n", + " if resdo:\n", + " newchildren.append((1.0, subprob, subnode))\n", + " probsum += subprob\n", + " logsum += 1.0\n", + " else:\n", + " newchildren.append((0.0, 0.0, subnode))\n", + "\n", + " # Normalize.\n", + " self.children = Node._normalize_children(newchildren, probsum, logsum)\n", + "\n", + " return (1.0, probsum)\n", + "\n", + " # Counterfactual/subjunctive conditional\n", + " def cf(self, root_prem, cut_subj):\n", + " root_subj = self.do(cut_subj)\n", + " root_subj._cf(root_prem, cut_subj)\n", + " return root_subj\n", + "\n", + " def _cf(self, prem, cut):\n", + " # Base case.\n", + " if self.id in cut.t:\n", + " return True\n", + " if self.id in cut.f:\n", + " return False\n", + "\n", + " # Recurse.\n", + " critical = False\n", + "\n", + " for child, child_prem in zip(self.children, prem.children):\n", + " (_, subnode) = child\n", + " (_, subnode_prem) = child_prem\n", + " in_do = subnode._cf(subnode_prem, cut)\n", + " if not in_do:\n", + " critical = True\n", + " continue\n", + "\n", + " # Pick children if node is critical.\n", + " if not critical:\n", + " self.children = [\n", + " (subprob, subnode)\n", + " for (_, subnode), (subprob, _) in zip(self.children, prem.children)\n", + " ]\n", + "\n", + " return True\n", + "\n", + " # Show probability tree.\n", + " def show(self, show_id=False, show_prob=False, cut=None, crit=None):\n", + " # Initialize Digraph.\n", + " graph_attr = {\n", + " 'bgcolor': 'White',\n", + " 'rankdir': 'LR',\n", + " 'nodesep': '0.1',\n", + " 'ranksep': '0.3',\n", + " 'sep': '0'\n", + " }\n", + " node_attr = {\n", + " 'style': 'rounded',\n", + " 'shape': 'box',\n", + " 'height': '0.1',\n", + " 'width': '0.5',\n", + " 'fontsize': '10',\n", + " 'margin': '0.1, 0.02'\n", + " }\n", + " edge_attr = {'fontsize': '10'}\n", + " g = graphviz.Digraph(\n", + " 'g',\n", + " format='svg',\n", + " graph_attr=graph_attr,\n", + " node_attr=node_attr,\n", + " edge_attr=edge_attr)\n", + "\n", + " # Recursion.\n", + " return self._show(\n", + " g, 1.0, show_id=show_id, show_prob=show_prob, cut=cut, crit=crit)\n", + "\n", + " def _show(self, g, prob, show_id=False, show_prob=False, cut=None, crit=None):\n", + " # Create label.\n", + " labels = [name + ' = ' + value for name, value in self.assign]\n", + " node_label = '\\n'.join(labels)\n", + " if show_id:\n", + " node_label = str(self.id) + '\\n' + node_label\n", + " if show_prob:\n", + " node_label = node_label + '\\np = ' + '{0:.3g}'.format(prob)\n", + "\n", + " # Decorate node.\n", + " attr = {'style': 'filled, rounded', 'fillcolor': 'WhiteSmoke'}\n", + " if not (cut is None):\n", + " if self.id in cut.t:\n", + " attr = {'style': 'filled, rounded', 'fillcolor': 'AquaMarine'}\n", + " elif self.id in cut.f:\n", + " attr = {'style': 'filled, rounded', 'fillcolor': 'LightCoral'}\n", + " if not (crit is None):\n", + " if self.id in crit.s:\n", + " attr = {'style': 'filled, rounded', 'fillcolor': 'Plum'}\n", + " 
g.node(str(self.id), label=node_label, **attr)\n", + "\n", + " # Recurse.\n", + " for child in self.children:\n", + " subprob, subnode = child\n", + " subnode._show(\n", + " g,\n", + " prob * subprob,\n", + " show_id=show_id,\n", + " show_prob=show_prob,\n", + " cut=cut,\n", + " crit=crit)\n", + " g.edge(str(self.id), str(subnode.id), label='{0:.3g}'.format(subprob))\n", + "\n", + " return g\n", + "\n", + " def find(self, uid):\n", + " if self.id == uid:\n", + " return self\n", + "\n", + " for child in self.children:\n", + " subprob, subnode = child\n", + " found_node = subnode.find(uid)\n", + " if found_node is not None:\n", + " return found_node\n", + "\n", + " return None\n", + "\n", + "\n", + "class PTree:\n", + " \"\"\"A probability tree.\"\"\"\n", + "\n", + " def __init__(self):\n", + " \"\"\"Create a probability tree.\"\"\"\n", + " self._root = None\n", + " self._count = 0\n", + "\n", + " def root(self, statements, children=None):\n", + " \"\"\"Sets the root node.\n", + "\n", + " Parameters\n", + " ----------\n", + " statements : str\n", + " A string containing a comma-separated list of statements of\n", + " the form \"var = val\", such as \"X=1, Y=0\". These are the\n", + " values resolved by the root node.\n", + " children : list((float, Node)), (default: None)\n", + " A list of (probability, child node) pairs. These are the root\n", + " node's children and their transition probabilities.\n", + "\n", + " Returns\n", + " -------\n", + " Node\n", + " the root node of the probability tree.\n", + " \"\"\"\n", + " self._count += 1\n", + " self._root = Node(self._count, statements, children)\n", + " return self._root\n", + "\n", + " def child(self, prob, statements, children=None):\n", + " \"\"\"Create a child node and its transition probability.\n", + "\n", + " Parameters\n", + " ----------\n", + " prob : float\n", + " The probability of the transition\n", + " statements : str\n", + " A string containing a comma-separated list of statements of\n", + " the form \"var = val\", such as \"X=1, Y=0\". 
These are the\n", + " values resolved by the child node.\n", + " children : list((float, Node)), (default: None)\n", + " A list of (probability, child node) pairs to be set as the\n", + " children of the node.\n", + "\n", + " Returns\n", + " -------\n", + " Node\n", + " the created node.\n", + " \"\"\"\n", + " self._count += 1\n", + " return (prob, Node(self._count, statements, children))\n", + "\n", + " def get_root(self):\n", + " \"\"\"Return the root node.\n", + "\n", + " Returns\n", + " -------\n", + " Node\n", + " the root node of the probability tree.\n", + " \"\"\"\n", + " return self._root\n", + "\n", + " def show(self, show_id=False, show_prob=False, cut=None, crit=None):\n", + " \"\"\"Returns a graph of the probability tree.\n", + "\n", + " Parameters\n", + " ----------\n", + " show_id: Bool (default: False)\n", + " If true, display the unique id's.\n", + " show_prob : Bool (default: False)\n", + " If true, display the node probabilities.\n", + " cut : MinCut (default: None)\n", + " If a MinCut is given, then display it.\n", + " crit : Critical (default: None)\n", + " If a Critical set is given, then show it.\n", + "\n", + " Returns\n", + " -------\n", + " Node\n", + " the created node.\n", + " \"\"\"\n", + " return self._root.show(\n", + " show_id=show_id, show_prob=show_prob, cut=cut, crit=crit)\n", + "\n", + " def rvs(self):\n", + " \"\"\"Return a dictionary with all the random variables and their values.\n", + "\n", + " Returns\n", + " -------\n", + " dict(str: list)\n", + " A dictionary with all the random variables pointing at lists\n", + " containing their possible values.\n", + " \"\"\"\n", + " return self._root.rvs()\n", + "\n", + " def rv(self, var):\n", + " \"\"\"Return a probability distribution for a given random variable.\n", + "\n", + " Parameters\n", + " ----------\n", + " var: str\n", + " A string containing the name of the random variable.\n", + "\n", + " Returns\n", + " -------\n", + " list((float, str))\n", + " A list with pairs (prob, val), where prob and val are the\n", + " probability\n", + " and the value of the random variable.\n", + " \"\"\"\n", + " return [(self.prob(self.prop(var + ' = ' + val)), val)\n", + " for val in self.rvs()[var]]\n", + "\n", + " def expect(self, var):\n", + " \"\"\"Return the expected value of a random variable.\n", + "\n", + " Parameters\n", + " ----------\n", + " var: str\n", + " A string containing the name of the random variable.\n", + "\n", + " Returns\n", + " -------\n", + " float\n", + " The expected value of the random variable.\n", + " \"\"\"\n", + " e = 0.0\n", + " for prob, val in self.rv(var):\n", + " e += prob * float(val)\n", + " return e\n", + "\n", + " def find(self, uid):\n", + " \"\"\"Return a node with given unique identifier.\n", + "\n", + " Parameters\n", + " ----------\n", + " uid: int\n", + " Identifier of the node to be returned.\n", + "\n", + " Returns\n", + " -------\n", + " Node or None\n", + " Returns the node if found, otherwise None.\n", + " \"\"\"\n", + " return self._root.find(uid)\n", + "\n", + " def prop(self, statement):\n", + " \"\"\"Returns min-cut of a statement.\n", + "\n", + " Parameters\n", + " ----------\n", + " statement: str\n", + " A single statement of the form \"var = val\", such as \"X = 1\".\n", + "\n", + " Returns\n", + " -------\n", + " MinCut\n", + " the min-cut of the event corresponding to the statement.\n", + " \"\"\"\n", + " return MinCut.prop(self._root, statement)\n", + "\n", + " def critical(self, cut):\n", + " \"\"\"Returns critical set of a min-cut.\n", + "\n", + " 
Parameters\n", + " ----------\n", + " cut: MinCut\n", + " A min-cuts.\n", + "\n", + " Returns\n", + " -------\n", + " Critical\n", + " the critical set for the min-cut.\n", + " \"\"\"\n", + " return Critical.critical(self._root, cut)\n", + "\n", + " def sample(self):\n", + " \"\"\"Sample a realization.\n", + " \n", + " Returns\n", + " -------\n", + " dict((str:str))\n", + " A dictionary of bound random variables such as\n", + "\n", + " { 'X': '1', 'Y': '0' }.\n", + " \"\"\"\n", + " return self._root.sample()\n", + "\n", + " def prob(self, cut):\n", + " \"\"\"Compute probability of a min-cut.\n", + "\n", + " Parameters\n", + " ----------\n", + " cut: MinCut\n", + " A min-cut for an event.\n", + "\n", + " Returns\n", + " -------\n", + " float\n", + " The probability of the event of the min-cut.\n", + " \"\"\"\n", + " return self._root.prob(cut)\n", + "\n", + " def see(self, cut):\n", + " \"\"\"Return a probability tree conditioned on a cut.\n", + "\n", + " Parameters\n", + " ----------\n", + " cut: MinCut\n", + " A min-cut for an event.\n", + "\n", + " Returns\n", + " -------\n", + " PTree\n", + " A new probability tree.\n", + " \"\"\"\n", + " newptree = PTree()\n", + " newptree._root = self._root.see(cut)\n", + " return newptree\n", + "\n", + " def do(self, cut):\n", + " \"\"\"Intervene on a cut.\n", + "\n", + " Parameters\n", + " ----------\n", + " cut: MinCut\n", + " A min-cut for an event.\n", + "\n", + " Returns\n", + " -------\n", + " float\n", + " A new probability tree.\n", + " \"\"\"\n", + " newptree = PTree()\n", + " newptree._root = self._root.do(cut)\n", + " return newptree\n", + "\n", + " def cf(self, tree_prem, cut_subj):\n", + " \"\"\"Return a subjunctive conditional tree.\n", + "\n", + " Parameters\n", + " ----------\n", + " tree_prem: PTree\n", + " A probality tree representing the premises for the subjunctive\n", + " evaluation.\n", + " This probability tree must have been obtained through operations on\n", + " the\n", + " base probability tree.\n", + " cut_do: MinCut\n", + " A min-cut for an event. This min-cut is the subjunctive condition of\n", + " the\n", + " counterfactual.\n", + "\n", + " Returns\n", + " -------\n", + " float\n", + " A new probability tree.\n", + " \"\"\"\n", + " newptree = PTree()\n", + " newptree._root = self._root.cf(tree_prem._root, cut_subj)\n", + " return newptree\n", + "\n", + " def fromFunc(func, root_statement=None):\n", + " \"\"\"Build a probability tree from a factory function.\n", + "\n", + " Building probability trees can be difficult, especially when we have\n", + " to manually specify all its nodes. To simplify this, `fromFunc` allows\n", + " building a probability tree using a factory function. 
A factory\n", + " function is a function that:\n", + "\n", + " - receives a dictionary of bound random variables, such as\n", + "\n", + " { 'X': '1', 'Y': '0' }\n", + "\n", + " - and returns either `None` if a leaf has been reached, or a list\n", + " of transitions and their statements, such as\n", + "\n", + " [(0.3, 'Z = 0'), (0.2, 'Z = 1'), (0.5, 'Z = 2')].\n", + "\n", + " Such a factory function contains all the necessary information for\n", + " building a probability tree.\n", + "\n", + " The advantage of using a factory function is that we can exploit\n", + " symmetries (such as conditional independencies) to code a much\n", + " more compact description of the probability tree.\n", + "\n", + "\n", + " Parameters\n", + " ----------\n", + " func: Function: dict((str: str)) -\u003e list((float, str))\n", + " A probality tree factory function.\n", + "\n", + " root_statement: str (default: None)\n", + " A string containing the statement (e.g. 'root = 0')\n", + " for the root node. If `None`, 'Ω = 1' is used.\n", + "\n", + " Returns\n", + " -------\n", + " PTree\n", + " A new probability tree.\n", + " \"\"\"\n", + " if not root_statement:\n", + " root_statement = 'O = 1'\n", + "\n", + " tree = PTree()\n", + " bvars = dict(Node._parse_statements(root_statement))\n", + " tree.root(root_statement, tree._fromFunc(func, bvars))\n", + " return tree\n", + "\n", + " def _fromFunc(self, func, bvars):\n", + " \"\"\"Auxiliary method for PTree.fromFunc().\"\"\"\n", + "\n", + " transition_list = func(bvars)\n", + " if not transition_list:\n", + " return None\n", + " children = []\n", + " for prob, statement in transition_list:\n", + " add_vars = dict(Node._parse_statements(statement))\n", + " new_bvars = {**bvars, **add_vars}\n", + " res = self._fromFunc(func, new_bvars)\n", + " children.append(self.child(prob, statement, res))\n", + " return children" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "lcEgcO7C_VlQ" + }, + "source": [ + "## 1. Probability trees\n", + "\n", + "A **probability tree** is a representation of a random experiment or process.\n", + "Starting from the **root node**, the process iteratively takes **random\n", + "transitions** to **child nodes**, terminating at a **leaf node**. A path from\n", + "the root node to a node is a **(partial) realization**, and a path from the root\n", + "node to a leaf node is a **total realization**. Every node in the tree has one\n", + "or more **statements** associated with it. When a realization reaches a node,\n", + "the statements indicate which values are bound to random variables.\n", + "\n", + "Let's create our first probability tree. It shows a random variable $X$ where: -\n", + "$X = 0$ with probability $0.5$; - $X = 1$ with probability $0.3$; - and $X = 2$\n", + "with probability $0.2$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IKiO-SgJ_VlQ" + }, + "outputs": [], + "source": [ + "# Create a blank probability tree.\n", + "pt = PTree()\n", + "\n", + "# Add a root node and the children.\n", + "pt.root(\n", + "    'O = 1',\n", + "    [pt.child(0.5, 'X = 0'),\n", + "     pt.child(0.3, 'X = 1'),\n", + "     pt.child(0.2, 'X = 2')])\n", + "\n", + "# Display it.\n", + "display(pt.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "pU9B9TJf_VlV" + }, + "source": [ + "We'll typically call the root node $O$, standing for \"**O**mega\" ($\\Omega$),\n", + "which is a common name for the sample space in the literature.\n", + "\n", + "After creating a probability tree, we can ask it to return: \n", + "- the list of random variables and their values using the method `rvs()`; \n", + "- the probability distribution for a given random variable using \n", + "`rv(varname)`; \n", + "- the expected value of a *numerical* random variable with \n", + "`expect(varname)`;\n", + "- and obtain a random sample from the tree with `sample()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "a-MsyV_d_VlV" + }, + "outputs": [], + "source": [ + "rvs = pt.rvs()\n", + "print('Random variables:', rvs)\n", + "\n", + "pdist = pt.rv('X')\n", + "print('P(X) =', pdist)\n", + "\n", + "expect = pt.expect('X')\n", + "print('E(X) =', expect)\n", + "\n", + "smp = pt.sample()\n", + "print('Sample =', smp)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BscAftLn_VlZ" + }, + "source": [ + "### Causal dependencies\n", + "\n", + "In a probability tree, a causal dependency $X \\rightarrow Y$ is expressed\n", + "through a node $X$ having a descendant node $Y$. For instance, consider the next\n", + "probability tree:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0DDBKwNs_Vla" + }, + "outputs": [], + "source": [ + "# Create a blank probability tree.\n", + "pt = PTree()\n", + "\n", + "# Add a root node and the children.\n", + "pt.root('O = 1', [\n", + "    pt.child(0.3, 'X = 0', [\n", + "        pt.child(0.2, 'Y = 0'),\n", + "        pt.child(0.8, 'Y = 1'),\n", + "    ]),\n", + "    pt.child(0.7, 'X = 1', [\n", + "        pt.child(0.8, 'Y = 0'),\n", + "        pt.child(0.2, 'Y = 1'),\n", + "    ]),\n", + "])\n", + "\n", + "# Display it.\n", + "display(pt.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YenMXlr6_Vld" + }, + "source": [ + "Here $Y$ is a descendant of $X$ and therefore $X \\rightarrow Y$. This means that\n", + "we can affect the value of $Y$ by choosing $X$ but not vice versa. The exact\n", + "semantics of this requires **interventions**, which we'll review later. 
Notice\n", + "how the value of $X$ changes the distribution over $Y$:\n", + "\n", + "- $P(Y=1|X=0) \u003e P(Y=0|X=0)$;\n", + "- $P(Y=1|X=1) \u003c P(Y=0|X=1)$.\n", + "\n", + "If we want to express that neither $X \\rightarrow Y$ nor $Y \\rightarrow X$ is\n", + "the case, then we need to combine both random variables into the same nodes as\n", + "follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Pot5uw-R_Vld" + }, + "outputs": [], + "source": [ + "# Create a blank probability tree.\n", + "pt = PTree()\n", + "\n", + "# Add a root node and the children.\n", + "pt.root('O = 1', [\n", + "    pt.child(0.3 * 0.2, 'X = 0, Y = 0'),\n", + "    pt.child(0.3 * 0.8, 'X = 0, Y = 1'),\n", + "    pt.child(0.7 * 0.8, 'X = 1, Y = 0'),\n", + "    pt.child(0.7 * 0.2, 'X = 1, Y = 1')\n", + "])\n", + "\n", + "# Display it.\n", + "display(pt.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "t9ntPBZd_Vlk" + }, + "source": [ + "### Another tree: drug testing\n", + "\n", + "Let's build another example. Here we have a drug testing situation:\n", + "\n", + "- A patient is ill ($D = 1$) with some probability.\n", + "- If the patient takes the drug ($T = 1$) when she is ill, she will likely\n", + "  feel better ($R = 1$), otherwise she will likely feel worse ($R = 0$).\n", + "- However, if she takes the drug when she is not ill, the situation is\n", + "  inverted: the drug might make her feel worse ($R = 0$).\n", + "\n", + "![Drug Testing CBN](http://www.adaptiveagents.org/_media/wiki/drug-testing.png)\n", + "\n", + "This tree can also be represented as the above causal Bayesian network. This is\n", + "always the case when the causal ordering of the random variables is the same, no\n", + "matter which realization path is taken in the tree." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "AxK8DRP3_Vlk" + }, + "outputs": [], + "source": [ + "med = PTree()\n", + "med.root('O = 1', [\n", + "    med.child(0.4, 'D = 0', [\n", + "        med.child(0.5, 'T = 0',\n", + "                  [med.child(0.2, 'R = 0'),\n", + "                   med.child(0.8, 'R = 1')]),\n", + "        med.child(0.5, 'T = 1',\n", + "                  [med.child(0.8, 'R = 0'),\n", + "                   med.child(0.2, 'R = 1')])\n", + "    ]),\n", + "    med.child(0.6, 'D = 1', [\n", + "        med.child(0.5, 'T = 0',\n", + "                  [med.child(0.8, 'R = 0'),\n", + "                   med.child(0.2, 'R = 1')]),\n", + "        med.child(0.5, 'T = 1',\n", + "                  [med.child(0.2, 'R = 0'),\n", + "                   med.child(0.8, 'R = 1')])\n", + "    ])\n", + "])\n", + "\n", + "print('Random variables:', med.rvs())\n", + "display(med.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fq63jb0u_Vln" + }, + "source": [ + "### A tree that cannot be represented as a Bayesian network: Weather-Barometer Worlds\n", + "\n", + "We can also build a tree where the different realization paths have different\n", + "causal dependencies. 
For instance, imagine we have two possible worlds: - Our\n", + "world ($A = 0$) where the weather ($W$) influences the barometer reading\n", + "($B$); - An alien world ($A = 1$) where the barometer influences the weather.\n", + "\n", + "Such a situation with multiple causal dependencies cannot be captured in a\n", + "single graphical model:\n", + "\n", + "![Weather-Barometer Worlds](http://www.adaptiveagents.org/_media/wiki/weather-barometer-worlds.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "er74BN7n_Vlo" + }, + "source": [ + "However, we can represent it using a probability tree:\n", + "\n", + "![Weather-Barometer Worlds Probability Tree](http://www.adaptiveagents.org/_media/wiki/wb_tree.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "afVSi5xo_Vlo" + }, + "source": [ + "### Exercise 1\n", + "\n", + "Now it's your turn to create a probability tree. Create the \"weather-barometer\n", + "worlds\" probability tree and name it `wb`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bCeTXeTe7WDM" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ehQ37yPg7Wil" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Q7W7exq4_Vlp" + }, + "outputs": [], + "source": [ + "# Create blank tree.\n", + "wb = PTree()\n", + "\n", + "# Set the root node and its sub-nodes.\n", + "wb.root('O = 1', [\n", + " wb.child(0.5, 'A = 0', [\n", + " wb.child(0.5, 'W = 0',\n", + " [wb.child(0.75, 'B = 0'),\n", + " wb.child(0.25, 'B = 1')]),\n", + " wb.child(0.5, 'W = 1',\n", + " [wb.child(0.25, 'B = 0'),\n", + " wb.child(0.75, 'B = 1')])\n", + " ]),\n", + " wb.child(0.5, 'A = 1', [\n", + " wb.child(0.5, 'B = 0',\n", + " [wb.child(0.75, 'W = 0'),\n", + " wb.child(0.25, 'W = 1')]),\n", + " wb.child(0.5, 'B = 1',\n", + " [wb.child(0.25, 'W = 0'),\n", + " wb.child(0.75, 'W = 1')])\n", + " ])\n", + "])\n", + "\n", + "# Display it.\n", + "display(wb.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9TvtFpta_Vls" + }, + "source": [ + "### Remember:\n", + "\n", + "- A node can contain more than one statement.\n", + "- The tree doesn't have to be balanced.\n", + "\n", + "See the next example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "e4RxAStn_Vls" + }, + "outputs": [], + "source": [ + "pt = PTree()\n", + "pt.root('O = 1', [\n", + " pt.child(0.2, 'X = 0, Y = 0'),\n", + " pt.child(0.8, 'X = 1', [pt.child(0.3, 'Y = 1'),\n", + " pt.child(0.7, 'Y = 2')])\n", + "])\n", + "\n", + "display(pt.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tZijtGvW_Vlw" + }, + "source": [ + "### Displaying additional information\n", + "\n", + "We can display additional information about probability trees: \n", + "- **Unique identifiers**: Each node has an automatically assigned \n", + "unique identifier. Use `show_id = True` to display it. \n", + "- **Probability**: Each node has a probability of being realized. \n", + "Use `show_prob = True` to display this information." 
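, + "\n", + "The unique identifiers can also be used to retrieve nodes programmatically via\n", + "`PTree.find(uid)`. The sketch below is illustrative only (the id `3` is an\n", + "arbitrary example; ids depend on the order in which nodes were created):\n", + "\n", + "```python\n", + "# Retrieve a node of the drug-testing tree by its unique id and inspect it.\n", + "node = med.find(3)\n", + "if node is not None:\n", + "  print('Assignments:', node.assign)\n", + "  print('Number of children:', len(node.children))\n", + "```"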
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "xZRlhlq__Vlw" + }, + "outputs": [], + "source": [ + "display(med.show(show_prob=True, show_id=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "CRsPVDcg_Vlz" + }, + "source": [ + "### Exercise 2\n", + "\n", + "For the probability tree `wb`: \n", + "- list all the random variables; \n", + "- compute the probability distribution of the barometer ($B$); \n", + "- display the probability tree with the unique ids and probabilities\n", + "of every node." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "qBqyFExK7eor" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HHmLu4w_7e-S" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "cHoN3mQp_Vl0" + }, + "outputs": [], + "source": [ + "print(wb.rvs())\n", + "print(wb.rv('B'))\n", + "display(wb.show(show_id=True, show_prob=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rqcodRJr_Vl3" + }, + "source": [ + "## 2. Propositions and min-cuts\n", + "\n", + "We've seen that a probability tree is a simple way of representing all the\n", + "possible realizations and their causal dependencies. We now investigate the\n", + "possible **events** in a probability tree.\n", + "\n", + "An **event** is a collection of full realizations. We can **describe** events\n", + "using propositions about random variables (e.g. $W = 0$, $B = 1$) and the\n", + "logical connectives of negation, conjunction (AND), and disjunction (OR). The\n", + "connectives allow us to state composite events, such as $\\neg(W = 1 \\wedge B =\n", + "0)$. For instance, the event $B = 0$ is the set of all realizations, i.e. paths\n", + "from the root to a leaf, that **pass through a node** with the statement $B=0$.\n", + "\n", + "We can **represent** events using cuts, and in particular, **min-cuts**. A\n", + "**min-cut** is a minimal representation of an event in terms of the nodes of a\n", + "probability tree. The min-cut of an event collects the smallest number of nodes\n", + "in the probability tree that resolves whether an event has occurred or not. In\n", + "other words, if a realization hits a node in the min-cut, then we know for sure\n", + "whether the event has occurred or not. (In measure theory, a similar notion to\n", + "min-cut would be the algebra that renders the event measurable.)\n", + "\n", + "Our implementation of min-cuts furthermore distinguishes between the nodes that\n", + "render the event true from the nodes that render the event false.\n", + "\n", + "Let's start by constructing a min-cut for a setting of a random variable in our\n", + "drug testing example. Verify that the min-cut is correct for the setting of the\n", + "random variable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "CmjmqUAW_Vl3" + }, + "outputs": [], + "source": [ + "# Build a cut for the proposition 'R = 1'.\n", + "cut = med.prop('R=1')\n", + "\n", + "# The result is of type MinCut:\n", + "print('Type of a cut:', type(cut))\n", + "\n", + "# Print the min-cut. 
Note that the elements in the\n", + "# true and false sets refer to the ids of the prob tree.\n", + "print('Min-cut for \"R = 1\":', cut)\n", + "\n", + "# Render the probability tree with a cut.\n", + "display(med.show(cut=cut, show_id=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9Bb9yBqP_Vl6" + }, + "source": [ + "Let's do a min-cut for not taking the treatment ($T = 0$)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "6-zNSWmL_Vl7" + }, + "outputs": [], + "source": [ + "# Build a cut for the proposition 'T = 0'.\n", + "cut = med.prop('T=0')\n", + "\n", + "# Print the min-cut. Note that the elements in the\n", + "# true and false sets refer to the ids of the prob tree.\n", + "print('Min-cut for \"T = 0\":', cut)\n", + "\n", + "# Render the probability tree with a cut.\n", + "display(med.show(cut=cut, show_id=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "IQxIGXeu_Vl-" + }, + "source": [ + "We can also build negated events using the `~` unary operator. As an example,\n", + "let's negate the previous event. Compare the two cuts. Notice that a negation\n", + "simply inverts the nodes that are true and false." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "hMApKX3O_Vl_" + }, + "outputs": [], + "source": [ + "cut = ~med.prop('T = 0')\n", + "print('Min-cut for \"T = 0\":', med.prop('T = 0'))\n", + "print('Min-cut for \"not T = 0\":', ~med.prop('T = 0'))\n", + "display(med.show(cut=cut, show_id=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DGP5kr7K_VmC" + }, + "source": [ + "Now let's build more complex events using conjunctions (`\u0026`) and disjunctions\n", + "(`|`). Make sure these min-cuts make sense to you. Notice that the conjunction\n", + "of two events picks out the earliest occurrence of false nodes and the latest\n", + "occurrence of true nodes, whereas the disjunction does the opposite." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "vOrfMXyA_VmD" + }, + "outputs": [], + "source": [ + "# Recovery\n", + "cut1 = med.prop('R=1')\n", + "print('Cut for \"R = 1\":')\n", + "display(med.show(cut=cut1))\n", + "\n", + "# Taking the treatment\n", + "cut2 = med.prop('T=1')\n", + "print('Cut for \"T=1\":')\n", + "display(med.show(cut=cut2))\n", + "\n", + "# Conjunction: taking the treatment and recovery\n", + "cut_and = cut1 \u0026 cut2\n", + "print('Cut for \"T=1 and R=1\":')\n", + "display(med.show(cut=cut_and))\n", + "\n", + "# Disjunction: taking the treatment or recovery\n", + "cut_or = cut1 | cut2\n", + "print('Cut for \"T=1 or R=1\":')\n", + "display(med.show(cut=cut_or))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1AxBz2Ml_VmG" + }, + "source": [ + "### The precedence relation\n", + "\n", + "In addition to the Boolean operators, we can also use a causal connective which\n", + "cannot be stated in logical terms: the **precedence relation** $\\prec$. This\n", + "relation allows building min-cuts for events where one event $A$ precedes\n", + "another event $B$, written $A \\prec B$, and thus requires the additional\n", + "information provided by the probability tree's structure.\n", + "\n", + "Let's try one example. 
We want to build the min-cut where having the disease\n", + "($D=1$) precedes feeling better ($R=1$), and vice-versa." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2WvNpbQG_VmH" + }, + "outputs": [], + "source": [ + "# Disease and recovery min-cuts.\n", + "cut1 = med.prop('D=1') \u003c med.prop('R=1')\n", + "cut2 = med.prop('R=1') \u003c med.prop('D=1')\n", + "\n", + "# Display.\n", + "print('Cut for D=1 \u003c R=1:')\n", + "display(med.show(cut=cut1))\n", + "\n", + "print('Cut for R=1 \u003c D=1:')\n", + "display(med.show(cut=cut2))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YbYnVRiM_VmJ" + }, + "source": [ + "### Requirement: random variables must be measurable\n", + "\n", + "If we try to build a min-cut using a variable that is not measurable, then an\n", + "exception is raised. For instance, the random variable $X$ below is not\n", + "measurable within the probability tree, because the realization starting at the\n", + "root and reaching the leaf $Y = 2$ never sets the value for $X$.\n", + "\n", + "Attempting to build a min-cut for an event involving $X$ will throw an error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ujFLeAMN_VmK" + }, + "outputs": [], + "source": [ + "pt = PTree()\n", + "pt.root('O = 1', [\n", + " pt.child(0.1, 'X = 0, Y = 0'),\n", + " pt.child(0.2, 'X = 1, Y = 1'),\n", + " pt.child(0.7, 'Y = 2')\n", + "])\n", + "\n", + "display(pt.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zyP7gSPl_VmN" + }, + "source": [ + "### Special case: probabilistic truth versus logical truth\n", + "\n", + "Let's have a look at one special case. Our definitions make a distinction\n", + "between **logical** and **probabilistic truth**. This is best seen in the\n", + "example below.\n", + "\n", + "In this example, we have a probability tree with three outcomes: $X = 1, 2$, and\n", + "$3$. - $X = 1$ occurs with probability one. \n", + "- Hence, probabilistically, the event $X=1$ is resolved at the level of the\n", + "root node. \n", + "- However, it isn't resolved at the logical level, since $X = 2$ or $X = 3$ \n", + "can happen logically, although with probability zero.\n", + "\n", + "Distinguishing between logical truth and probabilistic truth is important for\n", + "stating counterfactuals. This will become clearer later." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8IXq-jVl_VmN" + }, + "outputs": [], + "source": [ + "# First we add all the nodes.\n", + "pt = PTree()\n", + "pt.root('O = 1',\n", + " [pt.child(1, 'X = 1'),\n", + " pt.child(0, 'X = 2'),\n", + " pt.child(0, 'X = 3')])\n", + "\n", + "# Show the cut for 'X = 0'\n", + "cut = pt.prop('X = 1')\n", + "print('While the root node \"O=1\" does resolve the event \"X=1\"\\n' +\n", + " 'probabilistically, it does not resolve the event logically.')\n", + "display(pt.show(cut=cut))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K7OJKlcZ_VmQ" + }, + "source": [ + "### Exercise 3\n", + "\n", + "For the `wb` probability tree, build the min-cuts for the following events:\n", + "- the world is alien ($A = 1$); \n", + "- the weather is sunny ($W = 1$); \n", + "- the barometer goes down and the weather is sunny ($B = 0 \\wedge W = 1$); \n", + "- the negation of \"barometer does not go down or weather is not sunny\", \n", + "$\\neg(\\neg(B = 0) \\vee \\neg(W = 1))$.\n", + "\n", + "Display every min-cut. In particular, compare the last two. What do you observe?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "HlKSQi1l7rDJ" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "bD3LT8NV7rRV" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3ttu0mlR_VmQ" + }, + "outputs": [], + "source": [ + "# Exercise.\n", + "\n", + "# A = 1.\n", + "cut = wb.prop('A=1')\n", + "print('Cut for \"A=1\":')\n", + "display(wb.show(cut=cut))\n", + "\n", + "# W = 1.\n", + "cut = wb.prop('W=1')\n", + "print('Cut for \"W=1\":')\n", + "display(wb.show(cut=cut))\n", + "\n", + "# B = 0 and W = 1.\n", + "cut = wb.prop('B=0') \u0026 wb.prop('W=1')\n", + "print('Cut for \"B=0 and W=1\":')\n", + "display(wb.show(cut=cut))\n", + "\n", + "# not( not(B = 0) or not(W = 1) ).\n", + "cut = ~(~wb.prop('B=0') | ~wb.prop('W=1'))\n", + "print('Cut for \"not( not(B=0) or not(W=1) )\":')\n", + "display(wb.show(cut=cut))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ehK-sadN_VmT" + }, + "source": [ + "### Exercise 4\n", + "\n", + "For the `wb` probability tree, determine the min-cut for whenever the weather\n", + "($W$) affects the value of the barometer ($B$). This min-cut should coincide\n", + "with the min-cut for the event ($A=0$).\n", + "\n", + "Hint: enumerate all the 4 cases (values for $W$ and $B$) and combine them using\n", + "disjunctions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "rZe9JCMB7wo9" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZtQffc1P7wxk" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bX6AxxDk_VmT" + }, + "outputs": [], + "source": [ + "# Build the min-cut.\n", + "cut = (wb.prop('W=0') \u003c wb.prop('B=0')) \\\n", + " | (wb.prop('W=0') \u003c wb.prop('B=1')) \\\n", + " | (wb.prop('W=1') \u003c wb.prop('B=0')) \\\n", + " | (wb.prop('W=1') \u003c wb.prop('B=1'))\n", + "\n", + "# Display.\n", + "display(wb.show(cut=cut))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "5PCmW5ol_VmX" + }, + "source": [ + "## 3. Critical sets\n", + "\n", + "Min-cuts correspond to the smallest set of nodes where it becomes clear whether\n", + "an event has occurred or not. Every min-cut has an associated **critical set**:\n", + "the set of nodes that **determines** whether an event won't occur. Given an\n", + "event, the associated **critical set** is defined as the set of parents of the\n", + "event's false set in the min-cut.\n", + "\n", + "Together, a critical set and a min-cut form the set of **mechanisms** that\n", + "determine the occurrence of the event.\n", + "\n", + "Let's have a look at a simple example. Here, the critical set is the singleton\n", + "containing the root node. Critical sets are computed using the function\n", + "`PTree.critical(cut)`, where `cut` is an event's min-cut. We can display the\n", + "critical set by providing the optional argument `crit` to the `PTree.show()`\n", + "function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "9g2F9NKF_VmX" + }, + "outputs": [], + "source": [ + "# First we add all the nodes.\n", + "pt = PTree()\n", + "pt.root('O = 1',\n", + " [pt.child(1, 'X = 1'),\n", + " pt.child(0, 'X = 2'),\n", + " pt.child(0, 'X = 3')])\n", + "\n", + "# Get the critical set for a min-cut.\n", + "cut = pt.prop('X = 1')\n", + "crit = pt.critical(cut)\n", + "\n", + "# Show the critical set.\n", + "print('Min-cut for \"X=1\":', cut)\n", + "print('Critical set for \"X=1\":', crit)\n", + "display(pt.show(show_id=True, cut=cut, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wE7voIju_Vma" + }, + "source": [ + "Let's work out another example. Consider the following probability tree.\n", + "\n", + "![Min-Cuts and Critical Sets](http://www.adaptiveagents.org/_media/wiki/mincut-critical.png)\n", + "\n", + "Try to predict the min-cut and the critical set of the events $X=1$, $Y=1$, and\n", + "$Y=0$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4YvLDtzB_Vma" + }, + "outputs": [], + "source": [ + "pt = PTree()\n", + "pt.root('O = 1', [\n", + " pt.child(0.2, 'X = 0, Y = 0'),\n", + " pt.child(0.8, 'X = 1', [pt.child(0.3, 'Y = 1'),\n", + " pt.child(0.7, 'Y = 0')])\n", + "])\n", + "\n", + "# Original tree.\n", + "print('Original tree:')\n", + "display(pt.show(show_id=True))\n", + "\n", + "# 'X=1'\n", + "cut = pt.prop('X=1')\n", + "crit = pt.critical(cut)\n", + "print('Min-cut and critical set for \"X=1\":')\n", + "display(pt.show(show_id=True, cut=cut, crit=crit))\n", + "\n", + "# 'Y=1'\n", + "cut = pt.prop('Y=1')\n", + "crit = pt.critical(cut)\n", + "print('Min-cut and critical set for \"Y=1\":')\n", + "display(pt.show(show_id=True, cut=cut, crit=crit))\n", + "\n", + "# 'Y=0'\n", + "cut = pt.prop('Y=0')\n", + "crit = pt.critical(cut)\n", + "print('Min-cut and critical set for \"Y=0\":')\n", + "display(pt.show(show_id=True, cut=cut, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "CtfxGITj_Vmd" + }, + "source": [ + "### Exercise 5\n", + "\n", + "For the `wb` tree, compute and display the mechanisms (i.e. the min-cut and the\n", + "critical set) for the following events: \n", + "- the world is alien ($A = 1$); \n", + "- the barometer goes down ($B = 0$); \n", + "- the weather is sunny ($W = 1$); \n", + "- the barometer goes down and weather is sunny ($B = 0 \\wedge W = 1$)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pw-AG0Ma76Al" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-1fUlg1O76bp" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Wqkni_k4_Vmd" + }, + "outputs": [], + "source": [ + "# Exercise.\n", + "\n", + "# A = 1.\n", + "cut = wb.prop('A=1')\n", + "crit = wb.critical(cut)\n", + "print('Mechanism for \"A=1\":')\n", + "display(wb.show(cut=cut, crit=crit))\n", + "\n", + "# B = 0.\n", + "cut = wb.prop('B=0')\n", + "crit = wb.critical(cut)\n", + "print('Mechanisms for \"B=0\":')\n", + "display(wb.show(cut=cut, crit=crit))\n", + "\n", + "# W = 1.\n", + "cut = wb.prop('W=1')\n", + "crit = wb.critical(cut)\n", + "print('Mechanisms for \"W=1\":')\n", + "display(wb.show(cut=cut, crit=crit))\n", + "\n", + "# B = 0 and W = 1.\n", + "cut = wb.prop('B=0') \u0026 wb.prop('W=1')\n", + "crit = wb.critical(cut)\n", + "print('Mechanisms for \"B=0 and W=1\":')\n", + "display(wb.show(cut=cut, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FPMH5CQ-_Vmg" + }, + "source": [ + "We'll return later to critical sets, as they are important for determining the\n", + "operations of conditioning and intervening on probability trees." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9ggdcdjg_Vmg" + }, + "source": [ + "## 4. Evaluating probabilities\n", + "\n", + "We can also evaluate probabilities of events. 
For instance, you may ask:\n", + "\n", + "- \"$P(R=1)$: What is the probability of recovery?\"\n", + "- \"$P(R=0)$: What is the probability of not recovering?\"\n", + "- \"$P(D=1)$: What is the probability of having the disease?\"\n", + "- \"$P(D=1 \\wedge R=0)$: What is the probability of having the disease and not\n", + "  recovering?\"\n", + "- \"$P(D=1 \\vee R=0)$: What is the probability of having the disease or not\n", + "  recovering?\"\n", + "- \"$P(D=1 \\prec R=1)$: What is the probability of having the disease preceding\n", + "  the recovery?\"\n", + "\n", + "To do so, we use the min-cut of the event.\n", + "\n", + "Let's have a look at some of them. Compare them to the graph of the probability tree." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lkH9LPzB_Vmh" + }, + "outputs": [], + "source": [ + "# Min-cuts for some events\n", + "cut1 = med.prop('R=1')\n", + "cut2 = med.prop('D=1')\n", + "cut1_neg = ~cut1\n", + "cut_and = cut2 \u0026 cut1_neg\n", + "cut_or = cut2 | cut1_neg\n", + "cut_prec = cut2 \u003c cut1\n", + "\n", + "print('P(R=1) =', med.prob(cut1))\n", + "print('P(R=0) =', med.prob(cut1_neg))\n", + "print('P(D=1) =', med.prob(cut2))\n", + "print('P(D=1 and R=0) =', med.prob(cut_and))\n", + "print('P(D=1 or R=0) =', med.prob(cut_or))\n", + "print('P(D=1 precedes R=1) =', med.prob(cut_prec))\n", + "\n", + "display(med.show(show_prob=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Bwc00gKT_Vmj" + }, + "source": [ + "### Exercise 6\n", + "\n", + "For the `wb` tree, evaluate the probability of the following events: \n", + "- the world is ours ($A = 0$) and the barometer goes down ($B = 0$); \n", + "- it is not the case that the barometer goes down or the weather \n", + "is sunny ($\\neg(B = 0 \\vee W = 1)$).\n", + "\n", + "Print the probabilities and display the probability trees." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5NWdMrLW7_cj" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NUy8SH-O7_jT" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "xTbTw3Ky_Vmk" + }, + "outputs": [], + "source": [ + "# Exercise.\n", + "\n", + "# A = 0 and B = 0\n", + "cut = wb.prop('A=0') \u0026 wb.prop('B=0')\n", + "print('P(A=0 and B=0) =', wb.prob(cut))\n", + "display(wb.show(cut=cut))\n", + "\n", + "# not(B = 0 or W = 1)\n", + "cut = ~(wb.prop('B=0') | wb.prop('W=1'))\n", + "print('P(not(B=0 or W=1)) =', wb.prob(cut))\n", + "display(wb.show(cut=cut))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6Y8wjMCi_Vmm" + }, + "source": [ + "## 5. Conditioning\n", + "\n", + "We have learned how to represent events using min-cuts. Now we can use min-cuts\n", + "to **condition** probability trees **on events**. 
Conditioning allows asking\n", + "questions after making **observations**, such as:\n", + "\n", + "- \"$P(R=1|T=1)$: What is the probability of recovery given that a patient has\n", + " taken the treatment?\"\n", + "- \"$P(D=1|R=1)$: What is the probability of having had the disease given that\n", + " a patient has recovered/felt better?\"\n", + "\n", + "### How to compute conditions\n", + "\n", + "Conditioning takes a probability tree and produces a new probability tree with\n", + "modified transition probabilities. These are obtained by removing all the total\n", + "realizations that are **incompatible with the condition**, and then\n", + "renormalizing, as illustrated below.\n", + "\n", + "\u003cimg src=\"http://www.adaptiveagents.org/_media/wiki/see.png\" alt=\"Seeing\" width=\"700\"/\u003e\n", + "\n", + "In the example, we compute the result of seeing $Y= 1$. \n", + "Conditioning on an event proceeds in two steps: \n", + "- first, we remove the probability mass of the realizations\n", + "passing through the false set of the event’s min-cut \n", + "(hihglighted in dark, bottom row); \n", + "- then we renormalize the probabilities. \n", + "\n", + "We can do this recursively by aggregating the original probabilities \n", + "of the true set. The top row shows the result of conditioning a\n", + "probability tree on the event $Y= 1$, which also highlights the modified\n", + "transition probabilities in red. The bottom row shows the same \n", + "operation in a probability mass diagram, which is a representation of a\n", + "probability tree that emphasizes the probabilities." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "USe_1o89_Vmn" + }, + "source": [ + "Let's have a look at the drug testing example. We will condition on $R=1$.\n", + "Observe how the probabilities change." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-Is4k0IR_Vmn" + }, + "outputs": [], + "source": [ + "# Now we condition.\n", + "cut = med.prop('R=1')\n", + "med_see = med.see(cut)\n", + "\n", + "# Critical set.\n", + "crit = med.critical(cut)\n", + "\n", + "# Compare probabilities of events.\n", + "print('Before conditioning: P(R=1) =', med.prob(cut))\n", + "print('After conditioning: P(R=1 | R=1) =', med_see.prob(cut))\n", + "\n", + "# Display both trees for comparison.\n", + "print('\\nOriginal tree:')\n", + "display(med.show(show_prob=True))\n", + "\n", + "print('Tree after conditioning on \"R=1\":')\n", + "display(med_see.show(cut=cut, crit=crit, show_prob=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KF9E83_8_Vmq" + }, + "source": [ + "We can condition on composite events too and evaluate the probability of events.\n", + "\n", + "Assume you observe that the drug was taken and a recovery is observed. Then, it\n", + "is very likely that the patient had the disease." 
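+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "As a quick, optional sanity check of the two-step conditioning procedure\n", + "described above, the next cell redoes the computation by hand for the small\n", + "tree used earlier (root transitions 0.2 and 0.8, followed by 0.3 and 0.7),\n", + "conditioning on $Y = 0$. This is only a sketch in plain Python; it does not\n", + "use the `PTree` API. The drug-testing example then continues below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# Hand computation of conditioning on 'Y = 0' for the small tree from earlier.\n", + "# Leaf probabilities of the original tree.\n", + "p_x0_y0 = 0.2        # leaf (X = 0, Y = 0)\n", + "p_x1_y1 = 0.8 * 0.3  # leaf (X = 1, Y = 1)\n", + "p_x1_y0 = 0.8 * 0.7  # leaf (X = 1, Y = 0)\n", + "\n", + "# Step 1: remove the probability mass of the realizations that pass through\n", + "# the false set of the event's min-cut, i.e. the branch ending in Y = 1.\n", + "removed_mass = p_x1_y1\n", + "remaining_mass = p_x0_y0 + p_x1_y0\n", + "print('Removed mass =', removed_mass)\n", + "print('P(Y=0) =', remaining_mass)\n", + "\n", + "# Step 2: renormalize the surviving realizations.\n", + "print('P(X=0, Y=0 | Y=0) =', p_x0_y0 / remaining_mass)\n", + "print('P(X=1, Y=0 | Y=0) =', p_x1_y0 / remaining_mass)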
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "L_ULii7K_Vmq" + }, + "outputs": [], + "source": [ + "# Min-cuts.\n", + "cut_r = med.prop('R=1')\n", + "cut_tr = med.prop('T=1') \u0026 med.prop('R=1')\n", + "cut_disease = med.prop('D=1')\n", + "\n", + "# Critical set.\n", + "crit = med.critical(cut_tr)\n", + "\n", + "# Condition.\n", + "med_see_r = med.see(cut_r)\n", + "med_see_tr = med.see(cut_tr)\n", + "\n", + "# Now we evaluate the posterior probability of having the disease.\n", + "print('P(D = 1) =', med.prob(cut_disease))\n", + "print('P(D = 1 | R = 1) =', med_see_r.prob(cut_disease))\n", + "print('P(D = 1 | T = 1, R = 1) =', med_see_tr.prob(cut_disease))\n", + "\n", + "# Display prob tree.\n", + "print('\\nProbability tree after conditioning on \"T=1 and R=1\":')\n", + "display(med_see_tr.show(cut=cut_tr, show_id=True, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RVtd_Pov_Vmt" + }, + "source": [ + "### Special case: conditioning on trivial events\n", + "\n", + "Let's have a look at a special case: conditioning on **trivial events**, namely\n", + "the **sure event** and the **impossible event**.\n", + "\n", + "Observe that conditioning on trivial events does not change the probability\n", + "tree." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "R9Zcg6lH_Vmt" + }, + "outputs": [], + "source": [ + "# Create a simple tree.\n", + "pt = PTree()\n", + "pt.root('O = 1', [\n", + " pt.child(0.4, 'X = 0, Y = 0'),\n", + " pt.child(0.6, 'X = 1', [pt.child(0.3, 'Y = 0'),\n", + " pt.child(0.7, 'Y = 0')]),\n", + "])\n", + "\n", + "# Show tree.\n", + "print('Original tree:')\n", + "display(pt.show())\n", + "\n", + "# Condition on Y = 0.\n", + "cut = pt.prop('Y=0')\n", + "pt_see_sure = pt.see(cut)\n", + "print('Conditioning on \"Y = 0\":')\n", + "display(pt_see_sure.show(cut=cut))\n", + "\n", + "# Condition on not Y = 0.\n", + "neg_cut = ~cut\n", + "pt_see_impossible = pt.see(neg_cut)\n", + "print('Conditioning on \"not Y = 0\":')\n", + "display(pt_see_impossible.show(cut=neg_cut))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RLijoy-__Vmy" + }, + "source": [ + "### Special case: conditioning on an event with probability zero\n", + "\n", + "Let's return to our simple example with three outcomes. Assume we're conditioning\n", + "on an event with **probability zero**, which can happen **logically but not\n", + "probabilistically**. Using the measure-theoretic definition of conditional\n", + "probabilities, we are required to pick a so-called **version** of the\n", + "conditional distribution. There are infinitely many choices.\n", + "\n", + "Here, we have settled on the following. If we condition on an event with\n", + "probability zero, then we assign uniform probability over all the possible\n", + "transitions. This is just one arbitrary way of solving this problem.\n", + "\n", + "See the example below."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "61sYhPTI_Vmz" + }, + "outputs": [], + "source": [ + "# Create a simple tree.\n", + "pt = PTree()\n", + "pt.root(\n", + " 'O = 1',\n", + " [pt.child(1.0, 'X = 1'),\n", + " pt.child(0.0, 'X = 2'),\n", + " pt.child(0.0, 'X = 3')])\n", + "\n", + "# Let's pick the negative event for our minimal prob tree.\n", + "cut = ~pt.prop('X = 1')\n", + "display(pt.show(cut=cut))\n", + "\n", + "pt_see = pt.see(cut)\n", + "display(pt_see.show(cut=cut))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2fuYExoQ_Vm1" + }, + "source": [ + "### Exercise 7\n", + "\n", + "For the `wb` tree, print the probability distribution of \n", + "- the weather $W$ \n", + "- and the barometer $B$.\n", + "\n", + "Do this for the following probability trees: \n", + "- the original tree \n", + "- the probability tree conditioned on it being an alien world ($A = 1$) \n", + "- the probability tree conditioned on the weather being sunny ($W = 1$).\n", + "\n", + "What do you observe? Does observing (conditioning) give you any additional\n", + "information? If not, why not? If so, why is that?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "D3WqIjaK8FSA" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "13s9eGWF8FZA" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "mMryXm_d_Vm2" + }, + "outputs": [], + "source": [ + "# Exercise\n", + "\n", + "# No condition.\n", + "print('P(W) =', wb.rv('W'))\n", + "print('P(B) =', wb.rv('B'))\n", + "\n", + "# Condition on \"A = 1\"\n", + "cut = wb.prop('A=1')\n", + "print('P(W | A=1) =', wb.see(cut).rv('W'))\n", + "print('P(B | A=1) =', wb.see(cut).rv('B'))\n", + "\n", + "# Condition on \"W = 1\"\n", + "cut = wb.prop('W=1')\n", + "print('P(W | W=1) =', wb.see(cut).rv('W'))\n", + "print('P(B | W=1) =', wb.see(cut).rv('B'))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9Dmexb79_Vm4" + }, + "source": [ + "## 6. Interventions\n", + "\n", + "Interventions are at the heart of causal reasoning.\n", + "\n", + "We have seen how to filter probability trees using observational data through\n", + "the use of conditioning. Now we investigate how a probability tree transforms\n", + "when it is intervened on. An **intervention** is a change to the random process\n", + "itself to make something happen, as opposed to a mere filtering of realizations. 
We can ask\n", + "questions like:\n", + "\n", + "- \"$P(R=1|T \\leftarrow 1)$: What is the probability of recovery given that **I\n", + " take the drug**?\"\n", + "- \"$P(D=1|T \\leftarrow 1 \\wedge R=1)$: What is the probability of having the\n", + " disease given **that I take the drug** and that I observe a recovery?\"\n", + "\n", + "Here, the notation $T \\leftarrow 1$ is a shorthand for the more common notation\n", + "$\\mathrm{do}(T = 1)$.\n", + "\n", + "### How to compute interventions\n", + "\n", + "Interventions differ from conditioning in the following: \n", + "- they change the transition probabilities **minimally**, \n", + "so as to make a desired event happen; \n", + "- they **do not filter** the total realizations of the probability tree; \n", + "- they are **easier to execute** than conditions, because they only \n", + "change the transition probabilities that leave the critical set, \n", + "and they do not require the backward induction of probabilities. \n", + "\n", + "See the illustration below.\n", + "\n", + "\u003cimg src=\"http://www.adaptiveagents.org/_media/wiki/do.png\" alt=\"Doing\" width=\"700\"/\u003e\n", + "\n", + "Example intervention on $Y \\leftarrow 1$. An intervention proceeds in two steps:\n", + "- first, it selects the partial realizations starting in a critical node \n", + "and ending in a leaf that traverse the false set of the event’s min-cut; \n", + "- then it removes their probability mass, renormalizing the probabilities\n", + "from the transitions leaving the critical set. \n", + "\n", + "The top row shows the result of intervening a probability tree\n", + "on $Y \\leftarrow 1$. The bottom row show the same procedure on \n", + "the corresponding probability mass diagram.\n", + "\n", + "Let's start with a simple comparison to illustrate the difference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "z4PLcgN5_Vm5" + }, + "outputs": [], + "source": [ + "pt = PTree()\n", + "pt.root('O = 1', [\n", + " pt.child(0.2, 'X = 0, Y = 0'),\n", + " pt.child(0.8, 'X = 1', [pt.child(0.3, 'Y = 1'),\n", + " pt.child(0.7, 'Y = 0')])\n", + "])\n", + "\n", + "print('Original:')\n", + "display(pt.show(show_prob=True, cut=cut, crit=crit))\n", + "\n", + "# 'Y=1'\n", + "cut = pt.prop('Y = 1')\n", + "crit = pt.critical(cut)\n", + "pt_see = pt.see(cut)\n", + "pt_do = pt.do(cut)\n", + "print('Condition on \"Y=1\":')\n", + "display(pt_see.show(cut=cut, crit=crit))\n", + "print('Intervention on \"Y\u003c-1\":')\n", + "display(pt_do.show(cut=cut, crit=crit))\n", + "\n", + "# 'Y=0'\n", + "cut = pt.prop('Y = 0')\n", + "crit = pt.critical(cut)\n", + "pt_see = pt.see(cut)\n", + "pt_do = pt.do(cut)\n", + "print('Condition on \"Y = 0\":')\n", + "display(pt_see.show(cut=cut, crit=crit))\n", + "print('Intervention on \"Y \u003c- 0\":')\n", + "display(pt_do.show(cut=cut, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xJ-_TtAY_Vm7" + }, + "source": [ + "Notice that the mechanisms for $Y=0$ and $Y=1$ are different. In general, a\n", + "single random variable can have **multiple mechanism** for setting their\n", + "individual values." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LUvv_FNz_Vm8" + }, + "source": [ + "Let's return to our drug testing example. We investigate the effect of taking\n", + "the treatment, that is, by intervening on $T \\leftarrow 1$. 
How do the\n", + "probabilities of: \n", + "- having the disease ($D = 1$); \n", + "- taking the treatment ($T = 1$); \n", + "- and recovering ($R = 1$)\n", + "\n", + "change after taking the treatment ($T \\leftarrow 1$)?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "6iFDVc7Z_Vm8" + }, + "outputs": [], + "source": [ + "# Min-Cuts.\n", + "cut_dis = med.prop('D = 1')\n", + "cut_arg = med.prop('R = 1')\n", + "cut_do = med.prop('T = 1')\n", + "\n", + "# Critical set.\n", + "crit_do = med.critical(cut_do)\n", + "\n", + "# Perform intervention.\n", + "med_do = med.do(cut_do)\n", + "\n", + "# Display original tree.\n", + "print('Original tree:')\n", + "print('P(D = 1) =', med.prob(cut_dis))\n", + "print('P(T = 1) =', med.prob(cut_do))\n", + "print('P(R = 1) =', med.prob(cut_arg))\n", + "display(med.show(cut=cut_do, show_prob=True, crit=crit_do))\n", + "\n", + "# Display tree after invervention.\n", + "print('Tree after intervening on \"T \u003c- 1\":')\n", + "print('P(D = 1 | T \u003c- 1) =', med_do.prob(cut_dis))\n", + "print('P(T = 1 | T \u003c- 1) =', med_do.prob(cut_do))\n", + "print('P(R = 1 | T \u003c- 1) =', med_do.prob(cut_arg))\n", + "display(med_do.show(cut=cut_do, show_prob=True, crit=crit_do))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9auv9QLe_Vm_" + }, + "source": [ + "In other words, for the example above, taking the treatment increases the\n", + "chances of recovery. This is due to the base rates (i.e. the probability of\n", + "having a disease). The base rates are not affected by the decision of taking the\n", + "treatment." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RXjEiNWa_VnA" + }, + "source": [ + "### Special case: intervening on an event with probability zero\n", + "\n", + "Assume we're intervening on an event with **probability zero**. Recall that this\n", + "is possible **logically**, but **not probabilistically**. How do we set the\n", + "transition probabilities leaving the critical set? Here again we settle on\n", + "assigning uniform probabilities over all the transitions affected by the\n", + "intervention.\n", + "\n", + "See the example below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NUgHl3mT_VnA" + }, + "outputs": [], + "source": [ + "# Create a simple tree.\n", + "pt = PTree()\n", + "pt.root(\n", + " 'O = 1',\n", + " [pt.child(1.0, 'X = 1'),\n", + " pt.child(0.0, 'X = 2'),\n", + " pt.child(0.0, 'X = 3')])\n", + "\n", + "# Let's pick the negative event for our minimal prob tree.\n", + "cut = ~pt.prop('X = 1')\n", + "crit = pt.critical(cut)\n", + "\n", + "# Intervene.\n", + "pt_do = pt.do(cut)\n", + "\n", + "# Show results.\n", + "print('Before the intervention:')\n", + "display(pt.show(cut=cut, crit=crit))\n", + "print('After the invention on \"not X \u003c- 1\":')\n", + "display(pt_do.show(cut=cut, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qNqHBq19_VnD" + }, + "source": [ + "### Exercise 8\n", + "\n", + "For the `wb` tree, print the probability distribution of \n", + "- the weather $W$ \n", + "- and the barometer $B$.\n", + "\n", + "Do this for the following probability trees: \n", + "- the original tree \n", + "- the probability tree resulting from enforcing it to being \n", + "an alien world ($A \\leftarrow 1$) \n", + "- the probability tree resulting from setting the weather to\n", + "being sunny ($W \\leftarrow 1$).\n", + "\n", + "What do you observe? Compare these results with your previous exercise, where\n", + "you conditioned on the same events. Why are the probabilities different when you\n", + "condition and when you intervene? How is this related to the different causal\n", + "dependencies in both worlds?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pgP-6OwU8LBq" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9JFZjAra8LJd" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8Fui2W_l_VnE" + }, + "outputs": [], + "source": [ + "# Exercise\n", + "\n", + "# No intervention.\n", + "print('P(W) =', wb.rv('W'))\n", + "print('P(B) =', wb.rv('B'))\n", + "\n", + "# Intervention on \"A \u003c- 1\"\n", + "cut = wb.prop('A=1')\n", + "print('P(W|A \u003c- 1) =', wb.do(cut).rv('W'))\n", + "print('P(B|A \u003c- 1) =', wb.do(cut).rv('B'))\n", + "\n", + "# Condition on \"W \u003c- 1\"\n", + "cut = wb.prop('W=1')\n", + "print('P(W|W \u003c- 1) =', wb.do(cut).rv('W'))\n", + "print('P(B|W \u003c- 1) =', wb.do(cut).rv('B'))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "SCZyxF5l_VnG" + }, + "source": [ + "### Exercise 9\n", + "\n", + "Next, evaluate the following probabilities:\n", + "\n", + "- What is the probability of being in our world ($A=0$), given that you\n", + " observe a sunny weather ($W=1$) and the barometer going up ($B=1$)?\n", + "- What is the probability of being in our world ($A=0$), given that you first\n", + " observe a sunny weather ($W=1$) and then **you force** the barometer to go\n", + " up ($B\\leftarrow 1$)?\n", + "- What is the probability of being in our world ($A=0$), given that you first\n", + " **force** the barometer to go up ($B\\leftarrow 1$) and then observe a sunny\n", + " weather ($W=1$)?\n", + "\n", + "Answer the following questions:\n", + "\n", + "- Does conditioning give different results from intervening? 
If so, why?\n", + "- When you mix conditions and interventions, does the order matter? If so,\n", + " why?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "uB3igmF58OyH" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k1vLg3F_8O7c" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "i2B5UhDT_VnJ" + }, + "outputs": [], + "source": [ + "# Exercise\n", + "cutw = wb.prop('W=1')\n", + "cutb = wb.prop('B=1')\n", + "cuttheta = wb.prop('A=0')\n", + "\n", + "# Question 1\n", + "print('P(A = 0 | W = 1 and B = 1) =', wb.see(cutw).see(cutb).prob(cuttheta))\n", + "\n", + "# Question 2\n", + "print('P(A = 0 | W = 1 then B \u003c- 1) =', wb.see(cutw).do(cutb).prob(cuttheta))\n", + "\n", + "# Question 3\n", + "print('P(A = 0 | B \u003c- 1 then W = 1) =', wb.do(cutb).see(cutw).prob(cuttheta))\n", + "\n", + "display(wb.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3OMlUFKF_VnM" + }, + "source": [ + "## 7. Counterfactuals\n", + "\n", + "Finally, we have counterfactuals. Counterfactuals are questions about how the\n", + "experiment could have gone if something about it were different. For instance:\n", + "\n", + "- \"What is the probability of having the disease **had I not recovered**,\n", + " given that I have recovered?\"\n", + "- \"Given that I have taken the treatment and recovered, what is the\n", + " probability of recovery **had I not taken the treatment**?\"\n", + "\n", + "These are tricky questions because they mix two moods:\n", + "\n", + "- **indicative statements** - things that have actually happened;\n", + "- **subjunctive statements** - things that could have happened \n", + "in an alternate reality/possible world.\n", + "\n", + "Because of this, counterfactuals spawn a new scope of random variables:\n", + "\n", + "\u003cimg src=\"http://www.adaptiveagents.org/_media/wiki/counterfactual.png\" alt=\"Counterfactual\" width=\"400\"/\u003e\n", + "\n", + "These two questions above are spelled as follows:\n", + "\n", + "- $P(D^\\ast=1|R=1)$, where $D^\\ast=D_{R \\leftarrow 0}$\n", + "- $P(R^\\ast=1|T\\leftarrow 1; R=1)$, where $R^\\ast=R_{T\\leftarrow 0}$\n", + "\n", + "Here the random variables with an asterisk $D^\\ast, R^\\ast$ are copies of the\n", + "original random variables $D, R$ that ocurr in an alternate reality. The\n", + "notation $D_{T \\leftarrow 0}$ means that the random variable $D$ is in the new\n", + "scope spawned by the intervention on $T\\leftarrow 0$.\n", + "\n", + "### Computing a counterfactual\n", + "\n", + "The next figure shows how to obtain a counterfactual:\n", + "\n", + "\u003cimg src=\"http://www.adaptiveagents.org/_media/wiki/cf.png\" alt=\"Computing a counterfactual\" width=\"700\"/\u003e\n", + "\n", + "The example shows a counterfactual probability tree generated by imposing $Y\n", + "\\leftarrow 1$, given the factual premise $Z = 1$. Starting from a **reference\n", + "probability tree**, we first derive two additional trees: a **factual premise**,\n", + "capturing the current state of affairs; and a **counterfactual premise**,\n", + "represented as an intervention on the reference tree.\n", + "\n", + "To form the counterfactual we proceed as follows:\n", + "- We slice both derived trees along the critical set\n", + "of the counterfactual premise. 
\n", + "- Then, we compose the counterfactual tree by\n", + "taking the transition probabilities **upstream of the slice**\n", + "from the factual premise, and those **downstream of the slice** \n", + "from the counterfactual premise. \n", + "\n", + "The events downstream then span a new scope containing copies \n", + "of the original random variables (marked with \"∗\"), ready to \n", + "adopt new values. \n", + "\n", + "In particular note that $Z^\\ast = 0$ can happen in our alternate \n", + "reality, even though we know that $Z = 1$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4Xq_CBmU_VnM" + }, + "source": [ + "Let's have a look at a minimal example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wgfuKqZr_VnN" + }, + "outputs": [], + "source": [ + "pt = PTree()\n", + "pt.root('O = 1', [\n", + " pt.child(0.25, 'X = 0', [\n", + " pt.child(0.25, 'Y = 0',\n", + " [pt.child(0.1, 'Z = 0'),\n", + " pt.child(0.9, 'Z = 1')]),\n", + " pt.child(0.75, 'Y = 1',\n", + " [pt.child(0.2, 'Z = 0'),\n", + " pt.child(0.8, 'Z = 1')]),\n", + " ]),\n", + " pt.child(0.75, 'X = 1',\n", + " [pt.child(0.75, 'Y = 0, Z = 0'),\n", + " pt.child(0.25, 'Y = 1, Z = 0')])\n", + "])\n", + "\n", + "print('Original:')\n", + "display(pt.show())\n", + "\n", + "# Condition on 'Y=0', do 'Y=1'\n", + "cut_see = pt.prop('Y=0')\n", + "cut_do = pt.prop('Y=1')\n", + "\n", + "# Critical set.\n", + "crit = pt.critical(cut_do)\n", + "\n", + "# Evaluate conditional, intervention, and counterfactual.\n", + "pt_see = pt.see(cut_see)\n", + "pt_do = pt.do(cut_do)\n", + "pt_cf = pt.cf(pt_see, cut_do)\n", + "\n", + "# Display results.\n", + "print('Condition on \"Y = 0\":')\n", + "display(pt_see.show(cut=cut_see, crit=crit))\n", + "print('Intervention on \"Y \u003c- 1\":')\n", + "display(pt_do.show(cut=cut_do, crit=crit))\n", + "print('Counterfactual with premise \"Y = 0\" and subjunctive \"Y = 1\":')\n", + "display(pt_cf.show(cut=cut_do, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "u6L7VDH9_VnP" + }, + "source": [ + "Now we return to our drug testing example. Let's ask the two questions we asked\n", + "before. 
We start with the question: \"What is the probability of having the\n", + "disease **had I not recovered**, given that I have recovered?\", that is\n", + "$$P(D^\\ast=1|R=1), \\qquad D^\\ast=D_{R \\leftarrow 0}.$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1_dAJkjj_VnQ" + }, + "outputs": [], + "source": [ + "# Cuts.\n", + "cut_disease = med.prop('D = 1')\n", + "cut_recovery = med.prop('R = 1')\n", + "cut_not_recovery = ~cut_recovery\n", + "\n", + "# Critical.\n", + "crit = med.critical(cut_not_recovery)\n", + "\n", + "# Compute counterfactual:\n", + "# - compute factual premise,\n", + "# - use factual premise and subjunctive premise to compute counterfactual.\n", + "med_factual_prem = med.see(cut_recovery)\n", + "med_cf = med.cf(med_factual_prem, cut_not_recovery)\n", + "\n", + "print('Baseline:')\n", + "print('P(D = 1) =', med.prob(cut_disease))\n", + "display(med.show())\n", + "print('Premise:')\n", + "print('P(D = 1 | R = 1) =', med_factual_prem.prob(cut_disease))\n", + "display(med_factual_prem.show())\n", + "print('Counterfactual:')\n", + "print('P(D* = 1 | R = 1) =', med_cf.prob(cut_disease), ', D* = D[R \u003c- 0]')\n", + "display(med_cf.show(crit=crit, cut=cut_not_recovery))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eAPyY5O8_VnS" + }, + "source": [ + "As we can see, the probability of the disease in the indicative and the\n", + "counterfactual aren't different. This is because the recovery $R$ is independent\n", + "of the disease $D$, and because the disease is upstream of the critical set.\n", + "\n", + "Let's have a look at the second question: $$P(R^\\ast=1|T\\leftarrow 1;\n", + "R=1), \\qquad R^\\ast=R_{T\\leftarrow 0}$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Nm-h3YPh_VnS" + }, + "outputs": [], + "source": [ + "# Cuts.\n", + "cut_treatment = med.prop('T = 1')\n", + "cut_not_treatment = ~cut_treatment\n", + "cut_recovery = med.prop('R = 1')\n", + "\n", + "# Critical.\n", + "crit = med.critical(cut_not_treatment)\n", + "\n", + "# Compute counterfactual:\n", + "# - compute factual premise,\n", + "# - use factual premise and counterfactual premise to compute counterfactual.\n", + "med_factual_prem = med.do(cut_treatment).see(cut_recovery)\n", + "med_cf = med.cf(med_factual_prem, cut_not_treatment)\n", + "\n", + "# Display results.\n", + "print('Baseline:')\n", + "print('P(R = 1) =', med.prob(cut_recovery))\n", + "display(med.show())\n", + "\n", + "print('Premise:')\n", + "print('P(R = 1 | T \u003c- 1 and R = 1) =', med_factual_prem.prob(cut_recovery))\n", + "display(med_factual_prem.show())\n", + "\n", + "print('Counterfactual:')\n", + "print('P(R* = 1 | T \u003c- 1 and R = 1) =', med_cf.prob(cut_recovery),\n", + " ', R* = R[T \u003c- 0]')\n", + "display(med_cf.show(cut=cut_not_treatment, crit=crit))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "bCsw7tTM_VnV" + }, + "source": [ + "Hence, if I had not taken the treatment, then the probability of recovery would\n", + "have been lower. Why is that? \n", + "- In our premise, I have taken the treatment and\n", + "then observed a recovery. \n", + "- This implies that, most likely, I had the disease,\n", + "since taking the treatment when I don't have the disease is risky and can lead\n", + "to illness. 
\n", + "- Thus, knowing that I probably have the disease, I know that, had I\n", + "not taken the treatment, I would most likely not have recovered." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "B3DdaUXJ_VnV" + }, + "source": [ + "### Exercise 10\n", + "\n", + "Consider the drug testing probability tree `med`.\n", + "\n", + "- Assume you take the drug ($T \\leftarrow 1$) and you feel bad afterwards\n", + " ($R = 0$).\n", + "- Given this information, what is the probability of recovery ($R = 1$) had\n", + " you not taken the drug ($T = 0$)?\n", + "\n", + "Compute the **regret**, i.e. the difference: $$ \\mathbb{E}[ R^\\ast | T\n", + "\\leftarrow 1; R = 0 ] - \\mathbb{E}[ R | T \\leftarrow 1; R = 0 ], $$ where\n", + "$R^\\ast = R_{T \\leftarrow 0}$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JAoYBJcn8gve" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "yxuqdgvA8g-J" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "23ziOdxr_VnV" + }, + "outputs": [], + "source": [ + "# Exercise\n", + "\n", + "med_prem = med.do(med.prop('T=1')).see(med.prop('R=0'))\n", + "med_cf = med.cf(med_prem, med.prop('T=0'))\n", + "\n", + "print('P(R* = 1 | T \u003c- 1, R = 0) =', med_cf.prob(med.prop('R=1')))\n", + "\n", + "regret = med_cf.expect('R') - med_prem.expect('R')\n", + "print('Regret = ', regret)\n", + "\n", + "display(med_prem.show(cut=med.prop('R=0')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ioUtMVOs_VnX" + }, + "source": [ + "### Exercise 11\n", + "\n", + "Take the probability tree `wb`. Evaluate the following counterfactuals:\n", + "\n", + "1. Assume that you set the world to ours ($A \\leftarrow 0$) and the weather to\n", + " sunny ($W \\leftarrow 1$). What is the probability distribution of observing\n", + " a high barometer value ($B = 1$) had you set the weather to rainy ($W\n", + " \\leftarrow 0$)? Does the fact that you set the world and the weather affect\n", + " the value of the counterfactual?\n", + "\n", + "2. Assume that you set the barometer to a high value ($B \\leftarrow 1$), and\n", + " you observe that the weather is sunny ($W=1$). What is the probability of\n", + " observing a sunny weather ($W=1$) had you set the barometer to a low value\n", + " ($B=0$)?\n", + "\n", + "These are highly non-trivial questions. What do you observe? Do the results make\n", + "sense to you?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "6dh5MySa8j83" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jWpZ0riB8kG3" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "KGzDgfzs_VnY" + }, + "outputs": [], + "source": [ + "# Question 1.\n", + "wb_prem = wb.do(wb.prop('A=0')).do(wb.prop('W=1'))\n", + "wb_cf = wb.cf(wb_prem, wb.prop('W=0'))\n", + "print('P(B*| A \u003c- 0, W \u003c- 1) =', wb_cf.rv('B'), ' where B* = B[W \u003c- 0]')\n", + "display(wb_cf.show(show_prob=True, cut=wb.prop('B=1')))\n", + "\n", + "# Question 2.\n", + "wb_prem = wb.do(wb.prop('B=1')).see(wb.prop('W=1'))\n", + "wb_cf = wb.cf(wb_prem, wb.prop('B=0'))\n", + "print('P(W* | B \u003c- 1 then W \u003c- 1) =', wb_cf.rv('W'), ' where W* = W[B \u003c- 0]')\n", + "display(wb_cf.show(show_prob=True, cut=wb.prop('W=1')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZB-JZdum16ZJ" + }, + "source": [ + "# Part II: Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jw5OiWWC_Vna" + }, + "source": [ + "## Construction of probability trees using factory functions\n", + "\n", + "Building probability trees can be difficult, especially when we have to manually\n", + "specify all its nodes. \n", + "\n", + "To simplify this, we could design a function `factory(bvar)` which: \n", + "- receives a dictionary `bvar` of bound random variables, such as \n", + "`{ 'X': '1', 'Y': '0' }` \n", + "- and returns a list of transitions and their statements, such as \n", + "`[(0.3, 'Z = 0'), (0.2, 'Z = 1'), (0.5, 'Z = 2')].` If all relevant\n", + "events have been defined already, return `None`.\n", + "\n", + "Such a function contains all the necessary information for building a\n", + "probability tree. We call this a **probability tree factory**. We can pass a\n", + "description function to the method `PTree.fromFunc()` to build a probability\n", + "tree.\n", + "\n", + "The advantage of using this method is that we can exploit symmetries (e.g.\n", + "conditional independencies) to code a much more compact description of the\n", + "probability tree. Essentially, it is like specifying a probabilistic program.\n", + "\n", + "Let's experiment with this." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ce_kzARO_Vna" + }, + "source": [ + "## Burglar, Earthquake, Alarm\n", + "\n", + "Let's start with a classical example: a burglar alarm. The alarm gets \n", + "triggered by a burglar breaking into our home. However, the alarm can \n", + "also be set off by an earthquake. \n", + "\n", + "Let's define the factory function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "94KpJKbm_Vnb" + }, + "outputs": [], + "source": [ + "def alarm(bvar):\n", + " # Define the burglar and earthquake events.\n", + " if 'Burglar' not in bvar:\n", + " pb = 0.1 # Probability of burglar\n", + " pe = 0.001 # Probability of earthquake\n", + " return [((1 - pb) * (1 - pe), 'Burglar=0, Earthquake=0'),\n", + " ((1 - pb) * pe, 'Burglar=0, Earthquake=1'),\n", + " (pb * (1 - pe), 'Burglar=1, Earthquake=0'),\n", + " (pb * pe, 'Burglar=1, Earthquake=1')]\n", + "\n", + " # Define the alarm event.\n", + " if 'Alarm' not in bvar:\n", + " if bvar['Burglar'] == '0' and bvar['Earthquake'] == '0':\n", + " return [(0.999, 'Alarm=0'), (0.001, 'Alarm=1')]\n", + " if bvar['Burglar'] == '0' and bvar['Earthquake'] == '1':\n", + " return [(0.01, 'Alarm=0'), (0.99, 'Alarm=1')]\n", + " if bvar['Burglar'] == '1' and bvar['Earthquake'] == '0':\n", + " return [(0.1, 'Alarm=0'), (0.9, 'Alarm=1')]\n", + " else:\n", + " return [(0.001, 'Alarm=0'), (0.999, 'Alarm=1')]\n", + "\n", + " # All the events defined.\n", + " return None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HYjrsERo_Vnd" + }, + "source": [ + "Now, let's create the probability tree." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JOvj1N5X_Vnd" + }, + "outputs": [], + "source": [ + "# Create the probability tree.\n", + "al = PTree.fromFunc(alarm, 'Root = 1')\n", + "\n", + "# Print all the random variables.\n", + "print('Random variables:', al.rvs())\n", + "print('\\nP(Alarm) =', al.rv('Alarm'))\n", + "\n", + "print('\\nOriginal probability tree:')\n", + "display(al.show())\n", + "\n", + "print('\\nSome samples from the probability tree:')\n", + "for k in range(5):\n", + " print(al.sample())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "z6PQk0Du_Vnh" + }, + "source": [ + "Assume now you hear the alarm. Which explanation is more likely:\n", + "did the earthquake or the burglar trigger the alarm?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bG3qSQz7_Vni" + }, + "outputs": [], + "source": [ + "# Condition on the alarm going off.\n", + "cut = al.prop('Alarm=1')\n", + "crit = al.critical(cut)\n", + "al_see = al.see(cut)\n", + "\n", + "# Compute probability distributions for earthquake and burglar.\n", + "print('P(Earthquake = 1 | Alarm = 1) =', al_see.prob(al.prop('Earthquake=1')))\n", + "print('P(Burglar = 1 | Alarm = 1) =', al_see.prob(al.prop('Burglar=1')))\n", + "\n", + "# Display the conditional probability tree.\n", + "\n", + "print('\\nConditional probability tree:')\n", + "display(al_see.show(show_prob=True, cut=cut, crit=crit))\n", + "\n", + "print('\\nSome samples from the conditional probability tree:')\n", + "for k in range(5):\n", + " print(al_see.sample())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "v8jmncNUF5kS" + }, + "source": [ + "As we can see, it is far more likely that the burglar set off the alarm.\n", + "\n", + "If we now tamper with the alarm, setting it off, then what is the probability\n", + "that there was a burglar or an earthquake?" 
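+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "As an optional aside, the posterior above can be cross-checked with Bayes'\n", + "rule in plain Python, using only the prior and conditional probabilities\n", + "hard-coded in the `alarm()` factory function (this is just a sketch that\n", + "assumes those numbers; it does not use the probability tree). The tampering\n", + "question is answered in the next code cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# Bayes-rule cross-check using the numbers from the alarm() factory above.\n", + "pb, pe = 0.1, 0.001  # prior probabilities of a burglar and of an earthquake\n", + "p_alarm = {(0, 0): 0.001, (0, 1): 0.99, (1, 0): 0.9, (1, 1): 0.999}\n", + "\n", + "# Joint probabilities P(Burglar=b, Earthquake=e, Alarm=1).\n", + "joint = {}\n", + "for b in (0, 1):\n", + "  for e in (0, 1):\n", + "    p_b = pb if b == 1 else 1 - pb\n", + "    p_e = pe if e == 1 else 1 - pe\n", + "    joint[(b, e)] = p_b * p_e * p_alarm[(b, e)]\n", + "\n", + "evidence = sum(joint.values())  # P(Alarm = 1)\n", + "print('P(Burglar=1 | Alarm=1) =', (joint[(1, 0)] + joint[(1, 1)]) / evidence)\n", + "print('P(Earthquake=1 | Alarm=1) =', (joint[(0, 1)] + joint[(1, 1)]) / evidence)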
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pl3Ps_PAJB89" + }, + "outputs": [], + "source": [ + "# Intervene on the alarm going off.\n", + "cut = al.prop('Alarm=1')\n", + "crit = al.critical(cut)\n", + "al_do = al.do(cut)\n", + "\n", + "# Compute probability distributions for earthquake and burglar.\n", + "print('P(Earthquake = 1 | Alarm \u003c- 1) =', al_do.prob(al.prop('Earthquake=1')))\n", + "print('P(Burglar = 1 | Alarm \u003c- 1) =', al_do.prob(al.prop('Burglar=1')))\n", + "\n", + "# Display the intervened probability tree.\n", + "\n", + "print('\\nIntervened probability tree:')\n", + "display(al_do.show(show_prob=True, cut=cut, crit=crit))\n", + "\n", + "print('\\nSome samples from the intervened probability tree:')\n", + "for k in range(5):\n", + " print(al_do.sample())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xYUfpc-2KJ27" + }, + "source": [ + "Now we observe that the probabilities of the burglar and earthquake\n", + "events are exactly as the base rates - we have severed the\n", + "causal dependencies connecting those events with the alarm." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_UO1GC7T_Vnl" + }, + "source": [ + "## Coin toss prediction\n", + "\n", + "Let's build another probability tree. This is a discrete approximation to a\n", + "process having a continuous random variable: a **Beta-Bernoulli process**. \n", + "This problem was first studied by Rev. Thomas Bayes (\"An Essay towards\n", + "solving a Problem in the Doctrine of Chances\", 1763) .\n", + "\n", + "The story goes as follows. Someone picks a coin with an unknown bias and then throws it repeatedly. Our goal is to infer the next outcome based only on the observed outcomes (and not on the latent bias). The unknown bias is drawn\n", + "uniformly from the interval [0, 1].\n", + "\n", + "Let's start by coding the factory function for the discretized Beta-Bernoulli\n", + "process. Here we assume that the prior distribution over the bias is uniform,\n", + "and discretized into `divtheta = 40` bins. Then `T = 5` coin tosses follow. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "X00sYFKB_Vnm" + }, + "outputs": [], + "source": [ + "#@title Beta-Bernoulli factory function.\n", + "\n", + "def betaBernoulli(bvar, divtheta=41, T=5):\n", + " # Root: defined.\n", + " # Define biases Bias=0, 1/divtheta, 2/divtheta, ... , 1\n", + " if 'Bias' not in bvar:\n", + " ptheta = 1.0 / divtheta\n", + " biases = [(ptheta, 'Bias=' + str(theta))\n", + " for theta in np.linspace(0.0, 1.0, divtheta, endpoint=True)]\n", + " return biases\n", + "\n", + " # Biases: defined.\n", + " # Now create Bernoulli observations X_1, X_2, ... , X_T,\n", + " # where X_t=0 or X_t=1.\n", + " t = 1\n", + " for var in bvar:\n", + " if '_' not in var:\n", + " continue\n", + " t += 1\n", + " if t \u003c= T:\n", + " theta = float(bvar['Bias'])\n", + " varstr = 'X_' + str(t)\n", + " return [(1 - theta, varstr + '=0'), (theta, varstr + '=1')]\n", + "\n", + " # All the events defined.\n", + " return None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4e3ZM5Vl_Vnr" + }, + "source": [ + "We now build the probability tree. Let's also print the \n", + "random variables and get a few samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "zN3TzCnQHoz6" + }, + "outputs": [], + "source": [ + "# Create tree.\n", + "bb = PTree.fromFunc(betaBernoulli)\n", + "\n", + "# Show random variables.\n", + "print('Random variables:')\n", + "print(bb.rvs())\n", + "\n", + "# Get samples.\n", + "print('\\nSamples from the process:')\n", + "for n in range(10):\n", + " print(bb.sample())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "R4yVyNboW84N" + }, + "source": [ + "The tree itself is quite large (over 1000 nodes). \n", + "Normally such trees are too large to\n", + "display, for instance when `T` is large.\n", + "\n", + "Let's display it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "EvV6Gv4XaTxZ" + }, + "outputs": [], + "source": [ + "bb.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JEuxhMkKHlUm" + }, + "source": [ + "\n", + "### Exercise \n", + "\n", + "Let's do some inference now.\n", + "\n", + "Assume you observe the first four coin tosses. They are\n", + "```\n", + "observations = ['X_1=1', 'X_2=1', 'X_3=0', 'X_4=1']\n", + "```\n", + "\n", + "Answer the following questions:\n", + "1. What is the prior distribution over the unknown bias?\n", + "2. What is the probability of the next outcome being Heads (`X_5=1`)?\n", + "3. Given the observations, what is the distribution over the\n", + "latent bias?\n", + "4. Rather than observing the four outcomes, assume instead\n", + "that you enforce the outcomes. What is the probability of\n", + "the next outcome being Heads?\n", + "5. What is the distribution over the latent bias if you enforce\n", + "the data?"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "cwdfiRj-bNBR" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7fTawNEfbNZ1" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "yPG6AbQo_Vnr" + }, + "outputs": [], + "source": [ + "# Prepare the cut forthe data.\n", + "observations = ['X_1=1', 'X_2=1', 'X_3=0', 'X_4=1']\n", + "cut_data = None\n", + "for s in observations:\n", + " if cut_data is None:\n", + " cut_data = bb.prop(s)\n", + " else:\n", + " cut_data \u0026= bb.prop(s)\n", + "\n", + "# Prepare the cut for the query.\n", + "cut_query = bb.prop('X_5=1')\n", + "\n", + "# Question 1\n", + "bias = bb.rv('Bias')\n", + "print('P(Bias) :\\n' + str(bias))\n", + "\n", + "# Question 2\n", + "bb_cond = bb.see(cut_data)\n", + "print('\\nP(X_5 = 1 | Data) = ' + str(bb_cond.prob(cut_query)))\n", + "\n", + "# Question 3\n", + "bias_cond = bb_cond.rv('Bias')\n", + "print('\\nP(Bias | Data) :\\n' + str(bias_cond))\n", + "\n", + "# Question 4\n", + "bb_int = bb.do(cut_data)\n", + "print('\\nP(X_5 = 1 | do(Data)) = ' + str(bb_int.prob(cut_query)))\n", + "\n", + "# Question 5\n", + "bias_int = bb_int.rv('Bias')\n", + "print('\\nP(Bias | do(Data)) :' + str(bias_int))\n", + "\n", + "# Display distribution over bias.\n", + "print('\\nDistribution over biases for the three settings:')\n", + "\n", + "fig = plt.figure(figsize=(15, 5))\n", + "\n", + "# Show prior.\n", + "plt.subplot(131)\n", + "res = bb.rv('Bias')\n", + "theta = np.array([theta for _, theta in res], dtype=np.float)\n", + "prob = np.array([prob for prob, _ in res])\n", + "plt.fill_between(theta, prob, 0)\n", + "plt.title('P(Bias)')\n", + "plt.ylim([-0.005, 0.1])\n", + "plt.xlabel('Bias')\n", + "\n", + "# Show posterior after conditioning.\n", + "plt.subplot(132)\n", + "res = bb.see(cut).rv('Bias')\n", + "theta = np.array([theta for _, theta in res], dtype=np.float)\n", + "prob = np.array([prob for prob, _ in res])\n", + "plt.fill_between(theta, prob, 0)\n", + "plt.title('P(Bias|D)')\n", + "plt.ylim([-0.005, 0.1])\n", + "plt.xlabel('Bias')\n", + "\n", + "# Show posterior after intervening.\n", + "plt.subplot(133)\n", + "res = bb.do(cut).rv('Bias')\n", + "theta = np.array([theta for _, theta in res], dtype=np.float)\n", + "prob = np.array([prob for prob, _ in res])\n", + "plt.fill_between(theta, prob, 0)\n", + "plt.title('P(Bias|do(D))')\n", + "plt.ylim([-0.005, 0.1])\n", + "plt.xlabel('Bias')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eoNn0c5rj5U1" + }, + "source": [ + "## Who's in charge?\n", + "\n", + "In this problem we will look at causal induction. Alice and Bob play a game\n", + "where both of them shout either 'chicken' or 'egg'.\n", + "\n", + "At the beginning of the game, one of them is chosen to be the leader, and\n", + "the other, the follower. 
The follower will always attempt to match the\n", + "leader: so if Alice is the leader and Bob the follower, and Alice\n", + "shouts 'chicken', then Bob will attempt to shout 'chicken' too (with\n", + "60% success rate).\n", + "\n", + "A typical game would look like this:\n", + "\n", + "- Round 1: Alice shouts 'egg', Bob shouts 'chicken'.\n", + "- Round 2: Alice shouts 'chicken', Bob shouts 'chicken'.\n", + "- Round 3: Alice shouts 'chicken', Bob shouts 'chicken'.\n", + "- Round 4: Alice shouts 'egg', Bob shouts 'egg'.\n", + "\n", + "Note that you hear both of them shouting simultaneously.\n", + "\n", + "Our goal is to discover who's the leader. This is a **causal induction\n", + "problem**, because we want to figure out whether:\n", + "- hypothesis `Leader = Alice`: Alice $\\rightarrow$ Bob;\n", + "- or hypothesis `Leader = Bob`: Bob $\\rightarrow$ Alice.\n", + "\n", + "Let's start by defining the factory function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "pRDAaQwW_zuF" + }, + "outputs": [], + "source": [ + "#@title Leader factory function.\n", + "\n", + "def leader(bvar, T=2):\n", + " p = 0.75 # Probability of match.\n", + "\n", + " # Define leader.\n", + " if 'Leader' not in bvar:\n", + " return [(0.5, 'Leader=Alice'), (0.5, 'Leader=Bob')]\n", + "\n", + " # Now create the shouts.\n", + "\n", + " # Figure out the leader.\n", + " if bvar['Leader'] == 'Alice':\n", + " leader = 'Alice'\n", + " follower = 'Bob'\n", + " else:\n", + " leader = 'Bob'\n", + " follower = 'Alice'\n", + "\n", + " # Define random variables of shouts.\n", + " for t in range(1, T+1):\n", + " leader_str = leader + '_' + str(t)\n", + " if leader_str not in bvar:\n", + " return [(0.5, leader_str + '=chicken'), (0.5, leader_str + '=egg')]\n", + "\n", + " follower_str = follower + '_' + str(t)\n", + " if follower_str not in bvar:\n", + " if bvar[leader_str] == 'chicken':\n", + " return [(p, follower_str + '=chicken'), (1-p, follower_str + '=egg')]\n", + " else:\n", + " return [(1-p, follower_str + '=chicken'), (p, follower_str + '=egg')]\n", + "\n", + " # We're done.\n", + " return None\n", + "\n", + "# Create true environment.\n", + "class ChickenEggGame:\n", + "\n", + " def __init__(self, T=2):\n", + " self.T = T\n", + " self.pt = PTree.fromFunc(lambda bvar:leader(bvar, T=T))\n", + " smp = self.pt.sample()\n", + " self.pt.do(self.pt.prop('Leader=' + smp['Leader']))\n", + " self.time = 0\n", + "\n", + " def step(self, name, word):\n", + " # Check whether parameters are okay.\n", + " if name != 'Alice' and name != 'Bob':\n", + " raise Exception('\"name\" has to be either \"Alice\" or \"Bob\".')\n", + " if word != 'chicken' and word != 'egg':\n", + " raise Exception('\"word\" has to be either \"chicken\" or \"egg\".')\n", + " if self.time \u003e self.T -1:\n", + " raise Exception('The game has only ' + str(self.T) + ' rounds.')\n", + "\n", + " # Enforce instruction.\n", + " self.time = self.time + 1\n", + " cut_do = self.pt.prop(name + '_' + str(self.time) + '=' + word)\n", + " self.pt = self.pt.do(cut_do)\n", + "\n", + " # Produce next sample.\n", + " smp = self.pt.sample()\n", + " if name == 'Alice':\n", + " varname = 'Bob_' + str(self.time)\n", + " else:\n", + " varname = 'Alice_' + str(self.time)\n", + " response = smp[varname]\n", + " cut_see = self.pt.prop(varname + '=' + response) \n", + " self.pt = self.pt.see(cut_see)\n", + "\n", + " return varname + '=' + response\n", + "\n", + " def reveal(self):\n", + " 
smp = self.pt.sample()\n", + " return smp['Leader']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0IwyuPLrsE0v" + }, + "source": [ + "The factory function is called `leader()`.\n", + "\n", + "Let's first have a look at how the probability tree would\n", + "look like for `T = 2` rounds." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gveaR8W2sSJv" + }, + "outputs": [], + "source": [ + "ld = PTree.fromFunc(lambda bvar: leader(bvar, T=2), root_statement='Root = 1')\n", + "display(ld.show())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LU6ZQ0e1cUw3" + }, + "source": [ + "Notice how the transition probabilities of `Alice_n, Bob_n`,\n", + "`n = 1, 2`, are identical within the subtree rooted at \n", + "`Leader = Alice`. The same is true for the transitions probabilities\n", + "within the subtree rooted at `Leader = Bob`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZtXmvPbIsqBI" + }, + "source": [ + "Now, let's create a new probability tree for a slightly longer game,\n", + "namely `T = 5`. **This tree is too big to display** (over 2K nodes)\n", + "but we can still sample from it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lekPv0r1Egti" + }, + "outputs": [], + "source": [ + "T = 5\n", + "ld = PTree.fromFunc(lambda bvar: leader(bvar, T=T), root_statement='Root = 1')\n", + "\n", + "print('Samples from the probability tree:')\n", + "for n in range(T):\n", + " print(ld.sample())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "dNdejpEjR6WI" + }, + "source": [ + "Let's first figure out the joint distribution over Alice's and Bob's shouts\n", + "in the first round (remember, rounds are i.i.d.) when Alice is the leader,\n", + "and compare this to the situation when Bob is the leader.\n", + "\n", + "We can do this by setting `Leader` to whoever we want to be the leader,\n", + "and then enumerate the joint probabilities over the combinations of\n", + "shouts." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "X1lMShnYEqDz" + }, + "outputs": [], + "source": [ + "import itertools\n", + "\n", + "# Define cuts for both leaders.\n", + "cut_leader_a = ld.prop('Leader = Alice')\n", + "cut_leader_b = ld.prop('Leader = Bob')\n", + "\n", + "# The words they can say.\n", + "words = ['chicken', 'egg']\n", + "\n", + "# Print the joint distribution over\n", + "# shouts when Alice is the leader.\n", + "print('Leader = Alice')\n", + "for word_a, word_b in itertools.product(words, words):\n", + " cut = ld.prop('Alice_1 = ' + word_a) \u0026 ld.prop('Bob_1 = ' + word_b)\n", + " prob = ld.do(cut_leader_a).prob(cut)\n", + " fmt = 'P( Alice_1 = {}, Bob_1 = {} | Leader \u003c- Alice) = {:.2f}'\n", + " print(fmt.format(word_a, word_b, prob))\n", + "\n", + "# Print the joint distribution over\n", + "# shouts when Bob is the leader.\n", + "print('\\nLeader = Bob')\n", + "for word_a, word_b in itertools.product(words, words):\n", + " cut = ld.prop('Alice_1 = ' + word_a) \u0026 ld.prop('Bob_1 = ' + word_b)\n", + " prob = ld.do(cut_leader_b).prob(cut)\n", + " fmt = 'P( Alice_1 = {}, Bob_1 = {} | Leader \u003c- Bob) = {:.2f}'\n", + " print(fmt.format(word_a, word_b, prob))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "IxkzXrhDawc1" + }, + "source": [ + "Looking at the joint probabilities, **we realize that they are identical**.\n", + "This means that we cannot identify who's the leader by conditioning on\n", + "our observations. Let's try this with the following observations:\n", + "```\n", + "obs = [\n", + " 'Alice_1=chicken', 'Bob_1=egg', \n", + " 'Alice_2=egg', 'Bob_2=egg', \n", + " 'Alice_3=egg', 'Bob_3=egg'\n", + " ]\n", + "```\n", + "\n", + "We now compare the prior and posterior probabilities of Bob being the leader." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "skor-95PGNpa" + }, + "outputs": [], + "source": [ + "import functools \n", + "\n", + "obs = [\n", + " 'Alice_1=chicken', 'Bob_1=egg', \n", + " 'Alice_2=egg', 'Bob_2=egg', \n", + " 'Alice_3=egg', 'Bob_3=egg'\n", + " ]\n", + "cuts_data = [ld.prop(data) for data in obs]\n", + "cut_data = functools.reduce(lambda x, y: x \u0026 y, cuts_data)\n", + "cut_query = ld.prop('Leader=Bob')\n", + "\n", + "prob_prior = ld.prob(cut_query)\n", + "prob_post = ld.see(cut_data).prob(cut_query)\n", + "print('Prior and posterior probabilities:')\n", + "print('P( Leader = Bob ) = {:.2f}'.format(prob_prior))\n", + "print('P( Leader = Bob | Data ) = {:.2f}'.format(prob_post))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "dR2U7_0quNYm" + }, + "source": [ + "As you can see, this doesn't work - we can't disentangle the two hypotheses\n", + "just by looking at the data.\n", + "\n", + "Intuitively, we could figure out whether Alice or Bob is the leader by\n", + "intervening the game - for instance, by instructing Bob to say what\n", + "we want and observe Alice's reaction:\n", + "- if Alice matches Bob many times, then she's probably the follower;\n", + "- instead if Alice does not attempt to match Bob, then we can conclude \n", + "that Alice is the leader.\n", + "\n", + "Crucially, we need to **interact** in order to collect the data. \n", + "It's not enough to passively observe. 
For this, we'll use\n", + "an implementation of the game (`ChickenEggGame`) that allows\n", + "us to instruct either Alice or Bob to shout the word we want. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "f7rPh-f6UKWM" + }, + "outputs": [], + "source": [ + "T = 5\n", + "game = ChickenEggGame(T=T)\n", + "\n", + "# Do T rounds.\n", + "for n in range(T):\n", + " reply = game.step('Alice', 'chicken')\n", + " print(reply)\n", + "\n", + "# Reveal.\n", + "print('The true leader is:' + game.reveal())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "po_krgRljGmp" + }, + "source": [ + "### Exercise\n", + "\n", + "Using `ChickenEggGame`, play `T=5` rounds giving an instruction.\n", + "Use a copy of the probability tree `ld` to record the results,\n", + "appropriately distinguishing between conditions and interventions.\n", + "Finally, compute the posterior probability of Alice being the\n", + "leader and compare with ground truth (using the `reveal` method).\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5-5dNAcdk0Nz" + }, + "outputs": [], + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "z4icvktpk0v0" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "m9DZhCjcVn4-" + }, + "outputs": [], + "source": [ + "import copy\n", + "\n", + "T = 5\n", + "game = ChickenEggGame(T=T)\n", + "\n", + "# Do T rounds.\n", + "print('Game:')\n", + "ldg = copy.deepcopy(ld)\n", + "for t in range(1, T+1):\n", + " reply = game.step('Alice', 'chicken')\n", + " instruction = 'Alice_' + str(t) + '=chicken'\n", + " ldg = ldg.do(ldg.prop(instruction))\n", + " ldg = ldg.see(ldg.prop(reply))\n", + " print(instruction + ', ' + reply)\n", + "\n", + "# Prediction.\n", + "print('\\nPrediction:')\n", + "cut_query = ldg.prop('Leader=Alice')\n", + "prob_post = ldg.prob(cut_query)\n", + "print('P(Leader = Alice | Data) = {:.5f}'.format(prob_post))\n", + "\n", + "# Reveal ground truth.\n", + "print('\\nGround truth:')\n", + "print('Leader = ' + game.reveal())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "QMvQGn9KclPU" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "Causality Tutorial.ipynb", + "provenance": [ + { + "file_id": "1uLGieQXt93jX0ASo-qSCvnjEV67_4ZOO", + "timestamp": 1591027877583 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/causal_reasoning/README.md b/causal_reasoning/README.md new file mode 100644 index 0000000..eb35dee --- /dev/null +++ b/causal_reasoning/README.md @@ -0,0 +1,21 @@ +# Algorithms for Causal Reasoning in Probability Trees +*By the AGI Safety Analysis Team @ DeepMind* + +Probability trees are one of the simplest models of causal generative processes. 
+They possess clean semantics and are strictly more general than causal Bayesian +networks, as they are able to e.g. represent causal relations that causal Bayesian +networks cannot. Yet, they have received little attention from the AI and ML +community. Here we present new algorithms for causal reasoning in discrete +probability trees that cover the entire causal hierarchy (association, intervention, +and counterfactuals), and operate on arbitrary propositional and causal events. Our +work expands the domain of causal reasoning to a very general class of discrete +stochastic processes. + +For details, see our paper [Algorithms for Causal Reasoning in Probability Trees](). + +To launch the accompanying notebook in Google Colab, [click here](https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/causal_reasoning/Causal_Reasoning_in_Probability_Trees.ipynb). + +If you use the code here, please cite this paper. + +> Tim Genewein*, Tom McGrath*, Grégoire Delétang*, Vladimir Mikulik*, Miljan Martic, Shane Legg, Pedro A. Ortega. Algorithms for Causal Reasoning in Probability Trees. NeurIPS 2020. [\[arXiv\]]() + diff --git a/learned_free_energy_estimation/README.md b/learned_free_energy_estimation/README.md index 2afca08..c2c4c3f 100644 --- a/learned_free_energy_estimation/README.md +++ b/learned_free_energy_estimation/README.md @@ -1,7 +1,8 @@ # Targeted free energy estimation via learned mappings This repository contains supporting data for our publication -([arXiv](https://arxiv.org/abs/2002.04913)). Here, we provide +([journal](https://doi.org/10.1063/5.0018903), [arXiv](https://arxiv.org/abs/2002.04913)). +Here, we provide - molecular dynamics (MD) datasets underlying the results reported in our paper, - a LAMMPS input script to generate these datasets, and - the data plotted in Fig. 5 of our paper to facilitate comparison. @@ -96,9 +97,11 @@ If you find this repository helpful for your research, please cite our publicati title={Targeted free energy estimation via learned mappings}, author={Wirnsberger, Peter and Ballard, Andrew J and Papamakarios, George and Abercrombie, Stuart and Racanière, Sébastien and Pritzel, Alexander and - Jimenez Rezende, Danilo and Blundell, Charles} - journal={Journal of Chemical Physics}, - vol={153}, + Jimenez Rezende, Danilo and Blundell, Charles}, + journal={J. Chem. Phys.}, + volume={153}, + number={14}, + pages={144112}, year={2020}, doi={10.1063/5.0018903} }