mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
9e3d04c867
PiperOrigin-RevId: 272089371
242 lines
6.6 KiB
Plaintext
242 lines
6.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "rINYEKJlYpQU"
|
|
},
|
|
"source": [
|
|
"Copyright 2019 DeepMind Technologies Limited.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"you may not use this file except in compliance with the License.\n",
|
|
"You may obtain a copy of the License at\n",
|
|
"\n",
|
|
"https://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software\n",
|
|
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"See the License for the specific language governing permissions and\n",
|
|
"limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "KbCarv91XChI"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from __future__ import absolute_import\n",
|
|
"from __future__ import division\n",
|
|
"from __future__ import print_function\n",
|
|
"\n",
|
|
"from google.colab import files\n",
|
|
"import io\n",
|
|
"import pandas as pd\n",
|
|
"import seaborn as sns"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "irahOycBZM1E"
|
|
},
|
|
"source": [
|
|
"### Plot parameters (edit as needed)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "VeaB-Y9TYWCP"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Make a bar plot for average results from the final 100 episodes (True),\n",
|
|
"# or make a learning curve plot (False)\n",
|
|
"bar_plot = False\n",
|
|
"\n",
|
|
"# Compare different penalties using the best beta value for each penalty (True),\n",
|
|
"# or compare different beta values for the same penalty (False):\n",
|
|
"compare_penalties = False\n",
|
|
"\n",
|
|
"# If compare_penalties is False, specify the penalty parameters:\n",
|
|
"dev_measure = 'rel_reach'\n",
|
|
"dev_fun = 'truncation'\n",
|
|
"value_discount = 0.99\n",
|
|
"\n",
|
|
"# Environment name (used to build the summary CSV filename)\n",
|
|
"env_name = 'box'\n",
|
|
"\n",
|
|
"# Filename suffix (inserted just before the '.csv' extension)\n",
|
|
"suffix = '' "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "t8hYO3f2ZaHq"
|
|
},
|
|
"source": [
|
|
"### Plot settings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "foyP_qrUeTsx"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"final_str = '_final' if bar_plot else ''\n",
|
|
"if compare_penalties:\n",
|
|
" var = 'label'\n",
|
|
" x_label = 'deviation_measure'\n",
|
|
" legend_title = 'penalty'\n",
|
|
" palette = sns.color_palette()\n",
|
|
" filename = ('df_summary_penalties_' + env_name + final_str + suffix\n",
|
|
" + '.csv')\n",
|
|
"else:\n",
|
|
" var = 'beta'\n",
|
|
" x_label = 'beta'\n",
|
|
" legend_title = 'beta'\n",
|
|
" palette = sns.cubehelix_palette()\n",
|
|
" filename = ('df_summary_betas_' + env_name + '_' + dev_measure + '_' + dev_fun \n",
|
|
" + '_' + str(value_discount) + final_str + suffix + '.csv')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "5RakwQsFZc0V"
|
|
},
|
|
"source": [
|
|
"### Load the summary CSV produced by results_summary.py (upload the file named in `filename`)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "PQ615uuGYLF3"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"uploaded = files.upload()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "-FoI7u8BtYol"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"df = pd.read_csv(io.BytesIO(uploaded[filename]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "A7BGierpZi6t"
|
|
},
|
|
"source": [
|
|
"### Make bar plots (run this cell when `bar_plot` is True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "ch06DzzeXlGK"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"plot = sns.catplot(data=df, col='baseline', x=var, y='performance_smooth',\n",
|
|
" kind='bar', height=4, aspect=1.3)\n",
|
|
"axes = plot.axes.flatten()\n",
|
|
"for ax in axes:\n",
|
|
" title = ax.get_title().split()\n",
|
|
" ax.set_title(title[2] + ' baseline')\n",
|
|
" ax.set_ylabel('performance')\n",
|
|
" ax.set_xlabel(x_label)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "BU9PYyzOZlRu"
|
|
},
|
|
"source": [
|
|
"### Make learning curve plots (run this cell when `bar_plot` is False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "RRfU_2iIX-jo"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"plot = sns.FacetGrid(df, col='baseline', size=5, aspect=1.3,\n",
|
|
" sharey=False, sharex=False)\n",
|
|
"plot.map_dataframe(sns.tsplot, time='episode', unit='seed', condition=var,\n",
|
|
" value='performance_smooth', n_boot=100, color=palette,\n",
|
|
" alpha=1.0, linewidth=1)\n",
|
|
"plot.add_legend(title=legend_title)\n",
|
|
"axes = plot.axes.flatten()\n",
|
|
"for ax in axes:\n",
|
|
" title = ax.get_title().split()\n",
|
|
" ax.set_title(title[2] + ' baseline')\n",
|
|
" ax.set_ylabel('performance')\n",
|
|
" ax.set_xlabel('episode')"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "plot_results.ipynb",
|
|
"provenance": [
|
|
{
|
|
"file_id": "1a8ub19XYD4M-r5mGm0lKYTrNwTo1zF7Z",
|
|
"timestamp": 1569850224175
|
|
}
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 2",
|
|
"name": "python2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|