From d78eee81f9835c7944c41050bc7a6a02779be1e7 Mon Sep 17 00:00:00 2001 From: Thomas Keck Date: Wed, 23 Dec 2020 15:16:42 +0000 Subject: [PATCH] Adds JAX version of glassy dynamics training pipeline. PiperOrigin-RevId: 348791013 --- glassy_dynamics/train_binary.py | 8 +- glassy_dynamics/train_using_jax.py | 287 +++++++++++++++++++++++++++++ 2 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 glassy_dynamics/train_using_jax.py diff --git a/glassy_dynamics/train_binary.py b/glassy_dynamics/train_binary.py index d33d0b6..e8b5f2c 100644 --- a/glassy_dynamics/train_binary.py +++ b/glassy_dynamics/train_binary.py @@ -19,7 +19,8 @@ import os from absl import app from absl import flags -from glassy_dynamics import train +from glassy_dynamics import train as train_using_tf +from glassy_dynamics import train_using_jax FLAGS = flags.FLAGS @@ -39,6 +40,10 @@ flags.DEFINE_string( 'checkpoint_path', None, 'Path used to store a checkpoint of the best model.') +flags.DEFINE_boolean( + 'use_jax', + False, + 'Uses jax to train model.') def main(argv): @@ -47,6 +52,7 @@ def main(argv): train_file_pattern = os.path.join(FLAGS.data_directory, 'train/aggregated*') test_file_pattern = os.path.join(FLAGS.data_directory, 'test/aggregated*') + train = train_using_jax if FLAGS.use_jax else train_using_tf train.train_model( train_file_pattern=train_file_pattern, test_file_pattern=test_file_pattern, diff --git a/glassy_dynamics/train_using_jax.py b/glassy_dynamics/train_using_jax.py new file mode 100644 index 0000000..6a984bc --- /dev/null +++ b/glassy_dynamics/train_using_jax.py @@ -0,0 +1,287 @@ +# Lint as: python3 +# Copyright 2019 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training pipeline for the prediction of particle mobilities in glasses.""" + +import enum +import functools +import logging +import pickle +import random +import haiku as hk +import jax +import jax.numpy as jnp +import jraph +import numpy as np +import optax + +# Only used for file operations. +# You can use glob.glob and python's open function to replace the tf usage below +# on most platforms. +import tensorflow.compat.v1 as tf + + +class ParticleType(enum.IntEnum): + """The simulation contains two particle types, identified as type A and B. + + The dataset encodes the particle type in an integer. + - 0 corresponds to particle type A. + - 1 corresponds to particle type B. + """ + A = 0 + B = 1 + + +def make_graph_from_static_structure(positions, types, box, edge_threshold): + """Returns graph representing the static structure of the glass. + + Each particle is represented by a node in the graph. The particle type is + stored as a node feature. + Two particles at a distance less than the threshold are connected by an edge. + The relative distance vector is stored as an edge feature. + + Args: + positions: particle positions with shape [n_particles, 3]. + types: particle types with shape [n_particles]. + box: dimensions of the cubic box that contains the particles with shape [3]. + edge_threshold: particles at distance less than threshold are connected by + an edge. + """ + # Calculate pairwise relative distances between particles: shape [n, n, 3]. + cross_positions = positions[None, :, :] - positions[:, None, :] + # Enforces periodic boundary conditions. + box_ = box[None, None, :] + cross_positions += (cross_positions < -box_ / 2.).astype(np.float32) * box_ + cross_positions -= (cross_positions > box_ / 2.).astype(np.float32) * box_ + # Calculates adjacency matrix in a sparse format (indices), based on the given + # distances and threshold. + distances = np.linalg.norm(cross_positions, axis=-1) + indices = np.where(distances < edge_threshold) + # Defines graph. + nodes = types[:, None] + senders = indices[0] + receivers = indices[1] + edges = cross_positions[indices] + + return jraph.pad_with_graphs(jraph.GraphsTuple( + nodes=nodes.astype(np.float32), + n_node=np.reshape(nodes.shape[0], [1]), + edges=edges.astype(np.float32), + n_edge=np.reshape(edges.shape[0], [1]), + globals=np.zeros((1, 1), dtype=np.float32), + receivers=receivers.astype(np.int32), + senders=senders.astype(np.int32) + ), n_node=4097, n_edge=200000) + + +def get_targets(initial_positions, trajectory_target_positions): + """Returns the averaged particle mobilities from the sampled trajectories. + + Args: + initial_positions: the initial positions of the particles with shape + [n_particles, 3]. + trajectory_target_positions: the absolute positions of the particles at the + target time for all sampled trajectories, each with shape + [n_particles, 3]. + """ + targets = np.mean([np.linalg.norm(t - initial_positions, axis=-1) + for t in trajectory_target_positions], axis=0) + return targets.astype(np.float32) + + +def load_data(file_pattern, time_index, max_files_to_load=None): + """Returns a graphs and targets of the training or test dataset. + + Args: + file_pattern: pattern matching the files with the simulation data. + time_index: the time index of the targets. + max_files_to_load: the maximum number of files to load. + """ + filenames = tf.io.gfile.glob(file_pattern) + if max_files_to_load: + filenames = filenames[:max_files_to_load] + + graphs_and_targets = [] + for filename in filenames: + with tf.io.gfile.GFile(filename, 'rb') as f: + data = pickle.load(f) + mask = (data['types'] == ParticleType.A).astype(np.int32) + # Mask dummy node due to padding + mask = np.concatenate([mask, np.zeros((1,), dtype=np.int32)], axis=-1) + targets = get_targets( + data['positions'], data['trajectory_target_positions'][time_index]) + targets = np.concatenate( + [targets, np.zeros((1,), dtype=np.float32)], axis=-1) + graphs_and_targets.append( + (make_graph_from_static_structure( + data['positions'].astype(np.float32), + data['types'].astype(np.int32), + data['box'].astype(np.float32), + edge_threshold=2.0), + targets, + mask)) + return graphs_and_targets + + +def apply_random_rotation(graph): + """Returns randomly rotated graph representation. + + The rotation is an element of O(3) with rotation angles multiple of pi/2. + This function assumes that the relative particle distances are stored in + the edge features. + + Args: + graph: The graphs tuple as defined in `graph_nets.graphs`. + """ + # Transposes edge features, so that the axes are in the first dimension. + # Outputs a tensor of shape [3, n_particles]. + xyz = np.transpose(graph.edges) + # Random pi/2 rotation(s) + permutation = np.array([0, 1, 2], dtype=np.int32) + np.random.shuffle(permutation) + xyz = xyz[permutation] + # Random reflections. + symmetry = np.random.randint(0, 2, [3]) + symmetry = 1 - 2 * np.reshape(symmetry, [3, 1]).astype(np.float32) + xyz = xyz * symmetry + edges = np.transpose(xyz) + return graph._replace(edges=edges) + + +def network_definition(graph): + """Defines a graph neural network. + + Args: + graph: Graphstuple the network processes. + + Returns: + Decoded nodes. + """ + model_fn = functools.partial( + hk.nets.MLP, + w_init=hk.initializers.VarianceScaling(1.0), + b_init=hk.initializers.VarianceScaling(1.0)) + mlp_sizes = (64, 64) + num_message_passing_steps = 7 + + node_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True) + edge_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True) + node_decoder = model_fn(output_sizes=mlp_sizes + (1,), activate_final=False) + + node_encoding = node_encoder(graph.nodes) + edge_encoding = edge_encoder(graph.edges) + graph = graph._replace(nodes=node_encoding, edges=edge_encoding) + + update_edge_fn = jraph.concatenated_args( + model_fn(output_sizes=mlp_sizes, activate_final=True)) + update_node_fn = jraph.concatenated_args( + model_fn(output_sizes=mlp_sizes, activate_final=True)) + gn = jraph.InteractionNetwork( + update_edge_fn=update_edge_fn, + update_node_fn=update_node_fn, + include_sent_messages_in_node_update=True) + + for _ in range(num_message_passing_steps): + graph = graph._replace( + nodes=jnp.concatenate([graph.nodes, node_encoding], axis=-1), + edges=jnp.concatenate([graph.edges, edge_encoding], axis=-1)) + graph = gn(graph) + + return jnp.squeeze(node_decoder(graph.nodes), axis=-1) + + +def train_model(train_file_pattern, + test_file_pattern, + max_files_to_load=None, + n_epochs=1000, + time_index=9, + learning_rate=1e-4, + grad_clip=1.0, + measurement_store_interval=1000, + checkpoint_path=None): + """Trains GraphModel using tensorflow. + + Args: + train_file_pattern: pattern matching the files with the training data. + test_file_pattern: pattern matching the files with the test data. + max_files_to_load: the maximum number of train and test files to load. + If None, all files will be loaded. + n_epochs: the number of passes through the training dataset (epochs). + time_index: the time index (0-9) of the target mobilities. + learning_rate: the learning rate used by the optimizer. + grad_clip: all gradients are clipped to the given value. + measurement_store_interval: number of steps between storing objective values + (loss and correlation). + checkpoint_path: ignored by this implementation. + """ + if checkpoint_path: + logging.warning('The checkpoint_path argument is ignored.') + random.seed(42) + np.random.seed(42) + # Loads train and test dataset. + dataset_kwargs = dict( + time_index=time_index, + max_files_to_load=max_files_to_load) + logging.info('Load training data') + training_data = load_data(train_file_pattern, **dataset_kwargs) + logging.info('Load test data') + test_data = load_data(test_file_pattern, **dataset_kwargs) + logging.info('Finished loading data') + + network = hk.without_apply_rng(hk.transform(network_definition)) + params = network.init(jax.random.PRNGKey(42), training_data[0][0]) + + opt_init, opt_update = optax.chain( + optax.clip_by_global_norm(grad_clip), + optax.scale_by_adam(0.9, 0.999, 1e-8), + optax.scale(-learning_rate)) + opt_state = opt_init(params) + + network_apply = jax.jit(network.apply) + + @jax.jit + def loss_fn(params, graph, targets, mask): + decoded_nodes = network_apply(params, graph) * mask + return (jnp.sum((decoded_nodes - targets)**2 * mask) / + jnp.sum(mask)) + + @jax.jit + def update(params, opt_state, graph, targets, mask): + loss, grads = jax.value_and_grad(loss_fn)(params, graph, targets, mask) + updates, opt_state = opt_update(grads, opt_state) + return optax.apply_updates(params, updates), opt_state, loss + + train_stats = [] + i = 0 + logging.info('Start training') + for epoch in range(n_epochs): + logging.info('Start epoch %r', epoch) + random.shuffle(training_data) + for graph, targets, mask in training_data: + graph = apply_random_rotation(graph) + params, opt_state, loss = update(params, opt_state, graph, targets, mask) + train_stats.append(loss) + + if (i+1) % measurement_store_interval == 0: + logging.info('Start evaluation run') + test_stats = [] + for test_graph, test_targets, test_mask in test_data: + predictions = network_apply(params, test_graph) + test_stats.append(np.corrcoef( + predictions[test_mask == 1], test_targets[test_mask == 1])[0, 1]) + logging.info('Train loss %r', np.mean(train_stats)) + logging.info('Test correlation %r', np.mean(test_stats)) + train_stats = [] + i += 1