mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
@@ -16,7 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
@@ -20,7 +20,7 @@ The architecture and performance of this model is described in our publication:
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
@@ -35,10 +35,10 @@ from typing import Any, Dict, Text, Tuple, Optional
|
||||
|
||||
|
||||
def make_graph_from_static_structure(
|
||||
positions,
|
||||
types,
|
||||
box,
|
||||
edge_threshold):
|
||||
positions: tf.Tensor,
|
||||
types: tf.Tensor,
|
||||
box: tf.Tensor,
|
||||
edge_threshold: float) -> graphs.GraphsTuple:
|
||||
"""Returns graph representing the static structure of the glass.
|
||||
|
||||
Each particle is represented by a node in the graph. The particle type is
|
||||
@@ -81,7 +81,7 @@ def make_graph_from_static_structure(
|
||||
)
|
||||
|
||||
|
||||
def apply_random_rotation(graph):
|
||||
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
|
||||
"""Returns randomly rotated graph representation.
|
||||
|
||||
The rotation is an element of O(3) with rotation angles multiple of pi/2.
|
||||
@@ -118,9 +118,9 @@ class GraphBasedModel(snt.AbstractModule):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_recurrences,
|
||||
mlp_sizes,
|
||||
mlp_kwargs = None,
|
||||
n_recurrences: int,
|
||||
mlp_sizes: Tuple[int],
|
||||
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
||||
name='Graph'):
|
||||
"""Creates a new GraphBasedModel object.
|
||||
|
||||
@@ -168,7 +168,7 @@ class GraphBasedModel(snt.AbstractModule):
|
||||
node_model_fn=final_model_fn,
|
||||
edge_model_fn=model_fn)
|
||||
|
||||
def _build(self, graphs_tuple):
|
||||
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
|
||||
"""Connects the model into the tensorflow graph.
|
||||
|
||||
Args:
|
||||
|
||||
+32
-32
@@ -16,7 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
@@ -53,8 +53,8 @@ class ParticleType(enum.IntEnum):
|
||||
|
||||
|
||||
def get_targets(
|
||||
initial_positions,
|
||||
trajectory_target_positions):
|
||||
initial_positions: np.ndarray,
|
||||
trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray:
|
||||
"""Returns the averaged particle mobilities from the sampled trajectories.
|
||||
|
||||
Args:
|
||||
@@ -70,9 +70,9 @@ def get_targets(
|
||||
|
||||
|
||||
def load_data(
|
||||
file_pattern,
|
||||
time_index,
|
||||
max_files_to_load = None):
|
||||
file_pattern: Text,
|
||||
time_index: int,
|
||||
max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]:
|
||||
"""Returns a dictionary containing the training or test dataset.
|
||||
|
||||
The dictionary contains:
|
||||
@@ -108,9 +108,9 @@ def load_data(
|
||||
|
||||
|
||||
def get_loss_ops(
|
||||
prediction,
|
||||
target,
|
||||
types):
|
||||
prediction: tf.Tensor,
|
||||
target: tf.Tensor,
|
||||
types: tf.Tensor) -> LossCollection:
|
||||
"""Returns L1/L2 loss and correlation for type A particles.
|
||||
|
||||
Args:
|
||||
@@ -132,9 +132,9 @@ def get_loss_ops(
|
||||
|
||||
|
||||
def get_minimize_op(
|
||||
loss,
|
||||
learning_rate,
|
||||
grad_clip = None):
|
||||
loss: tf.Tensor,
|
||||
learning_rate: float,
|
||||
grad_clip: Optional[float] = None) -> tf.Tensor:
|
||||
"""Returns minimization operation.
|
||||
|
||||
Args:
|
||||
@@ -152,8 +152,8 @@ def get_minimize_op(
|
||||
|
||||
|
||||
def _log_stats_and_return_mean_correlation(
|
||||
label,
|
||||
stats):
|
||||
label: Text,
|
||||
stats: Sequence[LossCollection]) -> float:
|
||||
"""Logs performance statistics and returns mean correlation.
|
||||
|
||||
Args:
|
||||
@@ -171,20 +171,20 @@ def _log_stats_and_return_mean_correlation(
|
||||
return np.mean([s.correlation for s in stats])
|
||||
|
||||
|
||||
def train_model(train_file_pattern,
|
||||
test_file_pattern,
|
||||
max_files_to_load = None,
|
||||
n_epochs = 1000,
|
||||
time_index = 9,
|
||||
augment_data_using_rotations = True,
|
||||
learning_rate = 1e-4,
|
||||
grad_clip = 1.0,
|
||||
n_recurrences = 7,
|
||||
mlp_sizes = (64, 64),
|
||||
mlp_kwargs = None,
|
||||
edge_threshold = 2.0,
|
||||
measurement_store_interval = 1000,
|
||||
checkpoint_path = None):
|
||||
def train_model(train_file_pattern: Text,
|
||||
test_file_pattern: Text,
|
||||
max_files_to_load: Optional[int] = None,
|
||||
n_epochs: int = 1000,
|
||||
time_index: int = 9,
|
||||
augment_data_using_rotations: bool = True,
|
||||
learning_rate: float = 1e-4,
|
||||
grad_clip: Optional[float] = 1.0,
|
||||
n_recurrences: int = 7,
|
||||
mlp_sizes: Tuple[int] = (64, 64),
|
||||
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
||||
edge_threshold: float = 2.0,
|
||||
measurement_store_interval: int = 1000,
|
||||
checkpoint_path: Optional[Text] = None) -> float:
|
||||
"""Trains GraphModel using tensorflow.
|
||||
|
||||
Args:
|
||||
@@ -325,10 +325,10 @@ def train_model(train_file_pattern,
|
||||
return best_so_far
|
||||
|
||||
|
||||
def apply_model(checkpoint_path,
|
||||
file_pattern,
|
||||
max_files_to_load = None,
|
||||
time_index = 9):
|
||||
def apply_model(checkpoint_path: Text,
|
||||
file_pattern: Text,
|
||||
max_files_to_load: Optional[int] = None,
|
||||
time_index: int = 9) -> List[np.ndarray]:
|
||||
"""Applies trained GraphModel using tensorflow.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
Reference in New Issue
Block a user