mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 02:35:47 +08:00
Adding explicit option for adding self edges.
PiperOrigin-RevId: 363872110
This commit is contained in:
committed by
Louise Deason
parent
25bd036e91
commit
ad49bf36f7
@@ -16,18 +16,21 @@
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Tools to compute the connectivity of the graph."""
|
"""Tools to compute the connectivity of the graph."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import neighbors
|
from sklearn import neighbors
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
def _compute_connectivity(positions, radius):
|
def _compute_connectivity(positions, radius, add_self_edges):
|
||||||
"""Get the indices of connected edges with radius connectivity.
|
"""Get the indices of connected edges with radius connectivity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
positions: Positions of nodes in the graph. Shape:
|
positions: Positions of nodes in the graph. Shape:
|
||||||
[num_nodes_in_graph, num_dims].
|
[num_nodes_in_graph, num_dims].
|
||||||
radius: Radius of connectivity.
|
radius: Radius of connectivity.
|
||||||
|
add_self_edges: Whether to include self edges or not.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
senders indices [num_edges_in_graph]
|
senders indices [num_edges_in_graph]
|
||||||
@@ -39,10 +42,18 @@ def _compute_connectivity(positions, radius):
|
|||||||
num_nodes = len(positions)
|
num_nodes = len(positions)
|
||||||
senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list])
|
senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list])
|
||||||
receivers = np.concatenate(receivers_list, axis=0)
|
receivers = np.concatenate(receivers_list, axis=0)
|
||||||
|
|
||||||
|
if not add_self_edges:
|
||||||
|
# Remove self edges.
|
||||||
|
mask = senders != receivers
|
||||||
|
senders = senders[mask]
|
||||||
|
receivers = receivers[mask]
|
||||||
|
|
||||||
return senders, receivers
|
return senders, receivers
|
||||||
|
|
||||||
|
|
||||||
def _compute_connectivity_for_batch(positions, n_node, radius):
|
def _compute_connectivity_for_batch(
|
||||||
|
positions, n_node, radius, add_self_edges):
|
||||||
"""`compute_connectivity` for a batch of graphs.
|
"""`compute_connectivity` for a batch of graphs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -51,6 +62,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
|
|||||||
n_node: Number of nodes for each graph in the batch. Shape:
|
n_node: Number of nodes for each graph in the batch. Shape:
|
||||||
[num_graphs in batch].
|
[num_graphs in batch].
|
||||||
radius: Radius of connectivity.
|
radius: Radius of connectivity.
|
||||||
|
add_self_edges: Whether to include self edges or not.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
senders indices [num_edges_in_batch]
|
senders indices [num_edges_in_batch]
|
||||||
@@ -70,7 +82,7 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
|
|||||||
# Compute connectivity for each graph in the batch.
|
# Compute connectivity for each graph in the batch.
|
||||||
for positions_graph_i in positions_per_graph_list:
|
for positions_graph_i in positions_per_graph_list:
|
||||||
senders_graph_i, receivers_graph_i = _compute_connectivity(
|
senders_graph_i, receivers_graph_i = _compute_connectivity(
|
||||||
positions_graph_i, radius)
|
positions_graph_i, radius, add_self_edges)
|
||||||
|
|
||||||
num_edges_graph_i = len(senders_graph_i)
|
num_edges_graph_i = len(senders_graph_i)
|
||||||
n_edge_list.append(num_edges_graph_i)
|
n_edge_list.append(num_edges_graph_i)
|
||||||
@@ -92,11 +104,15 @@ def _compute_connectivity_for_batch(positions, n_node, radius):
|
|||||||
return senders, receivers, n_edge
|
return senders, receivers, n_edge
|
||||||
|
|
||||||
|
|
||||||
def compute_connectivity_for_batch_pyfunc(positions, n_node, radius):
|
def compute_connectivity_for_batch_pyfunc(
|
||||||
|
positions, n_node, radius, add_self_edges=True):
|
||||||
"""`_compute_connectivity_for_batch` wrapped in a pyfunc."""
|
"""`_compute_connectivity_for_batch` wrapped in a pyfunc."""
|
||||||
|
partial_fn = functools.partial(
|
||||||
|
_compute_connectivity_for_batch, add_self_edges=add_self_edges)
|
||||||
senders, receivers, n_edge = tf.py_function(
|
senders, receivers, n_edge = tf.py_function(
|
||||||
_compute_connectivity_for_batch,
|
partial_fn,
|
||||||
[positions, n_node, radius], [tf.int32, tf.int32, tf.int32])
|
[positions, n_node, radius],
|
||||||
|
[tf.int32, tf.int32, tf.int32])
|
||||||
senders.set_shape([None])
|
senders.set_shape([None])
|
||||||
receivers.set_shape([None])
|
receivers.set_shape([None])
|
||||||
n_edge.set_shape(n_node.get_shape())
|
n_edge.set_shape(n_node.get_shape())
|
||||||
|
|||||||
Reference in New Issue
Block a user