mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
9e3d04c867
PiperOrigin-RevId: 272089371
184 lines
7.7 KiB
Python
184 lines
7.7 KiB
Python
# Copyright 2019 DeepMind Technologies Limited.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Plot results for different side effects penalties.
|
|
|
|
Loads csv result files generated by `run_experiment` and outputs a summary data
|
|
frame in a csv file to be used for plotting by plot_results.ipynb.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os.path
|
|
from absl import app
|
|
from absl import flags
|
|
import pandas as pd
|
|
from side_effects_penalties.file_loading import load_files
|
|
|
|
|
|
FLAGS = flags.FLAGS

# Flags are registered only when this file is executed as a script, so that
# importing it as a library does not define them.
if __name__ == '__main__':
  # Location and naming of the input and output csv files.
  flags.DEFINE_string('path', '', 'File path.')
  flags.DEFINE_string(
      'input_suffix', '', 'Filename suffix to use when loading data files.')
  flags.DEFINE_string(
      'output_suffix', '', 'Filename suffix to use when saving files.')

  # Kind of summary to produce.
  flags.DEFINE_bool(
      'bar_plot', True,
      'Make a data frame for a bar plot (True) '
      'or learning curves (False)')

  # Experiment settings identifying which result files to load.
  flags.DEFINE_string('env_name', 'box', 'Environment name.')
  flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
  # NOTE(review): list flags supplied on the command line are presumably
  # parsed into lists of strings, unlike these numeric defaults -- confirm
  # against the absl documentation.
  flags.DEFINE_list(
      'beta_list', [0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0],
      'List of beta values.')
  flags.DEFINE_list('seed_list', [1], 'List of random seeds.')

  # Comparison mode, and the penalty design choices used when a single penalty
  # is examined across beta values.
  flags.DEFINE_bool(
      'compare_penalties', True,
      'Compare different penalties using the best beta value '
      'for each penalty (True), or compare different beta values '
      'for the same penalty (False).')
  flags.DEFINE_enum(
      'dev_measure', 'rel_reach',
      ['none', 'reach', 'rel_reach', 'att_util'],
      'Deviation measure (used if compare_penalties=False).')
  flags.DEFINE_enum(
      'dev_fun', 'truncation', ['truncation', 'absolute'],
      'Summary function for the deviation measure '
      '(used if compare_penalties=False)')
  flags.DEFINE_float(
      'value_discount', 0.99,
      'Discount factor for deviation measure value function '
      '(used if compare_penalties=False)')
|
|
|
|
|
|
def beta_choice(baseline, dev_measure, dev_fun, value_discount, env_name,
                beta_list, seed_list, noops=False, path='', suffix=''):
  """Select the beta whose results have the highest mean smoothed performance.

  Args:
    baseline: Baseline name, forwarded to `load_files`.
    dev_measure: Deviation measure name, forwarded to `load_files`.
    dev_fun: Deviation function name, forwarded to `load_files`.
    value_discount: Value discount, forwarded to `load_files`.
    env_name: Environment name, forwarded to `load_files`.
    beta_list: Candidate beta values, tried in order.
    seed_list: Random seeds, forwarded to `load_files`.
    noops: Forwarded to `load_files`.
    path: Forwarded to `load_files`.
    suffix: Forwarded to `load_files`.

  Returns:
    0.1 if `dev_measure` is 'none' (no files are loaded in that case).
    Otherwise the first candidate in `beta_list` with the strictly highest
    mean of the 'performance_smooth' column, or 0.0 if no candidate scores
    above negative infinity (for example when every loaded frame is empty).
  """
  if dev_measure == 'none':
    return 0.1

  def mean_performance(candidate):
    """Scores one candidate beta; an empty result frame scores -inf."""
    results = load_files(
        baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun,
        value_discount=value_discount, beta=candidate, env_name=env_name,
        noops=noops, path=path, suffix=suffix, seed_list=seed_list)
    if results.empty:
      return float('-inf')
    return results['performance_smooth'].mean()

  winner, top_score = 0.0, float('-inf')
  for candidate in beta_list:
    score = mean_performance(candidate)
    # Strict comparison: on ties the earliest candidate is kept.
    if score > top_score:
      winner, top_score = candidate, score
  return winner
