mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 12:02:08 +08:00
62 lines
2.4 KiB
Python
62 lines
2.4 KiB
Python
# Copyright 2019 DeepMind Technologies Limited.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Helper functions for loading files."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os.path
|
|
import pandas as pd
|
|
|
|
|
|
def filename(env_name, noops, dev_measure, dev_fun, baseline, beta,
             value_discount, seed, path='', suffix=''):
    """Build the path of the CSV results file for one parameter setting.

    The base name is assembled as
    `<env_name>_<noops|nonoops>_<dev_measure>_<dev_fun>_<baseline>`
    `_beta_<beta>_vd_<value_discount><suffix>[_<seed>].csv`
    and then joined onto `path` with `os.path.join`.

    Args:
      env_name: value inserted as the first component of the name.
      noops: if truthy the name contains 'noops', otherwise 'nonoops'.
      dev_measure: value inserted after the noops marker.
      dev_fun: value inserted after `dev_measure`.
      baseline: value inserted after `dev_fun`.
      beta: value inserted after the literal 'beta_'.
      value_discount: value inserted after the literal 'vd_'.
      seed: if truthy, '_<seed>' is appended just before the extension.
        Falsy seeds (None, and also 0) add nothing to the name.
      path: directory prefix passed to `os.path.join` (default '').
      suffix: text inserted verbatim right after `value_discount`.

    Returns:
      The joined path of the results file.
    """
    if noops:
        noop_marker = 'noops'
    else:
        noop_marker = 'nonoops'

    # Truthiness test on purpose: a seed of None or 0 leaves the name unseeded.
    seed_marker = ''
    if seed:
        seed_marker = '_' + str(seed)

    fields = dict(
        env_name=env_name,
        noop_str=noop_marker,
        dev_measure=dev_measure,
        dev_fun=dev_fun,
        baseline=baseline,
        beta=beta,
        value_discount=value_discount,
        suffix=suffix,
        seed_str=seed_marker,
    )
    template = ('{env_name}_{noop_str}_{dev_measure}_{dev_fun}'
                '_{baseline}_beta_{beta}_vd_{value_discount}'
                '{suffix}{seed_str}.csv')
    base_name = template.format(**fields)
    return os.path.join(path, base_name)
|
|
|
|
|
|
def load_files(baseline, dev_measure, dev_fun, value_discount, beta, env_name,
               noops, path, suffix, seed_list, final=True):
    """Load result files generated by run_experiment with the given parameters.

    One CSV file name is derived per seed via `filename(...)`; every file that
    exists is read with `pd.read_csv(..., index_col=0)` and the per-seed frames
    are concatenated in the order of `seed_list`. Files that do not exist
    contribute an empty frame instead of raising.

    Args:
      baseline: forwarded to `filename`.
      dev_measure: forwarded to `filename`.
      dev_fun: forwarded to `filename`.
      value_discount: forwarded to `filename`.
      beta: forwarded to `filename`.
      env_name: forwarded to `filename`.
      noops: forwarded to `filename`.
      path: directory prefix forwarded to `filename`.
      suffix: forwarded to `filename`.
      seed_list: iterable of seeds; each one is converted with `int()` before
        being forwarded to `filename`. May be empty.
      final: if True, keep only the rows whose 'episode' column equals the
        largest episode value in that file (the file must then contain an
        'episode' column); if False, keep all rows.

    Returns:
      A `pandas.DataFrame` with the concatenated rows. It is empty when
      `seed_list` is empty or when none of the files exist.
    """
    def try_loading(f, final):
        """Read one results file, or return an empty frame if it is missing."""
        if not os.path.isfile(f):
            return pd.DataFrame()
        df = pd.read_csv(f, index_col=0)
        if not final:
            return df
        # Series.max() skips NaNs and yields NaN (hence an empty selection)
        # for a file with no rows, where the builtin max() would raise
        # ValueError or give an order-dependent answer.
        last_episode = df['episode'].max()
        return df[df.episode == last_episode]

    dataframes = [
        try_loading(
            filename(baseline=baseline, dev_measure=dev_measure,
                     dev_fun=dev_fun, value_discount=value_discount,
                     beta=beta, env_name=env_name, noops=noops, path=path,
                     suffix=suffix, seed=int(seed)),
            final)
        for seed in seed_list
    ]
    # pd.concat raises ValueError on an empty list; treat "no seeds" the same
    # way as "no files found" and hand back an empty frame.
    if not dataframes:
        return pd.DataFrame()
    return pd.concat(dataframes)
|