Adding explicit option for adding self edges.

PiperOrigin-RevId: 363872110
This commit is contained in:
Alvaro Sanchez-Gonzalez
2021-03-19 12:39:47 +00:00
committed by Louise Deason
parent 25bd036e91
commit ad49bf36f7
+22 -6
View File
@@ -16,18 +16,21 @@
# ============================================================================
"""Tools to compute the connectivity of the graph."""
import functools
import numpy as np
from sklearn import neighbors
import tensorflow.compat.v1 as tf
def _compute_connectivity(positions, radius):
def _compute_connectivity(positions, radius, add_self_edges):
"""Get the indices of connected edges with radius connectivity.
Args:
positions: Positions of nodes in the graph. Shape:
[num_nodes_in_graph, num_dims].
radius: Radius of connectivity.
add_self_edges: Whether to include self edges or not.
Returns:
senders indices [num_edges_in_graph]
@@ -39,10 +42,18 @@ def _compute_connectivity(positions, radius):
num_nodes = len(positions)
senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list])
receivers = np.concatenate(receivers_list, axis=0)
if not add_self_edges:
# Remove self edges.
mask = senders != receivers
senders = senders[mask]
receivers = receivers[mask]
return senders, receivers
def _compute_connectivity_for_batch(positions, n_node, radius):
def _compute_connectivity_for_batch(
positions, n_node, radius, add_self_edges):
"""`compute_connectivity` for a batch of graphs.
Args:
@@ -51,6 +62,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
n_node: Number of nodes for each graph in the batch. Shape:
[num_graphs in batch].
radius: Radius of connectivity.
add_self_edges: Whether to include self edges or not.
Returns:
senders indices [num_edges_in_batch]
@@ -70,7 +82,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
# Compute connectivity for each graph in the batch.
for positions_graph_i in positions_per_graph_list:
senders_graph_i, receivers_graph_i = _compute_connectivity(
positions_graph_i, radius)
positions_graph_i, radius, add_self_edges)
num_edges_graph_i = len(senders_graph_i)
n_edge_list.append(num_edges_graph_i)
@@ -92,11 +104,15 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
return senders, receivers, n_edge
def compute_connectivity_for_batch_pyfunc(positions, n_node, radius):
def compute_connectivity_for_batch_pyfunc(
positions, n_node, radius, add_self_edges=True):
"""`_compute_connectivity_for_batch` wrapped in a pyfunc."""
partial_fn = functools.partial(
_compute_connectivity_for_batch, add_self_edges=add_self_edges)
senders, receivers, n_edge = tf.py_function(
_compute_connectivity_for_batch,
[positions, n_node, radius], [tf.int32, tf.int32, tf.int32])
partial_fn,
[positions, n_node, radius],
[tf.int32, tf.int32, tf.int32])
senders.set_shape([None])
receivers.set_shape([None])
n_edge.set_shape(n_node.get_shape())