Mirror of https://github.com/google-deepmind/deepmind-research.git

add UVFA approximation for relative reachability penalty for the future tasks paper

PiperOrigin-RevId: 336851625

Commit: 0b9372d5e6
Parent: bc398d8004
Committed by: Diego de Las Casas
@@ -26,10 +26,13 @@ import abc
 import collections
 import copy
 import enum
+import random
 import numpy as np
 import six
 from six.moves import range
 from six.moves import zip
+import sonnet as snt
+import tensorflow.compat.v1 as tf
 
 
 class Actions(enum.IntEnum):
@@ -282,7 +285,8 @@ class Reachability(ReachabilityMixin, DeviationMeasure):
     self._reachability = collections.defaultdict(
         lambda: collections.defaultdict(lambda: 0))
 
-  def update(self, prev_state, current_state):
+  def update(self, prev_state, current_state, action=None):
+    del action  # Unused.
     self._reachability[prev_state][prev_state] = 1
     self._reachability[current_state][current_state] = 1
     if self._reachability[prev_state][current_state] < self._value_discount:
@@ -300,6 +304,195 @@ class Reachability(ReachabilityMixin, DeviationMeasure):
     return self._discount
 
 
+class UVFAReachability(ReachabilityMixin, DeviationMeasure):
+  """Approximate relative reachability deviation measure using UVFA.
+
+  We approximate reachability using a neural network trained only on observed
+  state transitions. For each (s0, action, s1) transition, we update the
+  reachability estimate for (s0, action, s) towards the reachability estimate
+  between s1 and s, for each s in a random sample of size update_sample_size.
+  In particular, the loss for the neural network reachability estimate (NN) is
+    sum_s (max_a NN(s1, a, s) * value_discount - NN(s0, action, s))^2,
+  where the sum is over all sampled s and the max is taken over all actions a.
+
+  At evaluation time, the reachability difference is calculated with respect
+  to a randomly sampled set of states of size calc_sample_size.
+  """
+
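# Annotation, not part of the commit: a minimal numeric sketch of the
# bootstrapped target described in the docstring above. Here `nn` is a
# hypothetical stand-in for the learned estimator NN(s0, a, s); the real
# network is defined in __init__ below. Reachability is modelled as
# value_discount ** distance, so the target for NN(s0, action, s) is
# max_a NN(s1, a, s) * value_discount.
import numpy as np

value_discount = 0.95

def nn(s0, a, s):
  # Toy stand-in: states are integers on a line, distance is |s - s0|,
  # and the action is ignored.
  del a
  return value_discount ** abs(s - s0)

s0, action, s1 = 0, 1, 1       # one observed transition s0 -(action)-> s1
sampled_states = [0, 1, 2, 3]  # stands in for the update_sample_size draw
actions = range(5)

# Squared error against the bootstrapped target, averaged over the sample.
loss = np.mean([
    (max(nn(s1, a, s) for a in actions) * value_discount
     - nn(s0, action, s)) ** 2
    for s in sampled_states])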
+  def __init__(
+      self,
+      value_discount=0.95,
+      dev_fun=None,
+      discount=0.95,
+      state_size=36,  # Sokoban default
+      num_actions=5,
+      update_sample_size=10,
+      calc_sample_size=10,
+      hidden_size=50,
+      representation_size=5,
+      num_layers=1,
+      base_loss_coeff=0.1,
+      num_stored=100):
+
+    # Create networks to generate state representations. To get a reachability
+    # estimate, take the dot product of the origin network output and the goal
+    # network output, then pass it through a sigmoid function to constrain it
+    # to between 0 and 1.
+    output_sizes = [hidden_size] * num_layers + [representation_size]
+    self._origin_network = snt.nets.MLP(
+        output_sizes=output_sizes,
+        activation=tf.nn.relu,
+        activate_final=False,
+        name='origin_network')
+    self._goal_network = snt.nets.MLP(
+        output_sizes=output_sizes,
+        activation=tf.nn.relu,
+        activate_final=False,
+        name='goal_network')
+
+    self._value_discount = value_discount
+    self._dev_fun = dev_fun
+    self._discount = discount
+    self._state_size = state_size
+    self._num_actions = num_actions
+    self._update_sample_size = update_sample_size
+    self._calc_sample_size = calc_sample_size
+    self._num_stored = num_stored
+    self._stored_states = set()
+
+    self._state_0_placeholder = tf.placeholder(tf.float32, shape=(state_size))
+    self._state_1_placeholder = tf.placeholder(tf.float32, shape=(state_size))
+    self._action_placeholder = tf.placeholder(tf.float32, shape=(num_actions))
+    self._update_sample_placeholder = tf.placeholder(
+        tf.float32, shape=(update_sample_size, state_size))
+    self._calc_sample_placeholder = tf.placeholder(
+        tf.float32, shape=(calc_sample_size, state_size))
+
+    # Trained to estimate reachability = value_discount ^ distance.
+    self._sample_loss = self._get_state_action_loss(
+        self._state_0_placeholder,
+        self._state_1_placeholder,
+        self._action_placeholder,
+        self._update_sample_placeholder)
+
+    # Add additional loss to force observed transitions towards value_discount.
+    self._base_reachability = self._get_state_sample_reachability(
+        self._state_0_placeholder,
+        tf.expand_dims(self._state_1_placeholder, axis=0),
+        action=self._action_placeholder)
+    self._base_case_loss = tf.keras.losses.MSE(self._value_discount,
+                                               self._base_reachability)
+
+    self._opt = tf.train.AdamOptimizer().minimize(self._sample_loss +
+                                                  base_loss_coeff *
+                                                  self._base_case_loss)
+
+    current_state_reachability = self._get_state_sample_reachability(
+        self._state_0_placeholder, self._calc_sample_placeholder)
+    baseline_state_reachability = self._get_state_sample_reachability(
+        self._state_1_placeholder, self._calc_sample_placeholder)
+
+    self._reachability_calculation = [
+        tf.reshape(baseline_state_reachability, [-1]),
+        tf.reshape(current_state_reachability, [-1])
+    ]
+
+    init = tf.global_variables_initializer()
+    self._sess = tf.Session()
+    self._sess.run(init)
+
+  def calculate(self, current_state, baseline_state, rollout_func=None):
+    """Compute the reachability penalty between two states."""
+    current_state = np.array(current_state).flatten()
+    baseline_state = np.array(baseline_state).flatten()
+    sample = self._sample_n_states(self._calc_sample_size)
+    # Run if there are enough states to draw a correctly-sized sample from.
+    if sample:
+      base, curr = self._sess.run(
+          self._reachability_calculation,
+          feed_dict={
+              self._state_0_placeholder: current_state,
+              self._state_1_placeholder: baseline_state,
+              self._calc_sample_placeholder: sample
+          })
+      return sum(map(self._dev_fun, base - curr)) / self._calc_sample_size
+    else:
+      return 0
+
+  def _sample_n_states(self, n):
+    try:
+      return random.sample(self._stored_states, n)
+    except ValueError:
+      return None
+
+  def update(self, prev_state, current_state, action):
+    prev_state = np.array(prev_state).flatten()
+    current_state = np.array(current_state).flatten()
+    one_hot_action = np.zeros(self._num_actions)
+    one_hot_action[action] = 1
+
+    sample = self._sample_n_states(self._update_sample_size)
+    if self._num_stored is None or len(self._stored_states) < self._num_stored:
+      self._stored_states.add(tuple(prev_state))
+      self._stored_states.add(tuple(current_state))
+    elif (np.random.random() < 0.01 and
+          tuple(current_state) not in self._stored_states):
+      self._stored_states.pop()
+      self._stored_states.add(tuple(current_state))
+
+    # If there aren't enough states to get a full sample, do nothing.
+    if sample:
+      self._sess.run([self._opt], feed_dict={
+          self._state_0_placeholder: prev_state,
+          self._state_1_placeholder: current_state,
+          self._action_placeholder: one_hot_action,
+          self._update_sample_placeholder: sample
+      })
+
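# Annotation, not part of the commit: the branch above maintains a bounded
# store of visited states to sample from. A standalone sketch of the same
# policy with hypothetical names; states are stored as hashable tuples:
import numpy as np

def store(stored_states, state, num_stored=100):
  state = tuple(np.array(state).flatten())
  if num_stored is None or len(stored_states) < num_stored:
    stored_states.add(state)
  elif np.random.random() < 0.01 and state not in stored_states:
    stored_states.pop()  # set.pop() evicts an arbitrary stored state
    stored_states.add(state)

states = set()
store(states, np.zeros(4))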
+  def _get_state_action_loss(self, prev_state, current_state, action, sample):
+    """Get the loss from differences in state reachability estimates."""
+    # Calculate NN(s0, action, s) for all s in sample.
+    prev_state_reachability = self._get_state_sample_reachability(
+        prev_state, sample, action=action)
+    # Calculate max_a(NN(s1, a, s)) for all s in sample and all actions a.
+    current_state_reachability = tf.stop_gradient(
+        self._get_state_sample_reachability(current_state, sample))
+    # Combine to return loss.
+    return tf.keras.losses.MSE(
+        current_state_reachability * self._value_discount,
+        prev_state_reachability)
+
+  def _get_state_sample_reachability(self, state, sample, action=None):
+    """Calculate reachability from a state to each item in a sample."""
+    if action is None:
+      state_options = self._tile_with_all_actions(state)
+    else:
+      state_options = tf.expand_dims(tf.concat([state, action], axis=0),
+                                     axis=0)
+    goal_representations = self._goal_network(sample)
+    # Reachability of sampled states by taking actions.
+    reach_result = tf.sigmoid(
+        tf.reduce_max(
+            tf.matmul(
+                goal_representations,
+                self._origin_network(state_options),
+                transpose_b=True),
+            axis=1))
+    if action is None:
+      # Return 1 if a sampled state is already reached (equal to state).
+      reach_no_action = tf.cast(tf.reduce_all(tf.equal(sample, state), axis=1),
+                                dtype=tf.float32)
+      reach_result = tf.maximum(reach_result, reach_no_action)
+    return reach_result
+
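# Annotation, not part of the commit: the estimate above is a dot product
# of origin and goal representations, maxed over the action axis and
# squashed by a sigmoid. A numpy sketch with made-up shapes (10 sampled
# goal states, 5 state/action options, representation size 5):
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

goal_repr = np.random.randn(10, 5)   # plays the role of goal_network(sample)
origin_repr = np.random.randn(5, 5)  # origin_network(state_options)
# For each sampled goal, take the best state/action option, then sigmoid,
# mirroring reach_result above; values land in (0, 1).
reach = sigmoid((goal_repr @ origin_repr.T).max(axis=1))  # shape (10,)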
+  def _tile_with_all_actions(self, state):
+    """Returns tensor with all state/action combinations."""
+    state_tiled = tf.tile(tf.expand_dims(state, axis=0),
+                          [self._num_actions, 1])
+    all_actions_tiled = tf.one_hot(
+        tf.range(self._num_actions), depth=self._num_actions)
+    return tf.concat([state_tiled, all_actions_tiled], axis=1)
+
+
 class AttainableUtilityMixin(object):
   """Class for computing attainable utility measure.
 
@@ -385,8 +578,9 @@ class AttainableUtility(AttainableUtilityMixin, DeviationMeasure):
     # predecessors[s] = set of states known to lead, by some action, to s
     self._predecessors = collections.defaultdict(set)
 
-  def update(self, prev_state, current_state):
+  def update(self, prev_state, current_state, action=None):
     """Update predecessors and attainable utility estimates."""
+    del action  # Unused.
     self._predecessors[current_state].add(prev_state)
     seen = set()
     queue = [current_state]
@@ -437,7 +631,7 @@ class SideEffectPenalty(object):
   def calculate(self, prev_state, action, current_state):
     """Calculate the penalty associated with a transition, and update models."""
     if current_state:
-      self._dev_measure.update(prev_state, current_state)
+      self._dev_measure.update(prev_state, current_state, action)
       baseline_state = self._baseline.calculate(prev_state, action,
                                                 current_state)
       if self._use_inseparable_rollout:
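# Annotation, not part of the commit: a hypothetical end-to-end use of the
# new interface, assuming this module's UVFAReachability is importable and
# Sonnet/TF1 are available. UVFAReachability needs the action, so
# SideEffectPenalty.calculate now threads it through to dev_measure.update,
# while the tabular measures accept and discard it via action=None.
import numpy as np

dev_measure = UVFAReachability(state_size=4, num_actions=2)
s0, s1 = np.zeros(4), np.ones(4)
# Stores both states; runs a gradient step only once enough states have
# been stored to draw an update_sample_size sample.
dev_measure.update(s0, s1, action=1)
# Returns 0 until calc_sample_size states are available to sample.
penalty = dev_measure.calculate(s1, s0)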