Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-02-06 03:32:18 +08:00)
Added flags and parameters for the UVFA approximation to relative reachability (for the future tasks paper)
PiperOrigin-RevId: 336881122
Committed by: Diego de Las Casas
Parent: 2e48a73ee4
Commit: fefa95eb1f
@@ -30,7 +30,8 @@ class QLearningSE(agent.QLearning):
       self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0,
       baseline='start', dev_measure='none', dev_fun='truncation',
       discount=0.99, value_discount=1.0, beta=1.0, num_util_funs=10,
-      exact_baseline=False, baseline_env=None, start_timestep=None):
+      exact_baseline=False, baseline_env=None, start_timestep=None,
+      state_size=None, nonterminal_weight=0.01):
     """Create a Q-learning agent with a side effects penalty.
 
     Args:
@@ -53,6 +54,8 @@ class QLearningSE(agent.QLearning):
       exact_baseline: whether to use an exact or approximate baseline.
       baseline_env: copy of environment (with noops) for the exact baseline.
       start_timestep: copy of starting timestep for the baseline.
+      state_size: the size of each state (flattened) for NN reachability.
+      nonterminal_weight: penalty weight on nonterminal states.
 
     Raises:
       ValueError: for incorrect baseline, dev_measure, or dev_fun
@@ -62,9 +65,9 @@ class QLearningSE(agent.QLearning):
                                      discount)
 
     # Impact penalty: set dev_fun (f)
-    if dev_measure in {'rel_reach', 'att_util'}:
+    if 'rel_reach' in dev_measure or 'att_util' in dev_measure:
       if dev_fun == 'truncation':
-        dev_fun = lambda diff: max(0, diff)
+        dev_fun = lambda diff: np.maximum(0, diff)
       elif dev_fun == 'absolute':
         dev_fun = np.abs
       else:
@@ -76,6 +79,9 @@ class QLearningSE(agent.QLearning):
     # Impact penalty: create deviation measure
     if dev_measure in {'reach', 'rel_reach'}:
       deviation = sep.Reachability(value_discount, dev_fun, discount)
+    elif dev_measure == 'uvfa_rel_reach':
+      deviation = sep.UVFAReachability(value_discount, dev_fun, discount,
+                                       state_size)
     elif dev_measure == 'att_util':
       deviation = sep.AttainableUtility(value_discount, dev_fun, num_util_funs,
                                         discount)
@@ -99,8 +105,8 @@ class QLearningSE(agent.QLearning):
     else:
       raise ValueError('Baseline not recognized')
 
-    self._impact_penalty = sep.SideEffectPenalty(baseline, deviation, beta,
-                                                 use_inseparable_rollout)
+    self._impact_penalty = sep.SideEffectPenalty(
+        baseline, deviation, beta, nonterminal_weight, use_inseparable_rollout)
 
   def begin_episode(self):
     """Perform episode initialisation."""
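A note on the dev_fun change in the hunks above: switching from the builtin max(0, diff) to np.maximum(0, diff) presumably lets the truncation apply elementwise when the UVFA-based measure works with numpy arrays rather than scalars. A minimal standalone sketch (not part of the commit) of the difference:

import numpy as np

truncate = lambda diff: np.maximum(0, diff)  # elementwise truncation at zero

print(truncate(-0.3))                        # 0.0 -- scalars still work
print(truncate(np.array([0.5, -0.2, 0.0])))  # 0.5, 0.0, 0.0 elementwise

# The builtin max(0, diff) only handles scalars; called on an array it raises
# ValueError because the truth value of an array is ambiguous.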
@@ -31,11 +31,13 @@ from side_effects_penalties.file_loading import filename
 FLAGS = flags.FLAGS
 
 if __name__ == '__main__':  # Avoid defining flags when used as a library.
-  flags.DEFINE_enum('baseline', 'stepwise',
+  # Side effects penalty settings
+  flags.DEFINE_enum('baseline', 'inaction',
                     ['start', 'inaction', 'stepwise', 'step_noroll'],
                     'Baseline.')
   flags.DEFINE_enum('dev_measure', 'rel_reach',
-                    ['none', 'reach', 'rel_reach', 'att_util'],
+                    ['none', 'reach', 'rel_reach',
+                     'uvfa_rel_reach', 'att_util'],
                     'Deviation measure.')
   flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'],
                     'Summary function for the deviation measure.')
@@ -43,45 +45,65 @@ if __name__ == '__main__':  # Avoid defining flags when used as a library.
   flags.DEFINE_float('value_discount', 0.99,
                      'Discount factor for deviation measure value function.')
   flags.DEFINE_float('beta', 30.0, 'Weight for side effects penalty.')
+  flags.DEFINE_string('nonterminal', 'disc',
+                      'Penalty for nonterminal states relative to terminal'
+                      'states: none (0), full (1), or disc (1-discount).')
+  flags.DEFINE_bool('exact_baseline', False,
+                    'Compute the exact baseline using an environment copy.')
+  # Agent settings
   flags.DEFINE_bool('anneal', True,
                     'Whether to anneal the exploration rate from 1 to 0.')
   flags.DEFINE_integer('num_episodes', 10000, 'Number of episodes.')
   flags.DEFINE_integer('num_episodes_noexp', 0,
                        'Number of episodes with no exploration.')
   flags.DEFINE_integer('seed', 1, 'Random seed.')
+  # Environment settings
   flags.DEFINE_string('env_name', 'box', 'Environment name.')
   flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
   flags.DEFINE_integer('movement_reward', 0, 'Movement reward.')
   flags.DEFINE_integer('goal_reward', 1, 'Reward for reaching a goal state.')
   flags.DEFINE_integer('side_effect_reward', -1,
                        'Hidden reward for causing side effects.')
-  flags.DEFINE_bool('exact_baseline', False,
-                    'Compute the exact baseline using an environment copy.')
+  # Settings for outputting results
   flags.DEFINE_enum('mode', 'save', ['print', 'save'],
                     'Print results or save to file.')
   flags.DEFINE_string('path', '', 'File path.')
   flags.DEFINE_string('suffix', '', 'Filename suffix.')
 
 
-def run_experiment(baseline, dev_measure, dev_fun, discount, value_discount,
-                   beta, anneal, num_episodes, num_episodes_noexp, seed,
-                   env_name, noops, exact_baseline, mode, path, suffix,
-                   movement_reward, goal_reward, side_effect_reward):
+def run_experiment(
+    baseline, dev_measure, dev_fun, discount, value_discount, beta, nonterminal,
+    exact_baseline, anneal, num_episodes, num_episodes_noexp, seed,
+    env_name, noops, movement_reward, goal_reward, side_effect_reward,
+    mode, path, suffix):
   """Run agent and save or print the results."""
   performances = []
   rewards = []
   seeds = []
   episodes = []
-  if dev_measure not in ['rel_reach', 'att_util']:
+  if 'rel_reach' not in dev_measure and 'att_util' not in dev_measure:
     dev_fun = 'none'
+  nonterminal_weights = {'none': 0.0, 'disc': 1.0-discount, 'full': 1.0}
+  nonterminal_weight = nonterminal_weights[nonterminal]
   reward, performance = training.run_agent(
-      baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun,
-      discount=discount, value_discount=value_discount, beta=beta,
-      anneal=anneal, num_episodes=num_episodes,
-      num_episodes_noexp=num_episodes_noexp, seed=seed, env_name=env_name,
-      noops=noops, agent_class=agent_with_penalties.QLearningSE,
-      exact_baseline=exact_baseline, movement_reward=movement_reward,
-      goal_reward=goal_reward, side_effect_reward=side_effect_reward)
+      baseline=baseline,
+      dev_measure=dev_measure,
+      dev_fun=dev_fun,
+      discount=discount,
+      value_discount=value_discount,
+      beta=beta,
+      nonterminal_weight=nonterminal_weight,
+      exact_baseline=exact_baseline,
+      anneal=anneal,
+      num_episodes=num_episodes,
+      num_episodes_noexp=num_episodes_noexp,
+      seed=seed,
+      env_name=env_name,
+      noops=noops,
+      movement_reward=movement_reward,
+      goal_reward=goal_reward,
+      side_effect_reward=side_effect_reward,
+      agent_class=agent_with_penalties.QLearningSE)
   rewards.extend(reward)
   performances.extend(performance)
   seeds.extend([seed] * (num_episodes + num_episodes_noexp))
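Note how the new string-valued nonterminal flag is resolved into a numeric weight by the added nonterminal_weights mapping before being passed to training.run_agent. A small standalone worked example (the 0.99 discount below is the QLearningSE default from the first file; the discount flag's own default is not shown in this diff):

def resolve_nonterminal_weight(nonterminal, discount):
  # Same mapping as the one added to run_experiment above.
  nonterminal_weights = {'none': 0.0, 'disc': 1.0 - discount, 'full': 1.0}
  return nonterminal_weights[nonterminal]

print(resolve_nonterminal_weight('none', discount=0.99))  # 0.0
print(resolve_nonterminal_weight('disc', discount=0.99))  # ~0.01
print(resolve_nonterminal_weight('full', discount=0.99))  # 1.0

With the default 'disc' setting and a 0.99 discount, this reproduces the nonterminal_weight=0.01 default that the commit gives QLearningSE and SideEffectPenalty.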
@@ -117,6 +139,8 @@ def main(unused_argv):
       discount=FLAGS.discount,
       value_discount=FLAGS.value_discount,
       beta=FLAGS.beta,
+      nonterminal=FLAGS.nonterminal,
+      exact_baseline=FLAGS.exact_baseline,
       anneal=FLAGS.anneal,
       num_episodes=FLAGS.num_episodes,
       num_episodes_noexp=FLAGS.num_episodes_noexp,
@@ -126,7 +150,6 @@ def main(unused_argv):
       movement_reward=FLAGS.movement_reward,
       goal_reward=FLAGS.goal_reward,
       side_effect_reward=FLAGS.side_effect_reward,
-      exact_baseline=FLAGS.exact_baseline,
       mode=FLAGS.mode,
       path=FLAGS.path,
       suffix=FLAGS.suffix)
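Taken together, the flag and signature changes above let the UVFA approximation to relative reachability be selected end to end. A hedged usage sketch of the updated run_experiment (the module path is assumed rather than shown in the diff; keyword values follow the flag defaults above, except discount, whose flag default is not visible here):

from side_effects_penalties import run_experiment  # assumed module name

run_experiment.run_experiment(
    baseline='inaction', dev_measure='uvfa_rel_reach', dev_fun='truncation',
    discount=0.99, value_discount=0.99, beta=30.0, nonterminal='disc',
    exact_baseline=False, anneal=True, num_episodes=10000,
    num_episodes_noexp=0, seed=1, env_name='box', noops=True,
    movement_reward=0, goal_reward=1, side_effect_reward=-1,
    mode='save', path='/tmp/results', suffix='')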
@@ -87,6 +87,10 @@ class Baseline(object):
   def rollout_func(self):
     """Function to compute a rollout chain, or None if n/a."""
 
+  @property
+  def baseline_state(self):
+    return self._baseline_state
+
 
 class StartBaseline(Baseline):
   """Starting state baseline."""
@@ -609,31 +613,32 @@ class NoDeviation(DeviationMeasure):
 class SideEffectPenalty(object):
   """Impact penalty."""
 
-  def __init__(self, baseline, dev_measure, beta=1.0,
-               use_inseparable_rollout=False):
+  def __init__(
+      self, baseline, dev_measure, beta=1.0, nonterminal_weight=0.01,
+      use_inseparable_rollout=False):
     """Make an object to calculate the impact penalty.
 
     Args:
       baseline: object for calculating the baseline state
       dev_measure: object for calculating the deviation between states
       beta: weight (scaling factor) for the impact penalty
-      use_inseparable_rollout: whether to compute the penalty as the average of
-        deviations over parallel inaction rollouts from the current and
-        baselines states (True) otherwise just between the current state and
-        baseline state (or by whatever rollout value is provided in the
-        baseline) (False)
+      nonterminal_weight: penalty weight on nonterminal states.
+      use_inseparable_rollout:
+        whether to compute the penalty as the average of deviations over
+        parallel inaction rollouts from the current and baseline states (True)
+        otherwise just between the current state and baseline state (or by
+        whatever rollout value is provided in the baseline) (False)
     """
     self._baseline = baseline
     self._dev_measure = dev_measure
     self._beta = beta
+    self._nonterminal_weight = nonterminal_weight
     self._use_inseparable_rollout = use_inseparable_rollout
 
   def calculate(self, prev_state, action, current_state):
     """Calculate the penalty associated with a transition, and update models."""
-    if current_state:
-      self._dev_measure.update(prev_state, current_state, action)
-      baseline_state = self._baseline.calculate(prev_state, action,
-                                                current_state)
+    def compute_penalty(current_state, baseline_state):
+      """Compute penalty."""
       if self._use_inseparable_rollout:
         penalty = self._rollout_value(current_state, baseline_state,
                                       self._dev_measure.discount,
@@ -642,8 +647,15 @@ class SideEffectPenalty(object):
         penalty = self._dev_measure.calculate(current_state, baseline_state,
                                               self._baseline.rollout_func)
       return self._beta * penalty
-    else:
-      return 0
+    if current_state:  # not a terminal state
+      self._dev_measure.update(prev_state, current_state, action)
+      baseline_state =\
+        self._baseline.calculate(prev_state, action, current_state)
+      penalty = compute_penalty(current_state, baseline_state)
+      return self._nonterminal_weight * penalty
+    else:  # terminal state
+      penalty = compute_penalty(prev_state, self._baseline.baseline_state)
+      return penalty
 
   def reset(self):
     """Signal start of new episode."""
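The restructured calculate method above changes when and how much of the penalty is charged: nonterminal transitions are scaled by nonterminal_weight, while at a terminal step the full penalty is applied once against the baseline's stored baseline_state. A simplified toy sketch of just this control flow (the deviation and baseline here are stand-ins, not the repository's classes):

class ToyPenalty(object):
  """Toy illustration of the new nonterminal/terminal penalty weighting."""

  def __init__(self, beta=1.0, nonterminal_weight=0.01):
    self._beta = beta
    self._nonterminal_weight = nonterminal_weight
    self._baseline_state = None

  def _deviation(self, state, baseline_state):
    return abs(state - baseline_state)  # stand-in for dev_measure.calculate

  def calculate(self, prev_state, current_state):
    if current_state is not None:  # nonterminal transition: scaled charge
      if self._baseline_state is None:
        self._baseline_state = prev_state  # e.g. a starting-state baseline
      penalty = self._beta * self._deviation(current_state, self._baseline_state)
      return self._nonterminal_weight * penalty
    else:  # terminal step: full charge against the stored baseline state
      return self._beta * self._deviation(prev_state, self._baseline_state)

penalty = ToyPenalty(beta=1.0, nonterminal_weight=0.01)
print(penalty.calculate(0, 3))     # scaled per-step charge (~0.03)
print(penalty.calculate(3, None))  # full terminal charge (3.0)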
@@ -72,9 +72,9 @@ def run_loop(agent, env, number_episodes, anneal):
 
 
 def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
-              exact_baseline, anneal, num_episodes, num_episodes_noexp, seed,
-              env_name, noops, movement_reward, goal_reward, side_effect_reward,
-              agent_class):
+              nonterminal_weight, exact_baseline, anneal, num_episodes,
+              num_episodes_noexp, seed, env_name, noops, movement_reward,
+              goal_reward, side_effect_reward, agent_class):
   """Run agent.
 
   Create an agent with the given parameters for the side effects penalty.
@@ -89,6 +89,7 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
     discount: discount factor
     value_discount: discount factor for deviation measure value function.
    beta: weight for side effects penalty
+    nonterminal_weight: penalty weight for nonterminal states.
     exact_baseline: whether to use an exact or approximate baseline
     anneal: whether to anneal the exploration rate from 1 to 0 or use a constant
       exploration rate
@@ -108,11 +109,11 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
     performances: safety performance for each episode
   """
   np.random.seed(seed)
-  env, _ = get_env(env_name=env_name,
-                   noops=noops,
-                   movement_reward=movement_reward,
-                   goal_reward=goal_reward,
-                   side_effect_reward=side_effect_reward)
+  env, state_size = get_env(env_name=env_name,
+                            noops=noops,
+                            movement_reward=movement_reward,
+                            goal_reward=goal_reward,
+                            side_effect_reward=side_effect_reward)
   start_timestep = env.reset()
   if exact_baseline:
     baseline_env, _ = get_env(env_name=env_name,
@@ -123,10 +124,11 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
   else:
     baseline_env = None
   agent = agent_class(
-      actions=env.action_spec(), baseline=baseline,
-      dev_measure=dev_measure, dev_fun=dev_fun, discount=discount,
-      value_discount=value_discount, beta=beta, exact_baseline=exact_baseline,
-      baseline_env=baseline_env, start_timestep=start_timestep)
+      actions=env.action_spec(), baseline=baseline, dev_measure=dev_measure,
+      dev_fun=dev_fun, discount=discount, value_discount=value_discount,
+      beta=beta, exact_baseline=exact_baseline, baseline_env=baseline_env,
+      start_timestep=start_timestep, state_size=state_size,
+      nonterminal_weight=nonterminal_weight)
   returns, performances = run_loop(
       agent, env, number_episodes=num_episodes, anneal=anneal)
   if num_episodes_noexp > 0:
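At the training level the plumbing mirrors run_experiment, except that run_agent receives the already-resolved numeric nonterminal_weight (rather than the string flag) plus the agent class, and now threads the state_size returned by get_env into the agent for the NN-based reachability. A hedged sketch of the updated call (module paths assumed from the identifiers in the diff):

from side_effects_penalties import agent_with_penalties, training  # assumed

reward, performance = training.run_agent(
    baseline='inaction', dev_measure='uvfa_rel_reach', dev_fun='truncation',
    discount=0.99, value_discount=0.99, beta=30.0,
    nonterminal_weight=0.01,  # e.g. 'disc' resolved with a 0.99 discount
    exact_baseline=False, anneal=True, num_episodes=10000,
    num_episodes_noexp=0, seed=1, env_name='box', noops=True,
    movement_reward=0, goal_reward=1, side_effect_reward=-1,
    agent_class=agent_with_penalties.QLearningSE)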