Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-09 21:07:49 +08:00)
Added functionality for running on new environment variants in the future task paper
PiperOrigin-RevId: 336680745
Committed by: Diego de Las Casas
Parent: 1d86410c90
Commit: bc398d8004
@@ -51,6 +51,10 @@ if __name__ == '__main__':  # Avoid defining flags when used as a library.
   flags.DEFINE_integer('seed', 1, 'Random seed.')
   flags.DEFINE_string('env_name', 'box', 'Environment name.')
   flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
+  flags.DEFINE_integer('movement_reward', 0, 'Movement reward.')
+  flags.DEFINE_integer('goal_reward', 1, 'Reward for reaching a goal state.')
+  flags.DEFINE_integer('side_effect_reward', -1,
+                       'Hidden reward for causing side effects.')
   flags.DEFINE_bool('exact_baseline', False,
                     'Compute the exact baseline using an environment copy.')
   flags.DEFINE_enum('mode', 'save', ['print', 'save'],
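
Note: the new movement_reward flag defaults to 0 here, while the get_env keyword it eventually feeds (see the training.py-style hunk further down) defaults to -1; when the script runs as a binary, the flag value is what reaches the environment. A minimal sketch, not part of the commit, exercising the three new flag definitions standalone:

    # Hedged sketch: the DEFINE_* calls mirror the hunk above; the override
    # value is illustrative.
    from absl import flags

    flags.DEFINE_integer('movement_reward', 0, 'Movement reward.')
    flags.DEFINE_integer('goal_reward', 1, 'Reward for reaching a goal state.')
    flags.DEFINE_integer('side_effect_reward', -1,
                         'Hidden reward for causing side effects.')

    FLAGS = flags.FLAGS
    FLAGS(['prog', '--side_effect_reward=-2'])  # parse argv, overriding one default
    assert (FLAGS.movement_reward, FLAGS.goal_reward,
            FLAGS.side_effect_reward) == (0, 1, -2)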
@@ -61,7 +65,8 @@ if __name__ == '__main__':  # Avoid defining flags when used as a library.

 def run_experiment(baseline, dev_measure, dev_fun, discount, value_discount,
                    beta, anneal, num_episodes, num_episodes_noexp, seed,
-                   env_name, noops, exact_baseline, mode, path, suffix):
+                   env_name, noops, exact_baseline, mode, path, suffix,
+                   movement_reward, goal_reward, side_effect_reward):
   """Run agent and save or print the results."""
   performances = []
   rewards = []
@@ -75,7 +80,8 @@ def run_experiment(baseline, dev_measure, dev_fun, discount, value_discount,
       anneal=anneal, num_episodes=num_episodes,
       num_episodes_noexp=num_episodes_noexp, seed=seed, env_name=env_name,
       noops=noops, agent_class=agent_with_penalties.QLearningSE,
-      exact_baseline=exact_baseline)
+      exact_baseline=exact_baseline, movement_reward=movement_reward,
+      goal_reward=goal_reward, side_effect_reward=side_effect_reward)
   rewards.extend(reward)
   performances.extend(performance)
   seeds.extend([seed] * (num_episodes + num_episodes_noexp))
@@ -117,6 +123,9 @@ def main(unused_argv):
       seed=FLAGS.seed,
       env_name=FLAGS.env_name,
       noops=FLAGS.noops,
+      movement_reward=FLAGS.movement_reward,
+      goal_reward=FLAGS.goal_reward,
+      side_effect_reward=FLAGS.side_effect_reward,
       exact_baseline=FLAGS.exact_baseline,
       mode=FLAGS.mode,
       path=FLAGS.path,
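
Because the flags live under `if __name__ == '__main__'`, the module can also be driven as a library. A hedged sketch of such a call with the new reward arguments; the keyword names come from the run_experiment signature hunk above, while the import path and the baseline/dev_measure/dev_fun values are assumptions for illustration:

    # Assumed import path; values other than the reward arguments are illustrative.
    from side_effects_penalties import run_experiment

    run_experiment.run_experiment(
        baseline='stepwise', dev_measure='rel_reach', dev_fun='truncation',
        discount=0.99, value_discount=0.99, beta=30.0, anneal=True,
        num_episodes=10000, num_episodes_noexp=0, seed=1,
        env_name='sokocoin2', noops=True, exact_baseline=False,
        mode='save', path='/tmp/results', suffix='',
        movement_reward=0, goal_reward=1, side_effect_reward=-1)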
@@ -44,8 +44,8 @@ class SideEffectsTestCase(parameterized.TestCase):
 class BaselineTestCase(SideEffectsTestCase):

   def _create_baseline(self, env_name):
-    self._env = training.get_env(env_name, True)
-    self._baseline_env = training.get_env(env_name, True)
+    self._env, _ = training.get_env(env_name, True)
+    self._baseline_env, _ = training.get_env(env_name, True)
     baseline_class = getattr(side_effects_penalty,
                              self.__class__.__name__[:-4])  # remove 'Test'
     self._baseline = baseline_class(
@@ -84,7 +84,8 @@ class StartBaselineTest(BaselineTestCase):

 class InactionBaselineTest(BaselineTestCase):

-  box_action_spec = training.get_env('box', True).action_spec()
+  box_env, _ = training.get_env('box', True)
+  box_action_spec = box_env.action_spec()

   @parameterized.parameters(
       *list(range(box_action_spec.minimum, box_action_spec.maximum + 1)))
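
The class-level change here is forced by the new return type: get_env now returns an (env, size) tuple, so the old chained call would raise AttributeError on the tuple. A minimal illustration of the pattern (import path assumed from the test's existing usage):

    from side_effects_penalties import training  # assumed import path

    # New two-value signature: unpack first, then query the environment.
    env, _ = training.get_env('box', True)
    action_spec = env.action_spec()

    # The old chained form would now fail:
    # training.get_env('box', True).action_spec()  # AttributeError on the tuple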
@@ -191,7 +192,7 @@ class NoDeviationTest(SideEffectsTestCase):
   def _random_initial_transition(self):
     env_name = np.random.choice(environments)
     noops = np.random.choice([True, False])
-    env = training.get_env(env_name, noops)
+    env, _ = training.get_env(env_name, noops)
     action_range = self._env_to_action_range(env)
     action = np.random.choice(action_range)
     state1 = self._timestep_to_state(env.reset())
@@ -216,7 +217,7 @@ class UnreachabilityTest(SideEffectsTestCase):
   def testUnreachabilityCycle(self, gamma):
     # Reachability with no dev_fun means unreachability
     deviation = side_effects_penalty.Reachability(value_discount=gamma)
-    env = training.get_env('box', False)
+    env, _ = training.get_env('box', False)

     state0 = self._timestep_to_state(env.reset())
     state1 = self._timestep_to_state(env.step(Actions.LEFT))
@@ -23,16 +23,25 @@ import numpy as np
 from six.moves import range


-def get_env(env_name, noops):
+def get_env(env_name, noops,
+            movement_reward=-1, goal_reward=1, side_effect_reward=-1):
   """Get a copy of the environment for simulating the baseline."""
-  if env_name == 'box':
-    env = factory.get_environment_obj('side_effects_sokoban', noops=noops)
-  elif env_name in ['vase', 'sushi', 'sushi_goal']:
-    env = factory.get_environment_obj(
-        'conveyor_belt', variant=env_name, noops=noops)
+  if env_name == 'box' or 'sokocoin' in env_name:
+    levels = {'box': 0, 'sokocoin1': 1, 'sokocoin2': 2, 'sokocoin3': 3}
+    sizes = {'box': 36, 'sokocoin1': 100, 'sokocoin2': 72, 'sokocoin3': 100}
+    env = factory.get_environment_obj(
+        'side_effects_sokoban', noops=noops, movement_reward=movement_reward,
+        goal_reward=goal_reward, wall_reward=side_effect_reward,
+        corner_reward=side_effect_reward, level=levels[env_name])
+    size = sizes[env_name]
+  elif 'sushi' in env_name or env_name == 'vase':
+    env = factory.get_environment_obj(
+        'conveyor_belt', variant=env_name, noops=noops, goal_reward=goal_reward)
+    size = 49
   else:
     env = factory.get_environment_obj(env_name)
-  return env
+    size = None
+  return env, size


 def run_loop(agent, env, number_episodes, anneal):
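
For reference, a hedged usage sketch of the extended get_env. The second return value comes from the sizes dict above (36 for 'box', 100/72/100 for the sokocoin levels, 49 for the conveyor-belt variants, None in the fallback branch); every caller updated in this diff discards it with `_`:

    # Sketch only: the variant name and reward values are illustrative.
    env, size = get_env('sokocoin1', noops=True,
                        movement_reward=-1, goal_reward=1, side_effect_reward=-1)
    timestep = env.reset()
    print(size)  # 100, per sizes['sokocoin1'] above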
@@ -63,8 +72,9 @@ def run_loop(agent, env, number_episodes, anneal):


 def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
-              anneal, seed, env_name, noops, num_episodes, num_episodes_noexp,
-              exact_baseline, agent_class):
+              exact_baseline, anneal, num_episodes, num_episodes_noexp, seed,
+              env_name, noops, movement_reward, goal_reward, side_effect_reward,
+              agent_class):
   """Run agent.

   Create an agent with the given parameters for the side effects penalty.
@@ -79,14 +89,17 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
     discount: discount factor
     value_discount: discount factor for deviation measure value function.
     beta: weight for side effects penalty
+    exact_baseline: whether to use an exact or approximate baseline
     anneal: whether to anneal the exploration rate from 1 to 0 or use a constant
       exploration rate
+    num_episodes: number of episodes
+    num_episodes_noexp: number of episodes with no exploration
     seed: random seed
     env_name: environment name
     noops: whether the environment has noop actions
-    num_episodes: number of episodes
-    num_episodes_noexp: number of episodes with no exploration
-    exact_baseline: whether to use an exact or approximate baseline
+    movement_reward: movement reward
+    goal_reward: reward for reaching a goal state
+    side_effect_reward: hidden reward for causing side effects
     agent_class: Q-learning agent class: QLearning (regular) or QLearningSE
       (with side effects penalty)
@@ -95,10 +108,18 @@ def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
     performances: safety performance for each episode
   """
   np.random.seed(seed)
-  env = get_env(env_name, noops)
+  env, _ = get_env(env_name=env_name,
+                   noops=noops,
+                   movement_reward=movement_reward,
+                   goal_reward=goal_reward,
+                   side_effect_reward=side_effect_reward)
   start_timestep = env.reset()
   if exact_baseline:
-    baseline_env = get_env(env_name, True)
+    baseline_env, _ = get_env(env_name=env_name,
+                              noops=True,
+                              movement_reward=movement_reward,
+                              goal_reward=goal_reward,
+                              side_effect_reward=side_effect_reward)
   else:
     baseline_env = None
   agent = agent_class(
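
One behavioural detail: as before, the exact-baseline copy is always built with noops=True, but it now inherits the same reward parameters as the training environment, keeping the simulated baseline rewards consistent. A sketch of the pairing:

    # Sketch: the env and its exact-baseline copy share reward settings; only
    # the noops flag differs.
    reward_kwargs = dict(movement_reward=0, goal_reward=1, side_effect_reward=-1)
    env, _ = get_env(env_name='box', noops=False, **reward_kwargs)
    baseline_env, _ = get_env(env_name='box', noops=True, **reward_kwargs)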