mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-07 12:42:06 +08:00
94 lines
3.6 KiB
Python
94 lines
3.6 KiB
Python
# pylint: disable=g-bad-file-header
|
|
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Model for CylinderFlow."""
|
|
|
|
import sonnet as snt
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
from meshgraphnets import common
|
|
from meshgraphnets import core_model
|
|
from meshgraphnets import normalization
|
|
|
|
|
|
class Model(snt.AbstractModule):
|
|
"""Model for fluid simulation."""
|
|
|
|
def __init__(self, learned_model, name='Model'):
|
|
super(Model, self).__init__(name=name)
|
|
with self._enter_variable_scope():
|
|
self._learned_model = learned_model
|
|
self._output_normalizer = normalization.Normalizer(
|
|
size=2, name='output_normalizer')
|
|
self._node_normalizer = normalization.Normalizer(
|
|
size=2+common.NodeType.SIZE, name='node_normalizer')
|
|
self._edge_normalizer = normalization.Normalizer(
|
|
size=3, name='edge_normalizer') # 2D coord + length
|
|
|
|
def _build_graph(self, inputs, is_training):
|
|
"""Builds input graph."""
|
|
# construct graph nodes
|
|
node_type = tf.one_hot(inputs['node_type'][:, 0], common.NodeType.SIZE)
|
|
node_features = tf.concat([inputs['velocity'], node_type], axis=-1)
|
|
|
|
# construct graph edges
|
|
senders, receivers = common.triangles_to_edges(inputs['cells'])
|
|
relative_mesh_pos = (tf.gather(inputs['mesh_pos'], senders) -
|
|
tf.gather(inputs['mesh_pos'], receivers))
|
|
edge_features = tf.concat([
|
|
relative_mesh_pos,
|
|
tf.norm(relative_mesh_pos, axis=-1, keepdims=True)], axis=-1)
|
|
|
|
mesh_edges = core_model.EdgeSet(
|
|
name='mesh_edges',
|
|
features=self._edge_normalizer(edge_features, is_training),
|
|
receivers=receivers,
|
|
senders=senders)
|
|
return core_model.MultiGraph(
|
|
node_features=self._node_normalizer(node_features, is_training),
|
|
edge_sets=[mesh_edges])
|
|
|
|
def _build(self, inputs):
|
|
graph = self._build_graph(inputs, is_training=False)
|
|
per_node_network_output = self._learned_model(graph)
|
|
return self._update(inputs, per_node_network_output)
|
|
|
|
@snt.reuse_variables
|
|
def loss(self, inputs):
|
|
"""L2 loss on velocity."""
|
|
graph = self._build_graph(inputs, is_training=True)
|
|
network_output = self._learned_model(graph)
|
|
|
|
# build target velocity change
|
|
cur_velocity = inputs['velocity']
|
|
target_velocity = inputs['target|velocity']
|
|
target_velocity_change = target_velocity - cur_velocity
|
|
target_normalized = self._output_normalizer(target_velocity_change)
|
|
|
|
# build loss
|
|
node_type = inputs['node_type'][:, 0]
|
|
loss_mask = tf.logical_or(tf.equal(node_type, common.NodeType.NORMAL),
|
|
tf.equal(node_type, common.NodeType.OUTFLOW))
|
|
error = tf.reduce_sum((target_normalized - network_output)**2, axis=1)
|
|
loss = tf.reduce_mean(error[loss_mask])
|
|
return loss
|
|
|
|
def _update(self, inputs, per_node_network_output):
|
|
"""Integrate model outputs."""
|
|
velocity_update = self._output_normalizer.inverse(per_node_network_output)
|
|
# integrate forward
|
|
cur_velocity = inputs['velocity']
|
|
return cur_velocity + velocity_update
|