Internal.

PiperOrigin-RevId: 373093293
This commit is contained in:
Ravichandra Addanki
2021-05-11 07:23:38 +00:00
committed by Louise Deason
parent f0bd8651b9
commit 33076aa480
+10 -1
View File
@@ -32,11 +32,14 @@ to zero-mean unit-variance.
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library. Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
""" """
from typing import Callable
import graph_nets as gn import graph_nets as gn
import sonnet as snt import sonnet as snt
import tensorflow as tf import tensorflow as tf
Reducer = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]
def build_mlp( def build_mlp(
hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module: hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
@@ -55,6 +58,7 @@ class EncodeProcessDecode(snt.AbstractModule):
mlp_num_hidden_layers: int, mlp_num_hidden_layers: int,
num_message_passing_steps: int, num_message_passing_steps: int,
output_size: int, output_size: int,
reducer: Reducer = tf.math.unsorted_segment_sum,
name: str = "EncodeProcessDecode"): name: str = "EncodeProcessDecode"):
"""Inits the model. """Inits the model.
@@ -65,6 +69,9 @@ class EncodeProcessDecode(snt.AbstractModule):
num_message_passing_steps: Number of message passing steps. num_message_passing_steps: Number of message passing steps.
output_size: Output size of the decode node representations as required output_size: Output size of the decode node representations as required
by the downstream update function. by the downstream update function.
reducer: Reduction to be used when aggregating the edges in the nodes in
the interaction network. This should be a callable whose signature
matches tf.math.unsorted_segment_sum.
name: Name of the model. name: Name of the model.
""" """
@@ -75,6 +82,7 @@ class EncodeProcessDecode(snt.AbstractModule):
self._mlp_num_hidden_layers = mlp_num_hidden_layers self._mlp_num_hidden_layers = mlp_num_hidden_layers
self._num_message_passing_steps = num_message_passing_steps self._num_message_passing_steps = num_message_passing_steps
self._output_size = output_size self._output_size = output_size
self._reducer = reducer
with self._enter_variable_scope(): with self._enter_variable_scope():
self._networks_builder() self._networks_builder()
@@ -116,7 +124,8 @@ class EncodeProcessDecode(snt.AbstractModule):
self._processor_networks.append( self._processor_networks.append(
gn.modules.InteractionNetwork( gn.modules.InteractionNetwork(
edge_model_fn=build_mlp_with_layer_norm, edge_model_fn=build_mlp_with_layer_norm,
node_model_fn=build_mlp_with_layer_norm)) node_model_fn=build_mlp_with_layer_norm,
reducer=self._reducer))
# The decoder MLP decodes node latent features into the output size. # The decoder MLP decodes node latent features into the output size.
self._decoder_network = build_mlp( self._decoder_network = build_mlp(