Add UVFA approximation for the relative reachability penalty for the future tasks paper

PiperOrigin-RevId: 336851625
Victoria Krakovna
2020-10-13 12:38:01 +01:00
committed by Diego de Las Casas
parent bc398d8004
commit 0b9372d5e6
+197 -3
@@ -26,10 +26,13 @@ import abc
import collections
import copy
import enum
+import random
import numpy as np
import six
from six.moves import range
from six.moves import zip
+import sonnet as snt
+import tensorflow.compat.v1 as tf


class Actions(enum.IntEnum):
@@ -282,7 +285,8 @@ class Reachability(ReachabilityMixin, DeviationMeasure):
    self._reachability = collections.defaultdict(
        lambda: collections.defaultdict(lambda: 0))

-  def update(self, prev_state, current_state):
+  def update(self, prev_state, current_state, action=None):
+    del action  # Unused.
    self._reachability[prev_state][prev_state] = 1
    self._reachability[current_state][current_state] = 1
    if self._reachability[prev_state][current_state] < self._value_discount:
@@ -300,6 +304,195 @@ class Reachability(ReachabilityMixin, DeviationMeasure):
    return self._discount


class UVFAReachability(ReachabilityMixin, DeviationMeasure):
  """Approximate relative reachability deviation measure using UVFA.

  We approximate reachability using a neural network trained only on state
  transitions that have been observed. For each (s0, action, s1) transition,
  we update the reachability estimate for (s0, action, s) towards the
  reachability estimate between s1 and s, for each s in a random sample of
  size update_sample_size. In particular, the loss for the neural network
  reachability estimate (NN) is

    sum_s (max_a(NN(s1, a, s)) * value_discount - NN(s0, action, s))^2,

  where the sum is over all sampled s and the max is taken over all actions a.
  At evaluation time, the reachability difference is calculated with respect
  to a randomly sampled set of states of size calc_sample_size.
  """
  def __init__(
      self,
      value_discount=0.95,
      dev_fun=None,
      discount=0.95,
      state_size=36,  # Sokoban default
      num_actions=5,
      update_sample_size=10,
      calc_sample_size=10,
      hidden_size=50,
      representation_size=5,
      num_layers=1,
      base_loss_coeff=0.1,
      num_stored=100):
    # Create networks to generate state representations. To get a reachability
    # estimate, take the dot product of the origin network output and the goal
    # network output, then pass it through a sigmoid function to constrain it
    # to between 0 and 1.
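    # Concretely (see _get_state_sample_reachability below):
    # NN(s, a, g) = sigmoid(<origin_network([s; a]), goal_network(g)>),
    # an inner product of the two representation vectors.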
    output_sizes = [hidden_size] * num_layers + [representation_size]
    self._origin_network = snt.nets.MLP(
        output_sizes=output_sizes,
        activation=tf.nn.relu,
        activate_final=False,
        name='origin_network')
    self._goal_network = snt.nets.MLP(
        output_sizes=output_sizes,
        activation=tf.nn.relu,
        activate_final=False,
        name='goal_network')

    self._value_discount = value_discount
    self._dev_fun = dev_fun
    self._discount = discount
    self._state_size = state_size
    self._num_actions = num_actions
    self._update_sample_size = update_sample_size
    self._calc_sample_size = calc_sample_size
    self._num_stored = num_stored
    self._stored_states = set()
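
    # Placeholders for a single transition (s0, a, s1) and for batches of
    # sampled states used during updates and penalty calculation.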
    self._state_0_placeholder = tf.placeholder(tf.float32, shape=(state_size))
    self._state_1_placeholder = tf.placeholder(tf.float32, shape=(state_size))
    self._action_placeholder = tf.placeholder(tf.float32, shape=(num_actions))
    self._update_sample_placeholder = tf.placeholder(
        tf.float32, shape=(update_sample_size, state_size))
    self._calc_sample_placeholder = tf.placeholder(
        tf.float32, shape=(calc_sample_size, state_size))

    # Trained to estimate reachability = value_discount ^ distance.
    self._sample_loss = self._get_state_action_loss(
        self._state_0_placeholder,
        self._state_1_placeholder,
        self._action_placeholder,
        self._update_sample_placeholder)

    # Add additional loss to force observed transitions towards value_discount.
    self._base_reachability = self._get_state_sample_reachability(
        self._state_0_placeholder,
        tf.expand_dims(self._state_1_placeholder, axis=0),
        action=self._action_placeholder)
    self._base_case_loss = tf.keras.losses.MSE(self._value_discount,
                                               self._base_reachability)

    self._opt = tf.train.AdamOptimizer().minimize(self._sample_loss +
                                                  base_loss_coeff *
                                                  self._base_case_loss)

    current_state_reachability = self._get_state_sample_reachability(
        self._state_0_placeholder, self._calc_sample_placeholder)
    baseline_state_reachability = self._get_state_sample_reachability(
        self._state_1_placeholder, self._calc_sample_placeholder)
    self._reachability_calculation = [
        tf.reshape(baseline_state_reachability, [-1]),
        tf.reshape(current_state_reachability, [-1])
    ]
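    # calculate() runs these two tensors to compare, for each sampled state,
    # its estimated reachability from the baseline state (fed via state_1)
    # and from the current state (fed via state_0).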
    init = tf.global_variables_initializer()
    self._sess = tf.Session()
    self._sess.run(init)

  def calculate(self, current_state, baseline_state, rollout_func=None):
    """Compute the reachability penalty between two states."""
    current_state = np.array(current_state).flatten()
    baseline_state = np.array(baseline_state).flatten()
    sample = self._sample_n_states(self._calc_sample_size)
    # Run if there are enough states to draw a correctly-sized sample from.
    if sample:
      base, curr = self._sess.run(
          self._reachability_calculation,
          feed_dict={
              self._state_0_placeholder: current_state,
              self._state_1_placeholder: baseline_state,
              self._calc_sample_placeholder: sample
          })
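      # Penalty is dev_fun applied to each drop in reachability (baseline
      # minus current), averaged over the sampled states.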
      return sum(map(self._dev_fun, base - curr)) / self._calc_sample_size
    else:
      return 0

  def _sample_n_states(self, n):
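    """Returns n distinct stored states, or None if fewer are stored."""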
    try:
      return random.sample(self._stored_states, n)
    except ValueError:
      return None

  def update(self, prev_state, current_state, action):
    prev_state = np.array(prev_state).flatten()
    current_state = np.array(current_state).flatten()
    one_hot_action = np.zeros(self._num_actions)
    one_hot_action[action] = 1
    sample = self._sample_n_states(self._update_sample_size)
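    # Store states: fill the buffer up to num_stored states; once full,
    # replace an arbitrary stored state with the unseen current state with
    # probability 0.01.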
    if self._num_stored is None or len(self._stored_states) < self._num_stored:
      self._stored_states.add(tuple(prev_state))
      self._stored_states.add(tuple(current_state))
    elif (np.random.random() < 0.01 and
          tuple(current_state) not in self._stored_states):
      self._stored_states.pop()
      self._stored_states.add(tuple(current_state))
    # If there aren't enough states to get a full sample, do nothing.
    if sample:
      self._sess.run([self._opt], feed_dict={
          self._state_0_placeholder: prev_state,
          self._state_1_placeholder: current_state,
          self._action_placeholder: one_hot_action,
          self._update_sample_placeholder: sample
      })
  def _get_state_action_loss(self, prev_state, current_state, action, sample):
    """Get the loss from differences in state reachability estimates."""
    # Calculate NN(s0, action, s) for all s in sample.
    prev_state_reachability = self._get_state_sample_reachability(
        prev_state, sample, action=action)
    # Calculate max_a(NN(s1, a, s)) for all s in sample and all actions a.
    current_state_reachability = tf.stop_gradient(
        self._get_state_sample_reachability(current_state, sample))
    # Combine to return loss.
    return tf.keras.losses.MSE(
        current_state_reachability * self._value_discount,
        prev_state_reachability)

  def _get_state_sample_reachability(self, state, sample, action=None):
    """Calculate reachability from a state to each item in a sample."""
    if action is None:
      state_options = self._tile_with_all_actions(state)
    else:
      state_options = tf.expand_dims(
          tf.concat([state, action], axis=0), axis=0)
    goal_representations = self._goal_network(sample)
    # Reachability of sampled states by taking actions.
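    # The matmul produces a [sample_size, num_options] grid of inner products
    # (num_options is num_actions when maximising over actions, else 1); the
    # max over options is then squashed into (0, 1) by the sigmoid.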
    reach_result = tf.sigmoid(
        tf.reduce_max(
            tf.matmul(
                goal_representations,
                self._origin_network(state_options),
                transpose_b=True),
            axis=1))
    if action is None:
      # Return 1 if the sampled state is already reached (equal to state).
      reach_no_action = tf.cast(tf.reduce_all(tf.equal(sample, state), axis=1),
                                dtype=tf.float32)
      reach_result = tf.maximum(reach_result, reach_no_action)
    return reach_result

  def _tile_with_all_actions(self, state):
    """Returns a tensor with all state/action combinations."""
    state_tiled = tf.tile(tf.expand_dims(state, axis=0),
                          [self._num_actions, 1])
    all_actions_tiled = tf.one_hot(
        tf.range(self._num_actions), depth=self._num_actions)
    return tf.concat([state_tiled, all_actions_tiled], axis=1)
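
# --- Illustrative usage (editor's sketch, not part of this commit). ---
# Assumes hypothetical `env` and `agent` objects with flat array states and
# integer actions; the truncation dev_fun below mirrors the relative
# reachability penalty, which only penalises lost reachability.

def _example_uvfa_usage(env, agent, num_steps):
  dev_measure = UVFAReachability(
      dev_fun=lambda diff: max(diff, 0),  # penalise only drops in reachability
      state_size=36,
      num_actions=5)
  prev_state = env.reset()
  penalty = 0
  for _ in range(num_steps):
    action = agent.act(prev_state)
    current_state = env.step(action)
    # Train the reachability network on the observed transition.
    dev_measure.update(prev_state, current_state, action)
    # Deviation of the current state from a baseline; for simplicity the
    # baseline here is just the previous state (in the full penalty it would
    # come from a Baseline object, as in SideEffectPenalty.calculate).
    penalty = dev_measure.calculate(current_state, baseline_state=prev_state)
    prev_state = current_state
  return penalty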

class AttainableUtilityMixin(object):
  """Class for computing attainable utility measure.
@@ -385,8 +578,9 @@ class AttainableUtility(AttainableUtilityMixin, DeviationMeasure):
    # predecessors[s] = set of states known to lead, by some action, to s
    self._predecessors = collections.defaultdict(set)

-  def update(self, prev_state, current_state):
+  def update(self, prev_state, current_state, action=None):
    """Update predecessors and attainable utility estimates."""
+    del action  # Unused.
    self._predecessors[current_state].add(prev_state)
    seen = set()
    queue = [current_state]
@@ -437,7 +631,7 @@ class SideEffectPenalty(object):
  def calculate(self, prev_state, action, current_state):
    """Calculate the penalty associated with a transition, and update models."""
    if current_state:
-      self._dev_measure.update(prev_state, current_state)
+      self._dev_measure.update(prev_state, current_state, action)
      baseline_state = self._baseline.calculate(prev_state, action,
                                                current_state)
      if self._use_inseparable_rollout: