Open sourcing side effects code
PiperOrigin-RevId: 272089371
Committed by Diego de Las Casas · parent 60655a2797 · commit 9e3d04c867
@@ -0,0 +1,79 @@
# Side effects penalties

This is the code for the paper [Penalizing side effects using stepwise relative
reachability](https://arxiv.org/abs/1806.01186) by Krakovna et al. (2019). It
implements a tabular Q-learning agent with different penalties for side effects.
Each side effects penalty consists of a deviation measure (none, unreachability,
relative reachability, or attainable utility) and a baseline (starting state,
inaction, or stepwise inaction).

## Instructions

Clone the repository:

`git clone https://github.com/deepmind/deepmind-research.git`

### Running an agent with a side effects penalty

Run the agent with a given penalty on an AI Safety Gridworlds environment
(a concrete example follows the flag descriptions below):

`python -m side_effects_penalties.run_experiment -baseline <X> -dev_measure <Y> -env_name <Z> -suffix <S>`

The following parameters can be specified for the side effects penalty:
* Baseline state (`-baseline`): starting state (`start`), inaction (`inaction`),
  stepwise inaction with rollouts (`stepwise`), stepwise inaction without
  rollouts (`step_noroll`)
* Deviation measure (`-dev_measure`): none (`none`), unreachability (`reach`),
  relative reachability (`rel_reach`), attainable utility (`att_util`)
* Discount factor for the deviation measure value function (`-value_discount`)
* Summary function to apply to the relative reachability or attainable utility
  deviation measure (`-dev_fun`): max(0, x) (`truncation`) or |x| (`absolute`)
* Weight for the side effects penalty relative to the reward (`-beta`)

Other arguments:
* AI Safety Gridworlds environment name (`-env_name`)
* Number of episodes (`-num_episodes`)
* Filename suffix for saving result files (`-suffix`)
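
For example, this (illustrative) invocation trains an agent with the relative
reachability penalty and the stepwise inaction baseline on the `box`
environment, leaving the remaining flags at their defaults:

`python -m side_effects_penalties.run_experiment -baseline stepwise -dev_measure rel_reach -env_name box -suffix test`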

### Plotting the results

Make a summary data frame from the result files generated by `run_experiment`:

`python -m side_effects_penalties.results_summary -compare_penalties -input_suffix <S>`

Arguments:
* `-bar_plot`: make a data frame for a bar plot (True) or a learning curve
  plot (False)
* `-compare_penalties`: compare different penalties using the best beta value
  for each penalty (True), or compare different beta values for a given penalty
  (False)
* If `compare_penalties=False`, specify the penalty parameters (`-dev_measure`,
  `-dev_fun` and `-value_discount`)
* Environment name (`-env_name`)
* Filename suffix for loading result files (`-input_suffix`)
* Filename suffix for the summary data frame (`-output_suffix`)
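
For example, this (illustrative) command summarizes the result files saved with
suffix `test` above, comparing the different penalties:

`python -m side_effects_penalties.results_summary -compare_penalties -input_suffix test -output_suffix test`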

Import the summary data frame into `plot_results.ipynb` and make a bar plot or
learning curve plot.

## Dependencies

* Python 2.7 or 3 (tested with Python 2.7.15 and 3.6.7)
* [AI Safety Gridworlds](https://github.com/deepmind/ai-safety-gridworlds)
  suite of safety environments
* [Abseil](https://github.com/abseil/abseil-py) Python common libraries
* Numpy
* Pandas
* Six
* Matplotlib
* Seaborn

## Citing this work

If you use this code in your work, please cite the accompanying paper:

```
@article{srr2019,
  title = {Penalizing Side Effects using Stepwise Relative Reachability},
  author = {Victoria Krakovna and Laurent Orseau and Ramana Kumar and Miljan Martic and Shane Legg},
  journal = {CoRR},
  volume = {abs/1806.01186},
  year = {2019}
}
```
@@ -0,0 +1,14 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
@@ -0,0 +1,150 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Vanilla Q-Learning agent."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np
from six.moves import range


class EpsilonGreedyPolicy(object):
  """Epsilon greedy policy for table value function lookup."""

  def __init__(self, value_function, actions):
    """Construct an epsilon greedy policy object.

    Args:
      value_function: agent value function as a dict.
      actions: list of possible actions.

    Raises:
      ValueError: if `actions` argument is not an iterable.
    """
    if not isinstance(actions, collections.Iterable):
      raise ValueError('`actions` argument must be an iterable.')

    self._value_function = value_function
    self._actions = actions

  def get_action(self, epsilon, state):
    """Get action following the e-greedy policy.

    Args:
      epsilon: probability of selecting a random action.
      state: current state of the game, used with each action as a
        (state, action) key into the value function.

    Returns:
      Chosen action.
    """
    if np.random.random() < epsilon:
      return np.random.choice(self._actions)
    else:
      values = [self._value_function[(state, action)]
                for action in self._actions]

      max_value = max(values)
      max_indices = [i for i, value in enumerate(values) if value == max_value]

      return self._actions[np.random.choice(max_indices)]


class QLearning(object):
  """Q-learning agent."""

  def __init__(self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0,
               discount=0.99):
    """Create a Q-learning agent.

    Args:
      actions: a BoundedArraySpec that specifies the full discrete action spec.
      alpha: agent learning rate.
      epsilon: agent exploration rate.
      q_initialisation: float, used to initialise the value function.
      discount: discount factor for rewards.
    """

    self._value_function = collections.defaultdict(lambda: q_initialisation)
    self._valid_actions = list(range(actions.minimum, actions.maximum + 1))
    self._policy = EpsilonGreedyPolicy(self._value_function,
                                       self._valid_actions)

    # Hyperparameters.
    self.alpha = alpha
    self.epsilon = epsilon
    self.discount = discount

    # Episode internal variables.
    self._current_action = None
    self._current_state = None

  def begin_episode(self):
    """Perform episode initialisation."""
    self._current_state = None
    self._current_action = None

  def _timestep_to_state(self, timestep):
    return tuple(map(tuple, np.copy(timestep.observation['board'])))

  def step(self, timestep):
    """Perform a single step in the environment."""
    # Get state observations.
    state = self._timestep_to_state(timestep)

    # This is one of the follow-up states (i.e. not the initial state).
    if self._current_state is not None:
      self._update(timestep, state)

    self._current_state = state
    # Determine action.
    self._current_action = self._policy.get_action(self.epsilon, state)
    # Emit action.
    return self._current_action

  def _calculate_reward(self, timestep, unused_state):
    """Calculate reward: to be extended when impact penalty is added."""
    reward = timestep.reward
    return reward

  def _update(self, timestep, state):
    """Perform value function update."""

    reward = self._calculate_reward(timestep, state)

    # Terminal state.
    if not state:
      delta = (reward - self._value_function[(self._current_state,
                                              self._current_action)])
    # Non-terminal state.
    else:
      max_action = self._policy.get_action(0, state)
      delta = (
          reward + self.discount * self._value_function[(state, max_action)] -
          self._value_function[(self._current_state, self._current_action)])

    self._value_function[(self._current_state,
                          self._current_action)] += self.alpha * delta

  def end_episode(self, timestep):
    """Performs episode cleanup."""
    # Update for the terminal state.
    self._update(timestep, None)

  @property
  def value_function(self):
    return self._value_function
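
To make the episode API above concrete, here is a minimal, hypothetical sketch
of a training loop. The toy environment and the `TimeStep`/`ActionSpec`
stand-ins are assumptions for illustration only; the real experiments drive the
agent with AI Safety Gridworlds environments via the repository's `training`
module (not shown in this diff):

```python
import collections
import numpy as np

from side_effects_penalties import agent

# Hypothetical stand-ins for the timestep and action spec the agent expects.
TimeStep = collections.namedtuple('TimeStep', ['observation', 'reward'])
ActionSpec = collections.namedtuple('ActionSpec', ['minimum', 'maximum'])


class ToyEnv(object):
  """Toy 5-cell corridor: action 1 moves right, 0 moves left; reward at cell 4."""

  def reset(self):
    self._pos = 0
    return TimeStep({'board': np.array([[self._pos]])}, 0.0)

  def step(self, action):
    self._pos = int(np.clip(self._pos + (1 if action == 1 else -1), 0, 4))
    return TimeStep({'board': np.array([[self._pos]])}, float(self._pos == 4))


env = ToyEnv()
q_agent = agent.QLearning(actions=ActionSpec(minimum=0, maximum=1))
for _ in range(100):                 # episodes
  timestep = env.reset()
  q_agent.begin_episode()
  for _ in range(10):                # steps per episode
    timestep = env.step(q_agent.step(timestep))
  q_agent.end_episode(timestep)      # terminal value-function update
```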

@@ -0,0 +1,113 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Q-learning with side effects penalties."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from side_effects_penalties import agent
from side_effects_penalties import side_effects_penalty as sep


class QLearningSE(agent.QLearning):
  """Q-learning agent with side-effects penalties."""

  def __init__(
      self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0,
      baseline='start', dev_measure='none', dev_fun='truncation',
      discount=0.99, value_discount=1.0, beta=1.0, num_util_funs=10,
      exact_baseline=False, baseline_env=None, start_timestep=None):
    """Create a Q-learning agent with a side effects penalty.

    Args:
      actions: full discrete action spec.
      alpha: agent learning rate.
      epsilon: agent exploration rate.
      q_initialisation: float, used to initialise the value function.
      baseline: which baseline state to use ('start', 'inaction', 'stepwise',
        'step_noroll').
      dev_measure: deviation measure:
        - "none" for no penalty,
        - "reach" for unreachability,
        - "rel_reach" for relative reachability,
        - "att_util" for attainable utility.
      dev_fun: what function to apply in the deviation measure ('truncation' or
        'absolute' (for 'rel_reach' and 'att_util'), or 'none' (otherwise)).
      discount: discount factor for rewards.
      value_discount: discount factor for value functions in penalties.
      beta: side effects penalty weight.
      num_util_funs: number of random utility functions for attainable utility.
      exact_baseline: whether to use an exact or approximate baseline.
      baseline_env: copy of environment (with noops) for the exact baseline.
      start_timestep: copy of starting timestep for the baseline.

    Raises:
      ValueError: for incorrect baseline, dev_measure, or dev_fun.
    """

    super(QLearningSE, self).__init__(actions, alpha, epsilon, q_initialisation,
                                      discount)

    # Impact penalty: set dev_fun (f)
    if dev_measure in {'rel_reach', 'att_util'}:
      if dev_fun == 'truncation':
        dev_fun = lambda diff: max(0, diff)
      elif dev_fun == 'absolute':
        dev_fun = np.abs
      else:
        raise ValueError('Deviation function not recognized')
    else:
      assert dev_fun == 'none'
      dev_fun = None

    # Impact penalty: create deviation measure
    if dev_measure in {'reach', 'rel_reach'}:
      deviation = sep.Reachability(value_discount, dev_fun, discount)
    elif dev_measure == 'att_util':
      deviation = sep.AttainableUtility(value_discount, dev_fun, num_util_funs,
                                        discount)
    elif dev_measure == 'none':
      deviation = sep.NoDeviation()
    else:
      raise ValueError('Deviation measure not recognized')

    use_inseparable_rollout = (
        dev_measure == 'reach' and baseline == 'stepwise')

    # Impact penalty: create baseline
    if baseline in {'start', 'inaction', 'stepwise'}:
      baseline_class = getattr(sep, baseline.capitalize() + 'Baseline')
      baseline = baseline_class(start_timestep, exact_baseline, baseline_env,
                                self._timestep_to_state)
    elif baseline == 'step_noroll':
      baseline_class = getattr(sep, 'StepwiseBaseline')
      baseline = baseline_class(start_timestep, exact_baseline, baseline_env,
                                self._timestep_to_state, False)
    else:
      raise ValueError('Baseline not recognized')

    self._impact_penalty = sep.SideEffectPenalty(baseline, deviation, beta,
                                                 use_inseparable_rollout)

  def begin_episode(self):
    """Perform episode initialisation."""
    super(QLearningSE, self).begin_episode()
    self._impact_penalty.reset()

  def _calculate_reward(self, timestep, state):
    reward = super(QLearningSE, self)._calculate_reward(timestep, state)
    return (reward - self._impact_penalty.calculate(
        self._current_state, self._current_action, state))
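
Continuing the hypothetical toy setup from the previous sketch, the penalized
agent can be constructed as follows. Note that the stepwise inaction baseline is
only meaningful when the environment has a no-op action (as the Gridworlds
environments do with `noops=True`); the toy corridor lacks one, so this is
purely a construction example:

```python
from side_effects_penalties import agent_with_penalties

# `env` and `ActionSpec` are the hypothetical stand-ins from the sketch above.
se_agent = agent_with_penalties.QLearningSE(
    actions=ActionSpec(minimum=0, maximum=1),
    baseline='stepwise', dev_measure='rel_reach', dev_fun='truncation',
    value_discount=0.99, beta=30.0, start_timestep=env.reset())
# se_agent drops into the same begin_episode/step/end_episode loop as before.
```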

@@ -0,0 +1,61 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions for loading files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import pandas as pd


def filename(env_name, noops, dev_measure, dev_fun, baseline, beta,
             value_discount, seed, path='', suffix=''):
  """Generate filename for the given set of parameters."""
  noop_str = 'noops' if noops else 'nonoops'
  seed_str = '_' + str(seed) if seed else ''
  filename_template = ('{env_name}_{noop_str}_{dev_measure}_{dev_fun}' +
                       '_{baseline}_beta_{beta}_vd_{value_discount}' +
                       '{suffix}{seed_str}.csv')
  full_path = os.path.join(path, filename_template.format(
      env_name=env_name, noop_str=noop_str, dev_measure=dev_measure,
      dev_fun=dev_fun, baseline=baseline, beta=beta,
      value_discount=value_discount, suffix=suffix, seed_str=seed_str))
  return full_path


def load_files(baseline, dev_measure, dev_fun, value_discount, beta, env_name,
               noops, path, suffix, seed_list, final=True):
  """Load result files generated by run_experiment with the given parameters."""
  def try_loading(f, final):
    if os.path.isfile(f):
      df = pd.read_csv(f, index_col=0)
      if final:
        last_episode = max(df['episode'])
        return df[df.episode == last_episode]
      else:
        return df
    else:
      return pd.DataFrame()
  dataframes = []
  for seed in seed_list:
    f = filename(baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun,
                 value_discount=value_discount, beta=beta, env_name=env_name,
                 noops=noops, path=path, suffix=suffix, seed=int(seed))
    df_part = try_loading(f, final)
    dataframes.append(df_part)
  df = pd.concat(dataframes)
  return df
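
For illustration, with these (hypothetical) parameter values `filename`
produces `results/box_noops_rel_reach_truncation_stepwise_beta_30.0_vd_0.99_1.csv`
(on a POSIX path separator):

```python
from side_effects_penalties import file_loading

f = file_loading.filename('box', noops=True, dev_measure='rel_reach',
                          dev_fun='truncation', baseline='stepwise', beta=30.0,
                          value_discount=0.99, seed=1, path='results')
```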

@@ -0,0 +1,241 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rINYEKJlYpQU"
      },
      "source": [
        "Copyright 2019 DeepMind Technologies Limited.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "KbCarv91XChI"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "from google.colab import files\n",
        "import io\n",
        "import pandas as pd\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "irahOycBZM1E"
      },
      "source": [
        "### Plot parameters (edit as needed)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VeaB-Y9TYWCP"
      },
      "outputs": [],
      "source": [
        "# Make a bar plot for average results from the final 100 episodes (True),\n",
        "# or make a learning curve plot (False)\n",
        "bar_plot = False\n",
        "\n",
        "# Compare different penalties using the best beta value for each penalty (True),\n",
        "# or compare different beta values for the same penalty (False):\n",
        "compare_penalties = False\n",
        "\n",
        "# If compare_penalties is False, specify the penalty parameters:\n",
        "dev_measure = 'rel_reach'\n",
        "dev_fun = 'truncation'\n",
        "value_discount = 0.99\n",
        "\n",
        "# Environment name\n",
        "env_name = 'box'\n",
        "\n",
        "# Filename suffix\n",
        "suffix = ''"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t8hYO3f2ZaHq"
      },
      "source": [
        "### Plot settings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "foyP_qrUeTsx"
      },
      "outputs": [],
      "source": [
        "final_str = '_final' if bar_plot else ''\n",
        "if compare_penalties:\n",
        "  var = 'label'\n",
        "  x_label = 'deviation_measure'\n",
        "  legend_title = 'penalty'\n",
        "  palette = sns.color_palette()\n",
        "  filename = ('df_summary_penalties_' + env_name + final_str + suffix\n",
        "              + '.csv')\n",
        "else:\n",
        "  var = 'beta'\n",
        "  x_label = 'beta'\n",
        "  legend_title = 'beta'\n",
        "  palette = sns.cubehelix_palette()\n",
        "  filename = ('df_summary_betas_' + env_name + '_' + dev_measure + '_' + dev_fun\n",
        "              + '_' + str(value_discount) + final_str + suffix + '.csv')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5RakwQsFZc0V"
      },
      "source": [
        "### Load summary data output by results_summary.py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PQ615uuGYLF3"
      },
      "outputs": [],
      "source": [
        "uploaded = files.upload()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-FoI7u8BtYol"
      },
      "outputs": [],
      "source": [
        "df = pd.read_csv(io.BytesIO(uploaded[filename]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "A7BGierpZi6t"
      },
      "source": [
        "### Make bar plots"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ch06DzzeXlGK"
      },
      "outputs": [],
      "source": [
        "plot = sns.catplot(data=df, col='baseline', x=var, y='performance_smooth',\n",
        "                   kind='bar', height=4, aspect=1.3)\n",
        "axes = plot.axes.flatten()\n",
        "for ax in axes:\n",
        "  title = ax.get_title().split()\n",
        "  ax.set_title(title[2] + ' baseline')\n",
        "  ax.set_ylabel('performance')\n",
        "  ax.set_xlabel(x_label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BU9PYyzOZlRu"
      },
      "source": [
        "### Make learning curve plots"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RRfU_2iIX-jo"
      },
      "outputs": [],
      "source": [
        "plot = sns.FacetGrid(df, col='baseline', size=5, aspect=1.3,\n",
        "                     sharey=False, sharex=False)\n",
        "plot.map_dataframe(sns.tsplot, time='episode', unit='seed', condition=var,\n",
        "                   value='performance_smooth', n_boot=100, color=palette,\n",
        "                   alpha=1.0, linewidth=1)\n",
        "plot.add_legend(title=legend_title)\n",
        "axes = plot.axes.flatten()\n",
        "for ax in axes:\n",
        "  title = ax.get_title().split()\n",
        "  ax.set_title(title[2] + ' baseline')\n",
        "  ax.set_ylabel('performance')\n",
        "  ax.set_xlabel('episode')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "plot_results.ipynb",
      "provenance": [
        {
          "file_id": "1a8ub19XYD4M-r5mGm0lKYTrNwTo1zF7Z",
          "timestamp": 1569850224175
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 2",
      "name": "python2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}

@@ -0,0 +1,111 @@
absl-py==0.7.1
activity-log-manager==0.8.0
apt-xapian-index==0.49
asn1crypto==0.24.0
attrs==18.2.0
Automat==0.6.0
backports.functools-lru-cache==1.5
bcrypt==3.1.6
beautifulsoup4==4.6.3
blinker==1.4
ccsm==0.9.13.1
certifi==2018.8.24
chardet==3.0.4
Click==7.0
colorama==0.3.7
compizconfig-python==0.9.13.1
configparser==3.5.0b2
constantly==15.1.0
CredentialKit==0.7
cryptography==2.3
cycler==0.10.0
defer==1.0.6
defusedxml==0.5.0
dirspec==13.10
duplicity==0.7.18.2
entrypoints==0.3
enum34==1.1.6
fanotify==0.1
fasteners==0.12.0
fpconst==0.7.2
future==0.16.0
glinux-identity==1
goobuntu-config-tools==0.1
goobuntu-sso-watcher==0.1
goobuntu-welcome==11
googlenetworkaccess==0.1
gpg==1.12.0
gprof2dot==2017.9.19
hg-evolve==9.2.0.dev0
html5lib==1.0.1
httplib2==0.11.3
hyperlink==17.3.1
idna==2.6
incremental==16.10.1
inotifyx==0.2.0
ipaddress==1.0.17
IPy==0.83
kernel-pruner==47
keyring==17.1.1
keyrings.alt==3.1.1
kiwisolver==1.1.0
lockfile==0.12.2
lxml==4.3.2
lz4==1.1.0+dfsg
matplotlib==2.2.4
mercurial==5.1.1+194.5ca351ba2478
monotonic==1.0
mox==0.5.3
numpy==1.16.4
oauthlib==2.1.0
olefile==0.46
PAM==0.4.2
pandas==0.24.2
paramiko==2.4.2
parse==1.6.6
pexpect==4.6.0
Pillow==4.3.0
protobuf==3.6.1
psutil==5.5.1
pyasn1==0.4.2
pyasn1-modules==0.2.1
pycairo==1.16.2
pycrypto==2.6.1
pycups==1.9.73
pycurl==7.43.0.2
PyGObject==3.30.4
pyinotify==0.9.6
PyJWT==1.7.0
PyKCS11==1.2.4
PyNaCl==1.3.0
pyOpenSSL==19.0.0
pyparsing==2.4.2
pyserial==3.4
pysmbc==1.0.15.6
python-apt==1.8.4
python-augeas==0.5.0
python-dateutil==2.8.0
python-debian==0.1.34
python-networkmanager==2.1
python2-pythondialog==3.3.0
pytz==2019.2
pyudev==0.21.0
pyxattr==0.6.1
pyxdg==0.25
PyYAML==3.13
rekey==1
reportlab==3.5.13
requests==2.21.0
scipy==1.2.2
scour==0.37
seaborn==0.9.0
SecretStorage==2.3.1
service-identity==16.0.0
six==1.12.0
SOAPpy==0.12.22
subprocess32==3.5.4
Twisted==18.9.0
urllib3==1.24.1
webencodings==0.5.1
wstools==0.4.3
zope.interface==4.3.2

@@ -0,0 +1,183 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Plot results for different side effects penalties.

Loads csv result files generated by `run_experiment` and outputs a summary data
frame in a csv file to be used for plotting by plot_results.ipynb.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
from absl import app
from absl import flags
import pandas as pd
from side_effects_penalties.file_loading import load_files


FLAGS = flags.FLAGS

if __name__ == '__main__':  # Avoid defining flags when used as a library.
  flags.DEFINE_string('path', '', 'File path.')
  flags.DEFINE_string('input_suffix', '',
                      'Filename suffix to use when loading data files.')
  flags.DEFINE_string('output_suffix', '',
                      'Filename suffix to use when saving files.')
  flags.DEFINE_bool('bar_plot', True,
                    'Make a data frame for a bar plot (True) ' +
                    'or learning curves (False)')
  flags.DEFINE_string('env_name', 'box', 'Environment name.')
  flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
  flags.DEFINE_list('beta_list', [0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0],
                    'List of beta values.')
  flags.DEFINE_list('seed_list', [1], 'List of random seeds.')
  flags.DEFINE_bool('compare_penalties', True,
                    'Compare different penalties using the best beta value ' +
                    'for each penalty (True), or compare different beta values '
                    + 'for the same penalty (False).')
  flags.DEFINE_enum('dev_measure', 'rel_reach',
                    ['none', 'reach', 'rel_reach', 'att_util'],
                    'Deviation measure (used if compare_penalties=False).')
  flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'],
                    'Summary function for the deviation measure ' +
                    '(used if compare_penalties=False)')
  flags.DEFINE_float('value_discount', 0.99,
                     'Discount factor for deviation measure value function ' +
                     '(used if compare_penalties=False)')


def beta_choice(baseline, dev_measure, dev_fun, value_discount, env_name,
                beta_list, seed_list, noops=False, path='', suffix=''):
  """Choose the beta value that gives the highest final performance."""
  if dev_measure == 'none':
    return 0.1
  perf_max = float('-inf')
  best_beta = 0.0
  for beta in beta_list:
    df = load_files(baseline=baseline, dev_measure=dev_measure,
                    dev_fun=dev_fun, value_discount=value_discount, beta=beta,
                    env_name=env_name, noops=noops, path=path, suffix=suffix,
                    seed_list=seed_list)
    if df.empty:
      perf = float('-inf')
    else:
      perf = df['performance_smooth'].mean()
    if perf > perf_max:
      perf_max = perf
      best_beta = beta
  return best_beta


def penalty_label(dev_measure, dev_fun, value_discount):
  """Penalty label specifying design choices."""
  dev_measure_labels = {
      'none': 'None', 'rel_reach': 'RR', 'att_util': 'AU', 'reach': 'UR'}
  label = dev_measure_labels[dev_measure]
  disc_lab = 'u' if value_discount == 1.0 else 'd'
  dev_lab = ''
  if dev_measure in ['rel_reach', 'att_util']:
    dev_lab = 't' if dev_fun == 'truncation' else 'a'
  if dev_measure != 'none':
    label = label + '(' + disc_lab + dev_lab + ')'
  return label


def make_summary_data_frame(
    env_name, beta_list, seed_list, final=True, baseline=None, dev_measure=None,
    dev_fun=None, value_discount=None, noops=False, compare_penalties=True,
    path='', input_suffix='', output_suffix=''):
  """Make summary dataframe from multiple csv result files and output to csv."""
  # For each of the penalty parameters (baseline, dev_measure, dev_fun, and
  # value_discount), compare a list of multiple values if the parameter is
  # None, or use the provided parameter value if it is not None.
  baseline_list = ['start', 'inaction', 'stepwise', 'step_noroll']
  if dev_measure is not None:
    dev_measure_list = [dev_measure]
  else:
    dev_measure_list = ['none', 'reach', 'rel_reach', 'att_util']
  dataframes = []
  for dev_measure in dev_measure_list:
    # These deviation measures don't have a deviation function:
    if dev_measure in ['reach', 'none']:
      dev_fun_list = ['none']
    elif dev_fun is not None:
      dev_fun_list = [dev_fun]
    else:
      dev_fun_list = ['truncation', 'absolute']
    # These deviation measures must be discounted:
    if dev_measure in ['none', 'att_util']:
      value_discount_list = [0.99]
    elif value_discount is not None:
      value_discount_list = [value_discount]
    else:
      value_discount_list = [0.99, 1.0]
    for baseline in baseline_list:
      for vd in value_discount_list:
        for devf in dev_fun_list:
          # Choose the best beta for this set of penalty parameters if
          # compare_penalties=True, or compare all betas otherwise.
          if compare_penalties:
            beta = beta_choice(
                baseline=baseline, dev_measure=dev_measure, dev_fun=devf,
                value_discount=vd, env_name=env_name, noops=noops,
                beta_list=beta_list, seed_list=seed_list, path=path,
                suffix=input_suffix)
            betas = [beta]
          else:
            betas = beta_list
          for beta in betas:
            label = penalty_label(
                dev_measure=dev_measure, dev_fun=devf, value_discount=vd)
            df_part = load_files(
                baseline=baseline, dev_measure=dev_measure, dev_fun=devf,
                value_discount=vd, beta=beta, env_name=env_name,
                noops=noops, path=path, suffix=input_suffix, final=final,
                seed_list=seed_list)
            df_part = df_part.assign(
                baseline=baseline, dev_measure=dev_measure, dev_fun=devf,
                value_discount=vd, beta=beta, env_name=env_name, label=label)
            dataframes.append(df_part)
  df = pd.concat(dataframes, sort=False)
  # Output the summary data frame.
  final_str = '_final' if final else ''
  if compare_penalties:
    filename = ('df_summary_penalties_' + env_name + final_str +
                output_suffix + '.csv')
  else:
    filename = ('df_summary_betas_' + env_name + '_' + dev_measure + '_' +
                dev_fun + '_' + str(value_discount) + final_str + output_suffix
                + '.csv')
  f = os.path.join(path, filename)
  df.to_csv(f)
  return df


def main(unused_argv):
  compare_penalties = FLAGS.compare_penalties
  dev_measure = None if compare_penalties else FLAGS.dev_measure
  dev_fun = None if compare_penalties else FLAGS.dev_fun
  value_discount = None if compare_penalties else FLAGS.value_discount
  make_summary_data_frame(
      compare_penalties=compare_penalties, env_name=FLAGS.env_name,
      noops=FLAGS.noops, final=FLAGS.bar_plot, dev_measure=dev_measure,
      value_discount=value_discount, dev_fun=dev_fun, path=FLAGS.path,
      input_suffix=FLAGS.input_suffix, output_suffix=FLAGS.output_suffix,
      beta_list=FLAGS.beta_list, seed_list=FLAGS.seed_list)


if __name__ == '__main__':
  app.run(main)
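
As an illustration of `penalty_label`'s encoding, here are a few hypothetical
parameter combinations and the labels they produce (`d`/`u` for
discounted/undiscounted, `t`/`a` for truncation/absolute):

```python
from side_effects_penalties.results_summary import penalty_label

print(penalty_label('rel_reach', 'truncation', 0.99))  # 'RR(dt)'
print(penalty_label('reach', 'none', 1.0))             # 'UR(u)'
print(penalty_label('none', 'none', 0.99))             # 'None'
```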

@@ -0,0 +1,130 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run a Q-learning agent with a side effects penalty."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import pandas as pd
from six.moves import range
from six.moves import zip
from side_effects_penalties import agent_with_penalties
from side_effects_penalties import training
from side_effects_penalties.file_loading import filename


FLAGS = flags.FLAGS

if __name__ == '__main__':  # Avoid defining flags when used as a library.
  flags.DEFINE_enum('baseline', 'stepwise',
                    ['start', 'inaction', 'stepwise', 'step_noroll'],
                    'Baseline.')
  flags.DEFINE_enum('dev_measure', 'rel_reach',
                    ['none', 'reach', 'rel_reach', 'att_util'],
                    'Deviation measure.')
  flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'],
                    'Summary function for the deviation measure.')
  flags.DEFINE_float('discount', 0.99, 'Discount factor for rewards.')
  flags.DEFINE_float('value_discount', 0.99,
                     'Discount factor for deviation measure value function.')
  flags.DEFINE_float('beta', 30.0, 'Weight for side effects penalty.')
  flags.DEFINE_bool('anneal', True,
                    'Whether to anneal the exploration rate from 1 to 0.')
  flags.DEFINE_integer('num_episodes', 10000, 'Number of episodes.')
  flags.DEFINE_integer('num_episodes_noexp', 0,
                       'Number of episodes with no exploration.')
  flags.DEFINE_integer('seed', 1, 'Random seed.')
  flags.DEFINE_string('env_name', 'box', 'Environment name.')
  flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
  flags.DEFINE_bool('exact_baseline', False,
                    'Compute the exact baseline using an environment copy.')
  flags.DEFINE_enum('mode', 'save', ['print', 'save'],
                    'Print results or save to file.')
  flags.DEFINE_string('path', '', 'File path.')
  flags.DEFINE_string('suffix', '', 'Filename suffix.')


def run_experiment(baseline, dev_measure, dev_fun, discount, value_discount,
                   beta, anneal, num_episodes, num_episodes_noexp, seed,
                   env_name, noops, exact_baseline, mode, path, suffix):
  """Run agent and save or print the results."""
  performances = []
  rewards = []
  seeds = []
  episodes = []
  if dev_measure not in ['rel_reach', 'att_util']:
    dev_fun = 'none'
  reward, performance = training.run_agent(
      baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun,
      discount=discount, value_discount=value_discount, beta=beta,
      anneal=anneal, num_episodes=num_episodes,
      num_episodes_noexp=num_episodes_noexp, seed=seed, env_name=env_name,
      noops=noops, agent_class=agent_with_penalties.QLearningSE,
      exact_baseline=exact_baseline)
  rewards.extend(reward)
  performances.extend(performance)
  seeds.extend([seed] * (num_episodes + num_episodes_noexp))
  episodes.extend(list(range(num_episodes + num_episodes_noexp)))
  if mode == 'save':
    d = {'reward': rewards, 'performance': performances,
         'seed': seeds, 'episode': episodes}
    df = pd.DataFrame(d)
    df1 = add_smoothed_data(df)
    f = filename(env_name, noops, dev_measure, dev_fun, baseline, beta,
                 value_discount, path=path, suffix=suffix, seed=seed)
    df1.to_csv(f)
  return reward, performance


def _smooth(values, window=100):
  return values.rolling(window).mean()


def add_smoothed_data(df, groupby='seed', window=100):
  grouped = df.groupby(groupby)[['reward', 'performance']]
  grouped = grouped.apply(_smooth, window=window).rename(columns={
      'performance': 'performance_smooth', 'reward': 'reward_smooth'})
  temp = pd.concat([df, grouped], axis=1)
  return temp


def main(unused_argv):
  reward, performance = run_experiment(
      baseline=FLAGS.baseline,
      dev_measure=FLAGS.dev_measure,
      dev_fun=FLAGS.dev_fun,
      discount=FLAGS.discount,
      value_discount=FLAGS.value_discount,
      beta=FLAGS.beta,
      anneal=FLAGS.anneal,
      num_episodes=FLAGS.num_episodes,
      num_episodes_noexp=FLAGS.num_episodes_noexp,
      seed=FLAGS.seed,
      env_name=FLAGS.env_name,
      noops=FLAGS.noops,
      exact_baseline=FLAGS.exact_baseline,
      mode=FLAGS.mode,
      path=FLAGS.path,
      suffix=FLAGS.suffix)
  if FLAGS.mode == 'print':
    print('Performance and reward in the last 10 episodes:')
    print(list(zip(performance, reward))[-10:])


if __name__ == '__main__':
  app.run(main)
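
To see what the smoothing step does, here is a tiny, hypothetical input, with
the window shortened to 2 for readability (the experiments use the default
window of 100):

```python
import pandas as pd
from side_effects_penalties.run_experiment import add_smoothed_data

df = pd.DataFrame({'seed': [1, 1, 1], 'episode': [0, 1, 2],
                   'reward': [0.0, 1.0, 1.0], 'performance': [0.0, 0.0, 1.0]})
# Adds reward_smooth [NaN, 0.5, 1.0] and performance_smooth [NaN, 0.0, 0.5]:
print(add_smoothed_data(df, window=2))
```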
|
||||
@@ -0,0 +1,476 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Side Effects Penalties.
|
||||
|
||||
Abstract class for implementing a side effects (impact measure) penalty,
|
||||
and various concrete penalties deriving from it.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import copy
|
||||
import enum
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import range
|
||||
from six.moves import zip
|
||||
|
||||
|
||||
class Actions(enum.IntEnum):
|
||||
"""Enum for actions the agent can take."""
|
||||
UP = 0
|
||||
DOWN = 1
|
||||
LEFT = 2
|
||||
RIGHT = 3
|
||||
NOOP = 4
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Baseline(object):
|
||||
"""Base class for baseline states."""
|
||||
|
||||
def __init__(self, start_timestep, exact=False, env=None,
|
||||
timestep_to_state=None):
|
||||
"""Create a baseline.
|
||||
|
||||
Args:
|
||||
start_timestep: starting state timestep
|
||||
exact: whether to use an exact or approximate baseline
|
||||
env: a copy of the environment (used to simulate exact baselines)
|
||||
timestep_to_state: a function that turns timesteps into states
|
||||
"""
|
||||
self._exact = exact
|
||||
self._env = env
|
||||
self._timestep_to_state = timestep_to_state
|
||||
self._start_timestep = start_timestep
|
||||
self._baseline_state = self._timestep_to_state(self._start_timestep)
|
||||
self._inaction_next = collections.defaultdict(
|
||||
lambda: collections.defaultdict(lambda: 0))
|
||||
|
||||
@abc.abstractmethod
|
||||
def calculate(self):
|
||||
"""Update and return the baseline state."""
|
||||
|
||||
def sample(self, state):
|
||||
"""Sample the outcome of a noop in `state`."""
|
||||
d = self._inaction_next[state]
|
||||
counts = np.array(list(d.values()))
|
||||
index = np.random.choice(a=len(counts), p=counts/sum(counts))
|
||||
return list(d.keys())[index]
|
||||
|
||||
def reset(self):
|
||||
"""Signal start of new episode."""
|
||||
self._baseline_state = self._timestep_to_state(self._start_timestep)
|
||||
if self._exact:
|
||||
self._env.reset()
|
||||
|
||||
@abc.abstractproperty
|
||||
def rollout_func(self):
|
||||
"""Function to compute a rollout chain, or None if n/a."""
|
||||
|
||||
|
||||
class StartBaseline(Baseline):
|
||||
"""Starting state baseline."""
|
||||
|
||||
def calculate(self, *unused_args):
|
||||
return self._baseline_state
|
||||
|
||||
@property
|
||||
def rollout_func(self):
|
||||
return None
|
||||
|
||||
|
||||
class InactionBaseline(Baseline):
|
||||
"""Inaction baseline: the state resulting from taking no-ops from start."""
|
||||
|
||||
def calculate(self, prev_state, action, current_state):
|
||||
if self._exact:
|
||||
self._baseline_state = self._timestep_to_state(
|
||||
self._env.step(Actions.NOOP))
|
||||
else:
|
||||
if action == Actions.NOOP:
|
||||
self._inaction_next[prev_state][current_state] += 1
|
||||
if self._baseline_state in self._inaction_next:
|
||||
self._baseline_state = self.sample(self._baseline_state)
|
||||
return self._baseline_state
|
||||
|
||||
@property
|
||||
def rollout_func(self):
|
||||
return None
|
||||
|
||||
|
||||
class StepwiseBaseline(Baseline):
|
||||
"""Stepwise baseline: the state one no-op after the previous state."""
|
||||
|
||||
def __init__(self, start_timestep, exact=False, env=None,
|
||||
timestep_to_state=None, use_rollouts=True):
|
||||
"""Create a stepwise baseline.
|
||||
|
||||
Args:
|
||||
start_timestep: starting state timestep
|
||||
exact: whether to use an exact or approximate baseline
|
||||
env: a copy of the environment (used to simulate exact baselines)
|
||||
timestep_to_state: a function that turns timesteps into states
|
||||
use_rollouts: whether to use inaction rollouts
|
||||
"""
|
||||
super(StepwiseBaseline, self).__init__(
|
||||
start_timestep, exact, env, timestep_to_state)
|
||||
self._rollouts = use_rollouts
|
||||
|
||||
def calculate(self, prev_state, action, current_state):
|
||||
"""Update and return the baseline state.
|
||||
|
||||
Args:
|
||||
prev_state: the state in which `action` was taken
|
||||
action: the action just taken
|
||||
current_state: the state resulting from taking `action`
|
||||
Returns:
|
||||
the baseline state, for computing the penalty for this transition
|
||||
"""
|
||||
if self._exact:
|
||||
if prev_state in self._inaction_next:
|
||||
self._baseline_state = self.sample(prev_state)
|
||||
else:
|
||||
inaction_env = copy.deepcopy(self._env)
|
||||
timestep_inaction = inaction_env.step(Actions.NOOP)
|
||||
self._baseline_state = self._timestep_to_state(timestep_inaction)
|
||||
self._inaction_next[prev_state][self._baseline_state] += 1
|
||||
timestep_action = self._env.step(action)
|
||||
assert current_state == self._timestep_to_state(timestep_action)
|
||||
else:
|
||||
if action == Actions.NOOP:
|
||||
self._inaction_next[prev_state][current_state] += 1
|
||||
if prev_state in self._inaction_next:
|
||||
self._baseline_state = self.sample(prev_state)
|
||||
else:
|
||||
self._baseline_state = prev_state
|
||||
return self._baseline_state
|
||||
|
||||
def _inaction_rollout(self, state):
|
||||
"""Compute an (approximate) inaction rollout from a state."""
|
||||
chain = []
|
||||
st = state
|
||||
while st not in chain:
|
||||
chain.append(st)
|
||||
if st in self._inaction_next:
|
||||
st = self.sample(st)
|
||||
return chain
|
||||
|
||||
def parallel_inaction_rollouts(self, s1, s2):
|
||||
"""Compute (approximate) parallel inaction rollouts from two states."""
|
||||
chain = []
|
||||
states = (s1, s2)
|
||||
while states not in chain:
|
||||
chain.append(states)
|
||||
s1, s2 = states
|
||||
states = (self.sample(s1) if s1 in self._inaction_next else s1,
|
||||
self.sample(s2) if s2 in self._inaction_next else s2)
|
||||
return chain
|
||||
|
||||
@property
|
||||
def rollout_func(self):
|
||||
return self._inaction_rollout if self._rollouts else None
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class DeviationMeasure(object):
|
||||
"""Base class for deviation measures."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def calculate(self):
|
||||
"""Calculate the deviation between two states."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self):
|
||||
"""Update any models after seeing a state transition."""
|
||||
|
||||
|
||||
class ReachabilityMixin(object):
|
||||
"""Class for computing reachability deviation measure.
|
||||
|
||||
Computes the relative/un- reachability given a dictionary of
|
||||
reachability scores for pairs of states.
|
||||
|
||||
Expects _reachability, _discount, and _dev_fun attributes to exist in the
|
||||
inheriting class.
|
||||
"""
|
||||
|
||||
def calculate(self, current_state, baseline_state, rollout_func=None):
|
||||
"""Calculate relative/un- reachability between particular states."""
|
||||
# relative reachability case
|
||||
if self._dev_fun:
|
||||
if rollout_func:
|
||||
curr_values = self._rollout_values(rollout_func(current_state))
|
||||
base_values = self._rollout_values(rollout_func(baseline_state))
|
||||
else:
|
||||
curr_values = self._reachability[current_state]
|
||||
base_values = self._reachability[baseline_state]
|
||||
all_s = set(list(curr_values.keys()) + list(base_values.keys()))
|
||||
total = 0
|
||||
for s in all_s:
|
||||
diff = base_values[s] - curr_values[s]
|
||||
total += self._dev_fun(diff)
|
||||
d = total / len(all_s)
|
||||
# unreachability case
|
||||
else:
|
||||
assert rollout_func is None
|
||||
d = 1 - self._reachability[current_state][baseline_state]
|
||||
return d
|
||||
|
||||
def _rollout_values(self, chain):
|
||||
"""Compute stepwise rollout values for the relative reachability penalty.
|
||||
|
||||
Args:
|
||||
chain: chain of states in an inaction rollout starting with the state for
|
||||
which to compute the rollout values
|
||||
|
||||
Returns:
|
||||
a dictionary of the form:
|
||||
{ s : (1-discount) sum_{k=0}^inf discount^k R_s(S_k) }
|
||||
where S_k is the k-th state in the inaction rollout from 'state',
|
||||
s is a state, and
|
||||
R_s(S_k) is the reachability of s from S_k.
|
||||
"""
|
||||
rollout_values = collections.defaultdict(lambda: 0)
|
||||
coeff = 1
|
||||
for st in chain:
|
||||
for s, rch in six.iteritems(self._reachability[st]):
|
||||
rollout_values[s] += coeff * rch * (1.0 - self._discount)
|
||||
coeff *= self._discount
|
||||
last_state = chain[-1]
|
||||
for s, rch in six.iteritems(self._reachability[last_state]):
|
||||
rollout_values[s] += coeff * rch
|
||||
return rollout_values
|
||||
|
||||
|
||||
class Reachability(ReachabilityMixin, DeviationMeasure):
|
||||
"""Approximate (relative) (un)reachability deviation measure.
|
||||
|
||||
Unreachability (the default, when `dev_fun=None`) uses the length (say, n)
|
||||
of the shortest path (sequence of actions) from the current state to the
|
||||
baseline state. The reachability score is value_discount ** n.
|
||||
Unreachability is then 1.0 - the reachability score.
|
||||
|
||||
Relative reachability (when `dev_fun` is not `None`) considers instead the
|
||||
difference in reachability of all other states from the current state
|
||||
versus from the baseline state.
|
||||
|
||||
We approximate reachability by only considering state transitions
|
||||
that have been observed. Add transitions using the `update` function.
|
||||
"""
|
||||
|
||||
def __init__(self, value_discount=1.0, dev_fun=None, discount=None):
|
||||
self._value_discount = value_discount
|
||||
self._dev_fun = dev_fun
|
||||
self._discount = discount
|
||||
self._reachability = collections.defaultdict(
|
||||
lambda: collections.defaultdict(lambda: 0))
|
||||
|
||||
def update(self, prev_state, current_state):
|
||||
self._reachability[prev_state][prev_state] = 1
|
||||
self._reachability[current_state][current_state] = 1
|
||||
if self._reachability[prev_state][current_state] < self._value_discount:
|
||||
for s1 in self._reachability.keys():
|
||||
if self._reachability[s1][prev_state] > 0:
|
||||
for s2 in self._reachability[current_state].keys():
|
||||
if self._reachability[current_state][s2] > 0:
|
||||
self._reachability[s1][s2] = max(
|
||||
self._reachability[s1][s2],
|
||||
self._reachability[s1][prev_state] * self._value_discount *
|
||||
self._reachability[current_state][s2])
|
||||
|
||||
@property
|
||||
def discount(self):
|
||||
return self._discount
|
||||
|
||||
|
||||
class AttainableUtilityMixin(object):
|
||||
"""Class for computing attainable utility measure.
|
||||
|
||||
Computes attainable utility (averaged over a set of utility functions)
|
||||
given value functions for each utility function.
|
||||
|
||||
Expects _u_values, _discount, _value_discount, and _dev_fun attributes to
|
||||
exist in the inheriting class.
|
||||
"""
|
||||
|
||||
def calculate(self, current_state, baseline_state, rollout_func=None):
|
||||
if rollout_func:
|
||||
current_values = self._rollout_values(rollout_func(current_state))
|
||||
baseline_values = self._rollout_values(rollout_func(baseline_state))
|
||||
else:
|
||||
current_values = [u_val[current_state] for u_val in self._u_values]
|
||||
baseline_values = [u_val[baseline_state] for u_val in self._u_values]
|
||||
penalties = [self._dev_fun(base_val - cur_val) * (1. - self._value_discount)
|
||||
for base_val, cur_val in zip(baseline_values, current_values)]
|
||||
return sum(penalties) / len(penalties)
|
||||
|
||||
def _rollout_values(self, chain):
|
||||
"""Compute stepwise rollout values for the attainable utility penalty.
|
||||
|
||||
Args:
|
||||
chain: chain of states in an inaction rollout starting with the state
|
||||
for which to compute the rollout values
|
||||
|
||||
Returns:
|
||||
a list containing
|
||||
(1-discount) sum_{k=0}^inf discount^k V_u(S_k)
|
||||
for each utility function u,
|
||||
where S_k is the k-th state in the inaction rollout from 'state'.
|
||||
"""
|
||||
rollout_values = [0 for _ in self._u_values]
|
||||
coeff = 1
|
||||
for st in chain:
|
||||
rollout_values = [rv + coeff * u_val[st] * (1.0 - self._discount)
|
||||
for rv, u_val in zip(rollout_values, self._u_values)]
|
||||
coeff *= self._discount
|
||||
last_state = chain[-1]
|
||||
rollout_values = [rv + coeff * u_val[last_state]
|
||||
for rv, u_val in zip(rollout_values, self._u_values)]
|
||||
return rollout_values
|
||||
|
||||
def _set_util_funs(self, util_funs):
|
||||
"""Set up this instance's utility functions.
|
||||
|
||||
Args:
|
||||
util_funs: either a number of functions to generate or a list of
|
||||
pre-defined utility functions, represented as dictionaries
|
||||
over states: util_funs[i][s] = u_i(s), the utility of s
|
||||
according to u_i.
|
||||
"""
|
||||
if isinstance(util_funs, int):
|
||||
self._util_funs = [
|
||||
collections.defaultdict(float) for _ in range(util_funs)
|
||||
]
|
||||
else:
|
||||
self._util_funs = util_funs
|
||||
|
||||
def _utility(self, u, state):
|
||||
"""Apply a random utility function, generating its value if necessary."""
|
||||
if state not in u:
|
||||
u[state] = np.random.random()
|
||||
return u[state]
|
||||
|
||||
|
||||
class AttainableUtility(AttainableUtilityMixin, DeviationMeasure):
|
||||
"""Approximate attainable utility deviation measure."""
|
||||
|
||||
def __init__(self, value_discount=0.99, dev_fun=np.abs, util_funs=10,
|
||||
discount=None):
|
||||
assert value_discount < 1.0 # AU does not converge otherwise
|
||||
self._value_discount = value_discount
|
||||
self._dev_fun = dev_fun
|
||||
self._discount = discount
|
||||
self._set_util_funs(util_funs)
|
||||
# u_values[i][s] = V_{u_i}(s), the (approximate) value of s according to u_i
|
||||
self._u_values = [
|
||||
collections.defaultdict(float) for _ in range(len(self._util_funs))
|
||||
]
|
||||
# predecessors[s] = set of states known to lead, by some action, to s
|
||||
self._predecessors = collections.defaultdict(set)
|
||||
|
||||
def update(self, prev_state, current_state):
|
||||
"""Update predecessors and attainable utility estimates."""
|
||||
self._predecessors[current_state].add(prev_state)
|
||||
seen = set()
|
||||
queue = [current_state]
|
||||
while queue:
|
||||
s_to = queue.pop(0)
|
||||
seen.add(s_to)
|
||||
for u, u_val in zip(self._util_funs, self._u_values):
|
||||
for s_from in self._predecessors[s_to]:
|
||||
v = self._utility(u, s_from) + self._value_discount * u_val[s_to]
|
||||
if u_val[s_from] < v:
|
||||
u_val[s_from] = v
|
||||
if s_from not in seen:
|
||||
queue.append(s_from)
|
||||
|
||||
|
||||
class NoDeviation(DeviationMeasure):
|
||||
"""Dummy deviation measure corresponding to no impact penalty."""
|
||||
|
||||
def calculate(self, *unused_args):
|
||||
return 0
|
||||
|
||||
def update(self, *unused_args):
|
||||
pass
|
||||
|
||||
|
||||


class SideEffectPenalty(object):
  """Impact penalty."""

  def __init__(self, baseline, dev_measure, beta=1.0,
               use_inseparable_rollout=False):
    """Make an object to calculate the impact penalty.

    Args:
      baseline: object for calculating the baseline state
      dev_measure: object for calculating the deviation between states
      beta: weight (scaling factor) for the impact penalty
      use_inseparable_rollout: whether to compute the penalty as the average
        of deviations over parallel inaction rollouts from the current and
        baseline states (True), or just between the current state and the
        baseline state (or by whatever rollout value is provided by the
        baseline) (False)
    """
    self._baseline = baseline
    self._dev_measure = dev_measure
    self._beta = beta
    self._use_inseparable_rollout = use_inseparable_rollout

  def calculate(self, prev_state, action, current_state):
    """Calculate the penalty associated with a transition, and update models."""
    if current_state:
      self._dev_measure.update(prev_state, current_state)
      baseline_state = self._baseline.calculate(prev_state, action,
                                                current_state)
      if self._use_inseparable_rollout:
        penalty = self._rollout_value(current_state, baseline_state,
                                      self._dev_measure.discount,
                                      self._dev_measure.calculate)
      else:
        penalty = self._dev_measure.calculate(current_state, baseline_state,
                                              self._baseline.rollout_func)
      return self._beta * penalty
    else:
      return 0

  def reset(self):
    """Signal start of new episode."""
    self._baseline.reset()

  def _rollout_value(self, cur_state, base_state, discount, func):
    """Compute stepwise rollout value for unreachability."""
    # Returns (1-discount) sum_{k=0}^inf discount^k R(S_{t,t+k}, S'_{t,t+k}),
    # where S_{t,t+k} is the k-th state in the inaction rollout from the
    # current state, S'_{t,t+k} is the k-th state in the inaction rollout
    # from the baseline state, and R is the reachability function.
    chain = self._baseline.parallel_inaction_rollouts(cur_state, base_state)
    coeff = 1
    rollout_value = 0
    for states in chain:
      rollout_value += (coeff * func(states[0], states[1]) * (1.0 - discount))
      coeff *= discount
    last_states = chain[-1]
    rollout_value += coeff * func(last_states[0], last_states[1])
    return rollout_value

  @property
  def beta(self):
    return self._beta
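
# A minimal usage sketch (hypothetical stand-ins for the baseline object;
# the real objects are wired up in run_experiment):
#   dev = Reachability(value_discount=0.99)
#   penalty = SideEffectPenalty(baseline=some_stepwise_baseline,
#                               dev_measure=dev, beta=0.1)
#   shaped_reward = reward - penalty.calculate(prev_state, action,
#                                              current_state)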
@@ -0,0 +1,241 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for side_effects_penalty."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from six.moves import range
from side_effects_penalties import side_effects_penalty
from side_effects_penalties import training
from side_effects_penalties.side_effects_penalty import Actions


environments = ['box', 'vase', 'sushi_goal']


class SideEffectsTestCase(parameterized.TestCase):

  def _timestep_to_state(self, timestep):
    return tuple(map(tuple, np.copy(timestep.observation['board'])))

  def _env_to_action_range(self, env):
    action_spec = env.action_spec()
    action_range = list(range(action_spec.minimum, action_spec.maximum + 1))
    return action_range


class BaselineTestCase(SideEffectsTestCase):

  def _create_baseline(self, env_name):
    self._env = training.get_env(env_name, True)
    self._baseline_env = training.get_env(env_name, True)
    baseline_class = getattr(side_effects_penalty,
                             self.__class__.__name__[:-4])  # remove 'Test'
    self._baseline = baseline_class(
        self._env.reset(), True, self._baseline_env, self._timestep_to_state)

  def _test_trajectory(self, actions, key):
    init_state = self._timestep_to_state(self._env.reset())
    self._baseline.reset()
    current_state = init_state
    for action in actions:
      timestep = self._env.step(action)
      next_state = self._timestep_to_state(timestep)
      baseline_state = self._baseline.calculate(current_state, action,
                                                next_state)
      comparison_dict = {
          'current_state': current_state,
          'next_state': next_state,
          'init_state': init_state
      }
      self.assertEqual(baseline_state, comparison_dict[key])
      current_state = next_state
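
  # _create_baseline maps each test class name to the baseline class under
  # test by stripping the 'Test' suffix, e.g. StartBaselineTest exercises
  # side_effects_penalty.StartBaseline, so each subclass below tests the
  # corresponding baseline.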


class StartBaselineTest(BaselineTestCase):

  @parameterized.parameters(*environments)
  def testInit(self, env_name):
    self._create_baseline(env_name)
    self._test_trajectory([Actions.NOOP], 'init_state')

  @parameterized.parameters(*environments)
  def testTenNoops(self, env_name):
    self._create_baseline(env_name)
    self._test_trajectory([Actions.NOOP for _ in range(10)], 'init_state')


class InactionBaselineTest(BaselineTestCase):

  box_action_spec = training.get_env('box', True).action_spec()

  @parameterized.parameters(
      *list(range(box_action_spec.minimum, box_action_spec.maximum + 1)))
  def testStaticEnvOneAction(self, action):
    self._create_baseline('box')
    self._test_trajectory([action], 'init_state')

  def testStaticEnvRandomActions(self):
    self._create_baseline('box')
    num_steps = np.random.randint(low=1, high=20)
    action_range = self._env_to_action_range(self._env)
    actions = [np.random.choice(action_range) for _ in range(num_steps)]
    self._test_trajectory(actions, 'init_state')

  @parameterized.parameters(*environments)
  def testInactionPolicy(self, env_name):
    self._create_baseline(env_name)
    num_steps = np.random.randint(low=1, high=20)
    self._test_trajectory([Actions.NOOP for _ in range(num_steps)],
                          'next_state')


class StepwiseBaselineTest(BaselineTestCase):

  def testStaticEnvRandomActions(self):
    self._create_baseline('box')
    action_range = self._env_to_action_range(self._env)
    num_steps = np.random.randint(low=1, high=20)
    actions = [np.random.choice(action_range) for _ in range(num_steps)]
    self._test_trajectory(actions, 'current_state')

  @parameterized.parameters(*environments)
  def testInactionPolicy(self, env_name):
    self._create_baseline(env_name)
    num_steps = np.random.randint(low=1, high=20)
    self._test_trajectory([Actions.NOOP for _ in range(num_steps)],
                          'next_state')

  @parameterized.parameters(*environments)
  def testInactionRollout(self, env_name):
    self._create_baseline(env_name)
    init_state = self._timestep_to_state(self._env.reset())
    self._baseline.reset()
    action = Actions.NOOP
    state1 = init_state
    trajectory = [init_state]
    for _ in range(10):
      trajectory.append(self._timestep_to_state(self._env.step(action)))
      state2 = trajectory[-1]
      self._baseline.calculate(state1, action, state2)
      state1 = state2
    chain = self._baseline.rollout_func(init_state)
    self.assertEqual(chain, trajectory[:len(chain)])
    if len(chain) < len(trajectory):
      self.assertEqual(trajectory[len(chain) - 1], trajectory[len(chain)])

  def testStaticRollouts(self):
    self._create_baseline('box')
    action_range = self._env_to_action_range(self._env)
    num_steps = np.random.randint(low=1, high=20)
    actions = [np.random.choice(action_range) for _ in range(num_steps)]
    state1 = self._timestep_to_state(self._env.reset())
    states = [state1]
    self._baseline.reset()
    for action in actions:
      state2 = self._timestep_to_state(self._env.step(action))
      states.append(state2)
      self._baseline.calculate(state1, action, state2)
      state1 = state2
    i1, i2 = np.random.choice(len(states), 2)
    chain = self._baseline.parallel_inaction_rollouts(states[i1], states[i2])
    self.assertLen(chain, 1)
    chain1 = self._baseline.rollout_func(states[i1])
    self.assertLen(chain1, 1)
    chain2 = self._baseline.rollout_func(states[i2])
    self.assertLen(chain2, 1)

  @parameterized.parameters(('parallel', 'vase'), ('parallel', 'sushi'),
                            ('inaction', 'vase'), ('inaction', 'sushi'))
  def testConveyorRollouts(self, which_rollout, env_name):
    self._create_baseline(env_name)
    init_state = self._timestep_to_state(self._env.reset())
    self._baseline.reset()
    action = Actions.NOOP
    state1 = init_state
    init_state_next = self._timestep_to_state(self._env.step(action))
    state2 = init_state_next
    self._baseline.calculate(state1, action, state2)
    state1 = state2
    for _ in range(10):
      state2 = self._timestep_to_state(self._env.step(action))
      self._baseline.calculate(state1, action, state2)
      state1 = state2
    if which_rollout == 'parallel':
      chain = self._baseline.parallel_inaction_rollouts(init_state,
                                                        init_state_next)
    else:
      chain = self._baseline.rollout_func(init_state)
    self.assertLen(chain, 5)


class NoDeviationTest(SideEffectsTestCase):

  def _random_initial_transition(self):
    env_name = np.random.choice(environments)
    noops = np.random.choice([True, False])
    env = training.get_env(env_name, noops)
    action_range = self._env_to_action_range(env)
    action = np.random.choice(action_range)
    state1 = self._timestep_to_state(env.reset())
    state2 = self._timestep_to_state(env.step(action))
    return (state1, state2)

  def testNoDeviation(self):
    deviation = side_effects_penalty.NoDeviation()
    state1, state2 = self._random_initial_transition()
    self.assertEqual(deviation.calculate(state1, state2), 0)

  def testNoDeviationUpdate(self):
    deviation = side_effects_penalty.NoDeviation()
    state1, state2 = self._random_initial_transition()
    deviation.update(state1, state2)
    self.assertEqual(deviation.calculate(state1, state2), 0)


class UnreachabilityTest(SideEffectsTestCase):

  @parameterized.named_parameters(('Discounted', 0.99), ('Undiscounted', 1.0))
  def testUnreachabilityCycle(self, gamma):
    # Reachability with no dev_fun means unreachability.
    deviation = side_effects_penalty.Reachability(value_discount=gamma)
    env = training.get_env('box', False)

    state0 = self._timestep_to_state(env.reset())
    state1 = self._timestep_to_state(env.step(Actions.LEFT))
    # The deviation should not be calculated before calling update.

    deviation.update(state0, state1)
    self.assertEqual(deviation.calculate(state0, state0), 1.0 - 1.0)
    self.assertEqual(deviation.calculate(state0, state1), 1.0 - gamma)
    self.assertEqual(deviation.calculate(state1, state0), 1.0 - 0.0)

    state2 = self._timestep_to_state(env.step(Actions.RIGHT))
    self.assertEqual(state0, state2)

    deviation.update(state1, state2)
    self.assertEqual(deviation.calculate(state0, state0), 1.0 - 1.0)
    self.assertEqual(deviation.calculate(state0, state1), 1.0 - gamma)
    self.assertEqual(deviation.calculate(state1, state0), 1.0 - gamma)
    self.assertEqual(deviation.calculate(state1, state1), 1.0 - 1.0)
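
  # Why these values: after the first update, state0 is known to reach state1
  # in one step, so R(state0, state1) = gamma and the deviation is 1 - gamma,
  # while state0 is not yet known to be reachable from state1 (deviation 1).
  # Stepping RIGHT closes the cycle, making each state reachable from the
  # other in one step (deviation 1 - gamma in both directions).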


if __name__ == '__main__':
  absltest.main()
@@ -0,0 +1,117 @@
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training loop."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ai_safety_gridworlds.helpers import factory
import numpy as np
from six.moves import range


def get_env(env_name, noops):
  """Get a copy of the environment for simulating the baseline."""
  if env_name == 'box':
    env = factory.get_environment_obj('side_effects_sokoban', noops=noops)
  elif env_name in ['vase', 'sushi', 'sushi_goal']:
    env = factory.get_environment_obj(
        'conveyor_belt', variant=env_name, noops=noops)
  else:
    env = factory.get_environment_obj(env_name)
  return env
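
# For example (assuming the AI Safety Gridworlds suite is installed):
#   env = get_env('box', True)   # side_effects_sokoban with noop actions
#   env = get_env('vase', True)  # conveyor_belt environment, 'vase' variant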


def run_loop(agent, env, number_episodes, anneal):
  """Run the training loop for the agent on the environment."""
  episodic_returns = []
  episodic_performances = []
  if anneal:
    agent.epsilon = 1.0
    eps_unit = 1.0 / number_episodes
  for episode in range(number_episodes):
    # Get the initial set of observations from the environment.
    timestep = env.reset()
    # Prepare agent for a new episode.
    agent.begin_episode()
    while True:
      action = agent.step(timestep)
      timestep = env.step(action)
      if timestep.last():
        agent.end_episode(timestep)
        episodic_returns.append(env.episode_return)
        episodic_performances.append(env.get_last_performance())
        break
    if anneal:
      agent.epsilon = max(0, agent.epsilon - eps_unit)
    if episode % 500 == 0:
      print('Episode', episode)
  return episodic_returns, episodic_performances
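
# With anneal=True the exploration rate decays linearly from 1 to 0 over the
# run: after episode k, epsilon = max(0, 1 - (k + 1) / number_episodes).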


def run_agent(baseline, dev_measure, dev_fun, discount, value_discount, beta,
              anneal, seed, env_name, noops, num_episodes, num_episodes_noexp,
              exact_baseline, agent_class):
  """Run agent.

  Create an agent with the given parameters for the side effects penalty.
  Run the agent for `num_episodes' episodes with an exploration rate that is
  either annealed from 1 to 0 (`anneal=True') or constant (`anneal=False').
  Then run the agent with no exploration for `num_episodes_noexp' episodes.

  Args:
    baseline: baseline state
    dev_measure: deviation measure
    dev_fun: summary function for the deviation measure
    discount: discount factor
    value_discount: discount factor for the deviation measure value function
    beta: weight for the side effects penalty
    anneal: whether to anneal the exploration rate from 1 to 0 or use a
      constant exploration rate
    seed: random seed
    env_name: environment name
    noops: whether the environment has noop actions
    num_episodes: number of episodes
    num_episodes_noexp: number of episodes with no exploration
    exact_baseline: whether to use an exact or approximate baseline
    agent_class: Q-learning agent class: QLearning (regular) or QLearningSE
      (with side effects penalty)

  Returns:
    returns: return for each episode
    performances: safety performance for each episode
  """
  np.random.seed(seed)
  env = get_env(env_name, noops)
  start_timestep = env.reset()
  if exact_baseline:
    baseline_env = get_env(env_name, True)
  else:
    baseline_env = None
  agent = agent_class(
      actions=env.action_spec(), baseline=baseline,
      dev_measure=dev_measure, dev_fun=dev_fun, discount=discount,
      value_discount=value_discount, beta=beta, exact_baseline=exact_baseline,
      baseline_env=baseline_env, start_timestep=start_timestep)
  returns, performances = run_loop(
      agent, env, number_episodes=num_episodes, anneal=anneal)
  if num_episodes_noexp > 0:
    agent.epsilon = 0
    returns_noexp, performances_noexp = run_loop(
        agent, env, number_episodes=num_episodes_noexp, anneal=False)
    returns.extend(returns_noexp)
    performances.extend(performances_noexp)
  return returns, performances