|
|
|
|
|
|
def penalty_label(dev_measure, dev_fun, value_discount):
  """Build a short plot label encoding the penalty design choices.

  Args:
    dev_measure: One of 'none', 'reach', 'rel_reach' or 'att_util'.
    dev_fun: Deviation function name; only consulted for 'rel_reach' and
      'att_util', where 'truncation' maps to 't' and anything else to 'a'.
    value_discount: Discount factor; exactly 1.0 maps to 'u', any other value
      to 'd'.

  Returns:
    'None' for the 'none' measure, otherwise the measure abbreviation followed
    by the discount and deviation-function codes in parentheses, e.g. 'RR(dt)'
    or 'UR(u)'.

  Raises:
    KeyError: If `dev_measure` is not one of the known measures.
  """
  abbreviation = {
      'none': 'None',
      'rel_reach': 'RR',
      'att_util': 'AU',
      'reach': 'UR',
  }[dev_measure]
  if dev_measure == 'none':
    return abbreviation

  if value_discount == 1.0:
    discount_code = 'u'
  else:
    discount_code = 'd'

  # Only these measures have a deviation function to report.
  if dev_measure in ['rel_reach', 'att_util']:
    fun_code = 't' if dev_fun == 'truncation' else 'a'
  else:
    fun_code = ''

  return '{}({}{})'.format(abbreviation, discount_code, fun_code)
