diff --git a/.travis.yml b/.travis.yml index b2fd28e..bc900d5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,6 +25,7 @@ env: - PROJECT="tvt" # - PROJECT="unrestricted_advx" # TODO(b/184862249): Fix and enable - PROJECT="unsupervised_adversarial_training" + - PROJECT="meshgraphnets" before_script: - pip install --upgrade pip - pip install --upgrade virtualenv diff --git a/README.md b/README.md index d0086f3..8f6c12a 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ https://deepmind.com/research/publications/ ## Projects +* [Learning Mesh-Based Simulation with Graph Networks](meshgraphnets), ICLR 2021 * [Open Graph Benchmark - Large-Scale Challenge (OGB-LSC)](ogb_lsc) * [Synthetic Returns for Long-Term Credit Assignment](synthetic_returns) * [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers) diff --git a/meshgraphnets/README.md b/meshgraphnets/README.md new file mode 100644 index 0000000..2806dc4 --- /dev/null +++ b/meshgraphnets/README.md @@ -0,0 +1,67 @@ +# Learning Mesh-Based Simulation with Graph Networks (ICLR 2021) + +Video site: [sites.google.com/view/meshgraphnets](https://sites.google.com/view/meshgraphnets) + +Paper: [arxiv.org/abs/2010.03409](https://arxiv.org/abs/2010.03409) + +If you use the code here please cite this paper: + + @inproceedings{pfaff2021learning, + title={Learning Mesh-Based Simulation with Graph Networks}, + author={Tobias Pfaff and + Meire Fortunato and + Alvaro Sanchez-Gonzalez and + Peter W. Battaglia}, + booktitle={International Conference on Learning Representations}, + year={2021} + } + +## Setup + +Prepare environment, install dependencies: + + virtualenv --python=python3.6 "${ENV}" + ${ENV}/bin/activate + pip install -r meshgraphnets/requirements.txt + +Download a dataset: + + mkdir -p ${DATA} + bash meshgraphnets/download_dataset.sh flag_simple ${DATA} + +## Running the model + +Train a model: + + python -m meshgraphnets.run_model --mode=train --model=cloth \ + --checkpoint_dir=${DATA}/chk --dataset_dir=${DATA}/flag_simple + +Generate some trajectory rollouts: + + python -m meshgraphnets.run_model --mode=eval --model=cloth \ + --checkpoint_dir=${DATA}/chk --dataset_dir=${DATA}/flag_simple \ + --rollout_path=${DATA}/rollout_flag.pkl + +Plot a trajectory: + + python -m meshgraphnets.plot_cloth --rollout_path=${DATA}/rollout_flag.pkl + +## Datasets + +Datasets can be downloaded using the script `download_dataset.sh`. They contain +a metadata file describing the available fields and their shape, and tfrecord +datasets for train, valid and test splits. +Dataset names match the naming in the paper. +The following datasets are available: + + airfoil + cylinder_flow + deforming_plate + flag_minimal + flag_simple + flag_dynamic + sphere_simple + sphere_dynamic + +`flag_minimal` is a truncated version of flag_simple, and is only used for +integration tests. diff --git a/meshgraphnets/cfd_eval.py b/meshgraphnets/cfd_eval.py new file mode 100644 index 0000000..e1449b1 --- /dev/null +++ b/meshgraphnets/cfd_eval.py @@ -0,0 +1,62 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Functions to build evaluation metrics for CFD data.""" + +import tensorflow.compat.v1 as tf + +from meshgraphnets.common import NodeType + + +def _rollout(model, initial_state, num_steps): + """Rolls out a model trajectory.""" + node_type = initial_state['node_type'][:, 0] + mask = tf.logical_or(tf.equal(node_type, NodeType.NORMAL), + tf.equal(node_type, NodeType.OUTFLOW)) + + def step_fn(step, velocity, trajectory): + prediction = model({**initial_state, + 'velocity': velocity}) + # don't update boundary nodes + next_velocity = tf.where(mask, prediction, velocity) + trajectory = trajectory.write(step, velocity) + return step+1, next_velocity, trajectory + + _, _, output = tf.while_loop( + cond=lambda step, cur, traj: tf.less(step, num_steps), + body=step_fn, + loop_vars=(0, initial_state['velocity'], + tf.TensorArray(tf.float32, num_steps)), + parallel_iterations=1) + return output.stack() + + +def evaluate(model, inputs): + """Performs model rollouts and create stats.""" + initial_state = {k: v[0] for k, v in inputs.items()} + num_steps = inputs['cells'].shape[0] + prediction = _rollout(model, initial_state, num_steps) + + error = tf.reduce_mean((prediction - inputs['velocity'])**2, axis=-1) + scalars = {'mse_%d_steps' % horizon: tf.reduce_mean(error[1:horizon+1]) + for horizon in [1, 10, 20, 50, 100, 200]} + traj_ops = { + 'faces': inputs['cells'], + 'mesh_pos': inputs['mesh_pos'], + 'gt_velocity': inputs['velocity'], + 'pred_velocity': prediction + } + return scalars, traj_ops diff --git a/meshgraphnets/cfd_model.py b/meshgraphnets/cfd_model.py new file mode 100644 index 0000000..8df66d7 --- /dev/null +++ b/meshgraphnets/cfd_model.py @@ -0,0 +1,94 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Model for CylinderFlow.""" + +import sonnet as snt +import tensorflow.compat.v1 as tf + +from meshgraphnets import common +from meshgraphnets import core_model +from meshgraphnets import normalization + + +class Model(snt.AbstractModule): + """Model for fluid simulation.""" + + def __init__(self, learned_model, name='Model'): + super(Model, self).__init__(name=name) + with self._enter_variable_scope(): + self._learned_model = learned_model + self._output_normalizer = normalization.Normalizer( + size=2, name='output_normalizer') + self._node_normalizer = normalization.Normalizer( + size=2+common.NodeType.SIZE, name='node_normalizer') + self._edge_normalizer = normalization.Normalizer( + size=3, name='edge_normalizer') # 2D coord + length + + def _build_graph(self, inputs, is_training): + """Builds input graph.""" + # construct graph nodes + node_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE) + node_features = tf.concat([inputs['velocity'], node_type], axis=-1) + + # construct graph edges + senders, receivers = common.triangles_to_edges(inputs['cells']) + relative_mesh_pos = (tf.gather(inputs['mesh_pos'], senders) - + tf.gather(inputs['mesh_pos'], receivers)) + edge_features = tf.concat([ + relative_mesh_pos, + tf.norm(relative_mesh_pos, axis=-1, keepdims=True)], axis=-1) + + mesh_edges = core_model.EdgeSet( + name='mesh_edges', + features=self._edge_normalizer(edge_features, is_training), + receivers=receivers, + senders=senders) + return core_model.MultiGraph( + node_features=self._node_normalizer(node_features, is_training), + edge_sets=[mesh_edges]) + + def _build(self, inputs): + graph = self._build_graph(inputs, is_training=False) + per_node_network_output = self._learned_model(graph) + return self._update(inputs, per_node_network_output) + + @snt.reuse_variables + def loss(self, inputs): + """L2 loss on velocity.""" + graph = self._build_graph(inputs, is_training=True) + network_output = self._learned_model(graph) + + # build target velocity change + cur_velocity = inputs['velocity'] + target_velocity = inputs['target|velocity'] + target_velocity_change = target_velocity - cur_velocity + target_normalized = self._output_normalizer(target_velocity_change) + + # build loss + node_type = inputs['node_type'][:, 0] + loss_mask = tf.logical_or(tf.equal(node_type, common.NodeType.NORMAL), + tf.equal(node_type, common.NodeType.OUTFLOW)) + error = tf.reduce_sum((target_normalized - network_output)**2, axis=1) + loss = tf.reduce_mean(error[loss_mask]) + return loss + + def _update(self, inputs, per_node_network_output): + """Integrate model outputs.""" + velocity_update = self._output_normalizer.inverse(per_node_network_output) + # integrate forward + cur_velocity = inputs['velocity'] + return cur_velocity + velocity_update diff --git a/meshgraphnets/cloth_eval.py b/meshgraphnets/cloth_eval.py new file mode 100644 index 0000000..3fef858 --- /dev/null +++ b/meshgraphnets/cloth_eval.py @@ -0,0 +1,61 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Functions to build evaluation metrics for cloth data.""" + +import tensorflow.compat.v1 as tf + +from meshgraphnets.common import NodeType + + +def _rollout(model, initial_state, num_steps): + """Rolls out a model trajectory.""" + mask = tf.equal(initial_state['node_type'][:, 0], NodeType.NORMAL) + + def step_fn(step, prev_pos, cur_pos, trajectory): + prediction = model({**initial_state, + 'prev|world_pos': prev_pos, + 'world_pos': cur_pos}) + # don't update kinematic nodes + next_pos = tf.where(mask, prediction, cur_pos) + trajectory = trajectory.write(step, cur_pos) + return step+1, cur_pos, next_pos, trajectory + + _, _, _, output = tf.while_loop( + cond=lambda step, last, cur, traj: tf.less(step, num_steps), + body=step_fn, + loop_vars=(0, initial_state['prev|world_pos'], initial_state['world_pos'], + tf.TensorArray(tf.float32, num_steps)), + parallel_iterations=1) + return output.stack() + + +def evaluate(model, inputs): + """Performs model rollouts and create stats.""" + initial_state = {k: v[0] for k, v in inputs.items()} + num_steps = inputs['cells'].shape[0] + prediction = _rollout(model, initial_state, num_steps) + + error = tf.reduce_mean((prediction - inputs['world_pos'])**2, axis=-1) + scalars = {'mse_%d_steps' % horizon: tf.reduce_mean(error[1:horizon+1]) + for horizon in [1, 10, 20, 50, 100, 200]} + traj_ops = { + 'faces': inputs['cells'], + 'mesh_pos': inputs['mesh_pos'], + 'gt_pos': inputs['world_pos'], + 'pred_pos': prediction + } + return scalars, traj_ops diff --git a/meshgraphnets/cloth_model.py b/meshgraphnets/cloth_model.py new file mode 100644 index 0000000..d1f2fe4 --- /dev/null +++ b/meshgraphnets/cloth_model.py @@ -0,0 +1,100 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Model for FlagSimple.""" + +import sonnet as snt +import tensorflow.compat.v1 as tf + +from meshgraphnets import common +from meshgraphnets import core_model +from meshgraphnets import normalization + + +class Model(snt.AbstractModule): + """Model for static cloth simulation.""" + + def __init__(self, learned_model, name='Model'): + super(Model, self).__init__(name=name) + with self._enter_variable_scope(): + self._learned_model = learned_model + self._output_normalizer = normalization.Normalizer( + size=3, name='output_normalizer') + self._node_normalizer = normalization.Normalizer( + size=3+common.NodeType.SIZE, name='node_normalizer') + self._edge_normalizer = normalization.Normalizer( + size=7, name='edge_normalizer') # 2D coord + 3D coord + 2*length = 7 + + def _build_graph(self, inputs, is_training): + """Builds input graph.""" + # construct graph nodes + velocity = inputs['world_pos'] - inputs['prev|world_pos'] + node_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE) + node_features = tf.concat([velocity, node_type], axis=-1) + + # construct graph edges + senders, receivers = common.triangles_to_edges(inputs['cells']) + relative_world_pos = (tf.gather(inputs['world_pos'], senders) - + tf.gather(inputs['world_pos'], receivers)) + relative_mesh_pos = (tf.gather(inputs['mesh_pos'], senders) - + tf.gather(inputs['mesh_pos'], receivers)) + edge_features = tf.concat([ + relative_world_pos, + tf.norm(relative_world_pos, axis=-1, keepdims=True), + relative_mesh_pos, + tf.norm(relative_mesh_pos, axis=-1, keepdims=True)], axis=-1) + + mesh_edges = core_model.EdgeSet( + name='mesh_edges', + features=self._edge_normalizer(edge_features, is_training), + receivers=receivers, + senders=senders) + return core_model.MultiGraph( + node_features=self._node_normalizer(node_features, is_training), + edge_sets=[mesh_edges]) + + def _build(self, inputs): + graph = self._build_graph(inputs, is_training=False) + per_node_network_output = self._learned_model(graph) + return self._update(inputs, per_node_network_output) + + @snt.reuse_variables + def loss(self, inputs): + """L2 loss on position.""" + graph = self._build_graph(inputs, is_training=True) + network_output = self._learned_model(graph) + + # build target acceleration + cur_position = inputs['world_pos'] + prev_position = inputs['prev|world_pos'] + target_position = inputs['target|world_pos'] + target_acceleration = target_position - 2*cur_position + prev_position + target_normalized = self._output_normalizer(target_acceleration) + + # build loss + loss_mask = tf.equal(inputs['node_type'][:, 0], common.NodeType.NORMAL) + error = tf.reduce_sum((target_normalized - network_output)**2, axis=1) + loss = tf.reduce_mean(error[loss_mask]) + return loss + + def _update(self, inputs, per_node_network_output): + """Integrate model outputs.""" + acceleration = self._output_normalizer.inverse(per_node_network_output) + # integrate forward + cur_position = inputs['world_pos'] + prev_position = inputs['prev|world_pos'] + position = 2*cur_position + acceleration - prev_position + return position diff --git a/meshgraphnets/common.py b/meshgraphnets/common.py new file mode 100644 index 0000000..96c9855 --- /dev/null +++ b/meshgraphnets/common.py @@ -0,0 +1,51 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Commonly used data structures and functions.""" + +import enum +import tensorflow.compat.v1 as tf + + +class NodeType(enum.IntEnum): + NORMAL = 0 + OBSTACLE = 1 + AIRFOIL = 2 + HANDLE = 3 + INFLOW = 4 + OUTFLOW = 5 + WALL_BOUNDARY = 6 + SIZE = 9 + + +def triangles_to_edges(faces): + """Computes mesh edges from triangles.""" + # collect edges from triangles + edges = tf.concat([faces[:, 0:2], + faces[:, 1:3], + tf.stack([faces[:, 2], faces[:, 0]], axis=1)], axis=0) + # those edges are sometimes duplicated (within the mesh) and sometimes + # single (at the mesh boundary). + # sort & pack edges as single tf.int64 + receivers = tf.reduce_min(edges, axis=1) + senders = tf.reduce_max(edges, axis=1) + packed_edges = tf.bitcast(tf.stack([senders, receivers], axis=1), tf.int64) + # remove duplicates and unpack + unique_edges = tf.bitcast(tf.unique(packed_edges)[0], tf.int32) + senders, receivers = tf.unstack(unique_edges, axis=1) + # create two-way connectivity + return (tf.concat([senders, receivers], axis=0), + tf.concat([receivers, senders], axis=0)) diff --git a/meshgraphnets/core_model.py b/meshgraphnets/core_model.py new file mode 100644 index 0000000..5d04a66 --- /dev/null +++ b/meshgraphnets/core_model.py @@ -0,0 +1,121 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Core learned graph net model.""" + +import collections +import functools +import sonnet as snt +import tensorflow.compat.v1 as tf + +EdgeSet = collections.namedtuple('EdgeSet', ['name', 'features', 'senders', + 'receivers']) +MultiGraph = collections.namedtuple('Graph', ['node_features', 'edge_sets']) + + +class GraphNetBlock(snt.AbstractModule): + """Multi-Edge Interaction Network with residual connections.""" + + def __init__(self, model_fn, name='GraphNetBlock'): + super(GraphNetBlock, self).__init__(name=name) + self._model_fn = model_fn + + def _update_edge_features(self, node_features, edge_set): + """Aggregrates node features, and applies edge function.""" + sender_features = tf.gather(node_features, edge_set.senders) + receiver_features = tf.gather(node_features, edge_set.receivers) + features = [sender_features, receiver_features, edge_set.features] + with tf.variable_scope(edge_set.name+'_edge_fn'): + return self._model_fn()(tf.concat(features, axis=-1)) + + def _update_node_features(self, node_features, edge_sets): + """Aggregrates edge features, and applies node function.""" + num_nodes = tf.shape(node_features)[0] + features = [node_features] + for edge_set in edge_sets: + features.append(tf.math.unsorted_segment_sum(edge_set.features, + edge_set.receivers, + num_nodes)) + with tf.variable_scope('node_fn'): + return self._model_fn()(tf.concat(features, axis=-1)) + + def _build(self, graph): + """Applies GraphNetBlock and returns updated MultiGraph.""" + + # apply edge functions + new_edge_sets = [] + for edge_set in graph.edge_sets: + updated_features = self._update_edge_features(graph.node_features, + edge_set) + new_edge_sets.append(edge_set._replace(features=updated_features)) + + # apply node function + new_node_features = self._update_node_features(graph.node_features, + new_edge_sets) + + # add residual connections + new_node_features += graph.node_features + new_edge_sets = [es._replace(features=es.features + old_es.features) + for es, old_es in zip(new_edge_sets, graph.edge_sets)] + return MultiGraph(new_node_features, new_edge_sets) + + +class EncodeProcessDecode(snt.AbstractModule): + """Encode-Process-Decode GraphNet model.""" + + def __init__(self, + output_size, + latent_size, + num_layers, + message_passing_steps, + name='EncodeProcessDecode'): + super(EncodeProcessDecode, self).__init__(name=name) + self._latent_size = latent_size + self._output_size = output_size + self._num_layers = num_layers + self._message_passing_steps = message_passing_steps + + def _make_mlp(self, output_size, layer_norm=True): + """Builds an MLP.""" + widths = [self._latent_size] * self._num_layers + [output_size] + network = snt.nets.MLP(widths, activate_final=False) + if layer_norm: + network = snt.Sequential([network, snt.LayerNorm()]) + return network + + def _encoder(self, graph): + """Encodes node and edge features into latent features.""" + with tf.variable_scope('encoder'): + node_latents = self._make_mlp(self._latent_size)(graph.node_features) + new_edges_sets = [] + for edge_set in graph.edge_sets: + latent = self._make_mlp(self._latent_size)(edge_set.features) + new_edges_sets.append(edge_set._replace(features=latent)) + return MultiGraph(node_latents, new_edges_sets) + + def _decoder(self, graph): + """Decodes node features from graph.""" + with tf.variable_scope('decoder'): + decoder = self._make_mlp(self._output_size, layer_norm=False) + return decoder(graph.node_features) + + def _build(self, graph): + """Encodes and processes a multigraph, and returns node features.""" + model_fn = functools.partial(self._make_mlp, output_size=self._latent_size) + latent_graph = self._encoder(graph) + for _ in range(self._message_passing_steps): + latent_graph = GraphNetBlock(model_fn)(latent_graph) + return self._decoder(latent_graph) diff --git a/meshgraphnets/dataset.py b/meshgraphnets/dataset.py new file mode 100644 index 0000000..465d3f8 --- /dev/null +++ b/meshgraphnets/dataset.py @@ -0,0 +1,119 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utility functions for reading the datasets.""" + +import functools +import json +import os + +import tensorflow.compat.v1 as tf + +from meshgraphnets.common import NodeType + + +def _parse(proto, meta): + """Parses a trajectory from tf.Example.""" + feature_lists = {k: tf.io.VarLenFeature(tf.string) + for k in meta['field_names']} + features = tf.io.parse_single_example(proto, feature_lists) + out = {} + for key, field in meta['features'].items(): + data = tf.io.decode_raw(features[key].values, getattr(tf, field['dtype'])) + data = tf.reshape(data, field['shape']) + if field['type'] == 'static': + data = tf.tile(data, [meta['trajectory_length'], 1, 1]) + elif field['type'] == 'dynamic_varlen': + length = tf.io.decode_raw(features['length_'+key].values, tf.int32) + length = tf.reshape(length, [-1]) + data = tf.RaggedTensor.from_row_lengths(data, row_lengths=length) + elif field['type'] != 'dynamic': + raise ValueError('invalid data format') + out[key] = data + return out + + +def load_dataset(path, split): + """Load dataset.""" + with open(os.path.join(path, 'meta.json'), 'r') as fp: + meta = json.loads(fp.read()) + ds = tf.data.TFRecordDataset(os.path.join(path, split+'.tfrecord')) + ds = ds.map(functools.partial(_parse, meta=meta), num_parallel_calls=8) + ds = ds.prefetch(1) + return ds + + +def add_targets(ds, fields, add_history): + """Adds target and optionally history fields to dataframe.""" + def fn(trajectory): + out = {} + for key, val in trajectory.items(): + out[key] = val[1:-1] + if key in fields: + if add_history: + out['prev|'+key] = val[0:-2] + out['target|'+key] = val[2:] + return out + return ds.map(fn, num_parallel_calls=8) + + +def split_and_preprocess(ds, noise_field, noise_scale, noise_gamma): + """Splits trajectories into frames, and adds training noise.""" + def add_noise(frame): + noise = tf.random.normal(tf.shape(frame[noise_field]), + stddev=noise_scale, dtype=tf.float32) + # don't apply noise to boundary nodes + mask = tf.equal(frame['node_type'], NodeType.NORMAL)[:, 0] + noise = tf.where(mask, noise, tf.zeros_like(noise)) + frame[noise_field] += noise + frame['target|'+noise_field] += (1.0 - noise_gamma) * noise + return frame + + ds = ds.flat_map(tf.data.Dataset.from_tensor_slices) + ds = ds.map(add_noise, num_parallel_calls=8) + ds = ds.shuffle(10000) + ds = ds.repeat(None) + return ds.prefetch(10) + + +def batch_dataset(ds, batch_size): + """Batches input datasets.""" + shapes = ds.output_shapes + types = ds.output_types + def renumber(buffer, frame): + nodes, cells = buffer + new_nodes, new_cells = frame + return nodes + new_nodes, tf.concat([cells, new_cells+nodes], axis=0) + + def batch_accumulate(ds_window): + out = {} + for key, ds_val in ds_window.items(): + initial = tf.zeros((0, shapes[key][1]), dtype=types[key]) + if key == 'cells': + # renumber node indices in cells + num_nodes = ds_window['node_type'].map(lambda x: tf.shape(x)[0]) + cells = tf.data.Dataset.zip((num_nodes, ds_val)) + initial = (tf.constant(0, tf.int32), initial) + _, out[key] = cells.reduce(initial, renumber) + else: + merge = lambda prev, cur: tf.concat([prev, cur], axis=0) + out[key] = ds_val.reduce(initial, merge) + return out + + if batch_size > 1: + ds = ds.window(batch_size, drop_remainder=True) + ds = ds.map(batch_accumulate, num_parallel_calls=8) + return ds diff --git a/meshgraphnets/download_dataset.sh b/meshgraphnets/download_dataset.sh new file mode 100755 index 0000000..ca4a826 --- /dev/null +++ b/meshgraphnets/download_dataset.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Usage: +# sh download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR} +# Example: +# sh download_dataset.sh flag_simple /tmp/ + +set -e + +DATASET_NAME="${1}" +OUTPUT_DIR="${2}/${DATASET_NAME}" + +BASE_URL="https://storage.googleapis.com/dm-meshgraphnets/${DATASET_NAME}/" + +mkdir -p ${OUTPUT_DIR} +for file in meta.json train.tfrecord valid.tfrecord test.tfrecord +do +wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}" +done diff --git a/meshgraphnets/normalization.py b/meshgraphnets/normalization.py new file mode 100644 index 0000000..a490f4b --- /dev/null +++ b/meshgraphnets/normalization.py @@ -0,0 +1,73 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Online data normalization.""" + +import sonnet as snt +import tensorflow.compat.v1 as tf + + +class Normalizer(snt.AbstractModule): + """Feature normalizer that accumulates statistics online.""" + + def __init__(self, size, max_accumulations=10**6, std_epsilon=1e-8, + name='Normalizer'): + super(Normalizer, self).__init__(name=name) + self._max_accumulations = max_accumulations + self._std_epsilon = std_epsilon + with self._enter_variable_scope(): + self._acc_count = tf.Variable(0, dtype=tf.float32, trainable=False) + self._num_accumulations = tf.Variable(0, dtype=tf.float32, + trainable=False) + self._acc_sum = tf.Variable(tf.zeros(size, tf.float32), trainable=False) + self._acc_sum_squared = tf.Variable(tf.zeros(size, tf.float32), + trainable=False) + + def _build(self, batched_data, accumulate=True): + """Normalizes input data and accumulates statistics.""" + update_op = tf.no_op() + if accumulate: + # stop accumulating after a million updates, to prevent accuracy issues + update_op = tf.cond(self._num_accumulations < self._max_accumulations, + lambda: self._accumulate(batched_data), + tf.no_op) + with tf.control_dependencies([update_op]): + return (batched_data - self._mean()) / self._std_with_epsilon() + + @snt.reuse_variables + def inverse(self, normalized_batch_data): + """Inverse transformation of the normalizer.""" + return normalized_batch_data * self._std_with_epsilon() + self._mean() + + def _accumulate(self, batched_data): + """Function to perform the accumulation of the batch_data statistics.""" + count = tf.cast(tf.shape(batched_data)[0], tf.float32) + data_sum = tf.reduce_sum(batched_data, axis=0) + squared_data_sum = tf.reduce_sum(batched_data**2, axis=0) + return tf.group( + tf.assign_add(self._acc_sum, data_sum), + tf.assign_add(self._acc_sum_squared, squared_data_sum), + tf.assign_add(self._acc_count, count), + tf.assign_add(self._num_accumulations, 1.)) + + def _mean(self): + safe_count = tf.maximum(self._acc_count, 1.) + return self._acc_sum / safe_count + + def _std_with_epsilon(self): + safe_count = tf.maximum(self._acc_count, 1.) + std = tf.sqrt(self._acc_sum_squared / safe_count - self._mean()**2) + return tf.math.maximum(std, self._std_epsilon) diff --git a/meshgraphnets/plot_cfd.py b/meshgraphnets/plot_cfd.py new file mode 100644 index 0000000..43e761a --- /dev/null +++ b/meshgraphnets/plot_cfd.py @@ -0,0 +1,68 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Plots a CFD trajectory rollout.""" + +import pickle + +from absl import app +from absl import flags +from matplotlib import animation +from matplotlib import tri as mtri +import matplotlib.pyplot as plt + +FLAGS = flags.FLAGS +flags.DEFINE_string('rollout_path', None, 'Path to rollout pickle file') + + +def main(unused_argv): + with open(FLAGS.rollout_path, 'rb') as fp: + rollout_data = pickle.load(fp) + + fig, ax = plt.subplots(1, 1, figsize=(12, 8)) + skip = 10 + num_steps = rollout_data[0]['gt_velocity'].shape[0] + num_frames = len(rollout_data) * num_steps // skip + + # compute bounds + bounds = [] + for trajectory in rollout_data: + bb_min = trajectory['gt_velocity'].min(axis=(0, 1)) + bb_max = trajectory['gt_velocity'].max(axis=(0, 1)) + bounds.append((bb_min, bb_max)) + + def animate(num): + step = (num*skip) % num_steps + traj = (num*skip) // num_steps + ax.cla() + ax.set_aspect('equal') + ax.set_axis_off() + vmin, vmax = bounds[traj] + pos = rollout_data[traj]['mesh_pos'][step] + faces = rollout_data[traj]['faces'][step] + velocity = rollout_data[traj]['pred_velocity'][step] + triang = mtri.Triangulation(pos[:, 0], pos[:, 1], faces) + ax.tripcolor(triang, velocity[:, 0], vmin=vmin[0], vmax=vmax[0]) + ax.triplot(triang, 'ko-', ms=0.5, lw=0.3) + ax.set_title('Trajectory %d Step %d' % (traj, step)) + return fig, + + _ = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100) + plt.show(block=True) + + +if __name__ == '__main__': + app.run(main) diff --git a/meshgraphnets/plot_cloth.py b/meshgraphnets/plot_cloth.py new file mode 100644 index 0000000..ed16983 --- /dev/null +++ b/meshgraphnets/plot_cloth.py @@ -0,0 +1,67 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Plots a cloth trajectory rollout.""" + +import pickle + +from absl import app +from absl import flags + +from matplotlib import animation +import matplotlib.pyplot as plt + +FLAGS = flags.FLAGS +flags.DEFINE_string('rollout_path', None, 'Path to rollout pickle file') + + +def main(unused_argv): + with open(FLAGS.rollout_path, 'rb') as fp: + rollout_data = pickle.load(fp) + + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111, projection='3d') + skip = 10 + num_steps = rollout_data[0]['gt_pos'].shape[0] + num_frames = len(rollout_data) * num_steps // skip + + # compute bounds + bounds = [] + for trajectory in rollout_data: + bb_min = trajectory['gt_pos'].min(axis=(0, 1)) + bb_max = trajectory['gt_pos'].max(axis=(0, 1)) + bounds.append((bb_min, bb_max)) + + def animate(num): + step = (num*skip) % num_steps + traj = (num*skip) // num_steps + ax.cla() + bound = bounds[traj] + ax.set_xlim([bound[0][0], bound[1][0]]) + ax.set_ylim([bound[0][1], bound[1][1]]) + ax.set_zlim([bound[0][2], bound[1][2]]) + pos = rollout_data[traj]['pred_pos'][step] + faces = rollout_data[traj]['faces'][step] + ax.plot_trisurf(pos[:, 0], pos[:, 1], faces, pos[:, 2], shade=True) + ax.set_title('Trajectory %d Step %d' % (traj, step)) + return fig, + + _ = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100) + plt.show(block=True) + + +if __name__ == '__main__': + app.run(main) diff --git a/meshgraphnets/requirements.txt b/meshgraphnets/requirements.txt new file mode 100644 index 0000000..efec3cb --- /dev/null +++ b/meshgraphnets/requirements.txt @@ -0,0 +1,5 @@ +tensorflow-gpu>=1.15,<2 +dm-sonnet<2 +matplotlib +absl-py +numpy diff --git a/meshgraphnets/run.sh b/meshgraphnets/run.sh new file mode 100644 index 0000000..a520a49 --- /dev/null +++ b/meshgraphnets/run.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Fail on any error. +set -e + +# Display commands being run. +set -x + +TMP_DIR=`mktemp -d` + +virtualenv --python=python3.6 "${TMP_DIR}/env" +source "${TMP_DIR}/env/bin/activate" + +# Install dependencies. +pip install --upgrade -r meshgraphnets/requirements.txt + +# Download minimal dataset +DATA_DIR="${TMP_DIR}/flag_minimal" +bash meshgraphnets/download_dataset.sh flag_minimal ${TMP_DIR} + +# Train for a few steps. +CHK_DIR="${TMP_DIR}/checkpoint" +python -m meshgraphnets.run_model --model=cloth --mode=train --checkpoint_dir=${CHK_DIR} --dataset_dir=${DATA_DIR} --num_training_steps=10 + +# Generate a rollout trajectory +ROLLOUT_PATH="${TMP_DIR}/rollout.pkl" +python -m meshgraphnets.run_model --model=cloth --mode=eval --checkpoint_dir=${CHK_DIR} --dataset_dir=${DATA_DIR} --rollout_path=${ROLLOUT_PATH} --num_rollouts=1 + +# Plot the rollout trajectory +python -m meshgraphnets.plot_cloth --rollout_path=${ROLLOUT_PATH} + +# Clean up. +rm -r ${TMP_DIR} +echo "Test run complete." diff --git a/meshgraphnets/run_model.py b/meshgraphnets/run_model.py new file mode 100644 index 0000000..4d31b19 --- /dev/null +++ b/meshgraphnets/run_model.py @@ -0,0 +1,131 @@ +# Lint as: python3 +# pylint: disable=g-bad-file-header +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Runs the learner/evaluator.""" + +import pickle +from absl import app +from absl import flags +from absl import logging +import numpy as np +import tensorflow.compat.v1 as tf +from meshgraphnets import cfd_eval +from meshgraphnets import cfd_model +from meshgraphnets import cloth_eval +from meshgraphnets import cloth_model +from meshgraphnets import core_model +from meshgraphnets import dataset + + +FLAGS = flags.FLAGS +flags.DEFINE_enum('mode', 'train', ['train', 'eval'], + 'Train model, or run evaluation.') +flags.DEFINE_enum('model', None, ['cfd', 'cloth'], + 'Select model to run.') +flags.DEFINE_string('checkpoint_dir', None, 'Directory to save checkpoint') +flags.DEFINE_string('dataset_dir', None, 'Directory to load dataset from.') +flags.DEFINE_string('rollout_path', None, + 'Pickle file to save eval trajectories') +flags.DEFINE_enum('rollout_split', 'valid', ['train', 'test', 'valid'], + 'Dataset split to use for rollouts.') +flags.DEFINE_integer('num_rollouts', 10, 'No. of rollout trajectories') +flags.DEFINE_integer('num_training_steps', int(10e6), 'No. of training steps') + +PARAMETERS = { + 'cfd': dict(noise=0.02, gamma=1.0, field='velocity', history=False, + size=2, batch=2, model=cfd_model, evaluator=cfd_eval), + 'cloth': dict(noise=0.003, gamma=0.1, field='world_pos', history=True, + size=3, batch=1, model=cloth_model, evaluator=cloth_eval) +} + + +def learner(model, params): + """Run a learner job.""" + ds = dataset.load_dataset(FLAGS.dataset_dir, 'train') + ds = dataset.add_targets(ds, [params['field']], add_history=params['history']) + ds = dataset.split_and_preprocess(ds, noise_field=params['field'], + noise_scale=params['noise'], + noise_gamma=params['gamma']) + inputs = tf.data.make_one_shot_iterator(ds).get_next() + + loss_op = model.loss(inputs) + global_step = tf.train.create_global_step() + lr = tf.train.exponential_decay(learning_rate=1e-4, + global_step=global_step, + decay_steps=int(5e6), + decay_rate=0.1) + 1e-6 + optimizer = tf.train.AdamOptimizer(learning_rate=lr) + train_op = optimizer.minimize(loss_op, global_step=global_step) + # Don't train for the first few steps, just accumulate normalization stats + train_op = tf.cond(tf.less(global_step, 1000), + lambda: tf.group(tf.assign_add(global_step, 1)), + lambda: tf.group(train_op)) + + with tf.train.MonitoredTrainingSession( + hooks=[tf.train.StopAtStepHook(last_step=FLAGS.num_training_steps)], + checkpoint_dir=FLAGS.checkpoint_dir, + save_checkpoint_secs=600) as sess: + + while not sess.should_stop(): + _, step, loss = sess.run([train_op, global_step, loss_op]) + if step % 1000 == 0: + logging.info('Step %d: Loss %g', step, loss) + logging.info('Training complete.') + + +def evaluator(model, params): + """Run a model rollout trajectory.""" + ds = dataset.load_dataset(FLAGS.dataset_dir, FLAGS.rollout_split) + ds = dataset.add_targets(ds, [params['field']], add_history=params['history']) + inputs = tf.data.make_one_shot_iterator(ds).get_next() + scalar_op, traj_ops = params['evaluator'].evaluate(model, inputs) + tf.train.create_global_step() + + with tf.train.MonitoredTrainingSession( + checkpoint_dir=FLAGS.checkpoint_dir, + save_checkpoint_secs=None, + save_checkpoint_steps=None) as sess: + trajectories = [] + scalars = [] + for traj_idx in range(FLAGS.num_rollouts): + logging.info('Rollout trajectory %d', traj_idx) + scalar_data, traj_data = sess.run([scalar_op, traj_ops]) + trajectories.append(traj_data) + scalars.append(scalar_data) + for key in scalars[0]: + logging.info('%s: %g', key, np.mean([x[key] for x in scalars])) + with open(FLAGS.rollout_path, 'wb') as fp: + pickle.dump(trajectories, fp) + + +def main(argv): + del argv + tf.enable_resource_variables() + tf.disable_eager_execution() + params = PARAMETERS[FLAGS.model] + learned_model = core_model.EncodeProcessDecode( + output_size=params['size'], + latent_size=128, + num_layers=2, + message_passing_steps=15) + model = params['model'].Model(learned_model) + if FLAGS.mode == 'train': + learner(model, params) + elif FLAGS.mode == 'eval': + evaluator(model, params) + +if __name__ == '__main__': + app.run(main)