mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 12:37:43 +08:00
Adding explicit option for adding self edges.
PiperOrigin-RevId: 363872110
This commit is contained in:
committed by
Louise Deason
parent
25bd036e91
commit
ad49bf36f7
@@ -16,18 +16,21 @@
|
||||
# ============================================================================
|
||||
"""Tools to compute the connectivity of the graph."""
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
from sklearn import neighbors
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
def _compute_connectivity(positions, radius):
|
||||
def _compute_connectivity(positions, radius, add_self_edges):
|
||||
"""Get the indices of connected edges with radius connectivity.
|
||||
|
||||
Args:
|
||||
positions: Positions of nodes in the graph. Shape:
|
||||
[num_nodes_in_graph, num_dims].
|
||||
radius: Radius of connectivity.
|
||||
add_self_edges: Whether to include self edges or not.
|
||||
|
||||
Returns:
|
||||
senders indices [num_edges_in_graph]
|
||||
@@ -39,10 +42,18 @@ def _compute_connectivity(positions, radius):
|
||||
num_nodes = len(positions)
|
||||
senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list])
|
||||
receivers = np.concatenate(receivers_list, axis=0)
|
||||
|
||||
if not add_self_edges:
|
||||
# Remove self edges.
|
||||
mask = senders != receivers
|
||||
senders = senders[mask]
|
||||
receivers = receivers[mask]
|
||||
|
||||
return senders, receivers
|
||||
|
||||
|
||||
def _compute_connectivity_for_batch(positions, n_node, radius):
|
||||
def _compute_connectivity_for_batch(
|
||||
positions, n_node, radius, add_self_edges):
|
||||
"""`compute_connectivity` for a batch of graphs.
|
||||
|
||||
Args:
|
||||
@@ -51,6 +62,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
|
||||
n_node: Number of nodes for each graph in the batch. Shape:
|
||||
[num_graphs in batch].
|
||||
radius: Radius of connectivity.
|
||||
add_self_edges: Whether to include self edges or not.
|
||||
|
||||
Returns:
|
||||
senders indices [num_edges_in_batch]
|
||||
@@ -70,7 +82,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
|
||||
# Compute connectivity for each graph in the batch.
|
||||
for positions_graph_i in positions_per_graph_list:
|
||||
senders_graph_i, receivers_graph_i = _compute_connectivity(
|
||||
positions_graph_i, radius)
|
||||
positions_graph_i, radius, add_self_edges)
|
||||
|
||||
num_edges_graph_i = len(senders_graph_i)
|
||||
n_edge_list.append(num_edges_graph_i)
|
||||
@@ -92,11 +104,15 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
|
||||
return senders, receivers, n_edge
|
||||
|
||||
|
||||
def compute_connectivity_for_batch_pyfunc(positions, n_node, radius):
|
||||
def compute_connectivity_for_batch_pyfunc(
|
||||
positions, n_node, radius, add_self_edges=True):
|
||||
"""`_compute_connectivity_for_batch` wrapped in a pyfunc."""
|
||||
partial_fn = functools.partial(
|
||||
_compute_connectivity_for_batch, add_self_edges=add_self_edges)
|
||||
senders, receivers, n_edge = tf.py_function(
|
||||
_compute_connectivity_for_batch,
|
||||
[positions, n_node, radius], [tf.int32, tf.int32, tf.int32])
|
||||
partial_fn,
|
||||
[positions, n_node, radius],
|
||||
[tf.int32, tf.int32, tf.int32])
|
||||
senders.set_shape([None])
|
||||
receivers.set_shape([None])
|
||||
n_edge.set_shape(n_node.get_shape())
|
||||
|
||||
Reference in New Issue
Block a user