mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-19 02:24:20 +08:00
Initial release of Learning Mesh-Based Simulation with Graph Networks
PiperOrigin-RevId: 380185977
This commit is contained in:
@@ -25,6 +25,7 @@ env:
|
||||
- PROJECT="tvt"
|
||||
# - PROJECT="unrestricted_advx" # TODO(b/184862249): Fix and enable
|
||||
- PROJECT="unsupervised_adversarial_training"
|
||||
- PROJECT="meshgraphnets"
|
||||
before_script:
|
||||
- pip install --upgrade pip
|
||||
- pip install --upgrade virtualenv
|
||||
|
||||
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
|
||||
|
||||
## Projects
|
||||
|
||||
* [Learning Mesh-Based Simulation with Graph Networks](meshgraphnets), ICLR 2021
|
||||
* [Open Graph Benchmark - Large-Scale Challenge (OGB-LSC)](ogb_lsc)
|
||||
* [Synthetic Returns for Long-Term Credit Assignment](synthetic_returns)
|
||||
* [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers)
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
# Learning Mesh-Based Simulation with Graph Networks (ICLR 2021)
|
||||
|
||||
Video site: [sites.google.com/view/meshgraphnets](https://sites.google.com/view/meshgraphnets)
|
||||
|
||||
Paper: [arxiv.org/abs/2010.03409](https://arxiv.org/abs/2010.03409)
|
||||
|
||||
If you use the code here please cite this paper:
|
||||
|
||||
@inproceedings{pfaff2021learning,
|
||||
title={Learning Mesh-Based Simulation with Graph Networks},
|
||||
author={Tobias Pfaff and
|
||||
Meire Fortunato and
|
||||
Alvaro Sanchez-Gonzalez and
|
||||
Peter W. Battaglia},
|
||||
booktitle={International Conference on Learning Representations},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
## Setup
|
||||
|
||||
Prepare environment, install dependencies:
|
||||
|
||||
virtualenv --python=python3.6 "${ENV}"
|
||||
source "${ENV}/bin/activate"
|
||||
pip install -r meshgraphnets/requirements.txt
|
||||
|
||||
Download a dataset:
|
||||
|
||||
mkdir -p ${DATA}
|
||||
bash meshgraphnets/download_dataset.sh flag_simple ${DATA}
|
||||
|
||||
## Running the model
|
||||
|
||||
Train a model:
|
||||
|
||||
python -m meshgraphnets.run_model --mode=train --model=cloth \
|
||||
--checkpoint_dir=${DATA}/chk --dataset_dir=${DATA}/flag_simple
|
||||
|
||||
Generate some trajectory rollouts:
|
||||
|
||||
python -m meshgraphnets.run_model --mode=eval --model=cloth \
|
||||
--checkpoint_dir=${DATA}/chk --dataset_dir=${DATA}/flag_simple \
|
||||
--rollout_path=${DATA}/rollout_flag.pkl
|
||||
|
||||
Plot a trajectory:
|
||||
|
||||
python -m meshgraphnets.plot_cloth --rollout_path=${DATA}/rollout_flag.pkl
|
||||
|
||||
## Datasets
|
||||
|
||||
Datasets can be downloaded using the script `download_dataset.sh`. They contain
|
||||
a metadata file describing the available fields and their shape, and tfrecord
|
||||
datasets for train, valid and test splits.
|
||||
Dataset names match the naming in the paper.
|
||||
The following datasets are available:
|
||||
|
||||
airfoil
|
||||
cylinder_flow
|
||||
deforming_plate
|
||||
flag_minimal
|
||||
flag_simple
|
||||
flag_dynamic
|
||||
sphere_simple
|
||||
sphere_dynamic
|
||||
|
||||
`flag_minimal` is a truncated version of `flag_simple`, and is only used for
|
||||
integration tests.
|
||||
@@ -0,0 +1,62 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Functions to build evaluation metrics for CFD data."""
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from meshgraphnets.common import NodeType
|
||||
|
||||
|
||||
def _rollout(model, initial_state, num_steps):
  """Unrolls the model for `num_steps` steps, starting from `initial_state`.

  Args:
    model: callable taking a state dict and returning the predicted next
      velocity.
    initial_state: dict of tensors for the first frame; 'node_type' and
      'velocity' are read here, the full dict is forwarded to `model`.
    num_steps: number of frames written to the returned trajectory.

  Returns:
    The stacked per-step velocities. Entry `i` is the velocity *before* the
    i-th model update, so entry 0 is the initial velocity.
  """
  node_type = initial_state['node_type'][:, 0]
  is_normal = tf.equal(node_type, NodeType.NORMAL)
  is_outflow = tf.equal(node_type, NodeType.OUTFLOW)
  # only NORMAL and OUTFLOW nodes are advanced; all others keep their value
  update_mask = tf.logical_or(is_normal, is_outflow)

  def keep_going(step, unused_velocity, unused_trajectory):
    return tf.less(step, num_steps)

  def advance(step, velocity, trajectory):
    state = dict(initial_state)
    state['velocity'] = velocity
    predicted = model(state)
    # don't update boundary nodes
    next_velocity = tf.where(update_mask, predicted, velocity)
    trajectory = trajectory.write(step, velocity)
    return step + 1, next_velocity, trajectory

  initial_loop_vars = (0,
                       initial_state['velocity'],
                       tf.TensorArray(tf.float32, num_steps))
  _, _, trajectory = tf.while_loop(
      cond=keep_going,
      body=advance,
      loop_vars=initial_loop_vars,
      parallel_iterations=1)
  return trajectory.stack()
|
||||
|
||||
|
||||
def evaluate(model, inputs):
  """Rolls out the model over one trajectory and builds evaluation outputs.

  Args:
    model: callable forwarded to `_rollout`.
    inputs: dict of trajectory tensors whose leading axis is used as the step
      axis; 'cells', 'mesh_pos' and 'velocity' are read here.

  Returns:
    A `(scalars, traj_ops)` tuple. `scalars` maps 'mse_<h>_steps' to the mean
    squared velocity error over steps 1..h; `traj_ops` holds the mesh, the
    ground-truth velocity and the predicted velocity.
  """
  first_frame = {name: tensor[0] for name, tensor in inputs.items()}
  trajectory_length = inputs['cells'].shape[0]
  predicted = _rollout(model, first_frame, trajectory_length)

  squared_error = (predicted - inputs['velocity'])**2
  error = tf.reduce_mean(squared_error, axis=-1)

  scalars = {}
  for horizon in [1, 10, 20, 50, 100, 200]:
    # step 0 is the given initial state, so it is excluded from the error
    scalars['mse_%d_steps' % horizon] = tf.reduce_mean(error[1:horizon+1])

  traj_ops = dict(
      faces=inputs['cells'],
      mesh_pos=inputs['mesh_pos'],
      gt_velocity=inputs['velocity'],
      pred_velocity=predicted)
  return scalars, traj_ops
|
||||
@@ -0,0 +1,94 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Model for CylinderFlow."""
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from meshgraphnets import common
|
||||
from meshgraphnets import core_model
|
||||
from meshgraphnets import normalization
|
||||
|
||||
|
||||
class Model(snt.AbstractModule):
  """Model for fluid simulation.

  Wraps a learned graph network with input/output normalization. The network
  output is interpreted as a normalized velocity change which is added to the
  current velocity.
  """

  def __init__(self, learned_model, name='Model'):
    """Creates the model and its normalizers.

    Args:
      learned_model: callable applied to the `core_model.MultiGraph` built
        from the inputs.
      name: name of this Sonnet module.
    """
    super().__init__(name=name)
    with self._enter_variable_scope():
      self._learned_model = learned_model
      self._output_normalizer = normalization.Normalizer(
          size=2, name='output_normalizer')
      # velocity + one-hot node type
      node_feature_size = 2 + common.NodeType.SIZE
      self._node_normalizer = normalization.Normalizer(
          size=node_feature_size, name='node_normalizer')
      # 2D coord + length
      self._edge_normalizer = normalization.Normalizer(
          size=3, name='edge_normalizer')

  def _build_graph(self, inputs, is_training):
    """Builds the normalized input graph.

    Args:
      inputs: dict providing 'node_type', 'velocity', 'cells' and 'mesh_pos'.
      is_training: forwarded to the node and edge normalizers as their
        second argument (`accumulate`).

    Returns:
      A `core_model.MultiGraph` with a single edge set named 'mesh_edges'.
    """
    # graph nodes: velocity concatenated with the one-hot node type
    one_hot_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE)
    node_features = tf.concat([inputs['velocity'], one_hot_type], axis=-1)

    # graph edges: sender-minus-receiver mesh offset and its norm
    senders, receivers = common.triangles_to_edges(inputs['cells'])
    mesh_pos = inputs['mesh_pos']
    offset = tf.gather(mesh_pos, senders) - tf.gather(mesh_pos, receivers)
    distance = tf.norm(offset, axis=-1, keepdims=True)
    edge_features = tf.concat([offset, distance], axis=-1)

    normalized_edges = self._edge_normalizer(edge_features, is_training)
    mesh_edges = core_model.EdgeSet(
        name='mesh_edges',
        features=normalized_edges,
        receivers=receivers,
        senders=senders)
    normalized_nodes = self._node_normalizer(node_features, is_training)
    return core_model.MultiGraph(
        node_features=normalized_nodes, edge_sets=[mesh_edges])

  def _build(self, inputs):
    """Predicts the next velocity without accumulating normalizer statistics."""
    graph = self._build_graph(inputs, is_training=False)
    network_output = self._learned_model(graph)
    return self._update(inputs, network_output)

  @snt.reuse_variables
  def loss(self, inputs):
    """L2 loss on the normalized velocity change.

    Args:
      inputs: dict as for `_build_graph`, plus 'target|velocity'.

    Returns:
      Mean over NORMAL and OUTFLOW nodes of the squared error, summed over
      axis 1, between the normalized target change and the network output.
    """
    graph = self._build_graph(inputs, is_training=True)
    network_output = self._learned_model(graph)

    # target is the velocity change between the current and the next frame
    velocity_change = inputs['target|velocity'] - inputs['velocity']
    target_normalized = self._output_normalizer(velocity_change)

    # only NORMAL and OUTFLOW nodes contribute to the loss
    node_type = inputs['node_type'][:, 0]
    is_normal = tf.equal(node_type, common.NodeType.NORMAL)
    is_outflow = tf.equal(node_type, common.NodeType.OUTFLOW)
    loss_mask = tf.logical_or(is_normal, is_outflow)
    error = tf.reduce_sum((target_normalized - network_output)**2, axis=1)
    return tf.reduce_mean(error[loss_mask])

  def _update(self, inputs, per_node_network_output):
    """Integrates the model output: current velocity plus predicted change."""
    velocity_change = self._output_normalizer.inverse(per_node_network_output)
    return inputs['velocity'] + velocity_change
|
||||
@@ -0,0 +1,61 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Functions to build evaluation metrics for cloth data."""
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from meshgraphnets.common import NodeType
|
||||
|
||||
|
||||
def _rollout(model, initial_state, num_steps):
  """Unrolls the model for `num_steps` steps, starting from `initial_state`.

  Args:
    model: callable taking a state dict and returning the predicted next
      world position.
    initial_state: dict of tensors for the first frame; 'node_type',
      'prev|world_pos' and 'world_pos' are read here, the full dict is
      forwarded to `model`.
    num_steps: number of frames written to the returned trajectory.

  Returns:
    The stacked per-step world positions. Entry `i` is the position *before*
    the i-th model update, so entry 0 is the initial position.
  """
  update_mask = tf.equal(initial_state['node_type'][:, 0], NodeType.NORMAL)

  def keep_going(step, unused_prev_pos, unused_cur_pos, unused_trajectory):
    return tf.less(step, num_steps)

  def advance(step, prev_pos, cur_pos, trajectory):
    state = dict(initial_state)
    state['prev|world_pos'] = prev_pos
    state['world_pos'] = cur_pos
    predicted = model(state)
    # don't update kinematic nodes
    next_pos = tf.where(update_mask, predicted, cur_pos)
    trajectory = trajectory.write(step, cur_pos)
    # the current position becomes the history of the next step
    return step + 1, cur_pos, next_pos, trajectory

  initial_loop_vars = (0,
                       initial_state['prev|world_pos'],
                       initial_state['world_pos'],
                       tf.TensorArray(tf.float32, num_steps))
  _, _, _, trajectory = tf.while_loop(
      cond=keep_going,
      body=advance,
      loop_vars=initial_loop_vars,
      parallel_iterations=1)
  return trajectory.stack()
|
||||
|
||||
|
||||
def evaluate(model, inputs):
  """Rolls out the model over one trajectory and builds evaluation outputs.

  Args:
    model: callable forwarded to `_rollout`.
    inputs: dict of trajectory tensors whose leading axis is used as the step
      axis; 'cells', 'mesh_pos' and 'world_pos' are read here.

  Returns:
    A `(scalars, traj_ops)` tuple. `scalars` maps 'mse_<h>_steps' to the mean
    squared position error over steps 1..h; `traj_ops` holds the mesh, the
    ground-truth positions and the predicted positions.
  """
  first_frame = {name: tensor[0] for name, tensor in inputs.items()}
  trajectory_length = inputs['cells'].shape[0]
  predicted = _rollout(model, first_frame, trajectory_length)

  squared_error = (predicted - inputs['world_pos'])**2
  error = tf.reduce_mean(squared_error, axis=-1)

  scalars = {}
  for horizon in [1, 10, 20, 50, 100, 200]:
    # step 0 is the given initial state, so it is excluded from the error
    scalars['mse_%d_steps' % horizon] = tf.reduce_mean(error[1:horizon+1])

  traj_ops = dict(
      faces=inputs['cells'],
      mesh_pos=inputs['mesh_pos'],
      gt_pos=inputs['world_pos'],
      pred_pos=predicted)
  return scalars, traj_ops
|
||||
@@ -0,0 +1,100 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Model for FlagSimple."""
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from meshgraphnets import common
|
||||
from meshgraphnets import core_model
|
||||
from meshgraphnets import normalization
|
||||
|
||||
|
||||
class Model(snt.AbstractModule):
  """Model for static cloth simulation.

  Wraps a learned graph network with input/output normalization. The network
  output is interpreted as a normalized acceleration which is integrated from
  the current and previous world positions.
  """

  def __init__(self, learned_model, name='Model'):
    """Creates the model and its normalizers.

    Args:
      learned_model: callable applied to the `core_model.MultiGraph` built
        from the inputs.
      name: name of this Sonnet module.
    """
    super().__init__(name=name)
    with self._enter_variable_scope():
      self._learned_model = learned_model
      self._output_normalizer = normalization.Normalizer(
          size=3, name='output_normalizer')
      # velocity + one-hot node type
      node_feature_size = 3 + common.NodeType.SIZE
      self._node_normalizer = normalization.Normalizer(
          size=node_feature_size, name='node_normalizer')
      # 2D coord + 3D coord + 2*length = 7
      self._edge_normalizer = normalization.Normalizer(
          size=7, name='edge_normalizer')

  def _build_graph(self, inputs, is_training):
    """Builds the normalized input graph.

    Args:
      inputs: dict providing 'world_pos', 'prev|world_pos', 'node_type',
        'cells' and 'mesh_pos'.
      is_training: forwarded to the node and edge normalizers as their
        second argument (`accumulate`).

    Returns:
      A `core_model.MultiGraph` with a single edge set named 'mesh_edges'.
    """
    # graph nodes: finite-difference velocity and the one-hot node type
    velocity = inputs['world_pos'] - inputs['prev|world_pos']
    one_hot_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE)
    node_features = tf.concat([velocity, one_hot_type], axis=-1)

    # graph edges: sender-minus-receiver offsets in world and mesh space,
    # each followed by its norm
    senders, receivers = common.triangles_to_edges(inputs['cells'])
    world_pos = inputs['world_pos']
    world_offset = (tf.gather(world_pos, senders) -
                    tf.gather(world_pos, receivers))
    mesh_pos = inputs['mesh_pos']
    mesh_offset = (tf.gather(mesh_pos, senders) -
                   tf.gather(mesh_pos, receivers))
    edge_parts = [
        world_offset,
        tf.norm(world_offset, axis=-1, keepdims=True),
        mesh_offset,
        tf.norm(mesh_offset, axis=-1, keepdims=True),
    ]
    edge_features = tf.concat(edge_parts, axis=-1)

    normalized_edges = self._edge_normalizer(edge_features, is_training)
    mesh_edges = core_model.EdgeSet(
        name='mesh_edges',
        features=normalized_edges,
        receivers=receivers,
        senders=senders)
    normalized_nodes = self._node_normalizer(node_features, is_training)
    return core_model.MultiGraph(
        node_features=normalized_nodes, edge_sets=[mesh_edges])

  def _build(self, inputs):
    """Predicts the next position without accumulating normalizer statistics."""
    graph = self._build_graph(inputs, is_training=False)
    network_output = self._learned_model(graph)
    return self._update(inputs, network_output)

  @snt.reuse_variables
  def loss(self, inputs):
    """L2 loss on the normalized acceleration.

    Args:
      inputs: dict as for `_build_graph`, plus 'target|world_pos'.

    Returns:
      Mean over NORMAL nodes of the squared error, summed over axis 1,
      between the normalized target acceleration and the network output.
    """
    graph = self._build_graph(inputs, is_training=True)
    network_output = self._learned_model(graph)

    # target acceleration as the second finite difference of the position
    current = inputs['world_pos']
    previous = inputs['prev|world_pos']
    target = inputs['target|world_pos']
    target_acceleration = target - 2*current + previous
    target_normalized = self._output_normalizer(target_acceleration)

    # only NORMAL nodes contribute to the loss
    loss_mask = tf.equal(inputs['node_type'][:, 0], common.NodeType.NORMAL)
    error = tf.reduce_sum((target_normalized - network_output)**2, axis=1)
    return tf.reduce_mean(error[loss_mask])

  def _update(self, inputs, per_node_network_output):
    """Integrates the model output into the next world position."""
    acceleration = self._output_normalizer.inverse(per_node_network_output)
    current = inputs['world_pos']
    previous = inputs['prev|world_pos']
    # inverse of the finite difference used to build the loss target
    return 2*current + acceleration - previous
|
||||
@@ -0,0 +1,51 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Commonly used data structures and functions."""
|
||||
|
||||
import enum
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
@enum.unique
class NodeType(enum.IntEnum):
  """Integer node-type labels stored in the 'node_type' dataset field.

  Being an `IntEnum`, members compare equal to (and do arithmetic as) plain
  integers. `@enum.unique` makes an accidental duplicate value fail at import
  time instead of silently creating an alias.
  """
  NORMAL = 0
  OBSTACLE = 1
  AIRFOIL = 2
  HANDLE = 3
  INFLOW = 4
  OUTFLOW = 5
  WALL_BOUNDARY = 6
  # Not a node type: number of one-hot classes (the models use it as the
  # `tf.one_hot` depth). Values 7 and 8 are not assigned to any member.
  SIZE = 9
|
||||
|
||||
|
||||
def triangles_to_edges(faces):
  """Computes the bidirectional edge list of a triangle mesh.

  Args:
    faces: integer tensor with one triangle per row (columns 0..2 are read).
      Presumably int32, since index pairs are bitcast to a single int64 —
      TODO confirm against the dataset metadata.

  Returns:
    A `(senders, receivers)` tuple of index tensors in which every unique
    undirected edge appears twice, once per direction.
  """
  # every triangle contributes the edges (0, 1), (1, 2) and (2, 0)
  closing_edge = tf.stack([faces[:, 2], faces[:, 0]], axis=1)
  all_edges = tf.concat([faces[:, 0:2], faces[:, 1:3], closing_edge], axis=0)

  # Interior edges show up once per adjacent triangle, boundary edges once.
  # Order each pair as (max, min) so both copies of an edge look the same,
  # then pack the pair into one tf.int64 so that tf.unique can dedupe it.
  low = tf.reduce_min(all_edges, axis=1)
  high = tf.reduce_max(all_edges, axis=1)
  ordered_pairs = tf.stack([high, low], axis=1)
  packed = tf.bitcast(ordered_pairs, tf.int64)

  # remove duplicates and unpack
  unique_packed = tf.unique(packed)[0]
  senders, receivers = tf.unstack(tf.bitcast(unique_packed, tf.int32), axis=1)

  # create two-way connectivity
  both_senders = tf.concat([senders, receivers], axis=0)
  both_receivers = tf.concat([receivers, senders], axis=0)
  return both_senders, both_receivers
|
||||
@@ -0,0 +1,121 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Core learned graph net model."""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
# One set of edges: a name (used to scope the edge function's variables),
# per-edge features, and the sender / receiver node index of each edge.
EdgeSet = collections.namedtuple(
    'EdgeSet', ['name', 'features', 'senders', 'receivers'])

# Node features plus a list of EdgeSets. The typename must match the name the
# class is bound to: with the previous typename ('Graph'), repr() was
# misleading and instances could not be pickled, because pickle looks the
# class up by its typename on this module.
MultiGraph = collections.namedtuple(
    'MultiGraph', ['node_features', 'edge_sets'])
|
||||
|
||||
|
||||
class GraphNetBlock(snt.AbstractModule):
  """Multi-Edge Interaction Network with residual connections."""

  def __init__(self, model_fn, name='GraphNetBlock'):
    """Creates the block.

    Args:
      model_fn: zero-argument factory; each call returns a fresh callable
        network used as an edge or node update function.
      name: name of this Sonnet module.
    """
    super().__init__(name=name)
    self._model_fn = model_fn

  def _update_edge_features(self, node_features, edge_set):
    """Aggregates node features, and applies edge function."""
    from_sender = tf.gather(node_features, edge_set.senders)
    from_receiver = tf.gather(node_features, edge_set.receivers)
    edge_inputs = [from_sender, from_receiver, edge_set.features]
    # one variable scope per edge set, derived from the edge set's name
    with tf.variable_scope(edge_set.name+'_edge_fn'):
      edge_fn = self._model_fn()
      return edge_fn(tf.concat(edge_inputs, axis=-1))

  def _update_node_features(self, node_features, edge_sets):
    """Aggregates edge features, and applies node function."""
    num_nodes = tf.shape(node_features)[0]
    node_inputs = [node_features]
    for edge_set in edge_sets:
      # sum the features of all edges arriving at each node
      incoming = tf.math.unsorted_segment_sum(
          edge_set.features, edge_set.receivers, num_nodes)
      node_inputs.append(incoming)
    with tf.variable_scope('node_fn'):
      node_fn = self._model_fn()
      return node_fn(tf.concat(node_inputs, axis=-1))

  def _build(self, graph):
    """Applies GraphNetBlock and returns updated MultiGraph."""
    # apply edge functions
    updated_edge_sets = []
    for edge_set in graph.edge_sets:
      new_features = self._update_edge_features(graph.node_features, edge_set)
      updated_edge_sets.append(edge_set._replace(features=new_features))

    # apply node function on the freshly updated (pre-residual) edge features
    updated_nodes = self._update_node_features(graph.node_features,
                                               updated_edge_sets)

    # add residual connections
    residual_nodes = updated_nodes + graph.node_features
    residual_edge_sets = []
    for new_set, old_set in zip(updated_edge_sets, graph.edge_sets):
      summed = new_set.features + old_set.features
      residual_edge_sets.append(new_set._replace(features=summed))
    return MultiGraph(residual_nodes, residual_edge_sets)
|
||||
|
||||
|
||||
class EncodeProcessDecode(snt.AbstractModule):
  """Encode-Process-Decode GraphNet model."""

  def __init__(self,
               output_size,
               latent_size,
               num_layers,
               message_passing_steps,
               name='EncodeProcessDecode'):
    """Creates the model.

    Args:
      output_size: width of the decoder's final layer.
      latent_size: width of every hidden layer and of the latent features.
      num_layers: number of hidden layers in each MLP.
      message_passing_steps: number of GraphNetBlocks applied in sequence.
      name: name of this Sonnet module.
    """
    super().__init__(name=name)
    self._output_size = output_size
    self._latent_size = latent_size
    self._num_layers = num_layers
    self._message_passing_steps = message_passing_steps

  def _make_mlp(self, output_size, layer_norm=True):
    """Builds an MLP, optionally followed by layer normalization."""
    hidden_widths = [self._latent_size] * self._num_layers
    mlp = snt.nets.MLP(hidden_widths + [output_size], activate_final=False)
    if not layer_norm:
      return mlp
    return snt.Sequential([mlp, snt.LayerNorm()])

  def _encoder(self, graph):
    """Encodes node and edge features into latent features."""
    with tf.variable_scope('encoder'):
      node_encoder = self._make_mlp(self._latent_size)
      node_latents = node_encoder(graph.node_features)
      encoded_edge_sets = []
      for edge_set in graph.edge_sets:
        # each edge set gets its own encoder MLP
        edge_encoder = self._make_mlp(self._latent_size)
        edge_latents = edge_encoder(edge_set.features)
        encoded_edge_sets.append(edge_set._replace(features=edge_latents))
    return MultiGraph(node_latents, encoded_edge_sets)

  def _decoder(self, graph):
    """Decodes node features from graph."""
    with tf.variable_scope('decoder'):
      # no layer norm on the decoder output
      decode = self._make_mlp(self._output_size, layer_norm=False)
      return decode(graph.node_features)

  def _build(self, graph):
    """Encodes and processes a multigraph, and returns node features."""
    make_latent_mlp = functools.partial(self._make_mlp,
                                        output_size=self._latent_size)
    latent_graph = self._encoder(graph)
    for _ in range(self._message_passing_steps):
      # a new block per step
      block = GraphNetBlock(make_latent_mlp)
      latent_graph = block(latent_graph)
    return self._decoder(latent_graph)
|
||||
@@ -0,0 +1,119 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utility functions for reading the datasets."""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from meshgraphnets.common import NodeType
|
||||
|
||||
|
||||
def _parse(proto, meta):
  """Parses a trajectory from tf.Example.

  Args:
    proto: serialized example, as accepted by `tf.io.parse_single_example`.
    meta: dataset metadata dict; 'field_names', 'features' (per-field
      'dtype', 'shape' and 'type') and 'trajectory_length' are read.

  Returns:
    Dict mapping each key of `meta['features']` to its decoded tensor
    (a `tf.RaggedTensor` for 'dynamic_varlen' fields).

  Raises:
    ValueError: if a field's 'type' is not 'static', 'dynamic' or
      'dynamic_varlen'.
  """
  # every field is stored as raw bytes in a string feature
  schema = {}
  for name in meta['field_names']:
    schema[name] = tf.io.VarLenFeature(tf.string)
  parsed = tf.io.parse_single_example(proto, schema)

  trajectory = {}
  for key, field in meta['features'].items():
    raw = tf.io.decode_raw(parsed[key].values, getattr(tf, field['dtype']))
    values = tf.reshape(raw, field['shape'])
    kind = field['type']
    if kind == 'static':
      # repeat the data along the first axis, once per trajectory step
      values = tf.tile(values, [meta['trajectory_length'], 1, 1])
    elif kind == 'dynamic_varlen':
      # row lengths are stored in a companion 'length_<key>' field
      lengths = tf.io.decode_raw(parsed['length_'+key].values, tf.int32)
      lengths = tf.reshape(lengths, [-1])
      values = tf.RaggedTensor.from_row_lengths(values, row_lengths=lengths)
    elif kind != 'dynamic':
      raise ValueError('invalid data format')
    trajectory[key] = values
  return trajectory
|
||||
|
||||
|
||||
def load_dataset(path, split):
  """Loads one split of a dataset as a `tf.data` dataset of trajectories.

  Args:
    path: directory containing 'meta.json' and '<split>.tfrecord'.
    split: name of the split; used as the tfrecord file stem.

  Returns:
    Dataset of parsed trajectories (see `_parse`).
  """
  meta_path = os.path.join(path, 'meta.json')
  with open(meta_path, 'r') as meta_file:
    meta = json.load(meta_file)

  record_path = os.path.join(path, split+'.tfrecord')
  parse_fn = functools.partial(_parse, meta=meta)
  records = tf.data.TFRecordDataset(record_path)
  return records.map(parse_fn, num_parallel_calls=8).prefetch(1)
|
||||
|
||||
|
||||
def add_targets(ds, fields, add_history):
  """Adds target and optionally history fields to each trajectory.

  Every field is trimmed to its interior steps `[1:-1]`. For keys listed in
  `fields`, 'target|<key>' holds the value one step later (`[2:]`) and, when
  `add_history` is set, 'prev|<key>' holds the value one step earlier
  (`[0:-2]`).

  Args:
    ds: dataset of trajectory dicts.
    fields: container of keys that receive target (and history) entries.
    add_history: whether to also emit the 'prev|<key>' entries.

  Returns:
    The mapped dataset.
  """
  def with_targets(trajectory):
    shifted = {}
    for key, val in trajectory.items():
      shifted[key] = val[1:-1]
      if key not in fields:
        continue
      if add_history:
        shifted['prev|'+key] = val[0:-2]
      shifted['target|'+key] = val[2:]
    return shifted

  return ds.map(with_targets, num_parallel_calls=8)
|
||||
|
||||
|
||||
def split_and_preprocess(ds, noise_field, noise_scale, noise_gamma):
  """Splits trajectories into frames, and adds training noise.

  Args:
    ds: dataset of trajectory dicts that already contain
      'target|<noise_field>'.
    noise_field: key of the field that is perturbed.
    noise_scale: standard deviation of the zero-mean normal noise.
    noise_gamma: the target receives `(1 - noise_gamma)` times the noise.

  Returns:
    A shuffled, endlessly repeating dataset of single noisy frames.
  """
  def add_noise(frame):
    perturbation = tf.random.normal(tf.shape(frame[noise_field]),
                                    stddev=noise_scale, dtype=tf.float32)
    # don't apply noise to boundary nodes
    is_normal = tf.equal(frame['node_type'], NodeType.NORMAL)[:, 0]
    perturbation = tf.where(is_normal, perturbation,
                            tf.zeros_like(perturbation))
    frame[noise_field] = frame[noise_field] + perturbation
    target_key = 'target|'+noise_field
    frame[target_key] = frame[target_key] + (1.0 - noise_gamma) * perturbation
    return frame

  # one dataset element per frame instead of per trajectory
  frames = ds.flat_map(tf.data.Dataset.from_tensor_slices)
  noisy_frames = frames.map(add_noise, num_parallel_calls=8)
  return noisy_frames.shuffle(10000).repeat(None).prefetch(10)
|
||||
|
||||
|
||||
def batch_dataset(ds, batch_size):
  """Batches input datasets.

  Frames are merged by concatenating every field along axis 0. Node indices
  stored in 'cells' are shifted by the number of nodes of the frames merged
  before them, so they keep pointing at the right rows.

  Args:
    ds: dataset of frame dicts; 'node_type' is used to count nodes.
    batch_size: number of frames per batch. Values <= 1 leave `ds` unchanged.

  Returns:
    The batched dataset (incomplete trailing windows are dropped).
  """
  shapes = ds.output_shapes
  types = ds.output_types

  def concat_rows(accumulated, frame_value):
    return tf.concat([accumulated, frame_value], axis=0)

  def renumber(state, frame):
    node_offset, merged_cells = state
    frame_nodes, frame_cells = frame
    shifted_cells = frame_cells+node_offset
    return (node_offset + frame_nodes,
            tf.concat([merged_cells, shifted_cells], axis=0))

  def count_nodes(node_type):
    return tf.shape(node_type)[0]

  def batch_accumulate(ds_window):
    batch = {}
    for key, window_values in ds_window.items():
      # zero-row accumulator with the field's column count and dtype
      empty = tf.zeros((0, shapes[key][1]), dtype=types[key])
      if key == 'cells':
        # renumber node indices in cells
        node_counts = ds_window['node_type'].map(count_nodes)
        counted_cells = tf.data.Dataset.zip((node_counts, window_values))
        start = (tf.constant(0, tf.int32), empty)
        _, batch[key] = counted_cells.reduce(start, renumber)
      else:
        batch[key] = window_values.reduce(empty, concat_rows)
    return batch

  if batch_size > 1:
    windows = ds.window(batch_size, drop_remainder=True)
    ds = windows.map(batch_accumulate, num_parallel_calls=8)
  return ds
|
||||
Executable
+32
@@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Usage:
|
||||
# sh download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR}
|
||||
# Example:
|
||||
# sh download_dataset.sh flag_simple /tmp/
|
||||
|
||||
set -e

# Both arguments are required: with an empty dataset name or output directory
# the script would create a directory at the filesystem root and request a
# malformed URL.
if [ "$#" -lt 2 ] || [ -z "${1}" ] || [ -z "${2}" ]; then
  echo "Usage: $0 DATASET_NAME OUTPUT_DIR" >&2
  exit 1
fi

DATASET_NAME="${1}"
OUTPUT_DIR="${2}/${DATASET_NAME}"

BASE_URL="https://storage.googleapis.com/dm-meshgraphnets/${DATASET_NAME}/"

# Quoted so that an output path containing spaces is not word-split.
mkdir -p "${OUTPUT_DIR}"
for file in meta.json train.tfrecord valid.tfrecord test.tfrecord
do
  wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done
|
||||
@@ -0,0 +1,73 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Online data normalization."""
|
||||
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class Normalizer(snt.AbstractModule):
  """Feature normalizer that accumulates statistics online.

  Keeps running sums of the data and of its square in non-trainable
  variables, and normalizes inputs with the mean and standard deviation
  derived from those sums.
  """

  def __init__(self, size, max_accumulations=10**6, std_epsilon=1e-8,
               name='Normalizer'):
    """Creates the normalizer.

    Args:
      size: shape passed to `tf.zeros` for the running sums (the models in
        this package pass the number of features).
      max_accumulations: number of accumulation steps after which the
        statistics are frozen.
      std_epsilon: lower bound applied to the standard deviation.
      name: name of this Sonnet module.
    """
    super(Normalizer, self).__init__(name=name)
    self._max_accumulations = max_accumulations
    self._std_epsilon = std_epsilon
    with self._enter_variable_scope():
      # total number of rows accumulated so far
      self._acc_count = tf.Variable(0, dtype=tf.float32, trainable=False)
      # number of accumulation steps (batches) so far
      self._num_accumulations = tf.Variable(0, dtype=tf.float32,
                                            trainable=False)
      self._acc_sum = tf.Variable(tf.zeros(size, tf.float32), trainable=False)
      self._acc_sum_squared = tf.Variable(tf.zeros(size, tf.float32),
                                          trainable=False)

  def _build(self, batched_data, accumulate=True):
    """Normalizes input data and accumulates statistics.

    Args:
      batched_data: tensor to normalize; statistics are reduced over axis 0.
      accumulate: whether to add this batch to the running statistics before
        normalizing.

    Returns:
      `(batched_data - mean) / std`, with `std` bounded below by
      `std_epsilon`.
    """
    update_op = tf.no_op()
    if accumulate:
      # stop accumulating after max_accumulations updates, to prevent
      # accuracy issues
      update_op = tf.cond(self._num_accumulations < self._max_accumulations,
                          lambda: self._accumulate(batched_data),
                          tf.no_op)
    # make sure the statistics include this batch before they are read
    with tf.control_dependencies([update_op]):
      return (batched_data - self._mean()) / self._std_with_epsilon()

  @snt.reuse_variables
  def inverse(self, normalized_batch_data):
    """Inverse transformation of the normalizer."""
    return normalized_batch_data * self._std_with_epsilon() + self._mean()

  def _accumulate(self, batched_data):
    """Adds the statistics of `batched_data` to the running sums.

    Returns:
      An op grouping the four variable updates.
    """
    count = tf.cast(tf.shape(batched_data)[0], tf.float32)
    data_sum = tf.reduce_sum(batched_data, axis=0)
    squared_data_sum = tf.reduce_sum(batched_data**2, axis=0)
    return tf.group(
        tf.assign_add(self._acc_sum, data_sum),
        tf.assign_add(self._acc_sum_squared, squared_data_sum),
        tf.assign_add(self._acc_count, count),
        tf.assign_add(self._num_accumulations, 1.))

  def _mean(self):
    """Returns the running mean (zeros before any accumulation)."""
    # avoid dividing by zero before the first accumulation
    safe_count = tf.maximum(self._acc_count, 1.)
    return self._acc_sum / safe_count

  def _std_with_epsilon(self):
    """Returns the running standard deviation, bounded below by epsilon."""
    safe_count = tf.maximum(self._acc_count, 1.)
    variance = self._acc_sum_squared / safe_count - self._mean()**2
    # E[x^2] - E[x]^2 can round to slightly below zero in float32 when a
    # feature is (nearly) constant; the square root of a negative value is
    # NaN, so clamp the variance at zero before taking it.
    std = tf.sqrt(tf.maximum(variance, 0.))
    return tf.math.maximum(std, self._std_epsilon)
|
||||
@@ -0,0 +1,68 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Plots a CFD trajectory rollout."""
|
||||
|
||||
import pickle
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from matplotlib import animation
|
||||
from matplotlib import tri as mtri
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
FLAGS = flags.FLAGS
flags.DEFINE_string('rollout_path', None, 'Path to rollout pickle file')
# The script has nothing to plot without an input file. Reject a missing flag
# at parse time with a usage error rather than letting main() hit open(None).
flags.mark_flag_as_required('rollout_path')
|
||||
|
||||
|
||||
def main(unused_argv):
  """Animates component 0 of `pred_velocity` for every stored trajectory."""
  # NOTE: unpickling can execute arbitrary code; only load rollout files that
  # come from a trusted source (e.g. your own evaluation run).
  with open(FLAGS.rollout_path, 'rb') as rollout_file:
    rollout_data = pickle.load(rollout_file)

  fig, ax = plt.subplots(1, 1, figsize=(12, 8))
  skip = 10  # Only every `skip`-th step becomes a frame.
  num_steps = rollout_data[0]['gt_velocity'].shape[0]
  num_frames = len(rollout_data) * num_steps // skip

  # Per-trajectory (min, max) of `gt_velocity`, reduced over its first two
  # axes; used below as a colour range that stays fixed across frames.
  bounds = [
      (trajectory['gt_velocity'].min(axis=(0, 1)),
       trajectory['gt_velocity'].max(axis=(0, 1)))
      for trajectory in rollout_data
  ]

  def animate(num):
    # Frames run through the trajectories one after another.
    traj, step = divmod(num * skip, num_steps)
    ax.cla()
    ax.set_aspect('equal')
    ax.set_axis_off()
    vmin, vmax = bounds[traj]
    trajectory = rollout_data[traj]
    pos = trajectory['mesh_pos'][step]
    faces = trajectory['faces'][step]
    velocity = trajectory['pred_velocity'][step]
    triang = mtri.Triangulation(pos[:, 0], pos[:, 1], faces)
    ax.tripcolor(triang, velocity[:, 0], vmin=vmin[0], vmax=vmax[0])
    ax.triplot(triang, 'ko-', ms=0.5, lw=0.3)
    ax.set_title('Trajectory %d Step %d' % (traj, step))
    return fig,

  # Hold a reference to the animation object until the window is closed.
  _ = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100)
  plt.show(block=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,67 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Plots a cloth trajectory rollout."""
|
||||
|
||||
import pickle
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from matplotlib import animation
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
FLAGS = flags.FLAGS
flags.DEFINE_string('rollout_path', None, 'Path to rollout pickle file')
# The script has nothing to plot without an input file. Reject a missing flag
# at parse time with a usage error rather than letting main() hit open(None).
flags.mark_flag_as_required('rollout_path')
|
||||
|
||||
|
||||
def main(unused_argv):
  """Animates the `pred_pos` surface of every stored trajectory in 3D."""
  # NOTE: unpickling can execute arbitrary code; only load rollout files that
  # come from a trusted source (e.g. your own evaluation run).
  with open(FLAGS.rollout_path, 'rb') as rollout_file:
    rollout_data = pickle.load(rollout_file)

  fig = plt.figure(figsize=(8, 8))
  ax = fig.add_subplot(111, projection='3d')
  skip = 10  # Only every `skip`-th step becomes a frame.
  num_steps = rollout_data[0]['gt_pos'].shape[0]
  num_frames = len(rollout_data) * num_steps // skip

  # Per-trajectory (min, max) corners of `gt_pos`, reduced over its first two
  # axes; used below as axis limits that stay fixed across frames.
  bounds = [
      (trajectory['gt_pos'].min(axis=(0, 1)),
       trajectory['gt_pos'].max(axis=(0, 1)))
      for trajectory in rollout_data
  ]

  def animate(num):
    # Frames run through the trajectories one after another.
    traj, step = divmod(num * skip, num_steps)
    ax.cla()
    lower, upper = bounds[traj]
    ax.set_xlim([lower[0], upper[0]])
    ax.set_ylim([lower[1], upper[1]])
    ax.set_zlim([lower[2], upper[2]])
    trajectory = rollout_data[traj]
    pos = trajectory['pred_pos'][step]
    faces = trajectory['faces'][step]
    ax.plot_trisurf(pos[:, 0], pos[:, 1], faces, pos[:, 2], shade=True)
    ax.set_title('Trajectory %d Step %d' % (traj, step))
    return fig,

  # Hold a reference to the animation object until the window is closed.
  _ = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100)
  plt.show(block=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,5 @@
|
||||
tensorflow-gpu>=1.15,<2
|
||||
dm-sonnet<2
|
||||
matplotlib
|
||||
absl-py
|
||||
numpy
|
||||
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Fail on any error.
set -e

# Display commands being run.
set -x

TMP_DIR="$(mktemp -d)"

# Clean up on every exit path. With `set -e` a failing step aborts the script
# immediately, so a trailing `rm` alone would leave the scratch directory
# (virtualenv, dataset, checkpoints) behind.
trap 'rm -r "${TMP_DIR}"' EXIT

virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"

# Install dependencies.
pip install --upgrade -r meshgraphnets/requirements.txt

# Download minimal dataset
DATA_DIR="${TMP_DIR}/flag_minimal"
bash meshgraphnets/download_dataset.sh flag_minimal "${TMP_DIR}"

# Train for a few steps.
CHK_DIR="${TMP_DIR}/checkpoint"
python -m meshgraphnets.run_model --model=cloth --mode=train --checkpoint_dir="${CHK_DIR}" --dataset_dir="${DATA_DIR}" --num_training_steps=10

# Generate a rollout trajectory
ROLLOUT_PATH="${TMP_DIR}/rollout.pkl"
python -m meshgraphnets.run_model --model=cloth --mode=eval --checkpoint_dir="${CHK_DIR}" --dataset_dir="${DATA_DIR}" --rollout_path="${ROLLOUT_PATH}" --num_rollouts=1

# Plot the rollout trajectory
python -m meshgraphnets.plot_cloth --rollout_path="${ROLLOUT_PATH}"

echo "Test run complete."
|
||||
@@ -0,0 +1,131 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Runs the learner/evaluator."""
|
||||
|
||||
import pickle
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from meshgraphnets import cfd_eval
|
||||
from meshgraphnets import cfd_model
|
||||
from meshgraphnets import cloth_eval
|
||||
from meshgraphnets import cloth_model
|
||||
from meshgraphnets import core_model
|
||||
from meshgraphnets import dataset
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS

# Job selection.
flags.DEFINE_enum(name='mode', default='train', enum_values=['train', 'eval'],
                  help='Train model, or run evaluation.')
flags.DEFINE_enum(name='model', default=None, enum_values=['cfd', 'cloth'],
                  help='Select model to run.')

# Input/output locations.
flags.DEFINE_string(name='checkpoint_dir', default=None,
                    help='Directory to save checkpoint')
flags.DEFINE_string(name='dataset_dir', default=None,
                    help='Directory to load dataset from.')
flags.DEFINE_string(name='rollout_path', default=None,
                    help='Pickle file to save eval trajectories')

# Evaluation / training extent.
flags.DEFINE_enum(name='rollout_split', default='valid',
                  enum_values=['train', 'test', 'valid'],
                  help='Dataset split to use for rollouts.')
flags.DEFINE_integer(name='num_rollouts', default=10,
                     help='No. of rollout trajectories')
flags.DEFINE_integer(name='num_training_steps', default=int(10e6),
                     help='No. of training steps')

# Per-model settings, keyed by the value of --model.
PARAMETERS = {
    'cfd': {
        'noise': 0.02,
        'gamma': 1.0,
        'field': 'velocity',
        'history': False,
        'size': 2,
        'batch': 2,
        'model': cfd_model,
        'evaluator': cfd_eval,
    },
    'cloth': {
        'noise': 0.003,
        'gamma': 0.1,
        'field': 'world_pos',
        'history': True,
        'size': 3,
        'batch': 1,
        'model': cloth_model,
        'evaluator': cloth_eval,
    },
}
|
||||
|
||||
|
||||
def learner(model, params):
  """Builds the training graph and runs it until the step limit is reached.

  Args:
    model: object exposing `loss(inputs)`.
    params: entry of `PARAMETERS` for the selected model; the keys `field`,
      `history`, `noise` and `gamma` are read here.
  """
  field = params['field']
  train_ds = dataset.load_dataset(FLAGS.dataset_dir, 'train')
  train_ds = dataset.add_targets(train_ds, [field],
                                 add_history=params['history'])
  train_ds = dataset.split_and_preprocess(train_ds,
                                          noise_field=field,
                                          noise_scale=params['noise'],
                                          noise_gamma=params['gamma'])
  inputs = tf.data.make_one_shot_iterator(train_ds).get_next()

  loss_op = model.loss(inputs)
  global_step = tf.train.create_global_step()
  decayed_lr = tf.train.exponential_decay(learning_rate=1e-4,
                                          global_step=global_step,
                                          decay_steps=int(5e6),
                                          decay_rate=0.1)
  # The constant offset keeps the learning rate from decaying below 1e-6.
  learning_rate = decayed_lr + 1e-6
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  minimize_op = optimizer.minimize(loss_op, global_step=global_step)

  # Don't train for the first few steps, just accumulate normalization stats
  # NOTE(review): `minimize_op` is created outside the cond branches; under
  # TF1 control flow such ops may execute regardless of the predicate. Confirm
  # that the warm-up really skips the optimizer update.
  train_op = tf.cond(tf.less(global_step, 1000),
                     lambda: tf.group(tf.assign_add(global_step, 1)),
                     lambda: tf.group(minimize_op))

  stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.num_training_steps)
  with tf.train.MonitoredTrainingSession(
      hooks=[stop_hook],
      checkpoint_dir=FLAGS.checkpoint_dir,
      save_checkpoint_secs=600) as sess:

    while not sess.should_stop():
      _, step, loss = sess.run([train_op, global_step, loss_op])
      if step % 1000 == 0:
        logging.info('Step %d: Loss %g', step, loss)
    logging.info('Training complete.')
|
||||
|
||||
|
||||
def evaluator(model, params):
  """Runs model rollouts, logs mean scalar metrics and pickles trajectories.

  Args:
    model: model object handed to the evaluator module's `evaluate`.
    params: entry of `PARAMETERS` for the selected model; the keys `field`,
      `history` and `evaluator` are read here.
  """
  eval_ds = dataset.load_dataset(FLAGS.dataset_dir, FLAGS.rollout_split)
  eval_ds = dataset.add_targets(eval_ds, [params['field']],
                                add_history=params['history'])
  inputs = tf.data.make_one_shot_iterator(eval_ds).get_next()
  scalar_op, traj_ops = params['evaluator'].evaluate(model, inputs)
  tf.train.create_global_step()

  # Checkpoint saving is switched off: the session is only used to restore
  # from `checkpoint_dir` and run the rollout ops.
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=FLAGS.checkpoint_dir,
      save_checkpoint_secs=None,
      save_checkpoint_steps=None) as sess:
    trajectories = []
    scalars = []
    for traj_idx in range(FLAGS.num_rollouts):
      logging.info('Rollout trajectory %d', traj_idx)
      rollout_scalars, rollout_trajectory = sess.run([scalar_op, traj_ops])
      trajectories.append(rollout_trajectory)
      scalars.append(rollout_scalars)

    # Report every scalar metric averaged over the rollouts.
    for key in scalars[0]:
      mean_value = np.mean([entry[key] for entry in scalars])
      logging.info('%s: %g', key, mean_value)

    with open(FLAGS.rollout_path, 'wb') as rollout_file:
      pickle.dump(trajectories, rollout_file)
|
||||
|
||||
|
||||
def main(argv):
  """Builds the model selected by --model and runs the job given by --mode."""
  del argv
  tf.enable_resource_variables()
  tf.disable_eager_execution()

  params = PARAMETERS[FLAGS.model]
  network = core_model.EncodeProcessDecode(
      output_size=params['size'],
      latent_size=128,
      num_layers=2,
      message_passing_steps=15)
  model = params['model'].Model(network)

  # Map each supported mode to its job function; any other mode is a no-op.
  jobs = {'train': learner, 'eval': evaluator}
  job = jobs.get(FLAGS.mode)
  if job is not None:
    job(model, params)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
Reference in New Issue
Block a user