Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-30 12:25:25 +08:00)
PAC Bayes Quadratic bound open sourcing.
PiperOrigin-RevId: 302629851
Committed by: Louise Deason
Parent: afcdc77239
Commit: f6395d709e
@@ -0,0 +1,80 @@
# Unveiling the predictive power of static structure in glassy systems

This repository contains an open source implementation of the graph neural
network model described in our paper.
The model can be trained using the training binary included in this repository,
and the dataset published with our paper.

## Abstract

Despite decades of theoretical studies, the nature of the glass transition
remains elusive and debated, while the existence of structural predictors of the
dynamics is a major open question. Recent approaches propose inferring
predictors from a variety of human-defined features using machine learning.
We learn the long time evolution of a glassy system solely from the initial
particle positions and without any hand-crafted features, using a powerful
model: graph neural networks. We show that this method strongly outperforms
state-of-the-art methods, generalizing over a wide range of temperatures,
pressures, and densities. In shear experiments, it predicts the location of
rearranging particles. The structural predictors learned by our network unveil a
correlation length which increases with larger timescales to reach the size of
our system. Beyond glasses, our method could apply to many other physical
systems that map to a graph of local interactions.


## Dataset

### System description

The dataset was generated with the LAMMPS molecular dynamics package.
The simulated system has periodic boundaries and is a binary mixture of 4096
large (A) and small (B) particles that interact via a 6-12 Lennard-Jones
potential.
The interaction coefficients are set for a typical Kob-Andersen configuration.
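
As a rough illustration of the pair interaction described above, the following
sketch evaluates an untruncated Lennard-Jones energy. The coefficient values
are the ones commonly quoted for the Kob-Andersen mixture in the literature and
are an assumption here: this README does not list the exact LAMMPS settings
(coefficients, cutoff, shifting) used to generate the dataset, and the helper
name `lennard_jones_energy` is hypothetical.

```python
import numpy as np

# Assumed, commonly quoted Kob-Andersen coefficients (not taken from this
# repository). Index 0 is type A and index 1 is type B.
EPSILON = np.array([[1.0, 1.5], [1.5, 0.5]])
SIGMA = np.array([[1.0, 0.8], [0.8, 0.88]])


def lennard_jones_energy(r, type_i, type_j):
  """Returns the untruncated Lennard-Jones pair energy at distance r."""
  eps = EPSILON[type_i, type_j]
  sr6 = (SIGMA[type_i, type_j] / r) ** 6
  return 4.0 * eps * (sr6 ** 2 - sr6)
```

Under these assumed coefficients the A-A energy vanishes at `r = 1.0` and
reaches its minimum of about -1.0 at `r = 2 ** (1 / 6)`.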

### Data format

The data is stored in Python's pickle format (protocol version 3).
Each file contains the data for one of the equilibrated systems in a Python
dictionary. The dictionary contains the following entries:

- `positions` the particle positions of the equilibrated system.
- `types` the particle types (0 == type A and 1 == type B) of the equilibrated
  system.
- `box` the dimensions of the periodic cubic simulation box.
- `time` the logarithmically sampled time points.
- `time_indices` the indices of the time points for which the sampled
  trajectories on average reach a certain value of the intermediate
  scattering function.
- `is_values` the values of the intermediate scattering function associated
  with each time index.
- `trajectory_start_velocities` the velocities drawn from a Boltzmann
  distribution at the start of each trajectory.
- `trajectory_target_positions` the positions of the particles for each of
  the trajectories at selected time points (as defined by the `time_indices`
  array and the corresponding values of the intermediate scattering function
  stored in `is_values`).
- `metadata` a dictionary containing additional metadata:
  - `temperature` the temperature at which the system was equilibrated.
  - `pressure` the pressure at which the system was equilibrated.
  - `fluid` the type of fluid which was simulated (Kob-Andersen).
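
The entries listed above can be read with the standard `pickle` module. The
sketch below is a minimal example, not the training pipeline itself: the helper
names `load_system` and `mean_mobility` are hypothetical, and the indexing of
`trajectory_target_positions` as `[time_index][trajectory]` is assumed from how
the training code in this repository accesses that array.

```python
import pickle

import numpy as np


def load_system(path):
  """Loads one equilibrated system (a Python dictionary) from a pickle file."""
  with open(path, 'rb') as f:
    return pickle.load(f)


def mean_mobility(data, time_index):
  """Averages each particle's displacement over all sampled trajectories."""
  initial = data['positions']
  # Assumed layout: one array of shape [n_particles, 3] per trajectory.
  targets = data['trajectory_target_positions'][time_index]
  displacements = [np.linalg.norm(t - initial, axis=-1) for t in targets]
  return np.mean(displacements, axis=0)
```

Calling `mean_mobility(load_system(path), time_index)` would then return one
value per particle, which is the kind of quantity the model is trained to
predict.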

All quantities are given in Lennard-Jones units. The positions are stored in
the absolute coordinate system, i.e. they lie outside of the simulation box if
the particle crossed a periodic boundary during the simulation.
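
Because the stored positions are absolute, it can be useful to map them back
into the simulation box or to reduce displacement vectors to their nearest
periodic image. The sketch below shows one way to do this with NumPy. It
assumes that the box has its origin at zero, which this README does not state,
and the function names are hypothetical.

```python
import numpy as np


def wrap_into_box(positions, box):
  """Maps absolute positions into the primary box [0, box) along each axis."""
  return np.mod(positions, box)


def minimum_image(displacement, box):
  """Maps displacement vectors onto their nearest periodic image."""
  return displacement - box * np.round(displacement / box)
```

Note that mobilities should be computed from the absolute positions, since
wrapping discards the information about boundary crossings.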


## Reference

If this repository is helpful for your research please cite the following
publication:

Unveiling the predictive power of static structure in glassy systems
V. Bapst, T. Keck, A. Grabska-Barwinska, C. Donner, E. D. Cubuk,
S. S. Schoenholz, A. Obika, A. W. R. Nelson, T. Back, D. Hassabis and P. Kohli


## Disclaimer
This is not an official Google product.

@@ -0,0 +1,62 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Applies a graph-based network to predict particle mobilities in glasses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags

from glassy_dynamics import train

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'data_directory',
    '',
    'Directory which contains the train or test datasets.')
flags.DEFINE_integer(
    'time_index',
    9,
    'The time index of the target mobilities.')
flags.DEFINE_integer(
    'max_files_to_load',
    None,
    'The maximum number of files to load.')
flags.DEFINE_string(
    'checkpoint_path',
    'checkpoints/t044_s09.ckpt',
    'Path used to load the model.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  file_pattern = os.path.join(FLAGS.data_directory, 'aggregated*')
  train.apply_model(
      checkpoint_path=FLAGS.checkpoint_path,
      file_pattern=file_pattern,
      max_files_to_load=FLAGS.max_files_to_load,
      time_index=FLAGS.time_index)


if __name__ == '__main__':
  app.run(main)
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,190 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A graph neural network based model to predict particle mobilities.

The architecture and performance of this model are described in our publication:
"Unveiling the predictive power of static structure in glassy systems".
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from graph_nets import graphs
from graph_nets import modules as gn_modules
from graph_nets import utils_tf

import sonnet as snt
import tensorflow.compat.v1 as tf
from typing import Any, Dict, Text, Tuple, Optional


def make_graph_from_static_structure(
    positions,
    types,
    box,
    edge_threshold):
  """Returns graph representing the static structure of the glass.

  Each particle is represented by a node in the graph. The particle type is
  stored as a node feature.
  Two particles at a distance less than the threshold are connected by an edge.
  The relative distance vector is stored as an edge feature.

  Args:
    positions: particle positions with shape [n_particles, 3].
    types: particle types with shape [n_particles].
    box: dimensions of the cubic box that contains the particles with shape [3].
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
  """
  # Calculate pairwise relative distances between particles: shape [n, n, 3].
  cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
  # Enforces periodic boundary conditions.
  box_ = box[tf.newaxis, tf.newaxis, :]
  cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
  cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
  # Calculates adjacency matrix in a sparse format (indices), based on the given
  # distances and threshold.
  distances = tf.norm(cross_positions, axis=-1)
  indices = tf.where(distances < edge_threshold)

  # Defines graph.
  nodes = types[:, tf.newaxis]
  senders = indices[:, 0]
  receivers = indices[:, 1]
  edges = tf.gather_nd(cross_positions, indices)

  return graphs.GraphsTuple(
      nodes=tf.cast(nodes, tf.float32),
      n_node=tf.reshape(tf.shape(nodes)[0], [1]),
      edges=tf.cast(edges, tf.float32),
      n_edge=tf.reshape(tf.shape(edges)[0], [1]),
      globals=tf.zeros((1, 1), dtype=tf.float32),
      receivers=tf.cast(receivers, tf.int32),
      senders=tf.cast(senders, tf.int32)
  )


def apply_random_rotation(graph):
  """Returns randomly rotated graph representation.

  The rotation is an element of O(3) with rotation angles that are multiples of
  pi/2. This function assumes that the relative particle distances are stored in
  the edge features.

  Args:
    graph: The graphs tuple as defined in `graph_nets.graphs`.
  """
  # Transposes edge features, so that the axes are in the first dimension.
  # Outputs a tensor of shape [3, n_edges].
  xyz = tf.transpose(graph.edges)
  # Random pi/2 rotation(s), implemented as a random permutation of the axes.
  permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
  xyz = tf.gather(xyz, permutation)
  # Random reflections: each axis is multiplied by +1 or -1.
  symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32)
  symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
  xyz = xyz * symmetry
  edges = tf.transpose(xyz)
  return graph.replace(edges=edges)


class GraphBasedModel(snt.AbstractModule):
  """Graph based model which predicts particle mobilities from their positions.

  This network encodes the nodes and edges of the input graph independently, and
  then performs message-passing on this graph, updating its edges based on their
  associated nodes, then updating the nodes based on the input nodes' features
  and their associated updated edge features.
  This update is repeated several times.
  Afterwards the resulting node embeddings are decoded to predict the particle
  mobility.
  """

  def __init__(self,
               n_recurrences,
               mlp_sizes,
               mlp_kwargs=None,
               name='Graph'):
    """Creates a new GraphBasedModel object.

    Args:
      n_recurrences: the number of message passing steps in the graph network.
      mlp_sizes: the number of neurons in each layer of the MLP.
      mlp_kwargs: additional keyword arguments passed to the MLP.
      name: the name of the Sonnet module.
    """
    super(GraphBasedModel, self).__init__(name=name)
    self._n_recurrences = n_recurrences

    if mlp_kwargs is None:
      mlp_kwargs = {}

    model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes,
        activate_final=True,
        **mlp_kwargs)

    final_model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes + (1,),
        activate_final=False,
        **mlp_kwargs)

    with self._enter_variable_scope():
      self._encoder = gn_modules.GraphIndependent(
          node_model_fn=model_fn,
          edge_model_fn=model_fn)

      if self._n_recurrences > 0:
        self._propagation_network = gn_modules.GraphNetwork(
            node_model_fn=model_fn,
            edge_model_fn=model_fn,
            # We do not use globals, hence we just pass the identity function.
            global_model_fn=lambda: lambda x: x,
            reducer=tf.unsorted_segment_sum,
            edge_block_opt=dict(use_globals=False),
            node_block_opt=dict(use_globals=False),
            global_block_opt=dict(use_globals=False))

      self._decoder = gn_modules.GraphIndependent(
          node_model_fn=final_model_fn,
          edge_model_fn=model_fn)

  def _build(self, graphs_tuple):
    """Connects the model into the tensorflow graph.

    Args:
      graphs_tuple: input graph tensor as defined in `graph_nets.graphs`.

    Returns:
      tensor with shape [n_particles] containing the predicted particle
      mobilities.
    """
    encoded = self._encoder(graphs_tuple)
    outputs = encoded

    for _ in range(self._n_recurrences):
      # Adds skip connections.
      inputs = utils_tf.concat([outputs, encoded], axis=-1)
      outputs = self._propagation_network(inputs)

    decoded = self._decoder(outputs)
    return tf.squeeze(decoded.nodes, axis=-1)
@@ -0,0 +1,147 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for graph_model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
from absl.testing import parameterized

from graph_nets import graphs
import numpy as np
import tensorflow.compat.v1 as tf

from glassy_dynamics import graph_model


class GraphModelTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    """Initializes a small tractable test (particle) system."""
    super(GraphModelTest, self).setUp()
    # Fixes random seed to ensure deterministic outputs.
    tf.random.set_random_seed(1234)

    # In this test we use a small tractable set of particles covering all corner
    # cases:
    # a) eight particles with different types,
    # b) periodic box is not cubic,
    # c) three disjoint clusters of particles separated by a threshold > 2,
    # d) first two clusters overlap with the periodic boundary,
    # e) first cluster is not fully connected,
    # f) second cluster is fully connected,
    # g) and third cluster is a single isolated particle.
    #
    # The formatting of the code below separates the three clusters by
    # adding linebreaks after each cluster.
    self._positions = np.array(
        [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 9.0],
         [0.0, 5.0, 0.0], [0.0, 5.0, 1.0], [3.0, 5.0, 0.0],
         [2.0, 3.0, 3.0]])
    self._types = np.array([0.0, 0.0, 1.0, 0.0,
                            0.0, 1.0, 0.0,
                            0.0])
    self._box = np.array([4.0, 10.0, 10.0])

    # Creates the corresponding graph elements, assuming a threshold of 2 and
    # the conventions described in `graph_nets.graphs`.
    self._edge_threshold = 2
    self._nodes = np.array(
        [[0.0], [0.0], [1.0], [0.0],
         [0.0], [1.0], [0.0],
         [0.0]])
    self._edges = np.array(
        [[0.0, 0.0, 0.0], [-1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, -1.0],
         [1.5, 0.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, -1.0],
         [0.0, -1.5, 0.0], [0.0, 0.0, 0.0], [0.0, -1.5, -1.0],
         [0.0, 0.0, 1.0], [-1.5, 0.0, 1.0], [0.0, 1.5, 1.0], [0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0],
         [0.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, -1.0],
         [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0]])

    self._receivers = np.array(
        [0, 1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2, 3,
         4, 5, 6, 4, 5, 6, 4, 5, 6,
         7])
    self._senders = np.array(
        [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3,
         4, 4, 4, 5, 5, 5, 6, 6, 6,
         7])

  def _get_graphs_tuple(self):
    """Returns a GraphsTuple containing a graph based on the test system."""
    return graphs.GraphsTuple(
        nodes=tf.constant(self._nodes, dtype=tf.float32),
        edges=tf.constant(self._edges, dtype=tf.float32),
        globals=tf.constant(np.array([[0.0]]), dtype=tf.float32),
        receivers=tf.constant(self._receivers, dtype=tf.int32),
        senders=tf.constant(self._senders, dtype=tf.int32),
        n_node=tf.constant([len(self._nodes)], dtype=tf.int32),
        n_edge=tf.constant([len(self._edges)], dtype=tf.int32))

  def test_make_graph_from_static_structure(self):
    graphs_tuple_op = graph_model.make_graph_from_static_structure(
        tf.constant(self._positions, dtype=tf.float32),
        tf.constant(self._types, dtype=tf.int32),
        tf.constant(self._box, dtype=tf.float32),
        self._edge_threshold)
    graphs_tuple = self.evaluate(graphs_tuple_op)
    self.assertLen(self._nodes, graphs_tuple.n_node)
    self.assertLen(self._edges, graphs_tuple.n_edge)
    np.testing.assert_almost_equal(graphs_tuple.nodes, self._nodes)
    np.testing.assert_equal(graphs_tuple.senders, self._senders)
    np.testing.assert_equal(graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(graphs_tuple.globals, np.array([[0.0]]))
    np.testing.assert_almost_equal(graphs_tuple.edges, self._edges)

  def _is_equal_up_to_rotation(self, x, y):
    for axes in itertools.permutations([0, 1, 2]):
      for mirrors in itertools.product([1, -1], repeat=3):
        if np.allclose(x, y[:, axes] * mirrors):
          return True
    return False

  def test_apply_random_rotation(self):
    graphs_tuple = self._get_graphs_tuple()
    rotated_graphs_tuple_op = graph_model.apply_random_rotation(graphs_tuple)
    rotated_graphs_tuple = self.evaluate(rotated_graphs_tuple_op)
    np.testing.assert_almost_equal(rotated_graphs_tuple.nodes, self._nodes)
    np.testing.assert_almost_equal(rotated_graphs_tuple.senders, self._senders)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.globals, np.array([[0.0]]))
    self.assertTrue(self._is_equal_up_to_rotation(rotated_graphs_tuple.edges,
                                                  self._edges))

  @parameterized.named_parameters(('no_propagation', 0, (30,)),
                                  ('multi_propagation', 5, (15,)),
                                  ('multi_layer', 1, (20, 30)))
  def test_GraphModel(self, n_recurrences, mlp_sizes):
    graphs_tuple = self._get_graphs_tuple()
    output_op = graph_model.GraphBasedModel(n_recurrences=n_recurrences,
                                            mlp_sizes=mlp_sizes)(graphs_tuple)
    self.assertListEqual(output_op.shape.as_list(), [len(self._types)])
    # Tests if the model runs without crashing.
    with self.session():
      tf.global_variables_initializer().run()
      output_op.eval()


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,30 @@
# GlassyDynamics
absl-py==0.8.1
astor==0.8.0
cloudpickle==1.1.1
contextlib2==0.6.0.post1
decorator==4.4.0
dm-sonnet==1.35
future==0.18.0
gast==0.3.2
google-pasta==0.1.7
graph-nets==1.0.4
grpcio==1.24.1
h5py==2.10.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
mock==3.0.5
networkx==2.3
numpy==1.17.2
pkg-resources==0.0.0
protobuf==3.10.0
semantic-version==2.8.2
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
tensorflow-probability==0.6.0
termcolor==1.1.0
Werkzeug==0.16.0
wrapt==1.11.2
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,381 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training pipeline for the prediction of particle mobilities in glasses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import pickle

from absl import logging
import enum
import numpy as np

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

from typing import Any, Dict, List, Optional, Text, Tuple, Sequence

from glassy_dynamics import graph_model

tf.enable_resource_variables()

LossCollection = collections.namedtuple('LossCollection',
                                        'l1_loss, l2_loss, correlation')
GlassSimulationData = collections.namedtuple('GlassSimulationData',
                                             'positions, targets, types, box')


class ParticleType(enum.IntEnum):
  """The simulation contains two particle types, identified as type A and B.

  The dataset encodes the particle type in an integer.
    - 0 corresponds to particle type A.
    - 1 corresponds to particle type B.
  """
  A = 0
  B = 1


def get_targets(
    initial_positions,
    trajectory_target_positions):
  """Returns the averaged particle mobilities from the sampled trajectories.

  Args:
    initial_positions: the initial positions of the particles with shape
      [n_particles, 3].
    trajectory_target_positions: the absolute positions of the particles at the
      target time for all sampled trajectories, each with shape
      [n_particles, 3].
  """
  targets = np.mean([np.linalg.norm(t - initial_positions, axis=-1)
                     for t in trajectory_target_positions], axis=0)
  return targets.astype(np.float32)


def load_data(
    file_pattern,
    time_index,
    max_files_to_load=None):
  """Returns the training or test dataset as a list of `GlassSimulationData`.

  Each `GlassSimulationData` tuple contains:
    `positions`: `np.ndarray` containing the particle positions with shape
      [n_particles, 3].
    `targets`: `np.ndarray` containing particle mobilities with shape
      [n_particles].
    `types`: `np.ndarray` containing the particle types with shape
      [n_particles].
    `box`: `np.ndarray` containing the dimensions of the periodic box with shape
      [3].

  Args:
    file_pattern: pattern matching the files with the simulation data.
    time_index: the time index of the targets.
    max_files_to_load: the maximum number of files to load.
  """
  filenames = tf.io.gfile.glob(file_pattern)
  if max_files_to_load:
    filenames = filenames[:max_files_to_load]

  static_structures = []
  for filename in filenames:
    with tf.io.gfile.GFile(filename, 'rb') as f:
      data = pickle.load(f)
    static_structures.append(GlassSimulationData(
        positions=data['positions'].astype(np.float32),
        targets=get_targets(
            data['positions'], data['trajectory_target_positions'][time_index]),
        types=data['types'].astype(np.int32),
        box=data['box'].astype(np.float32)))
  return static_structures


def get_loss_ops(
    prediction,
    target,
    types):
  """Returns L1/L2 loss and correlation for type A particles.

  Args:
    prediction: tensor with shape [n_particles] containing the predicted
      particle mobilities.
    target: tensor with shape [n_particles] containing the true particle
      mobilities.
    types: tensor with shape [n_particles] containing the particle types.
  """
  # Considers only type A particles.
  mask = tf.equal(types, ParticleType.A)
  prediction = tf.boolean_mask(prediction, mask)
  target = tf.boolean_mask(target, mask)
  return LossCollection(
      l1_loss=tf.reduce_mean(tf.abs(prediction - target)),
      l2_loss=tf.reduce_mean((prediction - target)**2),
      correlation=tf.squeeze(tfp.stats.correlation(
          prediction[:, tf.newaxis], target[:, tf.newaxis])))


def get_minimize_op(
    loss,
    learning_rate,
    grad_clip=None):
  """Returns minimization operation.

  Args:
    loss: the loss tensor which is minimized.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: if not None or 0, the gradients are rescaled so that their global
      norm does not exceed this value.
  """
  optimizer = tf.train.AdamOptimizer(learning_rate)
  grads_and_vars = optimizer.compute_gradients(loss)
  if grad_clip:
    grads, _ = tf.clip_by_global_norm([g for g, _ in grads_and_vars], grad_clip)
    grads_and_vars = [(g, pair[1]) for g, pair in zip(grads, grads_and_vars)]
  minimize = optimizer.apply_gradients(grads_and_vars)
  return minimize


def _log_stats_and_return_mean_correlation(
    label,
    stats):
  """Logs performance statistics and returns mean correlation.

  Args:
    label: label printed before the combined statistics e.g. train or test.
    stats: statistics calculated for each batch in a dataset.

  Returns:
    mean correlation
  """
  for key in LossCollection._fields:
    values = [getattr(s, key) for s in stats]
    mean = np.mean(values)
    std = np.std(values)
    logging.info('%s: %s: %.4f +/- %.4f', label, key, mean, std)
  return np.mean([s.correlation for s in stats])


def train_model(train_file_pattern,
|
||||||
|
test_file_pattern,
|
||||||
|
max_files_to_load = None,
|
||||||
|
n_epochs = 1000,
|
||||||
|
time_index = 9,
|
||||||
|
augment_data_using_rotations = True,
|
||||||
|
learning_rate = 1e-4,
|
||||||
|
grad_clip = 1.0,
|
||||||
|
n_recurrences = 7,
|
||||||
|
mlp_sizes = (64, 64),
|
||||||
|
mlp_kwargs = None,
|
||||||
|
edge_threshold = 2.0,
|
||||||
|
measurement_store_interval = 1000,
|
||||||
|
checkpoint_path = None):
|
||||||
|
"""Trains GraphModel using tensorflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_file_pattern: pattern matching the files with the training data.
|
||||||
|
test_file_pattern: pattern matching the files with the test data.
|
||||||
|
max_files_to_load: the maximum number of train and test files to load.
|
||||||
|
If None, all files will be loaded.
|
||||||
|
n_epochs: the number of passes through the training dataset (epochs).
|
||||||
|
time_index: the time index (0-9) of the target mobilities.
|
||||||
|
augment_data_using_rotations: data is augemented by using random rotations.
|
||||||
|
learning_rate: the learning rate used by the optimizer.
|
||||||
|
grad_clip: all gradients are clipped to the given value.
|
||||||
|
n_recurrences: the number of message passing steps in the graphnet.
|
||||||
|
mlp_sizes: the number of neurons in each layer of the MLP.
|
||||||
|
mlp_kwargs: additional keyword aguments passed to the MLP.
|
||||||
|
edge_threshold: particles at distance less than threshold are connected by
|
||||||
|
an edge.
|
||||||
|
measurement_store_interval: number of steps between storing objective values
|
||||||
|
(loss and correlation).
|
||||||
|
checkpoint_path: path used to store the checkpoint with the highest
|
||||||
|
correlation on the test set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Correlation on the test dataset of best model encountered during training.
|
||||||
|
"""
|
||||||
|
if mlp_kwargs is None:
|
||||||
|
mlp_kwargs = dict(initializers=dict(w=tf.variance_scaling_initializer(1.0),
|
||||||
|
b=tf.variance_scaling_initializer(0.1)))
|
||||||
|
# Loads train and test dataset.
|
||||||
|
dataset_kwargs = dict(
|
||||||
|
time_index=time_index,
|
||||||
|
max_files_to_load=max_files_to_load)
|
||||||
|
training_data = load_data(train_file_pattern, **dataset_kwargs)
|
||||||
|
test_data = load_data(test_file_pattern, **dataset_kwargs)
|
||||||
|
|
||||||
|
# Defines wrapper functions, which can directly be passed to the
|
||||||
|
# tf.data.Dataset.map function.
|
||||||
|
def _make_graph_from_static_structure(static_structure):
|
||||||
|
"""Converts static structure to graph, targets and types."""
|
||||||
|
return (graph_model.make_graph_from_static_structure(
|
||||||
|
static_structure.positions,
|
||||||
|
static_structure.types,
|
||||||
|
static_structure.box,
|
||||||
|
edge_threshold),
|
||||||
|
static_structure.targets,
|
||||||
|
static_structure.types)
|
||||||
|
|
||||||
|
def _apply_random_rotation(graph, targets, types):
|
||||||
|
"""Applies random rotations to the graph and forwards targets and types."""
|
||||||
|
return graph_model.apply_random_rotation(graph), targets, types
|
||||||
|
|
||||||
|
# Defines data-pipeline based on tf.data.Dataset following the official
|
||||||
|
# guideline: https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays.
|
||||||
|
# We use initializable iterators to avoid embedding the training and test data
|
||||||
|
# directly into the graph.
|
||||||
|
# Instead we feed the data to the iterators during the initialization of the
|
||||||
|
# iterators before the main training loop.
|
||||||
|
placeholders = GlassSimulationData._make(
|
||||||
|
tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices(placeholders)
|
||||||
|
dataset = dataset.map(_make_graph_from_static_structure)
|
||||||
|
dataset = dataset.cache()
|
||||||
|
dataset = dataset.shuffle(400)
|
||||||
|
# Augments data. This has to be done after calling dataset.cache!
|
||||||
|
if augment_data_using_rotations:
|
||||||
|
dataset = dataset.map(_apply_random_rotation)
|
||||||
|
dataset = dataset.repeat()
|
||||||
|
train_iterator = dataset.make_initializable_iterator()
|
||||||
|
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices(placeholders)
|
||||||
|
dataset = dataset.map(_make_graph_from_static_structure)
|
||||||
|
dataset = dataset.cache()
|
||||||
|
dataset = dataset.repeat()
|
||||||
|
test_iterator = dataset.make_initializable_iterator()
|
||||||
|
|
||||||
|
# Creates tensorflow graph.
|
||||||
|
# Note: We decouple the training and test datasets from the input pipeline
|
||||||
|
# by creating a new iterator from a string-handle placeholder with the same
|
||||||
|
# output types and shapes as the training dataset.
|
||||||
|
dataset_handle = tf.placeholder(tf.string, shape=[])
|
||||||
|
iterator = tf.data.Iterator.from_string_handle(
|
||||||
|
dataset_handle, train_iterator.output_types, train_iterator.output_shapes)
|
||||||
|
graph, targets, types = iterator.get_next()
|
||||||
|
|
||||||
|
model = graph_model.GraphBasedModel(
|
||||||
|
n_recurrences, mlp_sizes, mlp_kwargs)
|
||||||
|
prediction = model(graph)
|
||||||
|
|
||||||
|
# Defines loss and minimization operations.
|
||||||
|
loss_ops = get_loss_ops(prediction, targets, types)
|
||||||
|
minimize_op = get_minimize_op(loss_ops.l2_loss, learning_rate, grad_clip)
|
||||||
|
|
||||||
|
best_so_far = -1
|
||||||
|
train_stats = []
|
||||||
|
test_stats = []
|
||||||
|
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
|
||||||
|
with tf.train.SingularMonitoredSession() as session:
|
||||||
|
# Initializes train and test iterators with the training and test datasets.
|
||||||
|
# The obtained training and test string-handles can be passed to the
|
||||||
|
# dataset_handle placeholder to select the dataset.
|
||||||
|
train_handle = session.run(train_iterator.string_handle())
|
||||||
|
test_handle = session.run(test_iterator.string_handle())
|
||||||
|
feed_dict = {p: [x[i] for x in training_data]
|
||||||
|
for i, p in enumerate(placeholders)}
|
||||||
|
session.run(train_iterator.initializer, feed_dict=feed_dict)
|
||||||
|
feed_dict = {p: [x[i] for x in test_data]
|
||||||
|
for i, p in enumerate(placeholders)}
|
||||||
|
session.run(test_iterator.initializer, feed_dict=feed_dict)
|
||||||
|
|
||||||
|
# Trains model using stochastic gradient descent on the training dataset.
|
||||||
|
n_training_steps = len(training_data) * n_epochs
|
||||||
|
for i in range(n_training_steps):
|
||||||
|
feed_dict = {dataset_handle: train_handle}
|
||||||
|
train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
|
||||||
|
train_stats.append(train_loss)
|
||||||
|
|
||||||
|
if (i+1) % measurement_store_interval == 0:
|
||||||
|
# Evaluates model on test dataset.
|
||||||
|
for _ in range(len(test_data)):
|
||||||
|
feed_dict = {dataset_handle: test_handle}
|
||||||
|
test_stats.append(session.run(loss_ops, feed_dict=feed_dict))
|
||||||
|
|
||||||
|
# Outputs performance statistics on training and test dataset.
|
||||||
|
_log_stats_and_return_mean_correlation('Train', train_stats)
|
||||||
|
correlation = _log_stats_and_return_mean_correlation('Test', test_stats)
|
||||||
|
train_stats = []
|
||||||
|
test_stats = []
|
||||||
|
|
||||||
|
# Updates best model based on the observed correlation on the test
|
||||||
|
# dataset.
|
||||||
|
if correlation > best_so_far:
|
||||||
|
best_so_far = correlation
|
||||||
|
if checkpoint_path:
|
||||||
|
saver.save(session.raw_session(), checkpoint_path)
|
||||||
|
|
||||||
|
return best_so_far
|
||||||
|
|
||||||
|
|
||||||
|
def apply_model(checkpoint_path,
|
||||||
|
file_pattern,
|
||||||
|
max_files_to_load = None,
|
||||||
|
time_index = 9):
|
||||||
|
"""Applies trained GraphModel using tensorflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: path from which the model is loaded.
|
||||||
|
file_pattern: pattern matching the files with the data.
|
||||||
|
max_files_to_load: the maximum number of files to load.
|
||||||
|
If None, all files will be loaded.
|
||||||
|
time_index: the time index (0-9) of the target mobilities.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Predictions of the model for all files.
|
||||||
|
"""
|
||||||
|
dataset_kwargs = dict(
|
||||||
|
time_index=time_index,
|
||||||
|
max_files_to_load=max_files_to_load)
|
||||||
|
data = load_data(file_pattern, **dataset_kwargs)
|
||||||
|
|
||||||
|
tf.reset_default_graph()
|
||||||
|
saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
|
||||||
|
graph = tf.get_default_graph()
|
||||||
|
|
||||||
|
placeholders = GlassSimulationData(
|
||||||
|
positions=graph.get_tensor_by_name('Placeholder:0'),
|
||||||
|
targets=graph.get_tensor_by_name('Placeholder_1:0'),
|
||||||
|
types=graph.get_tensor_by_name('Placeholder_2:0'),
|
||||||
|
box=graph.get_tensor_by_name('Placeholder_3:0'))
|
||||||
|
prediction_tensor = graph.get_tensor_by_name('Graph_1/Squeeze:0')
|
||||||
|
correlation_tensor = graph.get_tensor_by_name('Squeeze:0')
|
||||||
|
|
||||||
|
dataset_handle = graph.get_tensor_by_name('Placeholder_4:0')
|
||||||
|
test_initalizer = graph.get_operation_by_name('MakeIterator_1')
|
||||||
|
test_string_handle = graph.get_tensor_by_name('IteratorToStringHandle_1:0')
|
||||||
|
|
||||||
|
with tf.Session() as session:
|
||||||
|
saver.restore(session, checkpoint_path)
|
||||||
|
handle = session.run(test_string_handle)
|
||||||
|
feed_dict = {p: [x[i] for x in data] for i, p in enumerate(placeholders)}
|
||||||
|
session.run(test_initalizer, feed_dict=feed_dict)
|
||||||
|
predictions = []
|
||||||
|
correlations = []
|
||||||
|
for _ in range(len(data)):
|
||||||
|
p, c = session.run((prediction_tensor, correlation_tensor),
|
||||||
|
feed_dict={dataset_handle: handle})
|
||||||
|
predictions.append(p)
|
||||||
|
correlations.append(c)
|
||||||
|
|
||||||
|
logging.info('Correlation: %.4f +/- %.4f',
|
||||||
|
np.mean(correlations),
|
||||||
|
np.std(correlations))
|
||||||
|
return predictions
|
||||||
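Both `train_model` and `apply_model` build their iterator `feed_dict` with the same comprehension, which turns a list of per-file records into one list per field. The following is a minimal sketch of that transposition only, using plain Python lists. The `Record` namedtuple and the `transpose_records` helper are hypothetical stand-ins (the real code uses `GlassSimulationData` and keys the dictionary by placeholder tensors rather than by field name).

```python
import collections

# Hypothetical stand-in for the GlassSimulationData namedtuple in train.py.
# The field names follow apply_model; the values are toy data.
Record = collections.namedtuple('Record', 'positions targets types box')


def transpose_records(records, keys):
  """Maps each key to the list of that field's values across all records."""
  return {key: [record[i] for record in records] for i, key in enumerate(keys)}


records = [
    Record(positions=[[0.0, 0.0, 0.0]], targets=[0.1], types=[0],
           box=[1.0, 1.0, 1.0]),
    Record(positions=[[1.0, 2.0, 3.0]], targets=[0.2], types=[1],
           box=[2.0, 2.0, 2.0]),
]

# In train.py the keys are tf.placeholder tensors; strings are used here so the
# sketch runs without TensorFlow.
batched = transpose_records(records, Record._fields)
```

Each value in `batched` has one entry per loaded file, which is the layout `tf.data.Dataset.from_tensor_slices` expects when it slices along the leading dimension.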
@@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Trains a graph-based network to predict particle mobilities in glasses."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
from glassy_dynamics import train
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'data_directory',
|
||||||
|
'',
|
||||||
|
'Directory which contains the train and test datasets.')
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
'time_index',
|
||||||
|
9,
|
||||||
|
'The time index of the target mobilities.')
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
'max_files_to_load',
|
||||||
|
None,
|
||||||
|
'The maximum number of files to load from the train and test datasets.')
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'checkpoint_path',
|
||||||
|
None,
|
||||||
|
'Path used to store a checkpoint of the best model.')
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
if len(argv) > 1:
|
||||||
|
raise app.UsageError('Too many command-line arguments.')
|
||||||
|
|
||||||
|
train_file_pattern = os.path.join(FLAGS.data_directory, 'train/aggregated*')
|
||||||
|
test_file_pattern = os.path.join(FLAGS.data_directory, 'test/aggregated*')
|
||||||
|
train.train_model(
|
||||||
|
train_file_pattern=train_file_pattern,
|
||||||
|
test_file_pattern=test_file_pattern,
|
||||||
|
max_files_to_load=FLAGS.max_files_to_load,
|
||||||
|
time_index=FLAGS.time_index,
|
||||||
|
checkpoint_path=FLAGS.checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for train."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from glassy_dynamics import train
|
||||||
|
|
||||||
|
|
||||||
|
class TrainTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_get_targets(self):
|
||||||
|
initial_positions = np.array([[0, 0, 0], [1, 2, 3]])
|
||||||
|
trajectory_target_positions = [
|
||||||
|
np.array([[1, 0, 0], [1, 2, 4]]),
|
||||||
|
np.array([[0, 1, 0], [1, 0, 3]]),
|
||||||
|
np.array([[0, 0, 5], [1, 2, 3]]),
|
||||||
|
]
|
||||||
|
expected_targets = np.array([7.0 / 3.0, 1.0])
|
||||||
|
targets = train.get_targets(initial_positions, trajectory_target_positions)
|
||||||
|
np.testing.assert_almost_equal(expected_targets, targets)
|
||||||
|
|
||||||
|
def test_load_data(self):
|
||||||
|
file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
|
||||||
|
'test_small.pickle')
|
||||||
|
|
||||||
|
with self.subTest('ContentAndShapesAreAsExpected'):
|
||||||
|
data = train.load_data(file_pattern, 0)
|
||||||
|
self.assertEqual(len(data), 1)
|
||||||
|
element = data[0]
|
||||||
|
self.assertTupleEqual(element.positions.shape, (20, 3))
|
||||||
|
self.assertTupleEqual(element.box.shape, (3,))
|
||||||
|
self.assertTupleEqual(element.targets.shape, (20,))
|
||||||
|
self.assertTupleEqual(element.types.shape, (20,))
|
||||||
|
|
||||||
|
with self.subTest('TargetsGrowAsAFunctionOfTime'):
|
||||||
|
previous_mean_target = 0.0
|
||||||
|
# Time index 9 refers to 1/e = 0.36 in the IS, and therefore it is between
|
||||||
|
# time index 5 (0.4) and time index 6 (0.3).
|
||||||
|
for time_index in [0, 1, 2, 3, 4, 5, 9, 6, 7, 8]:
|
||||||
|
data = train.load_data(file_pattern, time_index)[0]
|
||||||
|
current_mean_target = data.targets.mean()
|
||||||
|
self.assertGreater(current_mean_target, previous_mean_target)
|
||||||
|
previous_mean_target = current_mean_target
|
||||||
|
|
||||||
|
|
||||||
|
class TensorflowTrainTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_get_loss_op(self):
|
||||||
|
"""Tests the correct calculation of the loss operations."""
|
||||||
|
prediction = tf.constant([0.0, 1.0, 2.0, 1.0, 2.0], dtype=tf.float32)
|
||||||
|
target = tf.constant([1.0, 25.0, 0.0, 4.0, 2.0], dtype=tf.float32)
|
||||||
|
types = tf.constant([0, 1, 0, 0, 0], dtype=tf.int32)
|
||||||
|
loss_ops = train.get_loss_ops(prediction, target, types)
|
||||||
|
loss = self.evaluate(loss_ops)
|
||||||
|
self.assertAlmostEqual(loss.l1_loss, 1.5)
|
||||||
|
self.assertAlmostEqual(loss.l2_loss, 14.0 / 4.0)
|
||||||
|
self.assertAlmostEqual(loss.correlation, -0.15289416)
|
||||||
|
|
||||||
|
def test_get_minimize_op(self):
|
||||||
|
"""Tests the minimize operation by minimizing a single variable."""
|
||||||
|
var = tf.Variable([1.0], name='test')
|
||||||
|
loss = var**2
|
||||||
|
minimize = train.get_minimize_op(loss, 1e-1)
|
||||||
|
with self.session():
|
||||||
|
tf.global_variables_initializer().run()
|
||||||
|
for _ in range(100):
|
||||||
|
minimize.run()
|
||||||
|
value = var.eval()
|
||||||
|
self.assertLess(abs(value[0]), 0.01)
|
||||||
|
|
||||||
|
def test_train_model(self):
|
||||||
|
"""Tests if we can overfit to a small test dataset."""
|
||||||
|
file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
|
||||||
|
'test_small.pickle')
|
||||||
|
best_correlation_value = train.train_model(
|
||||||
|
train_file_pattern=file_pattern,
|
||||||
|
test_file_pattern=file_pattern,
|
||||||
|
n_epochs=1000,
|
||||||
|
augment_data_using_rotations=False,
|
||||||
|
learning_rate=1e-4,
|
||||||
|
n_recurrences=2,
|
||||||
|
edge_threshold=5,
|
||||||
|
mlp_sizes=(32, 32),
|
||||||
|
measurement_store_interval=1000)
|
||||||
|
# The test dataset contains only a single sample with 20 particles.
|
||||||
|
# Therefore we expect the model to be able to memorize the targets perfectly
|
||||||
|
# if the model works correctly.
|
||||||
|
self.assertGreater(best_correlation_value, 0.99)
|
||||||
|
|
||||||
|
def test_apply_model(self):
|
||||||
|
"""Tests if we can apply a model to a small test dataset."""
|
||||||
|
checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoints',
|
||||||
|
't044_s09.ckpt')
|
||||||
|
file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
|
||||||
|
'test_large.pickle')
|
||||||
|
predictions = train.apply_model(checkpoint_path=checkpoint_path,
|
||||||
|
file_pattern=file_pattern,
|
||||||
|
time_index=0)
|
||||||
|
data = train.load_data(file_pattern, 0)
|
||||||
|
targets = data[0].targets
|
||||||
|
correlation_value = np.corrcoef(predictions[0], targets)[0, 1]
|
||||||
|
self.assertGreater(correlation_value, 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
||||||
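The expected values in `test_get_loss_op` can be reproduced by hand if the losses are computed only over particles of type 0. That masking rule is inferred from the test inputs and expected outputs, not from the implementation of `train.get_loss_ops`, which is not shown here. The sketch below is a plain Python re-derivation under that assumption; `masked_losses` and its `keep_type` parameter are hypothetical names.

```python
import math


def masked_losses(prediction, target, types, keep_type=0):
  """Returns (l1, l2, pearson_correlation) over particles of keep_type.

  Assumption: only particles whose type equals keep_type contribute. This is
  inferred from the expected values in test_get_loss_op.
  """
  pairs = [(p, t) for p, t, k in zip(prediction, target, types)
           if k == keep_type]
  n = len(pairs)
  l1 = sum(abs(p - t) for p, t in pairs) / n
  l2 = sum((p - t) ** 2 for p, t in pairs) / n
  mean_p = sum(p for p, _ in pairs) / n
  mean_t = sum(t for _, t in pairs) / n
  cov = sum((p - mean_p) * (t - mean_t) for p, t in pairs)
  var_p = sum((p - mean_p) ** 2 for p, _ in pairs)
  var_t = sum((t - mean_t) ** 2 for _, t in pairs)
  return l1, l2, cov / math.sqrt(var_p * var_t)


# Same inputs as test_get_loss_op.
l1, l2, corr = masked_losses(
    prediction=[0.0, 1.0, 2.0, 1.0, 2.0],
    target=[1.0, 25.0, 0.0, 4.0, 2.0],
    types=[0, 1, 0, 0, 0])
```

With the type-1 particle excluded, the remaining four pairs give an L1 loss of 6 / 4 = 1.5 and an L2 loss of 14 / 4, matching the assertions in the test. The large target value 25.0 on the masked particle does not affect any of the three quantities.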