mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-22 07:11:25 +08:00
Internal.
PiperOrigin-RevId: 373093293
This commit is contained in:
committed by
Louise Deason
parent
f0bd8651b9
commit
33076aa480
@@ -32,11 +32,14 @@ to zero-mean unit-variance.
|
||||
|
||||
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
|
||||
"""
|
||||
from typing import Callable
|
||||
|
||||
import graph_nets as gn
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
|
||||
Reducer = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]
|
||||
|
||||
|
||||
def build_mlp(
|
||||
hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
|
||||
@@ -55,6 +58,7 @@ class EncodeProcessDecode(snt.AbstractModule):
|
||||
mlp_num_hidden_layers: int,
|
||||
num_message_passing_steps: int,
|
||||
output_size: int,
|
||||
reducer: Reducer = tf.math.unsorted_segment_sum,
|
||||
name: str = "EncodeProcessDecode"):
|
||||
"""Inits the model.
|
||||
|
||||
@@ -65,6 +69,9 @@ class EncodeProcessDecode(snt.AbstractModule):
|
||||
num_message_passing_steps: Number of message passing steps.
|
||||
output_size: Output size of the decode node representations as required
|
||||
by the downstream update function.
|
||||
reducer: Reduction to be used when aggregating the edges in the nodes in
|
||||
the interaction network. This should be a callable whose signature
|
||||
matches tf.math.unsorted_segment_sum.
|
||||
name: Name of the model.
|
||||
"""
|
||||
|
||||
@@ -75,6 +82,7 @@ class EncodeProcessDecode(snt.AbstractModule):
|
||||
self._mlp_num_hidden_layers = mlp_num_hidden_layers
|
||||
self._num_message_passing_steps = num_message_passing_steps
|
||||
self._output_size = output_size
|
||||
self._reducer = reducer
|
||||
|
||||
with self._enter_variable_scope():
|
||||
self._networks_builder()
|
||||
@@ -116,7 +124,8 @@ class EncodeProcessDecode(snt.AbstractModule):
|
||||
self._processor_networks.append(
|
||||
gn.modules.InteractionNetwork(
|
||||
edge_model_fn=build_mlp_with_layer_norm,
|
||||
node_model_fn=build_mlp_with_layer_norm))
|
||||
node_model_fn=build_mlp_with_layer_norm,
|
||||
reducer=self._reducer))
|
||||
|
||||
# The decoder MLP decodes node latent features into the output size.
|
||||
self._decoder_network = build_mlp(
|
||||
|
||||
Reference in New Issue
Block a user