diff --git a/side_effects_penalties/side_effects_penalty.py b/side_effects_penalties/side_effects_penalty.py
index ecbacaf..1894c68 100644
--- a/side_effects_penalties/side_effects_penalty.py
+++ b/side_effects_penalties/side_effects_penalty.py
@@ -26,10 +26,13 @@ import abc
 import collections
 import copy
 import enum
+import random
 
 import numpy as np
 import six
 from six.moves import range
 from six.moves import zip
+import sonnet as snt
+import tensorflow.compat.v1 as tf
 
 
 class Actions(enum.IntEnum):
@@ -282,7 +285,8 @@ class Reachability(ReachabilityMixin, DeviationMeasure):
     self._reachability = collections.defaultdict(
         lambda: collections.defaultdict(lambda: 0))
 
-  def update(self, prev_state, current_state):
+  def update(self, prev_state, current_state, action=None):
+    del action  # Unused.
     self._reachability[prev_state][prev_state] = 1
     self._reachability[current_state][current_state] = 1
     if self._reachability[prev_state][current_state] < self._value_discount:
@@ -300,6 +304,195 @@ class Reachability(ReachabilityMixin, DeviationMeasure):
     return self._discount
 
 
+class UVFAReachability(ReachabilityMixin, DeviationMeasure):
+  """Approximate relative reachability deviation measure using a UVFA.
+
+  We approximate reachability using a neural network trained only on state
+  transitions that have been observed. For each (s0, action, s1) transition,
+  we update the reachability estimate for (s0, action, s) towards the
+  reachability estimate between s1 and s, for each s in a random sample of
+  size update_sample_size. In particular, the loss for the neural network
+  reachability estimate (NN) is
+    sum_s (max_a NN(s1, a, s) * value_discount - NN(s0, action, s))^2,
+  where the sum is over all sampled states s and the max is taken over all
+  actions a.
+
+  At evaluation time, the reachability difference is calculated with respect
+  to a randomly sampled set of states of size calc_sample_size.
+  """
+
+  def __init__(
+      self,
+      value_discount=0.95,
+      dev_fun=None,
+      discount=0.95,
+      state_size=36,  # Sokoban default
+      num_actions=5,
+      update_sample_size=10,
+      calc_sample_size=10,
+      hidden_size=50,
+      representation_size=5,
+      num_layers=1,
+      base_loss_coeff=0.1,
+      num_stored=100):
+
+    # Create networks to generate state representations. To get a reachability
+    # estimate, take the dot product of the origin network output and the goal
+    # network output, then pass it through a sigmoid function to constrain it
+    # to between 0 and 1.
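+    # Schematically, for origin state s, one-hot action a and goal state g:
+    #   reachability(s, a, g) = sigmoid(origin_network([s; a]) . goal_network(g))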
+    output_sizes = [hidden_size] * num_layers + [representation_size]
+    self._origin_network = snt.nets.MLP(
+        output_sizes=output_sizes,
+        activation=tf.nn.relu,
+        activate_final=False,
+        name='origin_network')
+    self._goal_network = snt.nets.MLP(
+        output_sizes=output_sizes,
+        activation=tf.nn.relu,
+        activate_final=False,
+        name='goal_network')
+
+    self._value_discount = value_discount
+    self._dev_fun = dev_fun
+    self._discount = discount
+    self._state_size = state_size
+    self._num_actions = num_actions
+    self._update_sample_size = update_sample_size
+    self._calc_sample_size = calc_sample_size
+    self._num_stored = num_stored
+    self._stored_states = set()
+
+    self._state_0_placeholder = tf.placeholder(tf.float32, shape=(state_size,))
+    self._state_1_placeholder = tf.placeholder(tf.float32, shape=(state_size,))
+    self._action_placeholder = tf.placeholder(tf.float32, shape=(num_actions,))
+    self._update_sample_placeholder = tf.placeholder(
+        tf.float32, shape=(update_sample_size, state_size))
+    self._calc_sample_placeholder = tf.placeholder(
+        tf.float32, shape=(calc_sample_size, state_size))
+
+    # Trained to estimate reachability = value_discount ^ distance.
+    self._sample_loss = self._get_state_action_loss(
+        self._state_0_placeholder,
+        self._state_1_placeholder,
+        self._action_placeholder,
+        self._update_sample_placeholder)
+
+    # Add an additional loss pushing the reachability estimate for each
+    # observed transition towards value_discount.
+    self._base_reachability = self._get_state_sample_reachability(
+        self._state_0_placeholder,
+        tf.expand_dims(self._state_1_placeholder, axis=0),
+        action=self._action_placeholder)
+    self._base_case_loss = tf.keras.losses.MSE(self._value_discount,
+                                               self._base_reachability)
+
+    self._opt = tf.train.AdamOptimizer().minimize(
+        self._sample_loss + base_loss_coeff * self._base_case_loss)
+
+    current_state_reachability = self._get_state_sample_reachability(
+        self._state_0_placeholder, self._calc_sample_placeholder)
+    baseline_state_reachability = self._get_state_sample_reachability(
+        self._state_1_placeholder, self._calc_sample_placeholder)
+
+    self._reachability_calculation = [
+        tf.reshape(baseline_state_reachability, [-1]),
+        tf.reshape(current_state_reachability, [-1])
+    ]
+
+    init = tf.global_variables_initializer()
+    self._sess = tf.Session()
+    self._sess.run(init)
+
+  def calculate(self, current_state, baseline_state, rollout_func=None):
+    """Compute the reachability penalty between two states."""
+    current_state = np.array(current_state).flatten()
+    baseline_state = np.array(baseline_state).flatten()
+    sample = self._sample_n_states(self._calc_sample_size)
+    # Run only if there are enough stored states to draw a correctly-sized
+    # sample.
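+    # The resulting penalty is the average of dev_fun applied to the per-goal
+    # reachability differences base - curr (baseline minus current state).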
+    if sample:
+      base, curr = self._sess.run(
+          self._reachability_calculation,
+          feed_dict={
+              self._state_0_placeholder: current_state,
+              self._state_1_placeholder: baseline_state,
+              self._calc_sample_placeholder: sample
+          })
+      return sum(map(self._dev_fun, base - curr)) / self._calc_sample_size
+    else:
+      return 0
+
+  def _sample_n_states(self, n):
+    """Returns a random sample of n stored states, or None if too few."""
+    try:
+      return random.sample(self._stored_states, n)
+    except ValueError:
+      return None
+
+  def update(self, prev_state, current_state, action):
+    prev_state = np.array(prev_state).flatten()
+    current_state = np.array(current_state).flatten()
+    one_hot_action = np.zeros(self._num_actions)
+    one_hot_action[action] = 1
+
+    sample = self._sample_n_states(self._update_sample_size)
+    if self._num_stored is None or len(self._stored_states) < self._num_stored:
+      self._stored_states.add(tuple(prev_state))
+      self._stored_states.add(tuple(current_state))
+    elif (np.random.random() < 0.01 and
+          tuple(current_state) not in self._stored_states):
+      # Occasionally evict an arbitrary stored state to make room.
+      self._stored_states.pop()
+      self._stored_states.add(tuple(current_state))
+
+    # If there aren't enough stored states to draw a full sample, do nothing.
+    if sample:
+      self._sess.run([self._opt], feed_dict={
+          self._state_0_placeholder: prev_state,
+          self._state_1_placeholder: current_state,
+          self._action_placeholder: one_hot_action,
+          self._update_sample_placeholder: sample
+      })
+
+  def _get_state_action_loss(self, prev_state, current_state, action, sample):
+    """Get the loss from differences in state reachability estimates."""
+    # Calculate NN(s0, action, s) for all s in sample.
+    prev_state_reachability = self._get_state_sample_reachability(
+        prev_state, sample, action=action)
+    # Calculate max_a NN(s1, a, s) for all s in sample, holding the target
+    # fixed (stop_gradient) so only the s0 estimate is trained.
+    current_state_reachability = tf.stop_gradient(
+        self._get_state_sample_reachability(current_state, sample))
+    # Combine to return the squared-error loss.
+    return tf.keras.losses.MSE(
+        current_state_reachability * self._value_discount,
+        prev_state_reachability)
+
+  def _get_state_sample_reachability(self, state, sample, action=None):
+    """Calculate reachability from a state to each item in a sample."""
+    if action is None:
+      state_options = self._tile_with_all_actions(state)
+    else:
+      state_options = tf.expand_dims(tf.concat([state, action], axis=0),
+                                     axis=0)
+    goal_representations = self._goal_network(sample)
+    # Reachability of each sampled state under the best available action.
+    reach_result = tf.sigmoid(
+        tf.reduce_max(
+            tf.matmul(
+                goal_representations,
+                self._origin_network(state_options),
+                transpose_b=True),
+            axis=1))
+    if action is None:
+      # Return 1 if a sampled state is already reached (i.e. equal to state).
+      reach_no_action = tf.cast(tf.reduce_all(tf.equal(sample, state), axis=1),
+                                dtype=tf.float32)
+      reach_result = tf.maximum(reach_result, reach_no_action)
+    return reach_result
+
+  def _tile_with_all_actions(self, state):
+    """Returns a tensor with all state/action combinations."""
+    state_tiled = tf.tile(tf.expand_dims(state, axis=0),
+                          [self._num_actions, 1])
+    all_actions_tiled = tf.one_hot(
+        tf.range(self._num_actions), depth=self._num_actions)
+    return tf.concat([state_tiled, all_actions_tiled], axis=1)
+
+
 class AttainableUtilityMixin(object):
   """Class for computing attainable utility measure.
@@ -385,8 +578,9 @@ class AttainableUtility(AttainableUtilityMixin, DeviationMeasure):
     # predecessors[s] = set of states known to lead, by some action, to s
     self._predecessors = collections.defaultdict(set)
 
-  def update(self, prev_state, current_state):
+  def update(self, prev_state, current_state, action=None):
     """Update predecessors and attainable utility estimates."""
+    del action  # Unused.
     self._predecessors[current_state].add(prev_state)
     seen = set()
     queue = [current_state]
@@ -437,7 +631,7 @@ class SideEffectPenalty(object):
   def calculate(self, prev_state, action, current_state):
     """Calculate the penalty associated with a transition, and update models."""
     if current_state:
-      self._dev_measure.update(prev_state, current_state)
+      self._dev_measure.update(prev_state, current_state, action)
       baseline_state = self._baseline.calculate(prev_state, action,
                                                 current_state)
     if self._use_inseparable_rollout:
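
For reviewers who want to exercise the new measure, here is a minimal smoke-test sketch (not part of the patch). It assumes the repository's package layout plus the Sonnet 1.x and TF 1.x dependencies imported above; the truncation dev_fun, the toy state/action sizes, and the random-walk transitions are illustrative choices only.

import numpy as np
from side_effects_penalties.side_effects_penalty import UVFAReachability

state_size, num_actions = 4, 2
dev = UVFAReachability(
    dev_fun=lambda d: max(0.0, d),  # penalize lost reachability only
    state_size=state_size,
    num_actions=num_actions,
    update_sample_size=2,
    calc_sample_size=2)

rng = np.random.RandomState(0)
prev = rng.randint(2, size=state_size)
for _ in range(20):
  action = rng.randint(num_actions)
  curr = rng.randint(2, size=state_size)
  dev.update(prev, curr, action)  # one gradient step per observed transition
  prev = curr

baseline = rng.randint(2, size=state_size)
# Average truncated reachability difference over sampled goal states; returns
# 0 until enough distinct states have been stored to draw a full sample.
print(dev.calculate(prev, baseline))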