Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+1 -1
View File
@@ -16,7 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import os
+10 -10
View File
@@ -20,7 +20,7 @@ The architecture and performance of this model is described in our publication:
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import functools
@@ -35,10 +35,10 @@ from typing import Any, Dict, Text, Tuple, Optional
def make_graph_from_static_structure(
positions,
types,
box,
edge_threshold):
positions: tf.Tensor,
types: tf.Tensor,
box: tf.Tensor,
edge_threshold: float) -> graphs.GraphsTuple:
"""Returns graph representing the static structure of the glass.
Each particle is represented by a node in the graph. The particle type is
@@ -81,7 +81,7 @@ def make_graph_from_static_structure(
)
def apply_random_rotation(graph):
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
"""Returns randomly rotated graph representation.
The rotation is an element of O(3) with rotation angles multiple of pi/2.
@@ -118,9 +118,9 @@ class GraphBasedModel(snt.AbstractModule):
"""
def __init__(self,
n_recurrences,
mlp_sizes,
mlp_kwargs = None,
n_recurrences: int,
mlp_sizes: Tuple[int],
mlp_kwargs: Optional[Dict[Text, Any]] = None,
name='Graph'):
"""Creates a new GraphBasedModel object.
@@ -168,7 +168,7 @@ class GraphBasedModel(snt.AbstractModule):
node_model_fn=final_model_fn,
edge_model_fn=model_fn)
def _build(self, graphs_tuple):
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
"""Connects the model into the tensorflow graph.
Args:
+32 -32
View File
@@ -16,7 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import collections
@@ -53,8 +53,8 @@ class ParticleType(enum.IntEnum):
def get_targets(
initial_positions,
trajectory_target_positions):
initial_positions: np.ndarray,
trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray:
"""Returns the averaged particle mobilities from the sampled trajectories.
Args:
@@ -70,9 +70,9 @@ def get_targets(
def load_data(
file_pattern,
time_index,
max_files_to_load = None):
file_pattern: Text,
time_index: int,
max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]:
"""Returns a dictionary containing the training or test dataset.
The dictionary contains:
@@ -108,9 +108,9 @@ def load_data(
def get_loss_ops(
prediction,
target,
types):
prediction: tf.Tensor,
target: tf.Tensor,
types: tf.Tensor) -> LossCollection:
"""Returns L1/L2 loss and correlation for type A particles.
Args:
@@ -132,9 +132,9 @@ def get_loss_ops(
def get_minimize_op(
loss,
learning_rate,
grad_clip = None):
loss: tf.Tensor,
learning_rate: float,
grad_clip: Optional[float] = None) -> tf.Tensor:
"""Returns minimization operation.
Args:
@@ -152,8 +152,8 @@ def get_minimize_op(
def _log_stats_and_return_mean_correlation(
label,
stats):
label: Text,
stats: Sequence[LossCollection]) -> float:
"""Logs performance statistics and returns mean correlation.
Args:
@@ -171,20 +171,20 @@ def _log_stats_and_return_mean_correlation(
return np.mean([s.correlation for s in stats])
def train_model(train_file_pattern,
test_file_pattern,
max_files_to_load = None,
n_epochs = 1000,
time_index = 9,
augment_data_using_rotations = True,
learning_rate = 1e-4,
grad_clip = 1.0,
n_recurrences = 7,
mlp_sizes = (64, 64),
mlp_kwargs = None,
edge_threshold = 2.0,
measurement_store_interval = 1000,
checkpoint_path = None):
def train_model(train_file_pattern: Text,
test_file_pattern: Text,
max_files_to_load: Optional[int] = None,
n_epochs: int = 1000,
time_index: int = 9,
augment_data_using_rotations: bool = True,
learning_rate: float = 1e-4,
grad_clip: Optional[float] = 1.0,
n_recurrences: int = 7,
mlp_sizes: Tuple[int] = (64, 64),
mlp_kwargs: Optional[Dict[Text, Any]] = None,
edge_threshold: float = 2.0,
measurement_store_interval: int = 1000,
checkpoint_path: Optional[Text] = None) -> float:
"""Trains GraphModel using tensorflow.
Args:
@@ -325,10 +325,10 @@ def train_model(train_file_pattern,
return best_so_far
def apply_model(checkpoint_path,
file_pattern,
max_files_to_load = None,
time_index = 9):
def apply_model(checkpoint_path: Text,
file_pattern: Text,
max_files_to_load: Optional[int] = None,
time_index: int = 9) -> List[np.ndarray]:
"""Applies trained GraphModel using tensorflow.
Args:
+1 -1
View File
@@ -16,7 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import os