Internal.

PiperOrigin-RevId: 373093293
This commit is contained in:
Ravichandra Addanki
2021-05-11 07:23:38 +00:00
committed by Louise Deason
parent f0bd8651b9
commit 33076aa480
+10 -1
View File
@@ -32,11 +32,14 @@ to zero-mean unit-variance.
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
"""
from typing import Callable
import graph_nets as gn
import sonnet as snt
import tensorflow as tf
Reducer = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]
def build_mlp(
hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
@@ -55,6 +58,7 @@ class EncodeProcessDecode(snt.AbstractModule):
mlp_num_hidden_layers: int,
num_message_passing_steps: int,
output_size: int,
reducer: Reducer = tf.math.unsorted_segment_sum,
name: str = "EncodeProcessDecode"):
"""Inits the model.
@@ -65,6 +69,9 @@ class EncodeProcessDecode(snt.AbstractModule):
num_message_passing_steps: Number of message passing steps.
output_size: Output size of the decode node representations as required
by the downstream update function.
reducer: Reduction to be used when aggregating the edges in the nodes in
the interaction network. This should be a callable whose signature
matches tf.math.unsorted_segment_sum.
name: Name of the model.
"""
@@ -75,6 +82,7 @@ class EncodeProcessDecode(snt.AbstractModule):
self._mlp_num_hidden_layers = mlp_num_hidden_layers
self._num_message_passing_steps = num_message_passing_steps
self._output_size = output_size
self._reducer = reducer
with self._enter_variable_scope():
self._networks_builder()
@@ -116,7 +124,8 @@ class EncodeProcessDecode(snt.AbstractModule):
self._processor_networks.append(
gn.modules.InteractionNetwork(
edge_model_fn=build_mlp_with_layer_norm,
node_model_fn=build_mlp_with_layer_norm))
node_model_fn=build_mlp_with_layer_norm,
reducer=self._reducer))
# The decoder MLP decodes node latent features into the output size.
self._decoder_network = build_mlp(