mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 21:56:38 +08:00
Internal.
PiperOrigin-RevId: 373093293
This commit is contained in:
committed by
Louise Deason
parent
f0bd8651b9
commit
33076aa480
@@ -32,11 +32,14 @@ to zero-mean unit-variance.
|
|||||||
|
|
||||||
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
|
Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
|
||||||
"""
|
"""
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import graph_nets as gn
|
import graph_nets as gn
|
||||||
import sonnet as snt
|
import sonnet as snt
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
Reducer = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]
|
||||||
|
|
||||||
|
|
||||||
def build_mlp(
|
def build_mlp(
|
||||||
hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
|
hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
|
||||||
@@ -55,6 +58,7 @@ class EncodeProcessDecode(snt.AbstractModule):
|
|||||||
mlp_num_hidden_layers: int,
|
mlp_num_hidden_layers: int,
|
||||||
num_message_passing_steps: int,
|
num_message_passing_steps: int,
|
||||||
output_size: int,
|
output_size: int,
|
||||||
|
reducer: Reducer = tf.math.unsorted_segment_sum,
|
||||||
name: str = "EncodeProcessDecode"):
|
name: str = "EncodeProcessDecode"):
|
||||||
"""Inits the model.
|
"""Inits the model.
|
||||||
|
|
||||||
@@ -65,6 +69,9 @@ class EncodeProcessDecode(snt.AbstractModule):
|
|||||||
num_message_passing_steps: Number of message passing steps.
|
num_message_passing_steps: Number of message passing steps.
|
||||||
output_size: Output size of the decode node representations as required
|
output_size: Output size of the decode node representations as required
|
||||||
by the downstream update function.
|
by the downstream update function.
|
||||||
|
reducer: Reduction to be used when aggregating the edges in the nodes in
|
||||||
|
the interaction network. This should be a callable whose signature
|
||||||
|
matches tf.math.unsorted_segment_sum.
|
||||||
name: Name of the model.
|
name: Name of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -75,6 +82,7 @@ class EncodeProcessDecode(snt.AbstractModule):
|
|||||||
self._mlp_num_hidden_layers = mlp_num_hidden_layers
|
self._mlp_num_hidden_layers = mlp_num_hidden_layers
|
||||||
self._num_message_passing_steps = num_message_passing_steps
|
self._num_message_passing_steps = num_message_passing_steps
|
||||||
self._output_size = output_size
|
self._output_size = output_size
|
||||||
|
self._reducer = reducer
|
||||||
|
|
||||||
with self._enter_variable_scope():
|
with self._enter_variable_scope():
|
||||||
self._networks_builder()
|
self._networks_builder()
|
||||||
@@ -116,7 +124,8 @@ class EncodeProcessDecode(snt.AbstractModule):
|
|||||||
self._processor_networks.append(
|
self._processor_networks.append(
|
||||||
gn.modules.InteractionNetwork(
|
gn.modules.InteractionNetwork(
|
||||||
edge_model_fn=build_mlp_with_layer_norm,
|
edge_model_fn=build_mlp_with_layer_norm,
|
||||||
node_model_fn=build_mlp_with_layer_norm))
|
node_model_fn=build_mlp_with_layer_norm,
|
||||||
|
reducer=self._reducer))
|
||||||
|
|
||||||
# The decoder MLP decodes node latent features into the output size.
|
# The decoder MLP decodes node latent features into the output size.
|
||||||
self._decoder_network = build_mlp(
|
self._decoder_network = build_mlp(
|
||||||
|
|||||||
Reference in New Issue
Block a user