Added flags and parameters for the UVFA approximation to relative reachability (for the future tasks paper)

PiperOrigin-RevId: 336881122
This commit is contained in:
Victoria Krakovna
2020-10-13 16:30:51 +01:00
committed by Diego de Las Casas
parent 2e48a73ee4
commit fefa95eb1f
4 changed files with 90 additions and 47 deletions
+11 -5
View File
@@ -30,7 +30,8 @@ class QLearningSE(agent.QLearning):
self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0, self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0,
baseline='start', dev_measure='none', dev_fun='truncation', baseline='start', dev_measure='none', dev_fun='truncation',
discount=0.99, value_discount=1.0, beta=1.0, num_util_funs=10, discount=0.99, value_discount=1.0, beta=1.0, num_util_funs=10,
exact_baseline=False, baseline_env=None, start_timestep=None): exact_baseline=False, baseline_env=None, start_timestep=None,
state_size=None, nonterminal_weight=0.01):
"""Create a Q-learning agent with a side effects penalty. """Create a Q-learning agent with a side effects penalty.
Args: Args:
@@ -53,6 +54,8 @@ class QLearningSE(agent.QLearning):
exact_baseline: whether to use an exact or approximate baseline. exact_baseline: whether to use an exact or approximate baseline.
baseline_env: copy of environment (with noops) for the exact baseline. baseline_env: copy of environment (with noops) for the exact baseline.
start_timestep: copy of starting timestep for the baseline. start_timestep: copy of starting timestep for the baseline.
state_size: the size of each state (flattened) for NN reachability.
nonterminal_weight: penalty weight on nonterminal states.
Raises: Raises:
ValueError: for incorrect baseline, dev_measure, or dev_fun ValueError: for incorrect baseline, dev_measure, or dev_fun
@@ -62,9 +65,9 @@ class QLearningSE(agent.QLearning):
discount) discount)
# Impact penalty: set dev_fun (f) # Impact penalty: set dev_fun (f)
if dev_measure in {'rel_reach', 'att_util'}: if 'rel_reach' in dev_measure or 'att_util' in dev_measure:
if dev_fun == 'truncation': if dev_fun == 'truncation':
dev_fun = lambda diff: max(0, diff) dev_fun = lambda diff: np.maximum(0, diff)
elif dev_fun == 'absolute': elif dev_fun == 'absolute':
dev_fun = np.abs dev_fun = np.abs
else: else:
@@ -76,6 +79,9 @@ class QLearningSE(agent.QLearning):
# Impact penalty: create deviation measure # Impact penalty: create deviation measure
if dev_measure in {'reach', 'rel_reach'}: if dev_measure in {'reach', 'rel_reach'}:
deviation = sep.Reachability(value_discount, dev_fun, discount) deviation = sep.Reachability(value_discount, dev_fun, discount)
elif dev_measure == 'uvfa_rel_reach':
deviation = sep.UVFAReachability(value_discount, dev_fun, discount,
state_size)
elif dev_measure == 'att_util': elif dev_measure == 'att_util':
deviation = sep.AttainableUtility(value_discount, dev_fun, num_util_funs, deviation = sep.AttainableUtility(value_discount, dev_fun, num_util_funs,
discount) discount)
@@ -99,8 +105,8 @@ class QLearningSE(agent.QLearning):
else: else:
raise ValueError('Baseline not recognized') raise ValueError('Baseline not recognized')
self._impact_penalty = sep.SideEffectPenalty(baseline, deviation, beta, self._impact_penalty = sep.SideEffectPenalty(
use_inseparable_rollout) baseline, deviation, beta, nonterminal_weight, use_inseparable_rollout)
def begin_episode(self): def begin_episode(self):
"""Perform episode initialisation.""" """Perform episode initialisation."""
+40 -17
View File
@@ -31,11 +31,13 @@ from side_effects_penalties.file_loading import filename
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
if __name__ == '__main__': # Avoid defining flags when used as a library. if __name__ == '__main__': # Avoid defining flags when used as a library.
flags.DEFINE_enum('baseline', 'stepwise', # Side effects penalty settings
flags.DEFINE_enum('baseline', 'inaction',
['start', 'inaction', 'stepwise', 'step_noroll'], ['start', 'inaction', 'stepwise', 'step_noroll'],
'Baseline.') 'Baseline.')
flags.DEFINE_enum('dev_measure', 'rel_reach', flags.DEFINE_enum('dev_measure', 'rel_reach',
['none', 'reach', 'rel_reach', 'att_util'], ['none', 'reach', 'rel_reach',
'uvfa_rel_reach', 'att_util'],
'Deviation measure.') 'Deviation measure.')
flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'], flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'],
'Summary function for the deviation measure.') 'Summary function for the deviation measure.')
@@ -43,45 +45,65 @@ if __name__ == '__main__': # Avoid defining flags when used as a library.
flags.DEFINE_float('value_discount', 0.99, flags.DEFINE_float('value_discount', 0.99,
'Discount factor for deviation measure value function.') 'Discount factor for deviation measure value function.')
flags.DEFINE_float('beta', 30.0, 'Weight for side effects penalty.') flags.DEFINE_float('beta', 30.0, 'Weight for side effects penalty.')
flags.DEFINE_string('nonterminal', 'disc',
'Penalty for nonterminal states relative to terminal '
'states: none (0), full (1), or disc (1-discount).')
flags.DEFINE_bool('exact_baseline', False,
'Compute the exact baseline using an environment copy.')
# Agent settings
flags.DEFINE_bool('anneal', True, flags.DEFINE_bool('anneal', True,
'Whether to anneal the exploration rate from 1 to 0.') 'Whether to anneal the exploration rate from 1 to 0.')
flags.DEFINE_integer('num_episodes', 10000, 'Number of episodes.') flags.DEFINE_integer('num_episodes', 10000, 'Number of episodes.')
flags.DEFINE_integer('num_episodes_noexp', 0, flags.DEFINE_integer('num_episodes_noexp', 0,
'Number of episodes with no exploration.') 'Number of episodes with no exploration.')
flags.DEFINE_integer('seed', 1, 'Random seed.') flags.DEFINE_integer('seed', 1, 'Random seed.')
# Environment settings
flags.DEFINE_string('env_name', 'box', 'Environment name.') flags.DEFINE_string('env_name', 'box', 'Environment name.')
flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.') flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
flags.DEFINE_integer('movement_reward', 0, 'Movement reward.') flags.DEFINE_integer('movement_reward', 0, 'Movement reward.')
flags.DEFINE_integer('goal_reward', 1, 'Reward for reaching a goal state.') flags.DEFINE_integer('goal_reward', 1, 'Reward for reaching a goal state.')
flags.DEFINE_integer('side_effect_reward', -1, flags.DEFINE_integer('side_effect_reward', -1,
'Hidden reward for causing side effects.') 'Hidden reward for causing side effects.')
flags.DEFINE_bool('exact_baseline', False, # Settings for outputting results
'Compute the exact baseline using an environment copy.')
flags.DEFINE_enum('mode', 'save', ['print', 'save'], flags.DEFINE_enum('mode', 'save', ['print', 'save'],
'Print results or save to file.') 'Print results or save to file.')
flags.DEFINE_string('path', '', 'File path.') flags.DEFINE_string('path', '', 'File path.')
flags.DEFINE_string('suffix', '', 'Filename suffix.') flags.DEFINE_string('suffix', '', 'Filename suffix.')
def run_experiment(baseline, dev_measure, dev_fun, discount, value_discount, def run_experiment(
beta, anneal, num_episodes, num_episodes_noexp, seed, baseline, dev_measure, dev_fun, discount, value_discount, beta, nonterminal,
env_name, noops, exact_baseline, mode, path, suffix, exact_baseline, anneal, num_episodes, num_episodes_noexp, seed,
movement_reward, goal_reward, side_effect_reward): env_name, noops, movement_reward, goal_reward, side_effect_reward,
mode, path, suffix):
"""Run agent and save or print the results.""" """Run agent and save or print the results."""
performances = [] performances = []
rewards = [] rewards = []
seeds = [] seeds = []
episodes = [] episodes = []
if dev_measure not in ['rel_reach', 'att_util']: if 'rel_reach' not in dev_measure and 'att_util' not in dev_measure:
dev_fun = 'none' dev_fun = 'none'
nonterminal_weights = {'none': 0.0, 'disc': 1.0-discount, 'full': 1.0}
nonterminal_weight = nonterminal_weights[nonterminal]
reward, performance = training.run_agent( reward, performance = training.run_agent(
baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun, baseline=baseline,
discount=discount, value_discount=value_discount, beta=beta, dev_measure=dev_measure,
anneal=anneal, num_episodes=num_episodes, dev_fun=dev_fun,
num_episodes_noexp=num_episodes_noexp, seed=seed, env_name=env_name, discount=discount,
noops=noops, agent_class=agent_with_penalties.QLearningSE, value_discount=value_discount,
exact_baseline=exact_baseline, movement_reward=movement_reward, beta=beta,
goal_reward=goal_reward, side_effect_reward=side_effect_reward) nonterminal_weight=nonterminal_weight,
exact_baseline=exact_baseline,
anneal=anneal,
num_episodes=num_episodes,
num_episodes_noexp=num_episodes_noexp,
seed=seed,
env_name=env_name,
noops=noops,
movement_reward=movement_reward,
goal_reward=goal_reward,
side_effect_reward=side_effect_reward,
agent_class=agent_with_penalties.QLearningSE)
rewards.extend(reward) rewards.extend(reward)
performances.extend(performance) performances.extend(performance)
seeds.extend([seed] * (num_episodes + num_episodes_noexp)) seeds.extend([seed] * (num_episodes + num_episodes_noexp))
@@ -117,6 +139,8 @@ def main(unused_argv):
discount=FLAGS.discount, discount=FLAGS.discount,
value_discount=FLAGS.value_discount, value_discount=FLAGS.value_discount,
beta=FLAGS.beta, beta=FLAGS.beta,
nonterminal=FLAGS.nonterminal,
exact_baseline=FLAGS.exact_baseline,
anneal=FLAGS.anneal, anneal=FLAGS.anneal,
num_episodes=FLAGS.num_episodes, num_episodes=FLAGS.num_episodes,
num_episodes_noexp=FLAGS.num_episodes_noexp, num_episodes_noexp=FLAGS.num_episodes_noexp,
@@ -126,7 +150,6 @@ def main(unused_argv):
movement_reward=FLAGS.movement_reward, movement_reward=FLAGS.movement_reward,
goal_reward=FLAGS.goal_reward, goal_reward=FLAGS.goal_reward,
side_effect_reward=FLAGS.side_effect_reward, side_effect_reward=FLAGS.side_effect_reward,
exact_baseline=FLAGS.exact_baseline,
mode=FLAGS.mode, mode=FLAGS.mode,
path=FLAGS.path, path=FLAGS.path,
suffix=FLAGS.suffix) suffix=FLAGS.suffix)
+25 -13
View File
@@ -87,6 +87,10 @@ class Baseline(object):
def rollout_func(self): def rollout_func(self):
"""Function to compute a rollout chain, or None if n/a.""" """Function to compute a rollout chain, or None if n/a."""
@property
def baseline_state(self):
return self._baseline_state
class StartBaseline(Baseline): class StartBaseline(Baseline):
"""Starting state baseline.""" """Starting state baseline."""
@@ -609,31 +613,32 @@ class NoDeviation(DeviationMeasure):
class SideEffectPenalty(object): class SideEffectPenalty(object):
"""Impact penalty.""" """Impact penalty."""
def __init__(self, baseline, dev_measure, beta=1.0, def __init__(
use_inseparable_rollout=False): self, baseline, dev_measure, beta=1.0, nonterminal_weight=0.01,
use_inseparable_rollout=False):
"""Make an object to calculate the impact penalty. """Make an object to calculate the impact penalty.
Args: Args:
baseline: object for calculating the baseline state baseline: object for calculating the baseline state
dev_measure: object for calculating the deviation between states dev_measure: object for calculating the deviation between states
beta: weight (scaling factor) for the impact penalty beta: weight (scaling factor) for the impact penalty
use_inseparable_rollout: whether to compute the penalty as the average of nonterminal_weight: penalty weight on nonterminal states.
deviations over parallel inaction rollouts from the current and use_inseparable_rollout:
baselines states (True) otherwise just between the current state and whether to compute the penalty as the average of deviations over
baseline state (or by whatever rollout value is provided in the parallel inaction rollouts from the current and baseline states (True)
baseline) (False) otherwise just between the current state and baseline state (or by
whatever rollout value is provided in the baseline) (False)
""" """
self._baseline = baseline self._baseline = baseline
self._dev_measure = dev_measure self._dev_measure = dev_measure
self._beta = beta self._beta = beta
self._nonterminal_weight = nonterminal_weight
self._use_inseparable_rollout = use_inseparable_rollout self._use_inseparable_rollout = use_inseparable_rollout
def calculate(self, prev_state, action, current_state): def calculate(self, prev_state, action, current_state):
"""Calculate the penalty associated with a transition, and update models.""" """Calculate the penalty associated with a transition, and update models."""
if current_state: def compute_penalty(current_state, baseline_state):
self._dev_measure.update(prev_state, current_state, action) """Compute penalty."""
baseline_state = self._baseline.calculate(prev_state, action,
current_state)
if self._use_inseparable_rollout: if self._use_inseparable_rollout:
penalty = self._rollout_value(current_state, baseline_state, penalty = self._rollout_value(current_state, baseline_state,
self._dev_measure.discount, self._dev_measure.discount,
@@ -642,8 +647,15 @@ class SideEffectPenalty(object):
penalty = self._dev_measure.calculate(current_state, baseline_state, penalty = self._dev_measure.calculate(current_state, baseline_state,
self._baseline.rollout_func) self._baseline.rollout_func)
return self._beta * penalty return self._beta * penalty
else: if current_state: # not a terminal state
return 0 self._dev_measure.update(prev_state, current_state, action)
baseline_state =\
self._baseline.calculate(prev_state, action, current_state)
penalty = compute_penalty(current_state, baseline_state)
return self._nonterminal_weight * penalty
else: # terminal state
penalty = compute_penalty(prev_state, self._baseline.baseline_state)
return penalty
def reset(self): def reset(self):
"""Signal start of new episode.""" """Signal start of new episode."""
+14 -12
View File
@@ -72,9 +72,9 @@ def run_loop(agent, env, number_episodes, anneal):
def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta, def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
exact_baseline, anneal, num_episodes, num_episodes_noexp, seed, nonterminal_weight, exact_baseline, anneal, num_episodes,
env_name, noops, movement_reward, goal_reward, side_effect_reward, num_episodes_noexp, seed, env_name, noops, movement_reward,
agent_class): goal_reward, side_effect_reward, agent_class):
"""Run agent. """Run agent.
Create an agent with the given parameters for the side effects penalty. Create an agent with the given parameters for the side effects penalty.
@@ -89,6 +89,7 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
discount: discount factor discount: discount factor
value_discount: discount factor for deviation measure value function. value_discount: discount factor for deviation measure value function.
beta: weight for side effects penalty beta: weight for side effects penalty
nonterminal_weight: penalty weight for nonterminal states.
exact_baseline: whether to use an exact or approximate baseline exact_baseline: whether to use an exact or approximate baseline
anneal: whether to anneal the exploration rate from 1 to 0 or use a constant anneal: whether to anneal the exploration rate from 1 to 0 or use a constant
exploration rate exploration rate
@@ -108,11 +109,11 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
performances: safety performance for each episode performances: safety performance for each episode
""" """
np.random.seed(seed) np.random.seed(seed)
env, _ = get_env(env_name=env_name, env, state_size = get_env(env_name=env_name,
noops=noops, noops=noops,
movement_reward=movement_reward, movement_reward=movement_reward,
goal_reward=goal_reward, goal_reward=goal_reward,
side_effect_reward=side_effect_reward) side_effect_reward=side_effect_reward)
start_timestep = env.reset() start_timestep = env.reset()
if exact_baseline: if exact_baseline:
baseline_env, _ = get_env(env_name=env_name, baseline_env, _ = get_env(env_name=env_name,
@@ -123,10 +124,11 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
else: else:
baseline_env = None baseline_env = None
agent = agent_class( agent = agent_class(
actions=env.action_spec(), baseline=baseline, actions=env.action_spec(), baseline=baseline, dev_measure=dev_measure,
dev_measure=dev_measure, dev_fun=dev_fun, discount=discount, dev_fun=dev_fun, discount=discount, value_discount=value_discount,
value_discount=value_discount, beta=beta, exact_baseline=exact_baseline, beta=beta, exact_baseline=exact_baseline, baseline_env=baseline_env,
baseline_env=baseline_env, start_timestep=start_timestep) start_timestep=start_timestep, state_size=state_size,
nonterminal_weight=nonterminal_weight)
returns, performances = run_loop( returns, performances = run_loop(
agent, env, number_episodes=num_episodes, anneal=anneal) agent, env, number_episodes=num_episodes, anneal=anneal)
if num_episodes_noexp > 0: if num_episodes_noexp > 0: