Making integration test pull an actual dataset and do a few steps of training and evaluation.

PiperOrigin-RevId: 331524333
This commit is contained in:
Alvaro Sanchez-Gonzalez
2020-09-14 13:00:48 +01:00
committed by Diego de Las Casas
parent ef672a0db9
commit 94f98470b9
12 changed files with 1594 additions and 0 deletions
+33
View File
@@ -0,0 +1,33 @@
# Learning to Simulate Complex Physics with Graph Networks
This is a model implementation for the ICML 2020 submission (also available in
arXiv [arxiv.org/abs/2002.09405](https://arxiv.org/abs/2002.09405)). If you use
the code here please cite this paper:
@article{sanchezgonzalez2020learning,
title={Learning to Simulate Complex Physics with Graph Networks},
author={Alvaro Sanchez-Gonzalez and
Jonathan Godwin and
Tobias Pfaff and
Rex Ying and
Jure Leskovec and
Peter W. Battaglia},
url={https://arxiv.org/abs/2002.09405},
year={2020},
eprint={2002.09405},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
## Contents
* `model.py`: implementation of the graph network used as the learnable part of
the model.
* `model_demo.py`: example connecting the model to input dummy data.
## Running demo
(From one directory above)
pip install -r learning_to_simulate/requirements.txt
python -m learning_to_simulate.model_demo
+104
View File
@@ -0,0 +1,104 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tools to compute the connectivity of the graph."""
import numpy as np
from sklearn import neighbors
import tensorflow.compat.v1 as tf
def _compute_connectivity(positions, radius):
  """Get the indices of connected edges with radius connectivity.

  Args:
    positions: Positions of nodes in the graph. Shape:
      [num_nodes_in_graph, num_dims].
    radius: Radius of connectivity.

  Returns:
    senders indices [num_edges_in_graph]
    receiver indices [num_edges_in_graph]
  """
  # Query all neighbors within `radius` of each node (includes the node
  # itself, i.e. self-edges are produced).
  kd_tree = neighbors.KDTree(positions)
  neighbor_lists = kd_tree.query_radius(positions, r=radius)
  node_count = len(positions)
  # Each node i is the sender of one edge per neighbor found for it.
  edges_per_node = [len(neighbor_ids) for neighbor_ids in neighbor_lists]
  senders = np.repeat(range(node_count), edges_per_node)
  receivers = np.concatenate(neighbor_lists, axis=0)
  return senders, receivers
def _compute_connectivity_for_batch(positions, n_node, radius):
  """`compute_connectivity` for a batch of graphs.

  Args:
    positions: Positions of nodes in the batch of graphs. Shape:
      [num_nodes_in_batch, num_dims].
    n_node: Number of nodes for each graph in the batch. Shape:
      [num_graphs in batch].
    radius: Radius of connectivity.

  Returns:
    senders indices [num_edges_in_batch]
    receiver indices [num_edges_in_batch]
    number of edges per graph [num_graphs_in_batch]
  """
  # TODO(alvarosg): Consider if we want to support batches here or not.
  # Separate the positions corresponding to particles in different graphs.
  split_points = np.cumsum(n_node[:-1])
  positions_per_graph = np.split(positions, split_points, axis=0)

  senders_per_graph = []
  receivers_per_graph = []
  edge_counts = []
  node_offset = 0
  # Compute connectivity for each graph in the batch.
  for graph_positions in positions_per_graph:
    graph_senders, graph_receivers = _compute_connectivity(
        graph_positions, radius)
    edge_counts.append(len(graph_senders))

    # Because the inputs will be concatenated, we need to add offsets to the
    # sender and receiver indices according to the number of nodes in previous
    # graphs in the same batch.
    senders_per_graph.append(graph_senders + node_offset)
    receivers_per_graph.append(graph_receivers + node_offset)
    node_offset += len(graph_positions)

  # Concatenate all of the results.
  senders = np.concatenate(senders_per_graph, axis=0).astype(np.int32)
  receivers = np.concatenate(receivers_per_graph, axis=0).astype(np.int32)
  n_edge = np.stack(edge_counts).astype(np.int32)
  return senders, receivers, n_edge
def compute_connectivity_for_batch_pyfunc(positions, n_node, radius):
  """`_compute_connectivity_for_batch` wrapped in a pyfunc."""
  # The connectivity computation runs in numpy/sklearn outside the TF graph,
  # so the static shapes of the outputs must be re-declared below for
  # downstream shape inference to work.
  senders, receivers, n_edge = tf.py_function(
      _compute_connectivity_for_batch,
      [positions, n_node, radius], [tf.int32, tf.int32, tf.int32])
  # Edge counts are data-dependent, so only the rank is known statically.
  senders.set_shape([None])
  receivers.set_shape([None])
  # One edge count per graph in the batch.
  n_edge.set_shape(n_node.get_shape())
  return senders, receivers, n_edge
+32
View File
@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads one Learning-to-Simulate dataset (metadata plus the
# train/valid/test TFRecord files) from Google Cloud Storage.
#
# Usage:
#     bash download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR}
# Example:
#     bash download_dataset.sh WaterDrop /tmp/
set -e

# Fail fast with a usage message rather than downloading into a bogus path.
if [ "$#" -ne 2 ]; then
  echo "Usage: bash download_dataset.sh DATASET_NAME OUTPUT_DIR" >&2
  exit 1
fi

DATASET_NAME="${1}"
OUTPUT_DIR="${2}/${DATASET_NAME}"
BASE_URL="https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/${DATASET_NAME}/"

# Quote the target so output paths containing spaces work.
mkdir -p "${OUTPUT_DIR}"

for file in metadata.json train.tfrecord valid.tfrecord test.tfrecord
do
  wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done
+173
View File
@@ -0,0 +1,173 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph network implementation accompanying ICML 2020 submission.
"Learning to Simulate Complex Physics with Graph Networks"
Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
Jure Leskovec, Peter W. Battaglia
https://arxiv.org/abs/2002.09405
The Sonnet `EncodeProcessDecode` module provided here implements the learnable
parts of the model.
It assumes an encoder preprocessor has already built a graph with
connectivity and features as described in the paper, with features normalized
to zero-mean unit-variance.
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
"""
import graph_nets as gn
import sonnet as snt
import tensorflow as tf
def build_mlp(
    hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
  """Builds an MLP with `num_hidden_layers` layers of `hidden_size` units."""
  layer_sizes = [hidden_size] * num_hidden_layers
  layer_sizes.append(output_size)
  return snt.nets.MLP(output_sizes=layer_sizes)
class EncodeProcessDecode(snt.AbstractModule):
  """Encode-Process-Decode function approximator for learnable simulator."""

  def __init__(
      self,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      output_size: int,
      name: str = "EncodeProcessDecode"):
    """Inits the model.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of message passing steps.
      output_size: Output size of the decode node representations as required
        by the downstream update function.
      name: Name of the model.
    """
    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._output_size = output_size
    with self._enter_variable_scope():
      self._networks_builder()

  def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Forward pass of the learnable dynamics model."""

    # Encode the input_graph.
    latent_graph_0 = self._encode(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Decode from the last latent graph.
    return self._decode(latent_graph_m)

  def _networks_builder(self):
    """Builds the encoder, processor and decoder networks."""

    def build_mlp_with_layer_norm():
      mlp = build_mlp(
          hidden_size=self._mlp_hidden_size,
          num_hidden_layers=self._mlp_num_hidden_layers,
          output_size=self._latent_size)
      return snt.Sequential([mlp, snt.LayerNorm()])

    # The encoder graph network independently encodes edge and node features.
    encoder_kwargs = dict(
        edge_model_fn=build_mlp_with_layer_norm,
        node_model_fn=build_mlp_with_layer_norm)
    self._encoder_network = gn.modules.GraphIndependent(**encoder_kwargs)

    # Create `num_message_passing_steps` graph networks with unshared parameters
    # that update the node and edge latent features.
    # Note that we can use `modules.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for _ in range(self._num_message_passing_steps):
      self._processor_networks.append(
          gn.modules.InteractionNetwork(
              edge_model_fn=build_mlp_with_layer_norm,
              node_model_fn=build_mlp_with_layer_norm))

    # The decoder MLP decodes node latent features into the output size.
    self._decoder_network = build_mlp(
        hidden_size=self._mlp_hidden_size,
        num_hidden_layers=self._mlp_num_hidden_layers,
        output_size=self._output_size)

  def _encode(
      self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Encodes the input graph features into a latent graph."""

    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph)
      input_graph = input_graph.replace(
          nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Processes the latent graph with several steps of message passing."""

    # Do `m` message passing steps in the latent graphs.
    # (In the shared parameters case, just reuse the same `processor_network`)
    # BUGFIX: track the current graph in a single variable so that with zero
    # message-passing steps this returns the encoded graph instead of raising
    # a NameError on an unbound loop variable. Behavior for >= 1 steps is
    # unchanged.
    latent_graph = latent_graph_0
    for processor_network in self._processor_networks:
      latent_graph = self._process_step(processor_network, latent_graph)
    return latent_graph

  def _process_step(
      self, processor_network_k: snt.Module,
      latent_graph_prev_k: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Single step of message passing with node/edge residual connections."""

    # One step of message passing.
    latent_graph_k = processor_network_k(latent_graph_prev_k)

    # Add residuals.
    latent_graph_k = latent_graph_k.replace(
        nodes=latent_graph_k.nodes+latent_graph_prev_k.nodes,
        edges=latent_graph_k.edges+latent_graph_prev_k.edges)
    return latent_graph_k

  def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Decodes the node latent representations into output node features."""
    return self._decoder_network(latent_graph.nodes)
+266
View File
@@ -0,0 +1,266 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Full model implementation accompanying ICML 2020 submission.
"Learning to Simulate Complex Physics with Graph Networks"
Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
Jure Leskovec, Peter W. Battaglia
https://arxiv.org/abs/2002.09405
"""
import graph_nets as gn
import sonnet as snt
import tensorflow.compat.v1 as tf
from learning_to_simulate import connectivity_utils
from learning_to_simulate import graph_network
# Minimum std used when normalizing the global context, to avoid division by
# zero for datasets whose context statistics are all zero.
STD_EPSILON = 1e-8
class LearnedSimulator(snt.AbstractModule):
  """Learned simulator from https://arxiv.org/pdf/2002.09405.pdf."""

  def __init__(
      self,
      num_dimensions,
      connectivity_radius,
      graph_network_kwargs,
      boundaries,
      normalization_stats,
      num_particle_types,
      particle_type_embedding_size,
      name="LearnedSimulator"):
    """Inits the model.

    Args:
      num_dimensions: Dimensionality of the problem.
      connectivity_radius: Scalar with the radius of connectivity.
      graph_network_kwargs: Keyword arguments to pass to the learned part
        of the graph network `model.EncodeProcessDecode`.
      boundaries: List of 2-tuples, containing the lower and upper boundaries of
        the cuboid containing the particles along each dimensions, matching
        the dimensionality of the problem.
      normalization_stats: Dictionary with statistics with keys "acceleration"
        and "velocity", containing a named tuple for each with mean and std
        fields, matching the dimensionality of the problem.
      num_particle_types: Number of different particle types.
      particle_type_embedding_size: Embedding size for the particle type.
      name: Name of the Sonnet module.
    """
    super().__init__(name=name)

    self._connectivity_radius = connectivity_radius
    self._num_particle_types = num_particle_types
    self._boundaries = boundaries
    self._normalization_stats = normalization_stats
    with self._enter_variable_scope():
      self._graph_network = graph_network.EncodeProcessDecode(
          output_size=num_dimensions, **graph_network_kwargs)

      # A learned embedding per particle type is only needed when more than
      # one type exists.
      if self._num_particle_types > 1:
        self._particle_type_embedding = tf.get_variable(
            "particle_embedding",
            [self._num_particle_types, particle_type_embedding_size],
            trainable=True, use_resource=True)

  def _build(self, position_sequence, n_particles_per_example,
             global_context=None, particle_types=None):
    """Produces a model step, outputting the next position for each particle.

    Args:
      position_sequence: Sequence of positions for each node in the batch,
        with shape [num_particles_in_batch, sequence_length, num_dimensions]
      n_particles_per_example: Number of particles for each graph in the batch
        with shape [batch_size]
      global_context: Tensor of shape [batch_size, context_size], with global
        context.
      particle_types: Integer tensor of shape [num_particles_in_batch] with
        the integer types of the particles, from 0 to `num_particle_types - 1`.
        If None, we assume all particles are the same type.

    Returns:
      updated_position_sequence with shape [num_particles_in_batch,
      num_dimensions] for one step into the future from the input sequence.
    """
    input_graphs_tuple = self._encoder_preprocessor(
        position_sequence, n_particles_per_example, global_context,
        particle_types)

    # The network predicts a normalized acceleration per particle, which the
    # postprocessor integrates into the next position.
    normalized_acceleration = self._graph_network(input_graphs_tuple)

    next_position = self._decoder_postprocessor(
        normalized_acceleration, position_sequence)

    return next_position

  def _encoder_preprocessor(
      self, position_sequence, n_node, global_context, particle_types):
    """Builds the input `GraphsTuple`: connectivity plus node/edge features."""
    # Extract important features from the position_sequence.
    most_recent_position = position_sequence[:, -1]
    velocity_sequence = time_diff(position_sequence)  # Finite-difference.

    # Get connectivity of the graph.
    (senders, receivers, n_edge
     ) = connectivity_utils.compute_connectivity_for_batch_pyfunc(
         most_recent_position, n_node, self._connectivity_radius)

    # Collect node features.
    node_features = []

    # Normalized velocity sequence, merging spatial and time axes.
    velocity_stats = self._normalization_stats["velocity"]
    normalized_velocity_sequence = (
        velocity_sequence - velocity_stats.mean) / velocity_stats.std

    flat_velocity_sequence = snt.MergeDims(start=1, size=2)(
        normalized_velocity_sequence)
    node_features.append(flat_velocity_sequence)

    # Normalized clipped distances to lower and upper boundaries.
    # boundaries are an array of shape [num_dimensions, 2], where the second
    # axis, provides the lower/upper boundaries.
    boundaries = tf.constant(self._boundaries, dtype=tf.float32)
    distance_to_lower_boundary = (
        most_recent_position - tf.expand_dims(boundaries[:, 0], 0))
    distance_to_upper_boundary = (
        tf.expand_dims(boundaries[:, 1], 0) - most_recent_position)
    distance_to_boundaries = tf.concat(
        [distance_to_lower_boundary, distance_to_upper_boundary], axis=1)
    # Distances beyond the connectivity radius carry no extra information,
    # so they are clipped to [-1, 1] after normalization.
    normalized_clipped_distance_to_boundaries = tf.clip_by_value(
        distance_to_boundaries / self._connectivity_radius, -1., 1.)
    node_features.append(normalized_clipped_distance_to_boundaries)

    # Particle type.
    if self._num_particle_types > 1:
      particle_type_embeddings = tf.nn.embedding_lookup(
          self._particle_type_embedding, particle_types)
      node_features.append(particle_type_embeddings)

    # Collect edge features.
    edge_features = []

    # Relative displacement and distances normalized to radius.
    normalized_relative_displacements = (
        tf.gather(most_recent_position, senders) -
        tf.gather(most_recent_position, receivers)) / self._connectivity_radius
    edge_features.append(normalized_relative_displacements)

    normalized_relative_distances = tf.norm(
        normalized_relative_displacements, axis=-1, keepdims=True)
    edge_features.append(normalized_relative_distances)

    # Normalize the global context.
    if global_context is not None:
      context_stats = self._normalization_stats["context"]
      # Context in some datasets are all zero, so add an epsilon for numerical
      # stability.
      global_context = (global_context - context_stats.mean) / tf.math.maximum(
          context_stats.std, STD_EPSILON)

    return gn.graphs.GraphsTuple(
        nodes=tf.concat(node_features, axis=-1),
        edges=tf.concat(edge_features, axis=-1),
        globals=global_context,  # `self._graph_network` appends this to nodes.
        n_node=n_node,
        n_edge=n_edge,
        senders=senders,
        receivers=receivers,
        )

  def _decoder_postprocessor(self, normalized_acceleration, position_sequence):
    """Integrates a predicted normalized acceleration into the next position."""
    # The model produces the output in normalized space so we apply inverse
    # normalization.
    acceleration_stats = self._normalization_stats["acceleration"]
    acceleration = (
        normalized_acceleration * acceleration_stats.std
    ) + acceleration_stats.mean

    # Use an Euler integrator to go from acceleration to position, assuming
    # a dt=1 corresponding to the size of the finite difference.
    most_recent_position = position_sequence[:, -1]
    most_recent_velocity = most_recent_position - position_sequence[:, -2]

    new_velocity = most_recent_velocity + acceleration  # * dt = 1
    new_position = most_recent_position + new_velocity  # * dt = 1
    return new_position

  def get_predicted_and_target_normalized_accelerations(
      self, next_position, position_sequence_noise, position_sequence,
      n_particles_per_example, global_context=None, particle_types=None):  # pylint: disable=g-doc-args
    """Produces a model step, outputting the next position for each particle.

    Args:
      next_position: Tensor of shape [num_particles_in_batch, num_dimensions]
        with the positions the model should output given the inputs.
      position_sequence_noise: Tensor of the same shape as `position_sequence`
        with the noise to apply to each particle.
      position_sequence, n_node, global_context, particle_types: Inputs to the
        model as defined by `_build`.

    Returns:
      Tensors of shape [num_particles_in_batch, num_dimensions] with the
        predicted and target normalized accelerations.
    """

    # Add noise to the input position sequence.
    noisy_position_sequence = position_sequence + position_sequence_noise

    # Perform the forward pass with the noisy position sequence.
    input_graphs_tuple = self._encoder_preprocessor(
        noisy_position_sequence, n_particles_per_example, global_context,
        particle_types)
    predicted_normalized_acceleration = self._graph_network(input_graphs_tuple)

    # Calculate the target acceleration, using an `adjusted_next_position` that
    # is shifted by the noise in the last input position.
    next_position_adjusted = next_position + position_sequence_noise[:, -1]
    target_normalized_acceleration = self._inverse_decoder_postprocessor(
        next_position_adjusted, noisy_position_sequence)
    # As a result the inverted Euler update in the `_inverse_decoder` produces:
    # * A target acceleration that does not explicitly correct for the noise in
    #   the input positions, as the `next_position_adjusted` is different
    #   from the true `next_position`.
    # * A target acceleration that exactly corrects noise in the input velocity
    #   since the target next velocity calculated by the inverse Euler update
    #   as `next_position_adjusted - noisy_position_sequence[:,-1]`
    #   matches the ground truth next velocity (noise cancels out).
    return predicted_normalized_acceleration, target_normalized_acceleration

  def _inverse_decoder_postprocessor(self, next_position, position_sequence):
    """Inverse of `_decoder_postprocessor`."""
    # Recover the acceleration from the next position via the inverted Euler
    # update with dt=1.
    previous_position = position_sequence[:, -1]
    previous_velocity = previous_position - position_sequence[:, -2]
    next_velocity = next_position - previous_position
    acceleration = next_velocity - previous_velocity

    acceleration_stats = self._normalization_stats["acceleration"]
    normalized_acceleration = (
        acceleration - acceleration_stats.mean) / acceleration_stats.std
    return normalized_acceleration
def time_diff(input_sequence):
  """Returns finite differences along the time axis (axis 1)."""
  later_steps = input_sequence[:, 1:]
  earlier_steps = input_sequence[:, :-1]
  return later_steps - earlier_steps
+145
View File
@@ -0,0 +1,145 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Example script accompanying ICML 2020 submission.
"Learning to Simulate Complex Physics with Graph Networks"
Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
Jure Leskovec, Peter W. Battaglia
https://arxiv.org/abs/2002.09405
Here we provide the utility function `sample_random_position_sequence()` which
returns a sequence of positions for a variable number of particles, similar to
what a real dataset would provide, and connect the model to it, in both,
single step inference and training mode.
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
"""
import collections
from learning_to_simulate import learned_simulator
from learning_to_simulate import noise_utils
import numpy as np
import tensorflow.compat.v1 as tf
# Number of past positions the model conditions on.
INPUT_SEQUENCE_LENGTH = 6
SEQUENCE_LENGTH = INPUT_SEQUENCE_LENGTH + 1  # add one target position.
NUM_DIMENSIONS = 3
NUM_PARTICLE_TYPES = 6
BATCH_SIZE = 5
GLOBAL_CONTEXT_SIZE = 6

# Container for normalization statistics (per-dimension mean/std).
Stats = collections.namedtuple("Stats", ["mean", "std"])

# Identity normalization (mean 0, std 1) so the dummy data passes through
# unchanged.
DUMMY_STATS = Stats(
    mean=np.zeros([NUM_DIMENSIONS], dtype=np.float32),
    std=np.ones([NUM_DIMENSIONS], dtype=np.float32))
DUMMY_CONTEXT_STATS = Stats(
    mean=np.zeros([GLOBAL_CONTEXT_SIZE], dtype=np.float32),
    std=np.ones([GLOBAL_CONTEXT_SIZE], dtype=np.float32))

# Unit cube boundaries in every dimension.
DUMMY_BOUNDARIES = [(-1., 1.)] * NUM_DIMENSIONS
def sample_random_position_sequence():
  """Returns mock data mimicking the input features collected by the encoder."""
  # Draw a random particle count so each sampled sequence has a variable
  # number of nodes, like a real dataset.
  particle_count = tf.random_uniform(
      shape=(), minval=50, maxval=1000, dtype=tf.int32)
  return tf.random.normal(
      shape=[particle_count, SEQUENCE_LENGTH, NUM_DIMENSIONS])
def main():
  """Connects the model to dummy data for one inference and training step."""
  # Build the model.
  learnable_model = learned_simulator.LearnedSimulator(
      num_dimensions=NUM_DIMENSIONS,
      connectivity_radius=0.05,
      graph_network_kwargs=dict(
          latent_size=128,
          mlp_hidden_size=128,
          mlp_num_hidden_layers=2,
          num_message_passing_steps=10,
      ),
      boundaries=DUMMY_BOUNDARIES,
      normalization_stats={"acceleration": DUMMY_STATS,
                           "velocity": DUMMY_STATS,
                           "context": DUMMY_CONTEXT_STATS},
      num_particle_types=NUM_PARTICLE_TYPES,
      particle_type_embedding_size=16,
  )

  # Sample a batch of particle sequences with shape:
  # [TOTAL_NUM_PARTICLES, SEQUENCE_LENGTH, NUM_DIMENSIONS]
  sequences = []
  for _ in range(BATCH_SIZE):
    sequences.append(sample_random_position_sequence())
  position_sequence_batch = tf.concat(sequences, axis=0)

  # Count how many particles are present in each element in the batch.
  # [BATCH_SIZE]
  particles_per_example = tf.stack(
      [tf.shape(seq)[0] for seq in sequences], axis=0)

  # Sample particle types.
  # [TOTAL_NUM_PARTICLES]
  particle_types = tf.random_uniform(
      [tf.shape(position_sequence_batch)[0]],
      0, NUM_PARTICLE_TYPES, dtype=tf.int32)

  # Sample global context.
  global_context = tf.random_uniform(
      [BATCH_SIZE, GLOBAL_CONTEXT_SIZE], -1., 1., dtype=tf.float32)

  # Separate input sequence from target sequence.
  # [TOTAL_NUM_PARTICLES, INPUT_SEQUENCE_LENGTH, NUM_DIMENSIONS]
  input_position_sequence = position_sequence_batch[:, :-1]
  # [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS]
  target_next_position = position_sequence_batch[:, -1]

  # Single step of inference with the model to predict next position for each
  # particle [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS].
  predicted_next_position = learnable_model(
      input_position_sequence, particles_per_example, global_context,
      particle_types)
  print(f"Per-particle output tensor: {predicted_next_position}")

  # Obtaining predicted and target normalized accelerations for training.
  position_sequence_noise = (
      noise_utils.get_random_walk_noise_for_position_sequence(
          input_position_sequence, noise_std_last_step=6.7e-4))

  # Both with shape [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS]
  predicted_normalized_acceleration, target_normalized_acceleration = (
      learnable_model.get_predicted_and_target_normalized_accelerations(
          target_next_position, position_sequence_noise,
          input_position_sequence, particles_per_example, global_context,
          particle_types))
  print(f"Predicted norm. acceleration: {predicted_normalized_acceleration}")
  print(f"Target norm. acceleration: {target_normalized_acceleration}")

  # Run all of the ops once to make sure the graph actually executes.
  with tf.train.SingularMonitoredSession() as sess:
    sess.run([predicted_next_position,
              predicted_normalized_acceleration,
              target_normalized_acceleration])


if __name__ == "__main__":
  tf.disable_v2_behavior()
  main()
+54
View File
@@ -0,0 +1,54 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Methods to calculate input noise."""
import tensorflow.compat.v1 as tf
from learning_to_simulate import learned_simulator
def get_random_walk_noise_for_position_sequence(
    position_sequence, noise_std_last_step):
  """Returns random-walk noise in the velocity applied to the position.

  Args:
    position_sequence: Tensor of positions, with the time axis at position 1
      (as consumed by `learned_simulator.time_diff`).
    noise_std_last_step: Desired std of the accumulated velocity noise at the
      last input step.

  Returns:
    Tensor of the same shape as `position_sequence` with per-position noise;
    the first position receives zero noise.
  """
  velocity_sequence = learned_simulator.time_diff(position_sequence)

  # We want the noise scale in the velocity at the last step to be fixed.
  # Because we are going to compose noise at each step using a random_walk:
  # std_last_step**2 = num_velocities * std_each_step**2
  # so to keep `std_last_step` fixed, we apply at each step:
  # std_each_step = std_last_step / np.sqrt(num_input_velocities)
  # TODO(alvarosg): Make sure this is consistent with the value and
  # description provided in the paper.
  num_velocities = velocity_sequence.shape.as_list()[1]
  velocity_sequence_noise = tf.random.normal(
      tf.shape(velocity_sequence),
      stddev=noise_std_last_step / num_velocities ** 0.5,
      dtype=position_sequence.dtype)

  # Apply the random walk.
  velocity_sequence_noise = tf.cumsum(velocity_sequence_noise, axis=1)

  # Integrate the noise in the velocity to the positions, assuming
  # an Euler integrator and a dt = 1, and adding no noise to the very first
  # position (since that will only be used to calculate the first position
  # change).
  position_sequence_noise = tf.concat([
      tf.zeros_like(velocity_sequence_noise[:, 0:1]),
      tf.cumsum(velocity_sequence_noise, axis=1)], axis=1)

  return position_sequence_noise
+140
View File
@@ -0,0 +1,140 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for reading open sourced Learning Complex Physics data."""
import functools
import numpy as np
import tensorflow.compat.v1 as tf
# Create a description of the features.
_FEATURE_DESCRIPTION = {
    'position': tf.io.VarLenFeature(tf.string),
}

# Datasets with global context additionally carry a per-step context feature.
_FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT = _FEATURE_DESCRIPTION.copy()
_FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT['step_context'] = tf.io.VarLenFeature(
    tf.string)

# 'in': numpy dtype used to decode the serialized raw bytes;
# 'out': tf dtype of the resulting tensor from the py_function.
_FEATURE_DTYPES = {
    'position': {
        'in': np.float32,
        'out': tf.float32
    },
    'step_context': {
        'in': np.float32,
        'out': tf.float32
    }
}

# Features that do not vary over the trajectory.
_CONTEXT_FEATURES = {
    'key': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'particle_type': tf.io.VarLenFeature(tf.string)
}
def convert_to_tensor(x, encoded_dtype):
  """Decodes a list of serialized byte tensors into a numeric tensor."""
  if len(x) == 1:
    # Single frame: decode the raw bytes directly into a flat array.
    return np.frombuffer(x[0].numpy(), dtype=encoded_dtype)
  # Multiple frames: decode each one and stack them into a single tensor.
  decoded_frames = [
      np.frombuffer(frame.numpy(), dtype=encoded_dtype) for frame in x]
  return tf.convert_to_tensor(np.array(decoded_frames))
def parse_serialized_simulation_example(example_proto, metadata):
  """Parses a serialized simulation tf.SequenceExample.

  Args:
    example_proto: A string encoding of the tf.SequenceExample proto.
    metadata: A dict of metadata for the dataset.

  Returns:
    context: A dict, with features that do not vary over the trajectory.
    parsed_features: A dict of tf.Tensors representing the parsed examples
      across time, where axis zero is the time axis.
  """
  # Datasets with global context serialize an extra 'step_context' feature.
  if 'context_mean' in metadata:
    feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT
  else:
    feature_description = _FEATURE_DESCRIPTION
  context, parsed_features = tf.io.parse_single_sequence_example(
      example_proto,
      context_features=_CONTEXT_FEATURES,
      sequence_features=feature_description)
  # Decode the raw bytes of each sequence feature into numeric tensors.
  for feature_key, item in parsed_features.items():
    convert_fn = functools.partial(
        convert_to_tensor, encoded_dtype=_FEATURE_DTYPES[feature_key]['in'])
    parsed_features[feature_key] = tf.py_function(
        convert_fn, inp=[item.values], Tout=_FEATURE_DTYPES[feature_key]['out'])

  # There is an extra frame at the beginning so we can calculate pos change
  # for all frames used in the paper.
  position_shape = [metadata['sequence_length'] + 1, -1, metadata['dim']]

  # Reshape positions to correct dim:
  parsed_features['position'] = tf.reshape(parsed_features['position'],
                                           position_shape)
  # Set correct shapes of the remaining tensors.
  sequence_length = metadata['sequence_length'] + 1
  if 'context_mean' in metadata:
    context_feat_len = len(metadata['context_mean'])
    parsed_features['step_context'] = tf.reshape(
        parsed_features['step_context'],
        [sequence_length, context_feat_len])
  # Decode particle type explicitly.
  # NOTE(review): `convert_fn` here is the partial left over from the last
  # iteration of the loop above; the `encoded_dtype` keyword below overrides
  # its bound dtype, so this works, but it silently depends on the loop
  # running at least once.
  context['particle_type'] = tf.py_function(
      functools.partial(convert_fn, encoded_dtype=np.int64),
      inp=[context['particle_type'].values],
      Tout=[tf.int64])
  context['particle_type'] = tf.reshape(context['particle_type'], [-1])
  return context, parsed_features
def split_trajectory(context, features, window_length=7):
  """Splits a trajectory into overlapping sliding windows.

  All leading dimensions are made the same size, so the result can be
  converted with `tf.data.Dataset.from_tensor_slices`.

  Args:
    context: Dict of per-trajectory features; must contain 'particle_type'.
    features: Dict of per-step features; must contain 'position' and may
      contain 'step_context'.
    window_length: Number of consecutive steps per window.

  Returns:
    A `tf.data.Dataset` whose elements are the individual windows.
  """
  trajectory_length = features['position'].get_shape().as_list()[0]
  # Stacking `window_length` steps leaves `trajectory_length -
  # window_length + 1` windows (the +1 makes sure we keep the last split).
  num_windows = trajectory_length - window_length + 1

  window_features = {
      # The particle types are static, so repeat them once per window.
      'particle_type': tf.tile(
          tf.expand_dims(context['particle_type'], axis=0),
          [num_windows, 1]),
      # Slice out each window of positions.
      'position': tf.stack(
          [features['position'][start:start + window_length]
           for start in range(num_windows)]),
  }
  if 'step_context' in features:
    window_features['step_context'] = tf.stack(
        [features['step_context'][start:start + window_length]
         for start in range(num_windows)])
  return tf.data.Dataset.from_tensor_slices(window_features)
+104
View File
@@ -0,0 +1,104 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simple matplotlib rendering of a rollout prediction against ground truth.
Usage (from parent directory):
`python -m learning_to_simulate.render_rollout --rollout_path={OUTPUT_PATH}/rollout_test_1.pkl`
Where {OUTPUT_PATH} is the output path passed to `train.py` in "eval_rollout"
mode.
It may require installing Tkinter with `sudo apt-get install python3.7-tk`.
""" # pylint: disable=line-too-long
import pickle
from absl import app
from absl import flags
from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np
# Command-line flags selecting which rollout file to render and how.
flags.DEFINE_string("rollout_path", None, help="Path to rollout pickle file")
flags.DEFINE_integer("step_stride", 3, help="Stride of steps to skip.")
flags.DEFINE_boolean("block_on_show", True, help="For test purposes.")

FLAGS = flags.FLAGS

# Maps integer particle-type ids (as stored in the rollout data) to plot
# colors. Types not listed here are not drawn.
TYPE_TO_COLOR = {
    3: "black",  # Boundary particles.
    0: "green",  # Rigid solids.
    7: "magenta",  # Goop.
    6: "gold",  # Sand.
    5: "blue",  # Water.
}
def main(unused_argv):
  """Animates ground-truth and predicted rollouts side by side."""
  if not FLAGS.rollout_path:
    raise ValueError("A `rollout_path` must be passed.")
  with open(FLAGS.rollout_path, "rb") as file:
    rollout_data = pickle.load(file)

  # One subplot for the ground truth, one for the prediction.
  fig, axes = plt.subplots(1, 2, figsize=(10, 5))

  plot_info = []
  for ax_i, (label, rollout_field) in enumerate(
      [("Ground truth", "ground_truth_rollout"),
       ("Prediction", "predicted_rollout")]):
    # Append the initial positions to get the full trajectory.
    trajectory = np.concatenate([
        rollout_data["initial_positions"],
        rollout_data[rollout_field]], axis=0)
    ax = axes[ax_i]
    ax.set_title(label)
    # Fix the axes to the simulation bounds stored in the dataset metadata.
    bounds = rollout_data["metadata"]["bounds"]
    ax.set_xlim(bounds[0][0], bounds[0][1])
    ax.set_ylim(bounds[1][0], bounds[1][1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1.)
    # One marker line per particle type so each type gets its own color.
    points = {
        particle_type: ax.plot([], [], "o", ms=2, color=color)[0]
        for particle_type, color in TYPE_TO_COLOR.items()}
    plot_info.append((ax, trajectory, points))

  # NOTE: this is the `trajectory` from the last loop iteration; both
  # rollouts are assumed to have the same number of steps.
  num_steps = trajectory.shape[0]

  def update(step_i):
    # Animation callback: moves each particle marker to its position at
    # `step_i` and returns the updated artists (needed for blitting).
    outputs = []
    for _, trajectory, points in plot_info:
      for particle_type, line in points.items():
        mask = rollout_data["particle_types"] == particle_type
        line.set_data(trajectory[step_i, mask, 0],
                      trajectory[step_i, mask, 1])
        outputs.append(line)
    return outputs

  # Keep a reference so the animation is not garbage-collected while shown.
  unused_animation = animation.FuncAnimation(
      fig, update,
      frames=np.arange(0, num_steps, FLAGS.step_stride),
      blit=True, interval=10)

  plt.show(block=FLAGS.block_on_show)
if __name__ == "__main__":
  # Parses flags and dispatches to `main`.
  app.run(main)
+9
View File
@@ -0,0 +1,9 @@
absl-py
graph-nets>=1.1
tensorflow>=1.15,<2
numpy
dm-sonnet<2
tensorflow_probability<0.9
scikit-learn
dm-tree
matplotlib
+57
View File
@@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Integration test: installs dependencies into a fresh virtualenv, runs the
# demo, then downloads a small sample dataset and does a few steps of
# training, evaluation and rollout rendering.

# Fail on any error.
set -e

# Display commands being run.
set -x

# Use $(...) rather than backticks, and quote all path expansions so the
# script stays correct even if the temporary path contains spaces.
TMP_DIR="$(mktemp -d)"
virtualenv --python=python3.7 "${TMP_DIR}/learning_to_simulate"
source "${TMP_DIR}/learning_to_simulate/bin/activate"

# Install dependencies.
pip install --upgrade -r learning_to_simulate/requirements.txt

# Run the simple demo with dummy inputs.
python -m learning_to_simulate.model_demo

# Run some training and evaluation in one of the dataset samples.

# Download a sample of a dataset.
DATASET_NAME="WaterDropSample"
bash ./learning_to_simulate/download_dataset.sh "${DATASET_NAME}" "${TMP_DIR}/datasets"

# Train for a few steps.
DATA_PATH="${TMP_DIR}/datasets/${DATASET_NAME}"
MODEL_PATH="${TMP_DIR}/models/${DATASET_NAME}"
python -m learning_to_simulate.train --data_path="${DATA_PATH}" --model_path="${MODEL_PATH}" --num_steps=10

# Evaluate on validation split.
python -m learning_to_simulate.train --data_path="${DATA_PATH}" --model_path="${MODEL_PATH}" --mode="eval" --eval_split="valid"

# Generate test rollouts.
ROLLOUT_PATH="${TMP_DIR}/rollouts/${DATASET_NAME}"
mkdir -p "${ROLLOUT_PATH}"
python -m learning_to_simulate.train --data_path="${DATA_PATH}" --model_path="${MODEL_PATH}" --mode="eval_rollout" --output_path="${ROLLOUT_PATH}"

# Plot the first rollout.
python -m learning_to_simulate.render_rollout --rollout_path="${ROLLOUT_PATH}/rollout_test_0.pkl" --block_on_show=False

# Clean up.
rm -r "${TMP_DIR}"
+477
View File
@@ -0,0 +1,477 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=line-too-long
"""Training script for https://arxiv.org/pdf/2002.09405.pdf.
Example usage (from parent directory):
`python -m learning_to_simulate.train --data_path={DATA_PATH} --model_path={MODEL_PATH}`
Evaluate model from checkpoint (from parent directory):
`python -m learning_to_simulate.train --data_path={DATA_PATH} --model_path={MODEL_PATH} --mode=eval`
Produce rollouts (from parent directory):
`python -m learning_to_simulate.train --data_path={DATA_PATH} --model_path={MODEL_PATH} --output_path={OUTPUT_PATH} --mode=eval_rollout`
"""
# pylint: enable=line-too-long
import collections
import functools
import json
import os
import pickle
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
import tree
from learning_to_simulate import learned_simulator
from learning_to_simulate import noise_utils
from learning_to_simulate import reading_utils
# Command-line flags controlling training / evaluation.
flags.DEFINE_enum(
    'mode', 'train', ['train', 'eval', 'eval_rollout'],
    help='Train model, one step evaluation or rollout evaluation.')
flags.DEFINE_enum('eval_split', 'test', ['train', 'valid', 'test'],
                  help='Split to use when running evaluation.')
flags.DEFINE_string('data_path', None, help='The dataset directory.')
flags.DEFINE_integer('batch_size', 2, help='The batch size.')
flags.DEFINE_integer('num_steps', int(2e7), help='Number of steps of training.')
flags.DEFINE_float('noise_std', 6.7e-4, help='The std deviation of the noise.')
flags.DEFINE_string('model_path', None,
                    help=('The path for saving checkpoints of the model. '
                          'Defaults to a temporary directory.'))
flags.DEFINE_string('output_path', None,
                    help='The path for saving outputs (e.g. rollouts).')

FLAGS = flags.FLAGS

# Named (mean, std) pair used for the normalization statistics.
Stats = collections.namedtuple('Stats', ['mean', 'std'])

INPUT_SEQUENCE_LENGTH = 6  # So we can calculate the last 5 velocities.
NUM_PARTICLE_TYPES = 9
# Particle-type id of kinematic (obstacle) particles, whose positions are
# prescribed rather than predicted (see `rollout`).
KINEMATIC_PARTICLE_ID = 3
def get_kinematic_mask(particle_types):
  """Returns a boolean mask, set to true for kinematic (obstacle) particles."""
  is_kinematic = tf.equal(particle_types, KINEMATIC_PARTICLE_ID)
  return is_kinematic
def prepare_inputs(tensor_dict):
  """Splits a stack of steps into model inputs and a target position.

  Computes `n_particles_per_example`, a tensor describing how to partition
  the particle axis, i.e. which nodes belong to which graph.

  Adds a batch axis to `n_particles_per_example` and `step_context` so they
  can later be batched using `batch_concat`; this batch will be the same as
  if the elements had been batched via stacking. All other tensors have a
  variable-size particle axis and will simply be concatenated along it.

  Args:
    tensor_dict: A dict of tensors containing positions, and step context
      (if available). Modified in place.

  Returns:
    A tuple of input features and target positions.
  """
  # Stored as [sequence_length, num_particles, dim]; the model expects
  # [num_particles, sequence_length, dim].
  positions = tf.transpose(tensor_dict['position'], perm=[1, 0, 2])
  # The final step of the stack is the prediction target; everything
  # before it is model input.
  target_position = positions[:, -1]
  tensor_dict['position'] = positions[:, :-1]
  # Number of particles in this example, with an extra leading dimension
  # so examples can later be combined by concatenation.
  tensor_dict['n_particles_per_example'] = tf.shape(positions)[0][tf.newaxis]
  if 'step_context' in tensor_dict:
    # Take the penultimate global context: the final one belongs to the
    # target step. Add an extra dimension for stacking via concat.
    tensor_dict['step_context'] = tensor_dict['step_context'][-2][tf.newaxis]
  return tensor_dict, target_position
def prepare_rollout_inputs(context, features):
  """Prepares an inputs trajectory for rollout."""
  out_dict = dict(context)
  # Stored as [sequence_length, num_particles, dim]; the model expects
  # [num_particles, sequence_length, dim].
  positions = tf.transpose(features['position'], [1, 0, 2])
  # The final step of the stack is the prediction target; everything
  # before it is input.
  target_position = positions[:, -1]
  out_dict['position'] = positions[:, :-1]
  # Number of nodes (particles) in this trajectory.
  out_dict['n_particles_per_example'] = [tf.shape(positions)[0]]
  if 'step_context' in features:
    out_dict['step_context'] = features['step_context']
  out_dict['is_trajectory'] = tf.constant([True], tf.bool)
  return out_dict, target_position
def batch_concat(dataset, batch_size):
  """We implement batching as concatenating on the leading axis.

  Elements of `dataset` have a variable-size leading (particle) axis, so a
  "batch" is formed by concatenating `batch_size` consecutive elements
  along that axis instead of stacking them.

  Args:
    dataset: A `tf.data.Dataset` of nests of tensors whose leading axis may
      differ between elements.
    batch_size: Number of consecutive elements concatenated per batch.

  Returns:
    A `tf.data.Dataset` whose elements are the concatenated batches.
  """
  # We create a dataset of datasets of length batch_size.
  windowed_ds = dataset.window(batch_size)
  # The plan is then to reduce every nested dataset by concatenating. We can
  # do this using tf.data.Dataset.reduce. This requires an initial state, and
  # then incrementally reduces by running through the dataset
  # Get initial state. In this case this will be empty tensors of the
  # correct shape.
  initial_state = tree.map_structure(
      lambda spec: tf.zeros(  # pylint: disable=g-long-lambda
          shape=[0] + spec.shape.as_list()[1:], dtype=spec.dtype),
      dataset.element_spec)

  # We run through the nest and concatenate each entry with the previous state.
  def reduce_window(initial_state, ds):
    return ds.reduce(initial_state, lambda x, y: tf.concat([x, y], axis=0))

  return windowed_ds.map(
      lambda *x: tree.map_structure(reduce_window, initial_state, x))
def get_input_fn(data_path, batch_size, mode, split):
  """Gets the learning simulation input function for tf.estimator.Estimator.

  Args:
    data_path: the path to the dataset directory.
    batch_size: the number of graphs in a batch.
    mode: either 'one_step_train', 'one_step' or 'rollout'.
    split: either 'train', 'valid' or 'test'.

  Returns:
    The input function for the learning simulation model.
  """

  def input_fn():
    """Input function for learning simulation."""
    # Loads the metadata of the dataset.
    metadata = _read_metadata(data_path)
    # Create a tf.data.Dataset from the TFRecord.
    ds = tf.data.TFRecordDataset([os.path.join(data_path, f'{split}.tfrecord')])
    ds = ds.map(functools.partial(
        reading_utils.parse_serialized_simulation_example, metadata=metadata))
    if mode.startswith('one_step'):
      # Splits an entire trajectory into chunks of 7 steps.
      # Previous 5 velocities, current velocity and target.
      split_with_window = functools.partial(
          reading_utils.split_trajectory,
          window_length=INPUT_SEQUENCE_LENGTH + 1)
      ds = ds.flat_map(split_with_window)
      # Splits a chunk into input steps and target steps
      ds = ds.map(prepare_inputs)
      # If in train mode, repeat dataset forever and shuffle.
      if mode == 'one_step_train':
        ds = ds.repeat()
        ds = ds.shuffle(512)
      # Custom batching on the leading axis (see `batch_concat`).
      ds = batch_concat(ds, batch_size)
    elif mode == 'rollout':
      # Rollout evaluation only available for batch size 1
      assert batch_size == 1
      ds = ds.map(prepare_rollout_inputs)
    else:
      raise ValueError(f'mode: {mode} not recognized')
    return ds

  return input_fn
def rollout(simulator, features, num_steps):
  """Rolls out a trajectory by applying the model in sequence.

  Args:
    simulator: Learned simulator callable that predicts the next position
      from a window of current positions.
    features: Dict with 'position' (providing the initial window and the
      ground truth), 'particle_type', 'n_particles_per_example', and
      optionally 'step_context'.
    num_steps: Number of steps to roll out.

  Returns:
    Dict with time-major initial positions, predicted and ground-truth
    rollouts, and particle types ('global_context' included if provided).
  """
  initial_positions = features['position'][:, 0:INPUT_SEQUENCE_LENGTH]
  ground_truth_positions = features['position'][:, INPUT_SEQUENCE_LENGTH:]
  global_context = features.get('step_context')

  def step_fn(step, current_positions, predictions):
    # Select the global context matching the current step, if available.
    if global_context is None:
      global_context_step = None
    else:
      global_context_step = global_context[
          step + INPUT_SEQUENCE_LENGTH - 1][tf.newaxis]
    next_position = simulator(
        current_positions,
        n_particles_per_example=features['n_particles_per_example'],
        particle_types=features['particle_type'],
        global_context=global_context_step)
    # Update kinematic particles from prescribed trajectory.
    kinematic_mask = get_kinematic_mask(features['particle_type'])
    next_position_ground_truth = ground_truth_positions[:, step]
    next_position = tf.where(kinematic_mask, next_position_ground_truth,
                             next_position)
    updated_predictions = predictions.write(step, next_position)

    # Shift `current_positions`, removing the oldest position in the sequence
    # and appending the next position at the end.
    next_positions = tf.concat([current_positions[:, 1:],
                                next_position[:, tf.newaxis]], axis=1)
    return (step + 1, next_positions, updated_predictions)

  predictions = tf.TensorArray(size=num_steps, dtype=tf.float32)
  # Sequential rollout: each step depends on the previous one, hence
  # parallel_iterations=1 and no backprop through the loop.
  _, _, predictions = tf.while_loop(
      cond=lambda step, state, prediction: tf.less(step, num_steps),
      body=step_fn,
      loop_vars=(0, initial_positions, predictions),
      back_prop=False,
      parallel_iterations=1)

  output_dict = {
      'initial_positions': tf.transpose(initial_positions, [1, 0, 2]),
      'predicted_rollout': predictions.stack(),
      'ground_truth_rollout': tf.transpose(ground_truth_positions, [1, 0, 2]),
      'particle_types': features['particle_type'],
  }

  if global_context is not None:
    output_dict['global_context'] = global_context
  return output_dict
def _combine_std(std_x, std_y):
return np.sqrt(std_x**2 + std_y**2)
def _get_simulator(model_kwargs, metadata, acc_noise_std, vel_noise_std):
  """Instantiates the simulator.

  Args:
    model_kwargs: Keyword arguments for the learned graph network.
    metadata: Dataset metadata dict holding normalization statistics.
    acc_noise_std: Noise std folded into the acceleration statistics.
    vel_noise_std: Noise std folded into the velocity statistics.

  Returns:
    A `learned_simulator.LearnedSimulator` instance.
  """

  def cast(v):
    # Statistics enter the model as float32 numpy arrays.
    return np.array(v, dtype=np.float32)

  normalization_stats = {
      'acceleration': Stats(
          cast(metadata['acc_mean']),
          _combine_std(cast(metadata['acc_std']), acc_noise_std)),
      'velocity': Stats(
          cast(metadata['vel_mean']),
          _combine_std(cast(metadata['vel_std']), vel_noise_std)),
  }
  if 'context_mean' in metadata:
    normalization_stats['context'] = Stats(
        cast(metadata['context_mean']), cast(metadata['context_std']))
  return learned_simulator.LearnedSimulator(
      num_dimensions=metadata['dim'],
      connectivity_radius=metadata['default_connectivity_radius'],
      graph_network_kwargs=model_kwargs,
      boundaries=metadata['bounds'],
      num_particle_types=NUM_PARTICLE_TYPES,
      normalization_stats=normalization_stats,
      particle_type_embedding_size=16)
def get_one_step_estimator_fn(data_path,
                              noise_std,
                              latent_size=128,
                              hidden_size=128,
                              hidden_layers=2,
                              message_passing_steps=10):
  """Gets one step model for training simulation.

  Args:
    data_path: Path to the dataset directory (read for metadata/statistics).
    noise_std: Std of the random-walk noise added to input positions; also
      folded into the normalization statistics.
    latent_size: Latent node/edge size of the graph network.
    hidden_size: Hidden layer size of the graph network MLPs.
    hidden_layers: Number of hidden layers in the graph network MLPs.
    message_passing_steps: Number of message-passing steps.

  Returns:
    A model_fn suitable for `tf.estimator.Estimator`.
  """
  metadata = _read_metadata(data_path)

  model_kwargs = dict(
      latent_size=latent_size,
      mlp_hidden_size=hidden_size,
      mlp_num_hidden_layers=hidden_layers,
      num_message_passing_steps=message_passing_steps)

  def estimator_fn(features, labels, mode):
    """Estimator model_fn computing the loss, train op and eval metrics."""
    target_next_position = labels
    simulator = _get_simulator(model_kwargs, metadata,
                               vel_noise_std=noise_std,
                               acc_noise_std=noise_std)
    # Sample the noise to add to the inputs to the model during training.
    sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(
        features['position'], noise_std_last_step=noise_std)
    non_kinematic_mask = tf.logical_not(
        get_kinematic_mask(features['particle_type']))
    # Kinematic particles get no noise: broadcast the mask over the
    # sequence and spatial axes.
    noise_mask = tf.cast(
        non_kinematic_mask, sampled_noise.dtype)[:, tf.newaxis, tf.newaxis]
    sampled_noise *= noise_mask

    # Get the predictions and target accelerations.
    pred_target = simulator.get_predicted_and_target_normalized_accelerations(
        next_position=target_next_position,
        position_sequence=features['position'],
        position_sequence_noise=sampled_noise,
        n_particles_per_example=features['n_particles_per_example'],
        particle_types=features['particle_type'],
        global_context=features.get('step_context'))
    pred_acceleration, target_acceleration = pred_target

    # Calculate the loss and mask out loss on kinematic particles.
    loss = (pred_acceleration - target_acceleration)**2

    num_non_kinematic = tf.reduce_sum(
        tf.cast(non_kinematic_mask, tf.float32))
    loss = tf.where(non_kinematic_mask, loss, tf.zeros_like(loss))
    # Average over non-kinematic particles only. (`num_non_kinematic` is
    # already a scalar, so the inner reduce_sum is a no-op.)
    loss = tf.reduce_sum(loss) / tf.reduce_sum(num_non_kinematic)
    global_step = tf.train.get_global_step()
    # Set learning rate to decay from 1e-4 to 1e-6 exponentially.
    min_lr = 1e-6
    lr = tf.train.exponential_decay(learning_rate=1e-4 - min_lr,
                                    global_step=global_step,
                                    decay_steps=int(5e6),
                                    decay_rate=0.1) + min_lr
    opt = tf.train.AdamOptimizer(learning_rate=lr)
    train_op = opt.minimize(loss, global_step)

    # Calculate next position and add some additional eval metrics (only eval).
    predicted_next_position = simulator(
        position_sequence=features['position'],
        n_particles_per_example=features['n_particles_per_example'],
        particle_types=features['particle_type'],
        global_context=features.get('step_context'))
    predictions = {'predicted_next_position': predicted_next_position}

    eval_metrics_ops = {
        'loss_mse': tf.metrics.mean_squared_error(
            pred_acceleration, target_acceleration),
        'one_step_position_mse': tf.metrics.mean_squared_error(
            predicted_next_position, target_next_position)
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        train_op=train_op,
        loss=loss,
        predictions=predictions,
        eval_metric_ops=eval_metrics_ops)

  return estimator_fn
def get_rollout_estimator_fn(data_path,
                             noise_std,
                             latent_size=128,
                             hidden_size=128,
                             hidden_layers=2,
                             message_passing_steps=10):
  """Gets the model function for tf.estimator.Estimator.

  Args:
    data_path: Path to the dataset directory (read for metadata/statistics).
    noise_std: Noise std folded into the normalization statistics.
    latent_size: Latent node/edge size of the graph network.
    hidden_size: Hidden layer size of the graph network MLPs.
    hidden_layers: Number of hidden layers in the graph network MLPs.
    message_passing_steps: Number of message-passing steps.

  Returns:
    A model_fn producing full-trajectory rollouts and a rollout MSE loss.
  """
  metadata = _read_metadata(data_path)

  model_kwargs = dict(
      latent_size=latent_size,
      mlp_hidden_size=hidden_size,
      mlp_num_hidden_layers=hidden_layers,
      num_message_passing_steps=message_passing_steps)

  def estimator_fn(features, labels, mode):
    """Estimator model_fn that rolls out entire trajectories."""
    del labels  # Labels to conform to estimator spec.
    simulator = _get_simulator(model_kwargs, metadata,
                               acc_noise_std=noise_std,
                               vel_noise_std=noise_std)

    # Everything after the initial input window must be predicted.
    num_steps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
    rollout_op = rollout(simulator, features, num_steps=num_steps)
    squared_error = (rollout_op['predicted_rollout'] -
                     rollout_op['ground_truth_rollout']) ** 2
    loss = tf.reduce_mean(squared_error)
    eval_ops = {'rollout_error_mse': tf.metrics.mean_squared_error(
        rollout_op['predicted_rollout'], rollout_op['ground_truth_rollout'])}

    # Add a leading axis, since Estimator's predict method insists that all
    # tensors have a shared leading batch axis of the same dims.
    rollout_op = tree.map_structure(lambda x: x[tf.newaxis], rollout_op)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        train_op=None,
        loss=loss,
        predictions=rollout_op,
        eval_metric_ops=eval_ops)

  return estimator_fn
def _read_metadata(data_path):
with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
return json.loads(fp.read())
def main(_):
  """Trains or evaluates the model, depending on `FLAGS.mode`."""
  if FLAGS.mode in ['train', 'eval']:
    estimator = tf.estimator.Estimator(
        get_one_step_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
        model_dir=FLAGS.model_path)
    if FLAGS.mode == 'train':
      # Train all the way through.
      estimator.train(
          input_fn=get_input_fn(FLAGS.data_path, FLAGS.batch_size,
                                mode='one_step_train', split='train'),
          max_steps=FLAGS.num_steps)
    else:
      # One-step evaluation from checkpoint.
      eval_metrics = estimator.evaluate(input_fn=get_input_fn(
          FLAGS.data_path, FLAGS.batch_size,
          mode='one_step', split=FLAGS.eval_split))
      logging.info('Evaluation metrics:')
      logging.info(eval_metrics)
  elif FLAGS.mode == 'eval_rollout':
    if not FLAGS.output_path:
      raise ValueError('A rollout path must be provided.')
    # Create the output directory up front (including any missing parent
    # directories) instead of the original per-iteration, racy
    # `os.path.exists()` + `os.mkdir()` check, which also failed when the
    # parent of `output_path` did not exist.
    os.makedirs(FLAGS.output_path, exist_ok=True)
    rollout_estimator = tf.estimator.Estimator(
        get_rollout_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
        model_dir=FLAGS.model_path)

    # Iterate through rollouts saving them one by one.
    metadata = _read_metadata(FLAGS.data_path)
    rollout_iterator = rollout_estimator.predict(
        input_fn=get_input_fn(FLAGS.data_path, batch_size=1,
                              mode='rollout', split=FLAGS.eval_split))

    for example_index, example_rollout in enumerate(rollout_iterator):
      example_rollout['metadata'] = metadata
      filename = f'rollout_{FLAGS.eval_split}_{example_index}.pkl'
      filename = os.path.join(FLAGS.output_path, filename)
      logging.info('Saving: %s.', filename)
      with open(filename, 'wb') as file:
        pickle.dump(example_rollout, file)
if __name__ == '__main__':
  # This codebase is written against the TF1 graph-mode API, so disable
  # TF2 eager behavior before running.
  tf.disable_v2_behavior()
  app.run(main)