|
|
|
|
|
|
def make_summary_data_frame(
    env_name, beta_list, seed_list, final=True, baseline=None, dev_measure=None,
    dev_fun=None, value_discount=None, noops=False, compare_penalties=True,
    path='', input_suffix='', output_suffix=''):
  """Make summary dataframe from multiple csv result files and output to csv.

  For each of the penalty parameters (baseline, dev_measure, dev_fun, and
  value_discount), a list of multiple values is compared if the parameter is
  None, and only the provided value is used otherwise.

  Args:
    env_name: Environment name; forwarded to `load_files` and used in the
      output filename.
    beta_list: Beta values. With `compare_penalties=True` the best one is
      chosen per penalty via `beta_choice`; otherwise all of them are loaded.
    seed_list: Random seeds, forwarded to `load_files`.
    final: Forwarded to `load_files`; when true, '_final' is added to the
      output filename.
    baseline: Baseline to use, or None to compare 'start', 'inaction',
      'stepwise' and 'step_noroll'.
    dev_measure: Deviation measure to use, or None to compare 'none', 'reach',
      'rel_reach' and 'att_util'.
    dev_fun: Deviation function to use, or None to compare 'truncation' and
      'absolute'. Ignored for the 'reach' and 'none' measures, which always
      use 'none'.
    value_discount: Value discount to use, or None to compare 0.99 and 1.0.
      Ignored for the 'none' and 'att_util' measures, which always use 0.99.
    noops: Forwarded to `load_files`.
    compare_penalties: Compare penalties at their best beta (True) or compare
      all betas for a single penalty (False).
    path: Directory of the input files and of the output csv file.
    input_suffix: Filename suffix forwarded to `load_files`.
    output_suffix: Suffix appended to the output filename before '.csv'.

  Returns:
    The concatenated data frame, which is also written to a csv file in `path`.

  Raises:
    ValueError: If `compare_penalties` is False and `dev_measure` or `dev_fun`
      is None, since the output filename is built from both.
  """
  # The beta-comparison filename names a single penalty. Check this before any
  # file is loaded rather than failing after all of the loading work.
  if not compare_penalties and (dev_measure is None or dev_fun is None):
    raise ValueError(
        'dev_measure and dev_fun must be specified when '
        'compare_penalties=False.')

  if baseline is None:
    baseline_list = ['start', 'inaction', 'stepwise', 'step_noroll']
  else:
    baseline_list = [baseline]
  if dev_measure is None:
    dev_measure_list = ['none', 'reach', 'rel_reach', 'att_util']
  else:
    dev_measure_list = [dev_measure]

  dataframes = []
  # Loop variables are named differently from the parameters so that the
  # caller's values stay available for the output filename below.
  for measure in dev_measure_list:
    # These deviation measures don't have a deviation function:
    if measure in ['reach', 'none']:
      dev_fun_list = ['none']
    elif dev_fun is not None:
      dev_fun_list = [dev_fun]
    else:
      dev_fun_list = ['truncation', 'absolute']
    # These deviation measures must be discounted:
    if measure in ['none', 'att_util']:
      value_discount_list = [0.99]
    elif value_discount is not None:
      value_discount_list = [value_discount]
    else:
      value_discount_list = [0.99, 1.0]

    for base in baseline_list:
      for vd in value_discount_list:
        for devf in dev_fun_list:
          # The label does not depend on beta, so compute it once per penalty.
          label = penalty_label(
              dev_measure=measure, dev_fun=devf, value_discount=vd)
          # Choose the best beta for this set of penalty parameters if
          # compare_penalties=True, or compare all betas otherwise
          if compare_penalties:
            betas = [beta_choice(
                baseline=base, dev_measure=measure, dev_fun=devf,
                value_discount=vd, env_name=env_name, noops=noops,
                beta_list=beta_list, seed_list=seed_list, path=path,
                suffix=input_suffix)]
          else:
            betas = beta_list
          for beta in betas:
            df_part = load_files(
                baseline=base, dev_measure=measure, dev_fun=devf,
                value_discount=vd, beta=beta, env_name=env_name,
                noops=noops, path=path, suffix=input_suffix, final=final,
                seed_list=seed_list)
            # Tag every row with the settings that produced it.
            df_part = df_part.assign(
                baseline=base, dev_measure=measure, dev_fun=devf,
                value_discount=vd, beta=beta, env_name=env_name, label=label)
            dataframes.append(df_part)
  df = pd.concat(dataframes, sort=False)

  # Output summary data frame
  final_str = '_final' if final else ''
  if compare_penalties:
    filename = ('df_summary_penalties_' + env_name + final_str +
                output_suffix + '.csv')
  else:
    filename = ('df_summary_betas_' + env_name + '_' + dev_measure + '_' +
                dev_fun + '_' + str(value_discount) + final_str +
                output_suffix + '.csv')
  df.to_csv(os.path.join(path, filename))
  return df
|
|
|
|
|
|
def main(unused_argv):
  """Builds the summary data frame described by the command-line flags.

  When `compare_penalties` is set, the penalty design choices (deviation
  measure, deviation function and value discount) are passed as None so that
  `make_summary_data_frame` compares several values for each; otherwise the
  corresponding flag values are used.

  Args:
    unused_argv: Positional command-line arguments passed by `app.run`;
      ignored.
  """
  compare_penalties = FLAGS.compare_penalties
  if compare_penalties:
    dev_measure = dev_fun = value_discount = None
  else:
    dev_measure = FLAGS.dev_measure
    dev_fun = FLAGS.dev_fun
    value_discount = FLAGS.value_discount

  # List flags supplied on the command line are parsed as strings, while the
  # flag defaults are numeric. Normalise both cases to numbers so that the
  # values forwarded downstream (and the `beta` column of the summary) have
  # the same type however the script is invoked.
  beta_list = [float(beta) for beta in FLAGS.beta_list]
  seed_list = [int(seed) for seed in FLAGS.seed_list]

  make_summary_data_frame(
      compare_penalties=compare_penalties, env_name=FLAGS.env_name,
      noops=FLAGS.noops, final=FLAGS.bar_plot, dev_measure=dev_measure,
      value_discount=value_discount, dev_fun=dev_fun, path=FLAGS.path,
      input_suffix=FLAGS.input_suffix, output_suffix=FLAGS.output_suffix,
      beta_list=beta_list, seed_list=seed_list)


if __name__ == '__main__':
  app.run(main)
|