Adding explicit option for adding self edges.

PiperOrigin-RevId: 363872110
This commit is contained in:
Alvaro Sanchez-Gonzalez
2021-03-19 12:39:47 +00:00
committed by Louise Deason
parent 25bd036e91
commit ad49bf36f7
+22 -6
View File
@@ -16,18 +16,21 @@
# ============================================================================ # ============================================================================
"""Tools to compute the connectivity of the graph.""" """Tools to compute the connectivity of the graph."""
import functools
import numpy as np import numpy as np
from sklearn import neighbors from sklearn import neighbors
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
def _compute_connectivity(positions, radius): def _compute_connectivity(positions, radius, add_self_edges):
"""Get the indices of connected edges with radius connectivity. """Get the indices of connected edges with radius connectivity.
Args: Args:
positions: Positions of nodes in the graph. Shape: positions: Positions of nodes in the graph. Shape:
[num_nodes_in_graph, num_dims]. [num_nodes_in_graph, num_dims].
radius: Radius of connectivity. radius: Radius of connectivity.
add_self_edges: Whether to include self edges or not.
Returns: Returns:
senders indices [num_edges_in_graph] senders indices [num_edges_in_graph]
@@ -39,10 +42,18 @@ def _compute_connectivity(positions, radius):
num_nodes = len(positions) num_nodes = len(positions)
senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list]) senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list])
receivers = np.concatenate(receivers_list, axis=0) receivers = np.concatenate(receivers_list, axis=0)
if not add_self_edges:
# Remove self edges.
mask = senders != receivers
senders = senders[mask]
receivers = receivers[mask]
return senders, receivers return senders, receivers
def _compute_connectivity_for_batch(positions, n_node, radius): def _compute_connectivity_for_batch(
positions, n_node, radius, add_self_edges):
"""`compute_connectivity` for a batch of graphs. """`compute_connectivity` for a batch of graphs.
Args: Args:
@@ -51,6 +62,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
n_node: Number of nodes for each graph in the batch. Shape: n_node: Number of nodes for each graph in the batch. Shape:
[num_graphs in batch]. [num_graphs in batch].
radius: Radius of connectivity. radius: Radius of connectivity.
add_self_edges: Whether to include self edges or not.
Returns: Returns:
senders indices [num_edges_in_batch] senders indices [num_edges_in_batch]
@@ -70,7 +82,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
# Compute connectivity for each graph in the batch. # Compute connectivity for each graph in the batch.
for positions_graph_i in positions_per_graph_list: for positions_graph_i in positions_per_graph_list:
senders_graph_i, receivers_graph_i = _compute_connectivity( senders_graph_i, receivers_graph_i = _compute_connectivity(
positions_graph_i, radius) positions_graph_i, radius, add_self_edges)
num_edges_graph_i = len(senders_graph_i) num_edges_graph_i = len(senders_graph_i)
n_edge_list.append(num_edges_graph_i) n_edge_list.append(num_edges_graph_i)
@@ -92,11 +104,15 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
return senders, receivers, n_edge return senders, receivers, n_edge
def compute_connectivity_for_batch_pyfunc(positions, n_node, radius): def compute_connectivity_for_batch_pyfunc(
positions, n_node, radius, add_self_edges=True):
"""`_compute_connectivity_for_batch` wrapped in a pyfunc.""" """`_compute_connectivity_for_batch` wrapped in a pyfunc."""
partial_fn = functools.partial(
_compute_connectivity_for_batch, add_self_edges=add_self_edges)
senders, receivers, n_edge = tf.py_function( senders, receivers, n_edge = tf.py_function(
_compute_connectivity_for_batch, partial_fn,
[positions, n_node, radius], [tf.int32, tf.int32, tf.int32]) [positions, n_node, radius],
[tf.int32, tf.int32, tf.int32])
senders.set_shape([None]) senders.set_shape([None])
receivers.set_shape([None]) receivers.set_shape([None])
n_edge.set_shape(n_node.get_shape()) n_edge.set_shape(n_node.get_shape())