Files
deepmind-research/glassy_dynamics/graph_model_test.py
T
Florent Altché 486d8caf21 Update glass_dynamics to Python 3 imports.
PiperOrigin-RevId: 328363319
2020-08-26 16:55:54 +01:00

144 lines
6.0 KiB
Python

# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for graph_model."""
import itertools
from absl.testing import parameterized
from graph_nets import graphs
import numpy as np
import tensorflow.compat.v1 as tf
from glassy_dynamics import graph_model
class GraphModelTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
"""Initializes a small tractable test (particle) system."""
super(GraphModelTest, self).setUp()
# Fixes random seed to ensure deterministic outputs.
tf.random.set_random_seed(1234)
# In this test we use a small tractable set of particles covering all corner
# cases:
# a) eight particles with different types,
# b) periodic box is not cubic,
# c) three disjoint cluster of particles separated by a threshold > 2,
# d) first two clusters overlap with the periodic boundary,
# e) first cluster is not fully connected,
# f) second cluster is fully connected,
# g) and third cluster is a single isolated particle.
#
# The formatting of the code below separates the three clusters by
# adding linebreaks after each cluster.
self._positions = np.array(
[[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 9.0],
[0.0, 5.0, 0.0], [0.0, 5.0, 1.0], [3.0, 5.0, 0.0],
[2.0, 3.0, 3.0]])
self._types = np.array([0.0, 0.0, 1.0, 0.0,
0.0, 1.0, 0.0,
0.0])
self._box = np.array([4.0, 10.0, 10.0])
# Creates the corresponding graph elements, assuming a threshold of 2 and
# the conventions described in `graph_nets.graphs`.
self._edge_threshold = 2
self._nodes = np.array(
[[0.0], [0.0], [1.0], [0.0],
[0.0], [1.0], [0.0],
[0.0]])
self._edges = np.array(
[[0.0, 0.0, 0.0], [-1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, -1.0],
[1.5, 0.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, -1.0],
[0.0, -1.5, 0.0], [0.0, 0.0, 0.0], [0.0, -1.5, -1.0],
[0.0, 0.0, 1.0], [-1.5, 0.0, 1.0], [0.0, 1.5, 1.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0],
[0.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, -1.0],
[1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
self._receivers = np.array(
[0, 1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2, 3,
4, 5, 6, 4, 5, 6, 4, 5, 6,
7])
self._senders = np.array(
[0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3,
4, 4, 4, 5, 5, 5, 6, 6, 6,
7])
def _get_graphs_tuple(self):
"""Returns a GraphsTuple containing a graph based on the test system."""
return graphs.GraphsTuple(
nodes=tf.constant(self._nodes, dtype=tf.float32),
edges=tf.constant(self._edges, dtype=tf.float32),
globals=tf.constant(np.array([[0.0]]), dtype=tf.float32),
receivers=tf.constant(self._receivers, dtype=tf.int32),
senders=tf.constant(self._senders, dtype=tf.int32),
n_node=tf.constant([len(self._nodes)], dtype=tf.int32),
n_edge=tf.constant([len(self._edges)], dtype=tf.int32))
def test_make_graph_from_static_structure(self):
graphs_tuple_op = graph_model.make_graph_from_static_structure(
tf.constant(self._positions, dtype=tf.float32),
tf.constant(self._types, dtype=tf.int32),
tf.constant(self._box, dtype=tf.float32),
self._edge_threshold)
graphs_tuple = self.evaluate(graphs_tuple_op)
self.assertLen(self._nodes, graphs_tuple.n_node)
self.assertLen(self._edges, graphs_tuple.n_edge)
np.testing.assert_almost_equal(graphs_tuple.nodes, self._nodes)
np.testing.assert_equal(graphs_tuple.senders, self._senders)
np.testing.assert_equal(graphs_tuple.receivers, self._receivers)
np.testing.assert_almost_equal(graphs_tuple.globals, np.array([[0.0]]))
np.testing.assert_almost_equal(graphs_tuple.edges, self._edges)
def _is_equal_up_to_rotation(self, x, y):
for axes in itertools.permutations([0, 1, 2]):
for mirrors in itertools.product([1, -1], repeat=3):
if np.allclose(x, y[:, axes] * mirrors):
return True
return False
def test_apply_random_rotation(self):
graphs_tuple = self._get_graphs_tuple()
rotated_graphs_tuple_op = graph_model.apply_random_rotation(graphs_tuple)
rotated_graphs_tuple = self.evaluate(rotated_graphs_tuple_op)
np.testing.assert_almost_equal(rotated_graphs_tuple.nodes, self._nodes)
np.testing.assert_almost_equal(rotated_graphs_tuple.senders, self._senders)
np.testing.assert_almost_equal(
rotated_graphs_tuple.receivers, self._receivers)
np.testing.assert_almost_equal(
rotated_graphs_tuple.globals, np.array([[0.0]]))
self.assertTrue(self._is_equal_up_to_rotation(rotated_graphs_tuple.edges,
self._edges))
@parameterized.named_parameters(('no_propagation', 0, (30,)),
('multi_propagation', 5, (15,)),
('multi_layer', 1, (20, 30)))
def test_GraphModel(self, n_recurrences, mlp_sizes):
graphs_tuple = self._get_graphs_tuple()
output_op = graph_model.GraphBasedModel(n_recurrences=n_recurrences,
mlp_sizes=mlp_sizes)(graphs_tuple)
self.assertListEqual(output_op.shape.as_list(), [len(self._types)])
# Tests if the model runs without crashing.
with self.session():
tf.global_variables_initializer().run()
output_op.eval()
if __name__ == '__main__':
tf.test.main()