diff --git a/learning_to_simulate/connectivity_utils.py b/learning_to_simulate/connectivity_utils.py index e7c6334..f36e2c2 100644 --- a/learning_to_simulate/connectivity_utils.py +++ b/learning_to_simulate/connectivity_utils.py @@ -16,18 +16,21 @@ # ============================================================================ """Tools to compute the connectivity of the graph.""" +import functools + import numpy as np from sklearn import neighbors import tensorflow.compat.v1 as tf -def _compute_connectivity(positions, radius): +def _compute_connectivity(positions, radius, add_self_edges): """Get the indices of connected edges with radius connectivity. Args: positions: Positions of nodes in the graph. Shape: [num_nodes_in_graph, num_dims]. radius: Radius of connectivity. + add_self_edges: Whether to include self edges or not. Returns: senders indices [num_edges_in_graph] @@ -39,10 +42,18 @@ def _compute_connectivity(positions, radius): num_nodes = len(positions) senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list]) receivers = np.concatenate(receivers_list, axis=0) + + if not add_self_edges: + # Remove self edges. + mask = senders != receivers + senders = senders[mask] + receivers = receivers[mask] + return senders, receivers -def _compute_connectivity_for_batch(positions, n_node, radius): +def _compute_connectivity_for_batch( + positions, n_node, radius, add_self_edges): """`compute_connectivity` for a batch of graphs. Args: @@ -51,6 +62,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius): n_node: Number of nodes for each graph in the batch. Shape: [num_graphs in batch]. radius: Radius of connectivity. + add_self_edges: Whether to include self edges or not. Returns: senders indices [num_edges_in_batch] @@ -70,7 +82,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius): # Compute connectivity for each graph in the batch. for positions_graph_i in positions_per_graph_list: senders_graph_i, receivers_graph_i = _compute_connectivity( - positions_graph_i, radius) + positions_graph_i, radius, add_self_edges) num_edges_graph_i = len(senders_graph_i) n_edge_list.append(num_edges_graph_i) @@ -92,11 +104,15 @@ def _compute_connectivity_for_batch(positions, n_node, radius): return senders, receivers, n_edge -def compute_connectivity_for_batch_pyfunc(positions, n_node, radius): +def compute_connectivity_for_batch_pyfunc( + positions, n_node, radius, add_self_edges=True): """`_compute_connectivity_for_batch` wrapped in a pyfunc.""" + partial_fn = functools.partial( + _compute_connectivity_for_batch, add_self_edges=add_self_edges) senders, receivers, n_edge = tf.py_function( - _compute_connectivity_for_batch, - [positions, n_node, radius], [tf.int32, tf.int32, tf.int32]) + partial_fn, + [positions, n_node, radius], + [tf.int32, tf.int32, tf.int32]) senders.set_shape([None]) receivers.set_shape([None]) n_edge.set_shape(n_node.get_shape())