mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-07 04:22:09 +08:00
103 lines
3.2 KiB
Python
103 lines
3.2 KiB
Python
# pylint: disable=g-bad-file-header
|
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Simple matplotlib rendering of a rollout prediction against ground truth.
|
|
|
|
Usage (from parent directory):
|
|
|
|
`python -m learning_to_simulate.render_rollout --rollout_path={OUTPUT_PATH}/rollout_test_1.pkl`
|
|
|
|
Where {OUTPUT_PATH} is the output path passed to `train.py` in "eval_rollout"
|
|
mode.
|
|
|
|
It may require installing Tkinter with `sudo apt-get install python3.7-tk`.
|
|
|
|
""" # pylint: disable=line-too-long
|
|
|
|
import pickle
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
|
|
from matplotlib import animation
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# Command-line flags for the renderer.
flags.DEFINE_string("rollout_path", None, help="Path to rollout pickle file")
# The stride is used as the step of the animation's frame range, so a value
# below 1 cannot produce a valid frame sequence; reject it at flag-parsing
# time instead of failing later while building the animation.
flags.DEFINE_integer(
    "step_stride", 3, help="Stride of steps to skip.", lower_bound=1)
flags.DEFINE_boolean("block_on_show", True, help="For test purposes.")

FLAGS = flags.FLAGS
|
|
|
|
# Matplotlib color used to draw each particle type, keyed by the integer
# particle-type id that is compared against the rollout's "particle_types".
# Types that are not listed here are not drawn.
TYPE_TO_COLOR = {
    3: "black",    # Boundary particles.
    0: "green",    # Rigid solids.
    7: "magenta",  # Goop.
    6: "gold",     # Sand.
    5: "blue",     # Water.
}
|
|
|
|
|
|
def main(unused_argv):
  """Animates a predicted rollout next to its ground truth.

  Loads the rollout pickle given by `--rollout_path`, creates one subplot for
  the ground-truth rollout and one for the predicted rollout, and animates the
  2D particle positions (colored by particle type via `TYPE_TO_COLOR`),
  showing every `--step_stride`-th step.

  Args:
    unused_argv: Positional command-line arguments passed by `app.run`;
      ignored.

  Raises:
    ValueError: If `--rollout_path` was not provided.
  """
  if not FLAGS.rollout_path:
    raise ValueError("A `rollout_path` must be passed.")
  # NOTE: unpickling can execute arbitrary code. Only open rollout files that
  # come from a trusted source (e.g. produced by your own evaluation run).
  with open(FLAGS.rollout_path, "rb") as file:
    rollout_data = pickle.load(file)

  fig, axes = plt.subplots(1, 2, figsize=(10, 5))

  # Both subplots share the same bounds, so look them up once.
  bounds = rollout_data["metadata"]["bounds"]

  # The particle types are the same for every frame and for both subplots, so
  # build the boolean selection mask of each known type once instead of
  # recomputing it on every animation frame.
  particle_types = rollout_data["particle_types"]
  type_masks = {
      particle_type: particle_types == particle_type
      for particle_type in TYPE_TO_COLOR}

  plot_info = []
  for ax_i, (label, rollout_field) in enumerate(
      [("Ground truth", "ground_truth_rollout"),
       ("Prediction", "predicted_rollout")]):
    # Append the initial positions to get the full trajectory. It is indexed
    # below as [step, particle, coordinate].
    trajectory = np.concatenate([
        rollout_data["initial_positions"],
        rollout_data[rollout_field]], axis=0)
    ax = axes[ax_i]
    ax.set_title(label)
    ax.set_xlim(bounds[0][0], bounds[0][1])
    ax.set_ylim(bounds[1][0], bounds[1][1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1.)
    # One (initially empty) line artist per particle type; its data is filled
    # in on every animation frame.
    points = {
        particle_type: ax.plot([], [], "o", ms=2, color=color)[0]
        for particle_type, color in TYPE_TO_COLOR.items()}
    plot_info.append((ax, trajectory, points))

  # Number of steps of the last trajectory built above (the prediction).
  num_steps = trajectory.shape[0]

  def update(step_i):
    """Moves every line artist to the positions at `step_i`."""
    outputs = []
    for _, trajectory, points in plot_info:
      for particle_type, line in points.items():
        mask = type_masks[particle_type]
        line.set_data(trajectory[step_i, mask, 0],
                      trajectory[step_i, mask, 1])
        outputs.append(line)
    return outputs

  # Keep a reference to the animation for as long as the figure is shown;
  # matplotlib stops an animation whose object has been garbage collected.
  unused_animation = animation.FuncAnimation(
      fig, update,
      frames=np.arange(0, num_steps, FLAGS.step_stride), interval=10)
  plt.show(block=FLAGS.block_on_show)
|
|
|
|
|
|
# Script entry point: hand `main` to absl's `app.run`.
if __name__ == "__main__":
  app.run(main)
|