Files
deepmind-research/learning_to_simulate/graph_network.py
Yilei Yang 152ac280f2 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 440896914
2022-05-26 17:43:58 +01:00

183 lines
6.8 KiB
Python

# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph network implementation accompanying ICML 2020 submission.
"Learning to Simulate Complex Physics with Graph Networks"
Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
Jure Leskovec, Peter W. Battaglia
https://arxiv.org/abs/2002.09405
The Sonnet `EncodeProcessDecode` module provided here implements the learnable
parts of the model.
It assumes an encoder preprocessor has already built a graph with
connectivity and features as described in the paper, with features normalized
to zero-mean unit-variance.
Dependencies include TensorFlow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
"""
from typing import Callable
import graph_nets as gn
import sonnet as snt
import tensorflow as tf
# Signature of the edge-to-node aggregation function accepted by
# `EncodeProcessDecode`; it should match `tf.math.unsorted_segment_sum`.
Reducer = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]
def build_mlp(
    hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
  """Creates a Sonnet MLP with uniformly sized hidden layers.

  Args:
    hidden_size: Width of every hidden layer.
    num_hidden_layers: Number of hidden layers placed before the output layer.
    output_size: Width of the final layer.

  Returns:
    An `snt.nets.MLP` whose layer sizes are `num_hidden_layers` copies of
    `hidden_size` followed by `output_size`.
  """
  hidden_sizes = [hidden_size] * num_hidden_layers
  layer_sizes = hidden_sizes + [output_size]
  return snt.nets.MLP(output_sizes=layer_sizes)  # pytype: disable=bad-return-type  # gen-stub-imports
class EncodeProcessDecode(snt.AbstractModule):
  """Encode-Process-Decode function approximator for learnable simulator.

  The module encodes the node and edge features of an input graph into latent
  features, runs a stack of message passing steps with residual connections
  over those latents, and finally decodes the node latents into the output.
  """

  def __init__(
      self,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      output_size: int,
      reducer: Reducer = tf.math.unsorted_segment_sum,
      name: str = "EncodeProcessDecode"):
    """Stores the hyperparameters and builds all sub-networks.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size shared by every MLP in the model.
      mlp_num_hidden_layers: Number of hidden layers in every MLP.
      num_message_passing_steps: How many message passing steps the processor
        performs; one interaction network is created per step.
      output_size: Size of the decoded per-node output, as required by the
        downstream update function.
      reducer: Reduction applied when aggregating edges into nodes inside the
        interaction networks. Must be a callable with the same signature as
        `tf.math.unsorted_segment_sum`.
      name: Name of the module.
    """
    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._output_size = output_size
    self._reducer = reducer

    # Sub-networks are created inside this module's variable scope.
    with self._enter_variable_scope():
      self._networks_builder()

  def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Runs the forward pass: encode, then process, then decode.

    Args:
      input_graph: Graph holding the input node and edge features (and,
        optionally, globals).

    Returns:
      The output of the decoder applied to the final node latents.
    """
    encoded_graph = self._encode(input_graph)
    processed_graph = self._process(encoded_graph)
    return self._decode(processed_graph)

  def _networks_builder(self):
    """Creates the encoder, the processor stack and the decoder.

    Construction order is encoder, then processors (in step order), then
    decoder.
    """

    def make_latent_mlp():
      # MLP producing `latent_size` features, followed by layer normalization.
      return snt.Sequential([
          build_mlp(
              hidden_size=self._mlp_hidden_size,
              num_hidden_layers=self._mlp_num_hidden_layers,
              output_size=self._latent_size),
          snt.LayerNorm(),
      ])

    # Encoder: edge and node features are encoded independently of each other.
    self._encoder_network = gn.modules.GraphIndependent(
        edge_model_fn=make_latent_mlp,
        node_model_fn=make_latent_mlp)

    # Processor: one interaction network per message passing step, each with
    # its own (unshared) parameters. `InteractionNetwork` is suitable because
    # it also returns the messages as the updated edge latent features.
    self._processor_networks = [
        gn.modules.InteractionNetwork(
            edge_model_fn=make_latent_mlp,
            node_model_fn=make_latent_mlp,
            reducer=self._reducer)
        for _ in range(self._num_message_passing_steps)
    ]

    # Decoder: plain MLP mapping node latents to `output_size` features.
    self._decoder_network = build_mlp(
        hidden_size=self._mlp_hidden_size,
        num_hidden_layers=self._mlp_num_hidden_layers,
        output_size=self._output_size)

  def _encode(
      self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Maps the input graph features to the initial latent graph.

    When the graph carries globals, they are broadcast to the nodes,
    concatenated onto the node features along the last axis, and the globals
    field is cleared before encoding.

    Args:
      input_graph: Graph with the raw node and edge features.

    Returns:
      The graph produced by the encoder network.
    """
    if input_graph.globals is not None:
      globals_per_node = gn.blocks.broadcast_globals_to_nodes(input_graph)
      nodes_with_globals = tf.concat(
          [input_graph.nodes, globals_per_node], axis=-1)
      input_graph = input_graph.replace(
          nodes=nodes_with_globals, globals=None)

    return self._encoder_network(input_graph)

  def _process(
      self, latent_graph_0: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Applies every processor network in sequence to the latent graph.

    Args:
      latent_graph_0: Latent graph produced by the encoder.

    Returns:
      The latent graph after all message passing steps. If there are no
      processor networks, `latent_graph_0` is returned unchanged.
    """
    # Each step consumes the output of the previous one.
    # (With shared parameters, the same processor network would be reused.)
    current_graph = latent_graph_0
    for step_network in self._processor_networks:
      current_graph = self._process_step(step_network, current_graph)
    return current_graph

  def _process_step(
      self, processor_network_k: snt.Module,
      latent_graph_prev_k: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Performs one message passing step with residual connections.

    Args:
      processor_network_k: Network applied for this step.
      latent_graph_prev_k: Latent graph output by the previous step.

    Returns:
      The network output with the previous node and edge latents added to its
      nodes and edges respectively.
    """
    updated_graph = processor_network_k(latent_graph_prev_k)

    # Residual connections on both nodes and edges.
    residual_nodes = updated_graph.nodes + latent_graph_prev_k.nodes
    residual_edges = updated_graph.edges + latent_graph_prev_k.edges
    return updated_graph.replace(nodes=residual_nodes, edges=residual_edges)

  def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Applies the decoder MLP to the node latents of `latent_graph`."""
    node_latents = latent_graph.nodes
    return self._decoder_network(node_latents)