Added flags and parameters for the UVFA approximation to relative reachability (for the future tasks paper)

PiperOrigin-RevId: 336881122
Author: Victoria Krakovna (2020-10-13 16:30:51 +01:00)
Committed by: Diego de Las Casas
Parent: 2e48a73ee4
Commit: fefa95eb1f
4 changed files with 90 additions and 47 deletions

View File

@@ -30,7 +30,8 @@ class QLearningSE(agent.QLearning):
self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0,
baseline='start', dev_measure='none', dev_fun='truncation',
discount=0.99, value_discount=1.0, beta=1.0, num_util_funs=10,
exact_baseline=False, baseline_env=None, start_timestep=None):
exact_baseline=False, baseline_env=None, start_timestep=None,
state_size=None, nonterminal_weight=0.01):
"""Create a Q-learning agent with a side effects penalty.
Args:
@@ -53,6 +54,8 @@ class QLearningSE(agent.QLearning):
exact_baseline: whether to use an exact or approximate baseline.
baseline_env: copy of environment (with noops) for the exact baseline.
start_timestep: copy of starting timestep for the baseline.
state_size: the size of each state (flattened) for NN reachability.
nonterminal_weight: penalty weight on nonterminal states.
Raises:
ValueError: for incorrect baseline, dev_measure, or dev_fun
@@ -62,9 +65,9 @@ class QLearningSE(agent.QLearning):
discount)
# Impact penalty: set dev_fun (f)
if dev_measure in {'rel_reach', 'att_util'}:
if 'rel_reach' in dev_measure or 'att_util' in dev_measure:
if dev_fun == 'truncation':
dev_fun = lambda diff: max(0, diff)
dev_fun = lambda diff: np.maximum(0, diff)
elif dev_fun == 'absolute':
dev_fun = np.abs
else:
@@ -76,6 +79,9 @@ class QLearningSE(agent.QLearning):
# Impact penalty: create deviation measure
if dev_measure in {'reach', 'rel_reach'}:
deviation = sep.Reachability(value_discount, dev_fun, discount)
elif dev_measure == 'uvfa_rel_reach':
deviation = sep.UVFAReachability(value_discount, dev_fun, discount,
state_size)
elif dev_measure == 'att_util':
deviation = sep.AttainableUtility(value_discount, dev_fun, num_util_funs,
discount)
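
Two related details in the hunks above go together: the dev_measure check now uses substring matching so the new 'uvfa_rel_reach' option is handled like 'rel_reach', and the truncation summary switches from the built-in max to np.maximum, presumably so it also works elementwise on vectorised value differences produced by the UVFA model (which is why state_size is passed to sep.UVFAReachability). A minimal illustration of the elementwise behaviour, plain NumPy and not taken from the commit:

import numpy as np

# Truncation keeps only reductions in reachability (positive differences).
# np.maximum works elementwise, so the same summary function applies whether
# the deviation measure returns a scalar or an array of per-state differences.
def truncation(diff):
  return np.maximum(0, diff)

truncation(-0.3)                        # -> 0.0 (no penalty for increases)
truncation(np.array([-0.2, 0.0, 0.5]))  # -> array([0. , 0. , 0.5])
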
@@ -99,8 +105,8 @@ class QLearningSE(agent.QLearning):
else:
raise ValueError('Baseline not recognized')
self._impact_penalty = sep.SideEffectPenalty(baseline, deviation, beta,
use_inseparable_rollout)
self._impact_penalty = sep.SideEffectPenalty(
baseline, deviation, beta, nonterminal_weight, use_inseparable_rollout)
def begin_episode(self):
"""Perform episode initialisation."""

View File

@@ -31,11 +31,13 @@ from side_effects_penalties.file_loading import filename
FLAGS = flags.FLAGS
if __name__ == '__main__': # Avoid defining flags when used as a library.
flags.DEFINE_enum('baseline', 'stepwise',
# Side effects penalty settings
flags.DEFINE_enum('baseline', 'inaction',
['start', 'inaction', 'stepwise', 'step_noroll'],
'Baseline.')
flags.DEFINE_enum('dev_measure', 'rel_reach',
['none', 'reach', 'rel_reach', 'att_util'],
['none', 'reach', 'rel_reach',
'uvfa_rel_reach', 'att_util'],
'Deviation measure.')
flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'],
'Summary function for the deviation measure.')
@@ -43,45 +45,65 @@ if __name__ == '__main__': # Avoid defining flags when used as a library.
flags.DEFINE_float('value_discount', 0.99,
'Discount factor for deviation measure value function.')
flags.DEFINE_float('beta', 30.0, 'Weight for side effects penalty.')
flags.DEFINE_string('nonterminal', 'disc',
'Penalty for nonterminal states relative to terminal '
'states: none (0), full (1), or disc (1-discount).')
flags.DEFINE_bool('exact_baseline', False,
'Compute the exact baseline using an environment copy.')
# Agent settings
flags.DEFINE_bool('anneal', True,
'Whether to anneal the exploration rate from 1 to 0.')
flags.DEFINE_integer('num_episodes', 10000, 'Number of episodes.')
flags.DEFINE_integer('num_episodes_noexp', 0,
'Number of episodes with no exploration.')
flags.DEFINE_integer('seed', 1, 'Random seed.')
# Environment settings
flags.DEFINE_string('env_name', 'box', 'Environment name.')
flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
flags.DEFINE_integer('movement_reward', 0, 'Movement reward.')
flags.DEFINE_integer('goal_reward', 1, 'Reward for reaching a goal state.')
flags.DEFINE_integer('side_effect_reward', -1,
'Hidden reward for causing side effects.')
flags.DEFINE_bool('exact_baseline', False,
'Compute the exact baseline using an environment copy.')
# Settings for outputting results
flags.DEFINE_enum('mode', 'save', ['print', 'save'],
'Print results or save to file.')
flags.DEFINE_string('path', '', 'File path.')
flags.DEFINE_string('suffix', '', 'Filename suffix.')
def run_experiment(baseline, dev_measure, dev_fun, discount, value_discount,
beta, anneal, num_episodes, num_episodes_noexp, seed,
env_name, noops, exact_baseline, mode, path, suffix,
movement_reward, goal_reward, side_effect_reward):
def run_experiment(
baseline, dev_measure, dev_fun, discount, value_discount, beta, nonterminal,
exact_baseline, anneal, num_episodes, num_episodes_noexp, seed,
env_name, noops, movement_reward, goal_reward, side_effect_reward,
mode, path, suffix):
"""Run agent and save or print the results."""
performances = []
rewards = []
seeds = []
episodes = []
if dev_measure not in ['rel_reach', 'att_util']:
if 'rel_reach' not in dev_measure and 'att_util' not in dev_measure:
dev_fun = 'none'
nonterminal_weights = {'none': 0.0, 'disc': 1.0-discount, 'full': 1.0}
nonterminal_weight = nonterminal_weights[nonterminal]
reward, performance = training.run_agent(
baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun,
discount=discount, value_discount=value_discount, beta=beta,
anneal=anneal, num_episodes=num_episodes,
num_episodes_noexp=num_episodes_noexp, seed=seed, env_name=env_name,
noops=noops, agent_class=agent_with_penalties.QLearningSE,
exact_baseline=exact_baseline, movement_reward=movement_reward,
goal_reward=goal_reward, side_effect_reward=side_effect_reward)
baseline=baseline,
dev_measure=dev_measure,
dev_fun=dev_fun,
discount=discount,
value_discount=value_discount,
beta=beta,
nonterminal_weight=nonterminal_weight,
exact_baseline=exact_baseline,
anneal=anneal,
num_episodes=num_episodes,
num_episodes_noexp=num_episodes_noexp,
seed=seed,
env_name=env_name,
noops=noops,
movement_reward=movement_reward,
goal_reward=goal_reward,
side_effect_reward=side_effect_reward,
agent_class=agent_with_penalties.QLearningSE)
rewards.extend(reward)
performances.extend(performance)
seeds.extend([seed] * (num_episodes + num_episodes_noexp))
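
The new nonterminal flag is translated into a numeric weight before being handed down to the agent. A quick worked check of the mapping, a sketch using the agent's default discount of 0.99 rather than code from the commit:

discount = 0.99  # matches QLearningSE's default discount above
nonterminal_weights = {'none': 0.0, 'disc': 1.0 - discount, 'full': 1.0}

round(nonterminal_weights['disc'], 6)  # -> 0.01, the QLearningSE default nonterminal_weight
nonterminal_weights['full']            # -> 1.0: nonterminal states penalised like terminal ones
nonterminal_weights['none']            # -> 0.0: only the terminal-state penalty remains
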
@@ -117,6 +139,8 @@ def main(unused_argv):
discount=FLAGS.discount,
value_discount=FLAGS.value_discount,
beta=FLAGS.beta,
nonterminal=FLAGS.nonterminal,
exact_baseline=FLAGS.exact_baseline,
anneal=FLAGS.anneal,
num_episodes=FLAGS.num_episodes,
num_episodes_noexp=FLAGS.num_episodes_noexp,
@@ -126,7 +150,6 @@ def main(unused_argv):
movement_reward=FLAGS.movement_reward,
goal_reward=FLAGS.goal_reward,
side_effect_reward=FLAGS.side_effect_reward,
exact_baseline=FLAGS.exact_baseline,
mode=FLAGS.mode,
path=FLAGS.path,
suffix=FLAGS.suffix)
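
For reference, a hedged sketch of driving the launcher directly from Python with the new options. The enclosing module's import path is not shown in this excerpt, so the call is written as if run_experiment were already imported; values largely mirror the flag defaults above, with the new UVFA deviation measure selected and mode='print' to avoid writing files:

# Sketch only: run the UVFA relative-reachability penalty with the
# discounted nonterminal weighting and an inaction baseline.
run_experiment(
    baseline='inaction', dev_measure='uvfa_rel_reach', dev_fun='truncation',
    discount=0.99,  # discount flag default not shown here; 0.99 matches the agent default
    value_discount=0.99, beta=30.0, nonterminal='disc',
    exact_baseline=False, anneal=True, num_episodes=10000,
    num_episodes_noexp=0, seed=1, env_name='box', noops=True,
    movement_reward=0, goal_reward=1, side_effect_reward=-1,
    mode='print', path='', suffix='')
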

View File

@@ -87,6 +87,10 @@ class Baseline(object):
def rollout_func(self):
"""Function to compute a rollout chain, or None if n/a."""
@property
def baseline_state(self):
return self._baseline_state
class StartBaseline(Baseline):
"""Starting state baseline."""
@@ -609,31 +613,32 @@ class NoDeviation(DeviationMeasure):
class SideEffectPenalty(object):
"""Impact penalty."""
def __init__(self, baseline, dev_measure, beta=1.0,
use_inseparable_rollout=False):
def __init__(
self, baseline, dev_measure, beta=1.0, nonterminal_weight=0.01,
use_inseparable_rollout=False):
"""Make an object to calculate the impact penalty.
Args:
baseline: object for calculating the baseline state
dev_measure: object for calculating the deviation between states
beta: weight (scaling factor) for the impact penalty
use_inseparable_rollout: whether to compute the penalty as the average of
deviations over parallel inaction rollouts from the current and
baselines states (True) otherwise just between the current state and
baseline state (or by whatever rollout value is provided in the
baseline) (False)
nonterminal_weight: penalty weight on nonterminal states.
use_inseparable_rollout:
whether to compute the penalty as the average of deviations over
parallel inaction rollouts from the current and baseline states (True)
otherwise just between the current state and baseline state (or by
whatever rollout value is provided in the baseline) (False)
"""
self._baseline = baseline
self._dev_measure = dev_measure
self._beta = beta
self._nonterminal_weight = nonterminal_weight
self._use_inseparable_rollout = use_inseparable_rollout
def calculate(self, prev_state, action, current_state):
"""Calculate the penalty associated with a transition, and update models."""
if current_state:
self._dev_measure.update(prev_state, current_state, action)
baseline_state = self._baseline.calculate(prev_state, action,
current_state)
def compute_penalty(current_state, baseline_state):
"""Compute penalty."""
if self._use_inseparable_rollout:
penalty = self._rollout_value(current_state, baseline_state,
self._dev_measure.discount,
@@ -642,8 +647,15 @@ class SideEffectPenalty(object):
penalty = self._dev_measure.calculate(current_state, baseline_state,
self._baseline.rollout_func)
return self._beta * penalty
else:
return 0
if current_state: # not a terminal state
self._dev_measure.update(prev_state, current_state, action)
baseline_state =\
self._baseline.calculate(prev_state, action, current_state)
penalty = compute_penalty(current_state, baseline_state)
return self._nonterminal_weight * penalty
else: # terminal state
penalty = compute_penalty(prev_state, self._baseline.baseline_state)
return penalty
def reset(self):
"""Signal start of new episode."""

View File

@@ -72,9 +72,9 @@ def run_loop(agent, env, number_episodes, anneal):
def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
exact_baseline, anneal, num_episodes, num_episodes_noexp, seed,
env_name, noops, movement_reward, goal_reward, side_effect_reward,
agent_class):
nonterminal_weight, exact_baseline, anneal, num_episodes,
num_episodes_noexp, seed, env_name, noops, movement_reward,
goal_reward, side_effect_reward, agent_class):
"""Run agent.
Create an agent with the given parameters for the side effects penalty.
@@ -89,6 +89,7 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
discount: discount factor
value_discount: discount factor for deviation measure value function.
beta: weight for side effects penalty
nonterminal_weight: penalty weight for nonterminal states.
exact_baseline: whether to use an exact or approximate baseline
anneal: whether to anneal the exploration rate from 1 to 0 or use a constant
exploration rate
@@ -108,11 +109,11 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
performances: safety performance for each episode
"""
np.random.seed(seed)
env, _ = get_env(env_name=env_name,
noops=noops,
movement_reward=movement_reward,
goal_reward=goal_reward,
side_effect_reward=side_effect_reward)
env, state_size = get_env(env_name=env_name,
noops=noops,
movement_reward=movement_reward,
goal_reward=goal_reward,
side_effect_reward=side_effect_reward)
start_timestep = env.reset()
if exact_baseline:
baseline_env, _ = get_env(env_name=env_name,
@@ -123,10 +124,11 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
else:
baseline_env = None
agent = agent_class(
actions=env.action_spec(), baseline=baseline,
dev_measure=dev_measure, dev_fun=dev_fun, discount=discount,
value_discount=value_discount, beta=beta, exact_baseline=exact_baseline,
baseline_env=baseline_env, start_timestep=start_timestep)
actions=env.action_spec(), baseline=baseline, dev_measure=dev_measure,
dev_fun=dev_fun, discount=discount, value_discount=value_discount,
beta=beta, exact_baseline=exact_baseline, baseline_env=baseline_env,
start_timestep=start_timestep, state_size=state_size,
nonterminal_weight=nonterminal_weight)
returns, performances = run_loop(
agent, env, number_episodes=num_episodes, anneal=anneal)
if num_episodes_noexp > 0:
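
get_env now returns the flattened state size alongside the environment, and run_agent forwards it to the agent for the UVFA reachability network. The excerpt does not show how state_size is computed; one plausible definition, offered purely as an assumption, is the flattened size of the reset observation's board layer:

import numpy as np

def flattened_state_size(env):
  # Assumed helper, not from the commit: size of the flattened observation
  # that the UVFA reachability network would take as input.
  timestep = env.reset()
  return int(np.prod(np.shape(timestep.observation['board'])))
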