Initial release of Learning Mesh-Based Simulation with Graph Networks

PiperOrigin-RevId: 380185977
This commit is contained in:
Louise Deason
2021-06-18 14:56:42 +00:00
parent db9ef1d94e
commit fa8c9be4bb
17 changed files with 1100 additions and 0 deletions
+1
View File
@@ -25,6 +25,7 @@ env:
- PROJECT="tvt"
# - PROJECT="unrestricted_advx" # TODO(b/184862249): Fix and enable
- PROJECT="unsupervised_adversarial_training"
- PROJECT="meshgraphnets"
before_script:
- pip install --upgrade pip
- pip install --upgrade virtualenv
+1
View File
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [Learning Mesh-Based Simulation with Graph Networks](meshgraphnets), ICLR 2021
* [Open Graph Benchmark - Large-Scale Challenge (OGB-LSC)](ogb_lsc)
* [Synthetic Returns for Long-Term Credit Assignment](synthetic_returns)
* [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers)
+67
View File
@@ -0,0 +1,67 @@
# Learning Mesh-Based Simulation with Graph Networks (ICLR 2021)
Video site: [sites.google.com/view/meshgraphnets](https://sites.google.com/view/meshgraphnets)
Paper: [arxiv.org/abs/2010.03409](https://arxiv.org/abs/2010.03409)
If you use the code here please cite this paper:
@inproceedings{pfaff2021learning,
title={Learning Mesh-Based Simulation with Graph Networks},
author={Tobias Pfaff and
Meire Fortunato and
Alvaro Sanchez-Gonzalez and
Peter W. Battaglia},
booktitle={International Conference on Learning Representations},
year={2021}
}
## Setup
Prepare environment, install dependencies:
virtualenv --python=python3.6 "${ENV}"
${ENV}/bin/activate
pip install -r meshgraphnets/requirements.txt
Download a dataset:
mkdir -p ${DATA}
bash meshgraphnets/download_dataset.sh flag_simple ${DATA}
## Running the model
Train a model:
python -m meshgraphnets.run_model --mode=train --model=cloth \
--checkpoint_dir=${DATA}/chk --dataset_dir=${DATA}/flag_simple
Generate some trajectory rollouts:
python -m meshgraphnets.run_model --mode=eval --model=cloth \
--checkpoint_dir=${DATA}/chk --dataset_dir=${DATA}/flag_simple \
--rollout_path=${DATA}/rollout_flag.pkl
Plot a trajectory:
python -m meshgraphnets.plot_cloth --rollout_path=${DATA}/rollout_flag.pkl
## Datasets
Datasets can be downloaded using the script `download_dataset.sh`. They contain
a metadata file describing the available fields and their shape, and tfrecord
datasets for train, valid and test splits.
Dataset names match the naming in the paper.
The following datasets are available:
airfoil
cylinder_flow
deforming_plate
flag_minimal
flag_simple
flag_dynamic
sphere_simple
sphere_dynamic
`flag_minimal` is a truncated version of flag_simple, and is only used for
integration tests.
+62
View File
@@ -0,0 +1,62 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions to build evaluation metrics for CFD data."""
import tensorflow.compat.v1 as tf
from meshgraphnets.common import NodeType
def _rollout(model, initial_state, num_steps):
"""Rolls out a model trajectory."""
node_type = initial_state['node_type'][:, 0]
mask = tf.logical_or(tf.equal(node_type, NodeType.NORMAL),
tf.equal(node_type, NodeType.OUTFLOW))
def step_fn(step, velocity, trajectory):
prediction = model({**initial_state,
'velocity': velocity})
# don't update boundary nodes
next_velocity = tf.where(mask, prediction, velocity)
trajectory = trajectory.write(step, velocity)
return step+1, next_velocity, trajectory
_, _, output = tf.while_loop(
cond=lambda step, cur, traj: tf.less(step, num_steps),
body=step_fn,
loop_vars=(0, initial_state['velocity'],
tf.TensorArray(tf.float32, num_steps)),
parallel_iterations=1)
return output.stack()
def evaluate(model, inputs):
"""Performs model rollouts and create stats."""
initial_state = {k: v[0] for k, v in inputs.items()}
num_steps = inputs['cells'].shape[0]
prediction = _rollout(model, initial_state, num_steps)
error = tf.reduce_mean((prediction - inputs['velocity'])**2, axis=-1)
scalars = {'mse_%d_steps' % horizon: tf.reduce_mean(error[1:horizon+1])
for horizon in [1, 10, 20, 50, 100, 200]}
traj_ops = {
'faces': inputs['cells'],
'mesh_pos': inputs['mesh_pos'],
'gt_velocity': inputs['velocity'],
'pred_velocity': prediction
}
return scalars, traj_ops
+94
View File
@@ -0,0 +1,94 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model for CylinderFlow."""
import sonnet as snt
import tensorflow.compat.v1 as tf
from meshgraphnets import common
from meshgraphnets import core_model
from meshgraphnets import normalization
class Model(snt.AbstractModule):
"""Model for fluid simulation."""
def __init__(self, learned_model, name='Model'):
super(Model, self).__init__(name=name)
with self._enter_variable_scope():
self._learned_model = learned_model
self._output_normalizer = normalization.Normalizer(
size=2, name='output_normalizer')
self._node_normalizer = normalization.Normalizer(
size=2+common.NodeType.SIZE, name='node_normalizer')
self._edge_normalizer = normalization.Normalizer(
size=3, name='edge_normalizer') # 2D coord + length
def _build_graph(self, inputs, is_training):
"""Builds input graph."""
# construct graph nodes
node_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE)
node_features = tf.concat([inputs['velocity'], node_type], axis=-1)
# construct graph edges
senders, receivers = common.triangles_to_edges(inputs['cells'])
relative_mesh_pos = (tf.gather(inputs['mesh_pos'], senders) -
tf.gather(inputs['mesh_pos'], receivers))
edge_features = tf.concat([
relative_mesh_pos,
tf.norm(relative_mesh_pos, axis=-1, keepdims=True)], axis=-1)
mesh_edges = core_model.EdgeSet(
name='mesh_edges',
features=self._edge_normalizer(edge_features, is_training),
receivers=receivers,
senders=senders)
return core_model.MultiGraph(
node_features=self._node_normalizer(node_features, is_training),
edge_sets=[mesh_edges])
def _build(self, inputs):
graph = self._build_graph(inputs, is_training=False)
per_node_network_output = self._learned_model(graph)
return self._update(inputs, per_node_network_output)
@snt.reuse_variables
def loss(self, inputs):
"""L2 loss on velocity."""
graph = self._build_graph(inputs, is_training=True)
network_output = self._learned_model(graph)
# build target velocity change
cur_velocity = inputs['velocity']
target_velocity = inputs['target|velocity']
target_velocity_change = target_velocity - cur_velocity
target_normalized = self._output_normalizer(target_velocity_change)
# build loss
node_type = inputs['node_type'][:, 0]
loss_mask = tf.logical_or(tf.equal(node_type, common.NodeType.NORMAL),
tf.equal(node_type, common.NodeType.OUTFLOW))
error = tf.reduce_sum((target_normalized - network_output)**2, axis=1)
loss = tf.reduce_mean(error[loss_mask])
return loss
def _update(self, inputs, per_node_network_output):
"""Integrate model outputs."""
velocity_update = self._output_normalizer.inverse(per_node_network_output)
# integrate forward
cur_velocity = inputs['velocity']
return cur_velocity + velocity_update
+61
View File
@@ -0,0 +1,61 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions to build evaluation metrics for cloth data."""
import tensorflow.compat.v1 as tf
from meshgraphnets.common import NodeType
def _rollout(model, initial_state, num_steps):
"""Rolls out a model trajectory."""
mask = tf.equal(initial_state['node_type'][:, 0], NodeType.NORMAL)
def step_fn(step, prev_pos, cur_pos, trajectory):
prediction = model({**initial_state,
'prev|world_pos': prev_pos,
'world_pos': cur_pos})
# don't update kinematic nodes
next_pos = tf.where(mask, prediction, cur_pos)
trajectory = trajectory.write(step, cur_pos)
return step+1, cur_pos, next_pos, trajectory
_, _, _, output = tf.while_loop(
cond=lambda step, last, cur, traj: tf.less(step, num_steps),
body=step_fn,
loop_vars=(0, initial_state['prev|world_pos'], initial_state['world_pos'],
tf.TensorArray(tf.float32, num_steps)),
parallel_iterations=1)
return output.stack()
def evaluate(model, inputs):
"""Performs model rollouts and create stats."""
initial_state = {k: v[0] for k, v in inputs.items()}
num_steps = inputs['cells'].shape[0]
prediction = _rollout(model, initial_state, num_steps)
error = tf.reduce_mean((prediction - inputs['world_pos'])**2, axis=-1)
scalars = {'mse_%d_steps' % horizon: tf.reduce_mean(error[1:horizon+1])
for horizon in [1, 10, 20, 50, 100, 200]}
traj_ops = {
'faces': inputs['cells'],
'mesh_pos': inputs['mesh_pos'],
'gt_pos': inputs['world_pos'],
'pred_pos': prediction
}
return scalars, traj_ops
+100
View File
@@ -0,0 +1,100 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model for FlagSimple."""
import sonnet as snt
import tensorflow.compat.v1 as tf
from meshgraphnets import common
from meshgraphnets import core_model
from meshgraphnets import normalization
class Model(snt.AbstractModule):
"""Model for static cloth simulation."""
def __init__(self, learned_model, name='Model'):
super(Model, self).__init__(name=name)
with self._enter_variable_scope():
self._learned_model = learned_model
self._output_normalizer = normalization.Normalizer(
size=3, name='output_normalizer')
self._node_normalizer = normalization.Normalizer(
size=3+common.NodeType.SIZE, name='node_normalizer')
self._edge_normalizer = normalization.Normalizer(
size=7, name='edge_normalizer') # 2D coord + 3D coord + 2*length = 7
def _build_graph(self, inputs, is_training):
"""Builds input graph."""
# construct graph nodes
velocity = inputs['world_pos'] - inputs['prev|world_pos']
node_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE)
node_features = tf.concat([velocity, node_type], axis=-1)
# construct graph edges
senders, receivers = common.triangles_to_edges(inputs['cells'])
relative_world_pos = (tf.gather(inputs['world_pos'], senders) -
tf.gather(inputs['world_pos'], receivers))
relative_mesh_pos = (tf.gather(inputs['mesh_pos'], senders) -
tf.gather(inputs['mesh_pos'], receivers))
edge_features = tf.concat([
relative_world_pos,
tf.norm(relative_world_pos, axis=-1, keepdims=True),
relative_mesh_pos,
tf.norm(relative_mesh_pos, axis=-1, keepdims=True)], axis=-1)
mesh_edges = core_model.EdgeSet(
name='mesh_edges',
features=self._edge_normalizer(edge_features, is_training),
receivers=receivers,
senders=senders)
return core_model.MultiGraph(
node_features=self._node_normalizer(node_features, is_training),
edge_sets=[mesh_edges])
def _build(self, inputs):
graph = self._build_graph(inputs, is_training=False)
per_node_network_output = self._learned_model(graph)
return self._update(inputs, per_node_network_output)
@snt.reuse_variables
def loss(self, inputs):
"""L2 loss on position."""
graph = self._build_graph(inputs, is_training=True)
network_output = self._learned_model(graph)
# build target acceleration
cur_position = inputs['world_pos']
prev_position = inputs['prev|world_pos']
target_position = inputs['target|world_pos']
target_acceleration = target_position - 2*cur_position + prev_position
target_normalized = self._output_normalizer(target_acceleration)
# build loss
loss_mask = tf.equal(inputs['node_type'][:, 0], common.NodeType.NORMAL)
error = tf.reduce_sum((target_normalized - network_output)**2, axis=1)
loss = tf.reduce_mean(error[loss_mask])
return loss
def _update(self, inputs, per_node_network_output):
"""Integrate model outputs."""
acceleration = self._output_normalizer.inverse(per_node_network_output)
# integrate forward
cur_position = inputs['world_pos']
prev_position = inputs['prev|world_pos']
position = 2*cur_position + acceleration - prev_position
return position
+51
View File
@@ -0,0 +1,51 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Commonly used data structures and functions."""
import enum
import tensorflow.compat.v1 as tf
class NodeType(enum.IntEnum):
NORMAL = 0
OBSTACLE = 1
AIRFOIL = 2
HANDLE = 3
INFLOW = 4
OUTFLOW = 5
WALL_BOUNDARY = 6
SIZE = 9
def triangles_to_edges(faces):
"""Computes mesh edges from triangles."""
# collect edges from triangles
edges = tf.concat([faces[:, 0:2],
faces[:, 1:3],
tf.stack([faces[:, 2], faces[:, 0]], axis=1)], axis=0)
# those edges are sometimes duplicated (within the mesh) and sometimes
# single (at the mesh boundary).
# sort & pack edges as single tf.int64
receivers = tf.reduce_min(edges, axis=1)
senders = tf.reduce_max(edges, axis=1)
packed_edges = tf.bitcast(tf.stack([senders, receivers], axis=1), tf.int64)
# remove duplicates and unpack
unique_edges = tf.bitcast(tf.unique(packed_edges)[0], tf.int32)
senders, receivers = tf.unstack(unique_edges, axis=1)
# create two-way connectivity
return (tf.concat([senders, receivers], axis=0),
tf.concat([receivers, senders], axis=0))
+121
View File
@@ -0,0 +1,121 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Core learned graph net model."""
import collections
import functools
import sonnet as snt
import tensorflow.compat.v1 as tf
EdgeSet = collections.namedtuple('EdgeSet', ['name', 'features', 'senders',
'receivers'])
MultiGraph = collections.namedtuple('Graph', ['node_features', 'edge_sets'])
class GraphNetBlock(snt.AbstractModule):
"""Multi-Edge Interaction Network with residual connections."""
def __init__(self, model_fn, name='GraphNetBlock'):
super(GraphNetBlock, self).__init__(name=name)
self._model_fn = model_fn
def _update_edge_features(self, node_features, edge_set):
"""Aggregrates node features, and applies edge function."""
sender_features = tf.gather(node_features, edge_set.senders)
receiver_features = tf.gather(node_features, edge_set.receivers)
features = [sender_features, receiver_features, edge_set.features]
with tf.variable_scope(edge_set.name+'_edge_fn'):
return self._model_fn()(tf.concat(features, axis=-1))
def _update_node_features(self, node_features, edge_sets):
"""Aggregrates edge features, and applies node function."""
num_nodes = tf.shape(node_features)[0]
features = [node_features]
for edge_set in edge_sets:
features.append(tf.math.unsorted_segment_sum(edge_set.features,
edge_set.receivers,
num_nodes))
with tf.variable_scope('node_fn'):
return self._model_fn()(tf.concat(features, axis=-1))
def _build(self, graph):
"""Applies GraphNetBlock and returns updated MultiGraph."""
# apply edge functions
new_edge_sets = []
for edge_set in graph.edge_sets:
updated_features = self._update_edge_features(graph.node_features,
edge_set)
new_edge_sets.append(edge_set._replace(features=updated_features))
# apply node function
new_node_features = self._update_node_features(graph.node_features,
new_edge_sets)
# add residual connections
new_node_features += graph.node_features
new_edge_sets = [es._replace(features=es.features + old_es.features)
for es, old_es in zip(new_edge_sets, graph.edge_sets)]
return MultiGraph(new_node_features, new_edge_sets)
class EncodeProcessDecode(snt.AbstractModule):
"""Encode-Process-Decode GraphNet model."""
def __init__(self,
output_size,
latent_size,
num_layers,
message_passing_steps,
name='EncodeProcessDecode'):
super(EncodeProcessDecode, self).__init__(name=name)
self._latent_size = latent_size
self._output_size = output_size
self._num_layers = num_layers
self._message_passing_steps = message_passing_steps
def _make_mlp(self, output_size, layer_norm=True):
"""Builds an MLP."""
widths = [self._latent_size] * self._num_layers + [output_size]
network = snt.nets.MLP(widths, activate_final=False)
if layer_norm:
network = snt.Sequential([network, snt.LayerNorm()])
return network
def _encoder(self, graph):
"""Encodes node and edge features into latent features."""
with tf.variable_scope('encoder'):
node_latents = self._make_mlp(self._latent_size)(graph.node_features)
new_edges_sets = []
for edge_set in graph.edge_sets:
latent = self._make_mlp(self._latent_size)(edge_set.features)
new_edges_sets.append(edge_set._replace(features=latent))
return MultiGraph(node_latents, new_edges_sets)
def _decoder(self, graph):
"""Decodes node features from graph."""
with tf.variable_scope('decoder'):
decoder = self._make_mlp(self._output_size, layer_norm=False)
return decoder(graph.node_features)
def _build(self, graph):
"""Encodes and processes a multigraph, and returns node features."""
model_fn = functools.partial(self._make_mlp, output_size=self._latent_size)
latent_graph = self._encoder(graph)
for _ in range(self._message_passing_steps):
latent_graph = GraphNetBlock(model_fn)(latent_graph)
return self._decoder(latent_graph)
+119
View File
@@ -0,0 +1,119 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility functions for reading the datasets."""
import functools
import json
import os
import tensorflow.compat.v1 as tf
from meshgraphnets.common import NodeType
def _parse(proto, meta):
"""Parses a trajectory from tf.Example."""
feature_lists = {k: tf.io.VarLenFeature(tf.string)
for k in meta['field_names']}
features = tf.io.parse_single_example(proto, feature_lists)
out = {}
for key, field in meta['features'].items():
data = tf.io.decode_raw(features[key].values, getattr(tf, field['dtype']))
data = tf.reshape(data, field['shape'])
if field['type'] == 'static':
data = tf.tile(data, [meta['trajectory_length'], 1, 1])
elif field['type'] == 'dynamic_varlen':
length = tf.io.decode_raw(features['length_'+key].values, tf.int32)
length = tf.reshape(length, [-1])
data = tf.RaggedTensor.from_row_lengths(data, row_lengths=length)
elif field['type'] != 'dynamic':
raise ValueError('invalid data format')
out[key] = data
return out
def load_dataset(path, split):
"""Load dataset."""
with open(os.path.join(path, 'meta.json'), 'r') as fp:
meta = json.loads(fp.read())
ds = tf.data.TFRecordDataset(os.path.join(path, split+'.tfrecord'))
ds = ds.map(functools.partial(_parse, meta=meta), num_parallel_calls=8)
ds = ds.prefetch(1)
return ds
def add_targets(ds, fields, add_history):
"""Adds target and optionally history fields to dataframe."""
def fn(trajectory):
out = {}
for key, val in trajectory.items():
out[key] = val[1:-1]
if key in fields:
if add_history:
out['prev|'+key] = val[0:-2]
out['target|'+key] = val[2:]
return out
return ds.map(fn, num_parallel_calls=8)
def split_and_preprocess(ds, noise_field, noise_scale, noise_gamma):
"""Splits trajectories into frames, and adds training noise."""
def add_noise(frame):
noise = tf.random.normal(tf.shape(frame[noise_field]),
stddev=noise_scale, dtype=tf.float32)
# don't apply noise to boundary nodes
mask = tf.equal(frame['node_type'], NodeType.NORMAL)[:, 0]
noise = tf.where(mask, noise, tf.zeros_like(noise))
frame[noise_field] += noise
frame['target|'+noise_field] += (1.0 - noise_gamma) * noise
return frame
ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
ds = ds.map(add_noise, num_parallel_calls=8)
ds = ds.shuffle(10000)
ds = ds.repeat(None)
return ds.prefetch(10)
def batch_dataset(ds, batch_size):
"""Batches input datasets."""
shapes = ds.output_shapes
types = ds.output_types
def renumber(buffer, frame):
nodes, cells = buffer
new_nodes, new_cells = frame
return nodes + new_nodes, tf.concat([cells, new_cells+nodes], axis=0)
def batch_accumulate(ds_window):
out = {}
for key, ds_val in ds_window.items():
initial = tf.zeros((0, shapes[key][1]), dtype=types[key])
if key == 'cells':
# renumber node indices in cells
num_nodes = ds_window['node_type'].map(lambda x: tf.shape(x)[0])
cells = tf.data.Dataset.zip((num_nodes, ds_val))
initial = (tf.constant(0, tf.int32), initial)
_, out[key] = cells.reduce(initial, renumber)
else:
merge = lambda prev, cur: tf.concat([prev, cur], axis=0)
out[key] = ds_val.reduce(initial, merge)
return out
if batch_size > 1:
ds = ds.window(batch_size, drop_remainder=True)
ds = ds.map(batch_accumulate, num_parallel_calls=8)
return ds
+32
View File
@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
# sh download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR}
# Example:
# sh download_dataset.sh flag_simple /tmp/
set -e
DATASET_NAME="${1}"
OUTPUT_DIR="${2}/${DATASET_NAME}"
BASE_URL="https://storage.googleapis.com/dm-meshgraphnets/${DATASET_NAME}/"
mkdir -p ${OUTPUT_DIR}
for file in meta.json train.tfrecord valid.tfrecord test.tfrecord
do
wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done
+73
View File
@@ -0,0 +1,73 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Online data normalization."""
import sonnet as snt
import tensorflow.compat.v1 as tf
class Normalizer(snt.AbstractModule):
"""Feature normalizer that accumulates statistics online."""
def __init__(self, size, max_accumulations=10**6, std_epsilon=1e-8,
name='Normalizer'):
super(Normalizer, self).__init__(name=name)
self._max_accumulations = max_accumulations
self._std_epsilon = std_epsilon
with self._enter_variable_scope():
self._acc_count = tf.Variable(0, dtype=tf.float32, trainable=False)
self._num_accumulations = tf.Variable(0, dtype=tf.float32,
trainable=False)
self._acc_sum = tf.Variable(tf.zeros(size, tf.float32), trainable=False)
self._acc_sum_squared = tf.Variable(tf.zeros(size, tf.float32),
trainable=False)
def _build(self, batched_data, accumulate=True):
"""Normalizes input data and accumulates statistics."""
update_op = tf.no_op()
if accumulate:
# stop accumulating after a million updates, to prevent accuracy issues
update_op = tf.cond(self._num_accumulations < self._max_accumulations,
lambda: self._accumulate(batched_data),
tf.no_op)
with tf.control_dependencies([update_op]):
return (batched_data - self._mean()) / self._std_with_epsilon()
@snt.reuse_variables
def inverse(self, normalized_batch_data):
"""Inverse transformation of the normalizer."""
return normalized_batch_data * self._std_with_epsilon() + self._mean()
def _accumulate(self, batched_data):
"""Function to perform the accumulation of the batch_data statistics."""
count = tf.cast(tf.shape(batched_data)[0], tf.float32)
data_sum = tf.reduce_sum(batched_data, axis=0)
squared_data_sum = tf.reduce_sum(batched_data**2, axis=0)
return tf.group(
tf.assign_add(self._acc_sum, data_sum),
tf.assign_add(self._acc_sum_squared, squared_data_sum),
tf.assign_add(self._acc_count, count),
tf.assign_add(self._num_accumulations, 1.))
def _mean(self):
safe_count = tf.maximum(self._acc_count, 1.)
return self._acc_sum / safe_count
def _std_with_epsilon(self):
safe_count = tf.maximum(self._acc_count, 1.)
std = tf.sqrt(self._acc_sum_squared / safe_count - self._mean()**2)
return tf.math.maximum(std, self._std_epsilon)
+68
View File
@@ -0,0 +1,68 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Plots a CFD trajectory rollout."""
import pickle
from absl import app
from absl import flags
from matplotlib import animation
from matplotlib import tri as mtri
import matplotlib.pyplot as plt
FLAGS = flags.FLAGS
flags.DEFINE_string('rollout_path', None, 'Path to rollout pickle file')
def main(unused_argv):
with open(FLAGS.rollout_path, 'rb') as fp:
rollout_data = pickle.load(fp)
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
skip = 10
num_steps = rollout_data[0]['gt_velocity'].shape[0]
num_frames = len(rollout_data) * num_steps // skip
# compute bounds
bounds = []
for trajectory in rollout_data:
bb_min = trajectory['gt_velocity'].min(axis=(0, 1))
bb_max = trajectory['gt_velocity'].max(axis=(0, 1))
bounds.append((bb_min, bb_max))
def animate(num):
step = (num*skip) % num_steps
traj = (num*skip) // num_steps
ax.cla()
ax.set_aspect('equal')
ax.set_axis_off()
vmin, vmax = bounds[traj]
pos = rollout_data[traj]['mesh_pos'][step]
faces = rollout_data[traj]['faces'][step]
velocity = rollout_data[traj]['pred_velocity'][step]
triang = mtri.Triangulation(pos[:, 0], pos[:, 1], faces)
ax.tripcolor(triang, velocity[:, 0], vmin=vmin[0], vmax=vmax[0])
ax.triplot(triang, 'ko-', ms=0.5, lw=0.3)
ax.set_title('Trajectory %d Step %d' % (traj, step))
return fig,
_ = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100)
plt.show(block=True)
if __name__ == '__main__':
app.run(main)
+67
View File
@@ -0,0 +1,67 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Plots a cloth trajectory rollout."""
import pickle
from absl import app
from absl import flags
from matplotlib import animation
import matplotlib.pyplot as plt
FLAGS = flags.FLAGS
flags.DEFINE_string('rollout_path', None, 'Path to rollout pickle file')
def main(unused_argv):
with open(FLAGS.rollout_path, 'rb') as fp:
rollout_data = pickle.load(fp)
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
skip = 10
num_steps = rollout_data[0]['gt_pos'].shape[0]
num_frames = len(rollout_data) * num_steps // skip
# compute bounds
bounds = []
for trajectory in rollout_data:
bb_min = trajectory['gt_pos'].min(axis=(0, 1))
bb_max = trajectory['gt_pos'].max(axis=(0, 1))
bounds.append((bb_min, bb_max))
def animate(num):
step = (num*skip) % num_steps
traj = (num*skip) // num_steps
ax.cla()
bound = bounds[traj]
ax.set_xlim([bound[0][0], bound[1][0]])
ax.set_ylim([bound[0][1], bound[1][1]])
ax.set_zlim([bound[0][2], bound[1][2]])
pos = rollout_data[traj]['pred_pos'][step]
faces = rollout_data[traj]['faces'][step]
ax.plot_trisurf(pos[:, 0], pos[:, 1], faces, pos[:, 2], shade=True)
ax.set_title('Trajectory %d Step %d' % (traj, step))
return fig,
_ = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100)
plt.show(block=True)
if __name__ == '__main__':
app.run(main)
+5
View File
@@ -0,0 +1,5 @@
tensorflow-gpu>=1.15,<2
dm-sonnet<2
matplotlib
absl-py
numpy
+47
View File
@@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Fail on any error.
set -e
# Display commands being run.
set -x
TMP_DIR=`mktemp -d`
virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"
# Install dependencies.
pip install --upgrade -r meshgraphnets/requirements.txt
# Download minimal dataset
DATA_DIR="${TMP_DIR}/flag_minimal"
bash meshgraphnets/download_dataset.sh flag_minimal ${TMP_DIR}
# Train for a few steps.
CHK_DIR="${TMP_DIR}/checkpoint"
python -m meshgraphnets.run_model --model=cloth --mode=train --checkpoint_dir=${CHK_DIR} --dataset_dir=${DATA_DIR} --num_training_steps=10
# Generate a rollout trajectory
ROLLOUT_PATH="${TMP_DIR}/rollout.pkl"
python -m meshgraphnets.run_model --model=cloth --mode=eval --checkpoint_dir=${CHK_DIR} --dataset_dir=${DATA_DIR} --rollout_path=${ROLLOUT_PATH} --num_rollouts=1
# Plot the rollout trajectory
python -m meshgraphnets.plot_cloth --rollout_path=${ROLLOUT_PATH}
# Clean up.
rm -r ${TMP_DIR}
echo "Test run complete."
+131
View File
@@ -0,0 +1,131 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Runs the learner/evaluator."""
import pickle
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from meshgraphnets import cfd_eval
from meshgraphnets import cfd_model
from meshgraphnets import cloth_eval
from meshgraphnets import cloth_model
from meshgraphnets import core_model
from meshgraphnets import dataset
FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'train', ['train', 'eval'],
'Train model, or run evaluation.')
flags.DEFINE_enum('model', None, ['cfd', 'cloth'],
'Select model to run.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory to save checkpoint')
flags.DEFINE_string('dataset_dir', None, 'Directory to load dataset from.')
flags.DEFINE_string('rollout_path', None,
'Pickle file to save eval trajectories')
flags.DEFINE_enum('rollout_split', 'valid', ['train', 'test', 'valid'],
'Dataset split to use for rollouts.')
flags.DEFINE_integer('num_rollouts', 10, 'No. of rollout trajectories')
flags.DEFINE_integer('num_training_steps', int(10e6), 'No. of training steps')
PARAMETERS = {
'cfd': dict(noise=0.02, gamma=1.0, field='velocity', history=False,
size=2, batch=2, model=cfd_model, evaluator=cfd_eval),
'cloth': dict(noise=0.003, gamma=0.1, field='world_pos', history=True,
size=3, batch=1, model=cloth_model, evaluator=cloth_eval)
}
def learner(model, params):
"""Run a learner job."""
ds = dataset.load_dataset(FLAGS.dataset_dir, 'train')
ds = dataset.add_targets(ds, [params['field']], add_history=params['history'])
ds = dataset.split_and_preprocess(ds, noise_field=params['field'],
noise_scale=params['noise'],
noise_gamma=params['gamma'])
inputs = tf.data.make_one_shot_iterator(ds).get_next()
loss_op = model.loss(inputs)
global_step = tf.train.create_global_step()
lr = tf.train.exponential_decay(learning_rate=1e-4,
global_step=global_step,
decay_steps=int(5e6),
decay_rate=0.1) + 1e-6
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
train_op = optimizer.minimize(loss_op, global_step=global_step)
# Don't train for the first few steps, just accumulate normalization stats
train_op = tf.cond(tf.less(global_step, 1000),
lambda: tf.group(tf.assign_add(global_step, 1)),
lambda: tf.group(train_op))
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.num_training_steps)],
checkpoint_dir=FLAGS.checkpoint_dir,
save_checkpoint_secs=600) as sess:
while not sess.should_stop():
_, step, loss = sess.run([train_op, global_step, loss_op])
if step % 1000 == 0:
logging.info('Step %d: Loss %g', step, loss)
logging.info('Training complete.')
def evaluator(model, params):
"""Run a model rollout trajectory."""
ds = dataset.load_dataset(FLAGS.dataset_dir, FLAGS.rollout_split)
ds = dataset.add_targets(ds, [params['field']], add_history=params['history'])
inputs = tf.data.make_one_shot_iterator(ds).get_next()
scalar_op, traj_ops = params['evaluator'].evaluate(model, inputs)
tf.train.create_global_step()
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.checkpoint_dir,
save_checkpoint_secs=None,
save_checkpoint_steps=None) as sess:
trajectories = []
scalars = []
for traj_idx in range(FLAGS.num_rollouts):
logging.info('Rollout trajectory %d', traj_idx)
scalar_data, traj_data = sess.run([scalar_op, traj_ops])
trajectories.append(traj_data)
scalars.append(scalar_data)
for key in scalars[0]:
logging.info('%s: %g', key, np.mean([x[key] for x in scalars]))
with open(FLAGS.rollout_path, 'wb') as fp:
pickle.dump(trajectories, fp)
def main(argv):
del argv
tf.enable_resource_variables()
tf.disable_eager_execution()
params = PARAMETERS[FLAGS.model]
learned_model = core_model.EncodeProcessDecode(
output_size=params['size'],
latent_size=128,
num_layers=2,
message_passing_steps=15)
model = params['model'].Model(learned_model)
if FLAGS.mode == 'train':
learner(model, params)
elif FLAGS.mode == 'eval':
evaluator(model, params)
if __name__ == '__main__':
app.run(main)