mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
486d8caf21
PiperOrigin-RevId: 328363319
144 lines
6.0 KiB
Python
144 lines
6.0 KiB
Python
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""Tests for graph_model."""
|
|
|
|
import itertools
|
|
from absl.testing import parameterized
|
|
|
|
from graph_nets import graphs
|
|
import numpy as np
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
from glassy_dynamics import graph_model
|
|
|
|
|
|
class GraphModelTest(tf.test.TestCase, parameterized.TestCase):
  """Tests graph construction, random rotation and the graph-based model.

  All tests share one small hand-built particle system (see `setUp`) together
  with the graph elements (nodes, edges, senders, receivers) that are expected
  for it when particles closer than `self._edge_threshold` are connected.
  """

  def setUp(self):
    """Initializes a small tractable test (particle) system."""
    super(GraphModelTest, self).setUp()
    # Fixes random seed to ensure deterministic outputs.
    tf.random.set_random_seed(1234)

    # In this test we use a small tractable set of particles covering all
    # corner cases:
    # a) eight particles with different types,
    # b) periodic box is not cubic,
    # c) three disjoint clusters of particles separated by a threshold > 2,
    # d) first two clusters overlap with the periodic boundary,
    # e) first cluster is not fully connected,
    # f) second cluster is fully connected,
    # g) and third cluster is a single isolated particle.
    #
    # The formatting of the code below separates the three clusters by
    # adding linebreaks after each cluster.
    self._positions = np.array(
        [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 9.0],
         [0.0, 5.0, 0.0], [0.0, 5.0, 1.0], [3.0, 5.0, 0.0],
         [2.0, 3.0, 3.0]])
    self._types = np.array([0.0, 0.0, 1.0, 0.0,
                            0.0, 1.0, 0.0,
                            0.0])
    self._box = np.array([4.0, 10.0, 10.0])

    # Creates the corresponding graph elements, assuming a threshold of 2 and
    # the conventions described in `graph_nets.graphs`.
    self._edge_threshold = 2
    # One node per particle, carrying the particle type as its only feature.
    self._nodes = np.array(
        [[0.0], [0.0], [1.0], [0.0],
         [0.0], [1.0], [0.0],
         [0.0]])
    # One row per (sender, receiver) pair listed below, self-edges included.
    self._edges = np.array(
        [[0.0, 0.0, 0.0], [-1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, -1.0],
         [1.5, 0.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, -1.0],
         [0.0, -1.5, 0.0], [0.0, 0.0, 0.0], [0.0, -1.5, -1.0],
         [0.0, 0.0, 1.0], [-1.5, 0.0, 1.0], [0.0, 1.5, 1.0], [0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0],
         [0.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, -1.0],
         [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0]])

    self._receivers = np.array(
        [0, 1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2, 3,
         4, 5, 6, 4, 5, 6, 4, 5, 6,
         7])
    self._senders = np.array(
        [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3,
         4, 4, 4, 5, 5, 5, 6, 6, 6,
         7])

  def _get_graphs_tuple(self):
    """Returns a GraphsTuple containing a graph based on the test system."""
    return graphs.GraphsTuple(
        nodes=tf.constant(self._nodes, dtype=tf.float32),
        edges=tf.constant(self._edges, dtype=tf.float32),
        globals=tf.constant(np.array([[0.0]]), dtype=tf.float32),
        receivers=tf.constant(self._receivers, dtype=tf.int32),
        senders=tf.constant(self._senders, dtype=tf.int32),
        n_node=tf.constant([len(self._nodes)], dtype=tf.int32),
        n_edge=tf.constant([len(self._edges)], dtype=tf.int32))

  def test_make_graph_from_static_structure(self):
    """Checks the graph built from positions, types and box against the fixture."""
    graphs_tuple_op = graph_model.make_graph_from_static_structure(
        tf.constant(self._positions, dtype=tf.float32),
        tf.constant(self._types, dtype=tf.int32),
        tf.constant(self._box, dtype=tf.float32),
        self._edge_threshold)
    graphs_tuple = self.evaluate(graphs_tuple_op)

    # Compare the evaluated counts (actual) with the fixture sizes (expected).
    # `assert_equal` broadcasts the scalar, so this holds whether the count is
    # returned as a scalar or as a one-element array.
    np.testing.assert_equal(graphs_tuple.n_node, len(self._nodes))
    np.testing.assert_equal(graphs_tuple.n_edge, len(self._edges))

    np.testing.assert_almost_equal(graphs_tuple.nodes, self._nodes)
    np.testing.assert_equal(graphs_tuple.senders, self._senders)
    np.testing.assert_equal(graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(graphs_tuple.globals, np.array([[0.0]]))
    np.testing.assert_almost_equal(graphs_tuple.edges, self._edges)

  def _is_equal_up_to_rotation(self, x, y):
    """Returns whether `x` matches `y` under a signed permutation of columns.

    Every permutation of the three columns of `y` is combined with every
    per-column sign flip and compared to `x` with `np.allclose`. Note that
    this set also contains reflections, so the check is slightly more
    permissive than the method name suggests.

    Args:
      x: Array with three columns.
      y: Array with three columns, compared against `x`.

    Returns:
      True if any signed column permutation of `y` is close to `x`.
    """
    signed_permutations = itertools.product(
        itertools.permutations([0, 1, 2]),
        itertools.product([1, -1], repeat=3))
    return any(np.allclose(x, y[:, list(axes)] * mirrors)
               for axes, mirrors in signed_permutations)

  def test_apply_random_rotation(self):
    """Checks that a random rotation only changes the edge features."""
    graphs_tuple = self._get_graphs_tuple()
    rotated_graphs_tuple_op = graph_model.apply_random_rotation(graphs_tuple)
    rotated_graphs_tuple = self.evaluate(rotated_graphs_tuple_op)

    np.testing.assert_almost_equal(rotated_graphs_tuple.nodes, self._nodes)
    # Sender and receiver indices are integers, so compare them exactly (as
    # done in `test_make_graph_from_static_structure`).
    np.testing.assert_equal(rotated_graphs_tuple.senders, self._senders)
    np.testing.assert_equal(rotated_graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.globals, np.array([[0.0]]))
    self.assertTrue(
        self._is_equal_up_to_rotation(rotated_graphs_tuple.edges, self._edges))

  @parameterized.named_parameters(('no_propagation', 0, (30,)),
                                  ('multi_propagation', 5, (15,)),
                                  ('multi_layer', 1, (20, 30)))
  def test_GraphModel(self, n_recurrences, mlp_sizes):
    """Checks the model output shape and that the model can be evaluated.

    Args:
      n_recurrences: Value forwarded to `GraphBasedModel(n_recurrences=...)`.
      mlp_sizes: Value forwarded to `GraphBasedModel(mlp_sizes=...)`.
    """
    graphs_tuple = self._get_graphs_tuple()
    model = graph_model.GraphBasedModel(
        n_recurrences=n_recurrences, mlp_sizes=mlp_sizes)
    output_op = model(graphs_tuple)
    # The model is expected to produce one value per particle.
    self.assertListEqual(output_op.shape.as_list(), [len(self._types)])
    # Tests if the model runs without crashing.
    with self.session():
      tf.global_variables_initializer().run()
      output_op.eval()
|
|
|
|
|
|
if __name__ == '__main__':
  # Hands control to the TensorFlow test runner when run as a script.
  tf.test.main()
|