Open sourcing the physics inspired models code.

PiperOrigin-RevId: 408640606
This commit is contained in:
Alex Botev
2021-10-27 00:57:04 +01:00
committed by Saran Tunyasuvunakool
parent 9b751b7d20
commit 2c7c401024
20 changed files with 5902 additions and 0 deletions
+59
View File
@@ -0,0 +1,59 @@
# Implementation of multiple physics inspired models for modelling dynamics
This repository contains an implementation of different physics inspired models
used in the papers: **SyMetric: Measuring the Quality of Learnt Hamiltonian
Dynamics Inferred from Vision** and **Which priors matter? Benchmarking models
for learning latent dynamics**.
## Contributing
This is purely research code, provided with no further intentions of support or
any guarantees of backward compatibility.
## Installation
All package requirements are listed in `requirements.txt`.
You will still need to download and setup the datasets from the
[DeepMind Hamiltonian Dynamics Suite] manually.
```shell
git clone git@github.com:deepmind/deepmind-research.git
pip install -r ./deepmind_research/physics_inspired_models/requirements.txt
pip install ./deepmind_research/physics_inspired_models
pip install --upgrade "jax[XXX]"
```
where `XXX` is the correct type of accelerator that you have on your machine.
Note that if you are using a GPU you might need `XXX` to also include the
correct version of CUDA and cuDNN installed on your machine.
For more details please read [here](https://github.com/google/jax#installation).
## Usage
The file `jaxline_configs.py` contains all the configurations specifications for
the experiments in the two papers. To run an experiment, in addition to passing
the location of the configs file, you must provide extra arguments in the
following manner:
`${name_of_configuration},${index_in_sweep},${dataset_name}`
For example to run the second hyper-parameter configuration of the improved
Hamiltonian Generative Network (HGN++) on the mass-spring dataset you should
run in the command line (assuming that you are in the folder of the project):
```shell
python3 jaxline_train.py \
--config="jaxline_configs.py:sym_metric_hgn_plus_plus_sweep,1,toy_physics/mass_spring" \
--jaxline_mode="train" \
--logtostderr
```
## Reference
**SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics Inferred from Vision**
**Which priors matter? Benchmarking models for learning latent dynamics**
[DeepMind Hamiltonian Dynamics Suite]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+353
View File
@@ -0,0 +1,353 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing model evaluation metric."""
import _thread as thread
import sys
import threading
import time
import warnings
from absl import logging
import distrax
import numpy as np
from sklearn import linear_model
from sklearn import model_selection
from sklearn import preprocessing
def quit_function(fn_name):
logging.error('%s took too long', fn_name)
sys.stderr.flush()
thread.interrupt_main()
def exit_after(s):
"""Use as decorator to exit function after s seconds."""
def outer(fn):
def inner(*args, **kwargs):
timer = threading.Timer(s, quit_function, args=[fn.__name__])
timer.start()
try:
result = fn(*args, **kwargs)
finally:
timer.cancel()
return result
return inner
return outer
@exit_after(400)
def do_grid_search(data_x_exp, data_y, clf, parameters, cv):
scoring_choice = 'explained_variance'
regressor = model_selection.GridSearchCV(
clf, parameters, cv=cv, refit=True, scoring=scoring_choice)
regressor.fit(data_x_exp, data_y)
return regressor
def symplectic_matrix(dim):
"""Return anti-symmetric identity matrix of given dimensionality."""
half_dims = int(dim/2)
eye = np.eye(half_dims)
zeros = np.zeros([half_dims, half_dims])
top_rows = np.concatenate([zeros, - eye], axis=1)
bottom_rows = np.concatenate([eye, zeros], axis=1)
return np.concatenate([top_rows, bottom_rows], axis=0)
def create_latent_mask(z0, dist_std_threshold=0.5):
"""Create mask based on informativeness of each latent dimension.
For stochastic models those latent dimensions that are too close to the prior
are likely to be uninformative and can be ignored.
Args:
z0: distribution or array of phase space
dist_std_threshold: informative latents have average inferred stds <
dist_std_threshold
Returns:
latent_mask_final: boolean mask of the same dimensionality as z0
"""
if isinstance(z0, distrax.Normal):
std_vals = np.mean(z0.variance(), axis=0)
elif isinstance(z0, distrax.Distribution):
raise NotImplementedError()
else:
# If the latent is deterministic, pass through all dimensions
return np.array([True]*z0.shape[-1])
tensor_shape = std_vals.shape
half_dims = int(tensor_shape[-1] / 2)
std_vals_q = std_vals[:half_dims]
std_vals_p = std_vals[half_dims:]
# Keep both q and corresponding p as either one is informative
informative_latents_inds = np.array([
x for x in range(len(std_vals_q)) if
std_vals_q[x] < dist_std_threshold or std_vals_p[x] < dist_std_threshold
])
if informative_latents_inds.shape[0] > 0:
latent_mask_final = np.zeros_like(std_vals_q)
latent_mask_final[informative_latents_inds] = 1
latent_mask_final = np.concatenate([latent_mask_final, latent_mask_final])
latent_mask_final = latent_mask_final == 1
return latent_mask_final
else:
return np.array([True]*tensor_shape[-1])
def standardize_data(data):
"""Applies the sklearn standardization to the data."""
scaler = preprocessing.StandardScaler()
scaler.fit(data)
return scaler.transform(data)
def find_best_polynomial(data_x, data_y, max_poly_order, rsq_threshold,
max_dim_n=32,
alpha_sweep=None,
max_iter=1000, cv=2):
"""Find minimal polynomial expansion that is sufficient to explain data using Lasso regression."""
rsq = 0
poly_order = 1
if not np.any(alpha_sweep):
alpha_sweep = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
# Avoid a large polynomial expansion for large latent sizes
if data_x.shape[-1] > max_dim_n:
print(f'>WARNING! Data is too high dimensional at {data_x.shape[-1]}')
print('>WARNING! Setting max_poly_order = 1')
max_poly_order = 1
while rsq < rsq_threshold and poly_order <= max_poly_order:
time_start = time.perf_counter()
poly = preprocessing.PolynomialFeatures(poly_order, include_bias=False)
data_x_exp = poly.fit_transform(data_x)
time_end = time.perf_counter()
print(
f'Took {time_end-time_start}s to create polynomial features of order '
f'{poly_order} and size {data_x_exp.shape[1]}.')
with warnings.catch_warnings():
warnings.simplefilter('ignore')
time_start = time.perf_counter()
clf = linear_model.Lasso(
random_state=0, max_iter=max_iter, normalize=False, warm_start=False)
parameters = {'alpha': alpha_sweep}
try:
regressor = do_grid_search(data_x_exp, data_y, clf, parameters, cv)
time_end = time.perf_counter()
print(f'Took {time_end-time_start}s to do regression grid search.')
# Get rsq results
time_start = time.perf_counter()
clf = linear_model.Lasso(
random_state=0,
alpha=regressor.best_params_['alpha'],
max_iter=max_iter,
normalize=False,
warm_start=False)
clf.fit(data_x_exp, data_y)
rsq = clf.score(data_x_exp, data_y)
time_end = time.perf_counter()
print(f'Took {time_end-time_start}s to get rsq results.')
old_regressor = regressor
old_poly_order = poly_order
old_poly = poly
old_data_x_exp = data_x_exp
old_rsq = rsq
old_clf = clf
print(f'Polynomial of order {poly_order} with '
f' alpha={regressor.best_params_} RSQ: {rsq}')
poly_order += 1
except KeyboardInterrupt:
time_end = time.perf_counter()
print(f'Timed out after {time_end-time_start}s of doing grid search.')
print(f'Continuing with previous poly_order={old_poly_order}...')
regressor = old_regressor
poly_order = old_poly_order
poly = old_poly
data_x_exp = old_data_x_exp
rsq = old_rsq
clf = old_clf
print(f'Polynomial of order {poly_order} with '
f' alpha={regressor.best_params_} RSQ: {rsq}')
break
return clf, poly, data_x_exp, rsq
def eval_monomial_grad(feature, x, w, grad_acc):
"""Accumulates gradient from polynomial features and their weights."""
features = feature.split(' ')
variable_indices = []
grads = np.ones(len(features)) * w
for i, feature in enumerate(features):
name_and_power = feature.split('^')
if len(name_and_power) == 1:
name, power = name_and_power[0], 1
else:
name, power = name_and_power
power = int(power)
var_index = int(name[1:])
variable_indices.append(var_index)
new_prod = np.ones_like(grads) * (x[var_index] ** power)
# This needs a special case, for situation where x[index] = 0.0
if power == 1:
new_prod[i] = 1.0
else:
new_prod[i] = power * (x[var_index] ** (power - 1))
grads = grads * new_prod
grad_acc[variable_indices] += grads
return grad_acc
def compute_jacobian_manual(x, polynomial_features, weight_matrix, tolerance):
"""Computes the jacobian manually."""
# Put together the equation for each output var
# polynomial_features = np.array(polynomial_obj.get_feature_names())
weight_mask = np.abs(weight_matrix) > tolerance
weight_matrix = weight_mask * weight_matrix
jacobians = list()
for i in range(weight_matrix.shape[0]):
grad_accumulator = np.zeros_like(x)
for j, feature in enumerate(polynomial_features):
eval_monomial_grad(feature, x, weight_matrix[i, j], grad_accumulator)
jacobians.append(grad_accumulator)
return np.stack(jacobians)
def calculate_jacobian_prod(jacobian, noise_eps=1e-6):
"""Calculates AA*, where A=JEJ^T and A*=JE^TJ^T, which should be I."""
# Add noise as 0 in jacobian creates issues in calculations later
jacobian = jacobian + noise_eps
sym_matrix = symplectic_matrix(jacobian.shape[1])
pred = np.matmul(jacobian, sym_matrix)
pred = np.matmul(pred, np.transpose(jacobian))
pred_t = np.matmul(jacobian, np.transpose(sym_matrix))
pred_t = np.matmul(pred_t, np.transpose(jacobian))
pred_id = np.matmul(pred, pred_t)
return pred_id
def normalise_jacobian_prods(jacobian_preds):
"""Normalises Jacobians evaluated at various points by a constant."""
stacked_preds = np.stack(jacobian_preds)
# For each attempt at estimating E, get the max term, and take their average
normalisation_factor = np.mean(np.max(np.abs(stacked_preds), axis=(1, 2)))
if normalisation_factor != 0:
stacked_preds = stacked_preds/normalisation_factor
return stacked_preds
def calculate_symetric_score(
gt_data,
model_data,
max_poly_order,
max_sym_score,
rsq_threshold,
sym_threshold,
evaluation_point_n,
trajectory_n=1,
weight_tolerance=1e-5,
alpha_sweep=None,
max_iter=1000,
cv=2):
"""Finds minimal polynomial expansion to explain data using Lasso regression, gets the Jacobian of the mapping and calculates how symplectic the map is."""
model_data = model_data[..., :gt_data.shape[0], :]
# Fing polynomial expansion that explains enough variance in the gt data
print('Finding best polynomial expansion...')
time_start = time.perf_counter()
# Clean up model data to ensure it doesn't contain NaN, infinity
# or values too large for dtype('float32')
model_data = np.nan_to_num(model_data)
model_data = np.clip(model_data, -999999, 999999)
clf, poly, model_data_exp, best_rsq = find_best_polynomial(
model_data, gt_data, max_poly_order, rsq_threshold,
32, alpha_sweep, max_iter, cv)
time_end = time.perf_counter()
print(f'Took {time_end - time_start}s to find best polynomial.')
# Calculate Symplecticity score
all_raw_scores = []
features = np.array(poly.get_feature_names())
points_per_trajectory = int(len(gt_data) / trajectory_n)
for trajectory in range(trajectory_n):
random_data_inds = np.random.permutation(
range(points_per_trajectory))[:evaluation_point_n]
jacobian_preds = []
for point_ind in random_data_inds:
input_data_point = model_data[points_per_trajectory * trajectory +
point_ind]
time_start = time.perf_counter()
jacobian = compute_jacobian_manual(input_data_point, features,
clf.coef_, weight_tolerance)
pred = calculate_jacobian_prod(jacobian)
jacobian_preds.append(pred)
time_end = time.perf_counter()
print(f'Took {time_end - time_start}s to evaluate jacobian '
f'around point {point_ind}.')
# Normalise
normalised_jacobian_preds = normalise_jacobian_prods(jacobian_preds)
# The score is measured as the deviation from I
identity = np.eye(normalised_jacobian_preds.shape[-1])
scores = np.mean(np.power(normalised_jacobian_preds - identity, 2),
axis=(1, 2))
all_raw_scores.append(scores)
sym_score = np.min([np.mean(all_raw_scores), max_sym_score])
# Calculate final SyMetric score
if best_rsq > rsq_threshold and sym_score < sym_threshold:
sy_metric = 1.0
else:
sy_metric = 0.0
results = {
'poly_exp_order': poly.get_params()['degree'],
'rsq': best_rsq,
'sym': sym_score,
'SyMetric': sy_metric,
}
with np.printoptions(precision=4, suppress=True):
print(f'----------------FINAL RESULTS FOR {trajectory_n} '
'TRAJECTORIES------------------')
print(f'BEST POLYNOMIAL EXPANSION ORDER: {results["poly_exp_order"]}')
print(f'BEST RSQ (1-best): {results["rsq"]}')
print(f'SYMPLECTICITY SCORE AROUND ALL POINTS AND ALL '
f'TRAJECTORIES (0-best): {sym_score}')
print(f'SyMETRIC SCORE: {sy_metric}')
print(f'----------------FINAL RESULTS FOR {trajectory_n} '
f'TRAJECTORIES------------------')
return results, clf, poly, model_data_exp
File diff suppressed because it is too large Load Diff
+397
View File
@@ -0,0 +1,397 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing all of the configurations for various models."""
import copy
import os
from jaxline import base_config
import ml_collections as collections
_DATASETS_PATH_VAR_NAME = "DM_HAMILTONIAN_DYNAMICS_SUITE_DATASETS"
def get_config(arg_string):
"""Return config object for training."""
args = arg_string.split(",")
if len(args) != 3:
raise ValueError("You must provide exactly three arguments separated by a "
"comma - model_config_name,sweep_index,dataset_name.")
model_config_name, sweep_index, dataset_name = args
sweep_index = int(sweep_index)
config = base_config.get_base_config()
config.random_seed = 123109801
config.eval_modes = ("eval", "eval_metric")
# Get the model config and the sweeps
if model_config_name not in globals():
raise ValueError(f"The config name {model_config_name} does not exist in "
f"jaxline_configs.py")
config_and_sweep_fn = globals()[model_config_name]
model_config, sweeps = config_and_sweep_fn()
if not os.environ.get(_DATASETS_PATH_VAR_NAME, None):
raise ValueError(f"You need to set the {_DATASETS_PATH_VAR_NAME}")
dm_hamiltonian_suite_path = os.environ[_DATASETS_PATH_VAR_NAME]
dataset_folder = os.path.join(dm_hamiltonian_suite_path, dataset_name)
# Experiment config. Note that batch_size is per device.
# In the experiments we run on 4 GPUs, so the effective batch size was 128.
config.experiment_kwargs = collections.ConfigDict(
dict(
config=dict(
dataset_folder=dataset_folder,
model_kwargs=model_config,
num_extrapolation_steps=60,
drop_stats_containing=("neg_log_p_x", "l2_over_time", "neg_elbo"),
optimizer=dict(
name="adam",
kwargs=dict(
learning_rate=1.5e-4,
b1=0.9,
b2=0.999,
)
),
training=dict(
batch_size=32,
burnin_steps=5,
num_epochs=None,
lagging_vae=False
),
evaluation=dict(
batch_size=64,
),
evaluation_metric=dict(
batch_size=5,
batch_n=20,
num_eval_metric_steps=60,
max_poly_order=5,
max_jacobian_score=1000,
rsq_threshold=0.9,
sym_threshold=0.05,
evaluation_point_n=10,
weight_tolerance=1e-03,
max_iter=1000,
cv=2,
alpha_min_logspace=-4,
alpha_max_logspace=-0.5,
alpha_step_n=10,
calculate_fully_after_steps=40000,
),
evaluation_metric_mlp=dict(
batch_size=64,
batch_n=10000,
datapoint_param_multiplier=1000,
num_eval_metric_steps=60,
evaluation_point_n=10,
evaluation_trajectory_n=50,
rsq_threshold=0.9,
sym_threshold=0.05,
ridge_lambda=0.01,
model=dict(
num_units=4,
num_layers=4,
activation="tanh",
),
optimizer=dict(
name="adam",
kwargs=dict(
learning_rate=1.5e-3,
)
),
),
evaluation_vpt=dict(
batch_size=5,
batch_n=2,
vpt_threshold=0.025,
)
)
)
)
# Training loop config.
config.training_steps = int(500000)
config.interval_type = "steps"
config.log_tensors_interval = 50
config.log_train_data_interval = 50
config.log_all_train_data = False
config.save_checkpoint_interval = 100
config.checkpoint_dir = "/tmp/physics_inspired_models/"
config.train_checkpoint_all_hosts = False
config.eval_specific_checkpoint_dir = ""
config.update_from_flattened_dict(sweeps[sweep_index])
return config
config_prefix = "experiment_kwargs.config."
model_prefix = config_prefix + "model_kwargs."
default_encoder_kwargs = collections.ConfigDict(dict(
conv_channels=64,
num_blocks=3,
blocks_depth=2,
activation="leaky_relu",
))
default_decoder_kwargs = collections.ConfigDict(dict(
conv_channels=64,
num_blocks=3,
blocks_depth=2,
activation="leaky_relu",
))
default_latent_system_net_kwargs = collections.ConfigDict(dict(
conv_channels=64,
num_units=250,
num_layers=5,
activation="swish",
))
default_latent_system_kwargs = collections.ConfigDict(dict(
# Physics model arguments
input_space=collections.config_dict.placeholder(str),
simulation_space=collections.config_dict.placeholder(str),
potential_func_form="separable_net",
kinetic_func_form=collections.config_dict.placeholder(str),
hgn_kinetic_func_form="separable_net",
lgn_kinetic_func_form="matrix_dep_quad",
parametrize_mass_matrix=collections.config_dict.placeholder(bool),
hgn_parametrize_mass_matrix=False,
lgn_parametrize_mass_matrix=True,
mass_eps=1.0,
# ODE model arguments
integrator_method=collections.config_dict.placeholder(str),
# RGN model arguments
residual=collections.config_dict.placeholder(bool),
# General arguments
net_kwargs=default_latent_system_net_kwargs
))
default_config_dict = collections.ConfigDict(dict(
name=collections.config_dict.placeholder(str),
latent_system_dim=32,
latent_system_net_type="mlp",
latent_system_kwargs=default_latent_system_kwargs,
encoder_aggregation_type="linear_projection",
decoder_de_aggregation_type=collections.config_dict.placeholder(str),
encoder_kwargs=default_encoder_kwargs,
decoder_kwargs=default_decoder_kwargs,
has_latent_transform=False,
num_inference_steps=5,
num_target_steps=60,
latent_training_type="forward",
# Choices: overlap_by_one, no_overlap, include_inference
training_data_split="overlap_by_one",
objective_type="ELBO",
elbo_beta_delay=0,
elbo_beta_final=1.0,
geco_kappa=0.001,
geco_alpha=0.0,
dt=0.125,
))
hgn_paper_encoder_kwargs = collections.ConfigDict(dict(
conv_channels=[[32, 64], [64, 64], [64]],
num_blocks=3,
blocks_depth=2,
activation="relu",
kernel_shapes=[2, 4],
padding=["VALID", "SAME"],
))
hgn_paper_decoder_kwargs = collections.ConfigDict(dict(
conv_channels=64,
num_blocks=3,
blocks_depth=2,
activation="tf_leaky_relu",
))
hgn_paper_latent_net_kwargs = collections.ConfigDict(dict(
conv_channels=[32, 64, 64, 64],
num_units=250,
num_layers=5,
activation="softplus",
kernel_shapes=[3, 2, 2, 2, 2],
strides=[1, 2, 1, 2, 1],
padding=["SAME", "VALID", "SAME", "VALID", "SAME"]
))
hgn_paper_latent_system_kwargs = collections.ConfigDict(dict(
potential_func_form="separable_net",
kinetic_func_form="separable_net",
parametrize_mass_matrix=False,
net_kwargs=hgn_paper_latent_net_kwargs
))
hgn_paper_latent_transform_kwargs = collections.ConfigDict(dict(
num_layers=5,
conv_channels=64,
num_units=64,
activation="relu",
))
hgn_paper_config = copy.deepcopy(default_config_dict)
hgn_paper_config.training_data_split = "include_inference"
hgn_paper_config.latent_system_net_type = "conv"
hgn_paper_config.encoder_aggregation_type = (collections.config_dict.
placeholder(str))
hgn_paper_config.decoder_de_aggregation_type = (collections.config_dict.
placeholder(str))
hgn_paper_config.latent_system_kwargs = hgn_paper_latent_system_kwargs
hgn_paper_config.encoder_kwargs = hgn_paper_encoder_kwargs
hgn_paper_config.decoder_kwargs = hgn_paper_decoder_kwargs
hgn_paper_config.has_latent_transform = True
hgn_paper_config.latent_transform_kwargs = hgn_paper_latent_transform_kwargs
hgn_paper_config.num_inference_steps = 31
hgn_paper_config.num_target_steps = 0
hgn_paper_config.objective_type = "GECO"
forward_overlap_by_one = {
model_prefix + "latent_training_type": "forward",
model_prefix + "training_data_split": "overlap_by_one",
}
forward_backward_include_inference = {
model_prefix + "latent_training_type": "forward_backward",
model_prefix + "training_data_split": "include_inference",
}
latent_training_sweep = [
forward_overlap_by_one,
forward_backward_include_inference,
]
def sym_metric_hgn_plus_plus_sweep():
"""HGN++ experimental sweep for the SyMetric paper."""
model_config = copy.deepcopy(default_config_dict)
model_config.name = "HGN"
sweeps = list()
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "latent_training_type": "forward",
model_prefix + "training_data_split": "overlap_by_one",
model_prefix + "elbo_beta_final": elbo_beta_final,
})
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "latent_training_type": "forward_backward",
model_prefix + "training_data_split": "include_inference",
model_prefix + "elbo_beta_final": elbo_beta_final,
})
return model_config, sweeps
def sym_metric_hgn_sweep():
"""HGN experimental sweep for the SyMetric paper."""
model_config = copy.deepcopy(hgn_paper_config)
model_config.name = "HGN"
return model_config, list(dict())
def benchmark_hgn_overlap_sweep():
"""HGN++ sweep for the benchmark paper."""
model_config = copy.deepcopy(default_config_dict)
model_config.name = "HGN"
sweeps = list()
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
for train_dict in latent_training_sweep:
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "elbo_beta_final": elbo_beta_final,
})
sweeps[-1].update(train_dict)
return model_config, sweeps
def benchmark_lgn_sweep():
"""LGN sweep for the benchmark paper."""
model_config = copy.deepcopy(default_config_dict)
model_config.name = "LGN"
sweeps = list()
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
for train_dict in latent_training_sweep:
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "latent_system_kwargs.kinetic_func_form":
"matrix_dep_pure_quad",
model_prefix + "elbo_beta_final": elbo_beta_final,
})
sweeps[-1].update(train_dict)
return model_config, sweeps
def benchmark_ode_sweep():
"""Neural ODE sweep for the benchmark paper."""
model_config = copy.deepcopy(default_config_dict)
model_config.name = "ODE"
sweeps = list()
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
for integrator in ("adaptive", "rk2"):
for train_dict in latent_training_sweep:
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "integrator_method": integrator,
model_prefix + "elbo_beta_final": elbo_beta_final,
})
sweeps[-1].update(train_dict)
return model_config, sweeps
def benchmark_rgn_sweep():
"""RGN sweep for the benchmark paper."""
model_config = copy.deepcopy(default_config_dict)
model_config.name = "RGN"
sweeps = list()
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
for residual in (True, False):
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "latent_system_kwargs.residual": residual,
model_prefix + "elbo_beta_final": elbo_beta_final,
})
return model_config, sweeps
def benchmark_ar_sweep():
"""AR sweep for the benchmark paper."""
model_config = copy.deepcopy(default_config_dict)
model_config.name = "AR"
model_config.latent_dynamics_type = "vanilla"
sweeps = list()
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
for ar_type in ("vanilla", "lstm", "gru"):
sweeps.append({
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
model_prefix + "latent_dynamics_type": ar_type,
model_prefix + "elbo_beta_final": elbo_beta_final,
})
return model_config, sweeps
File diff suppressed because it is too large Load Diff
+52
View File
@@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script to execute a single configuration on all datasets.
if [[ "$#" -eq 2 ]]; then
readonly CONFIG_NAME="$1"
readonly NUM_SWEEPS="$2"
else
echo "You must provide exactly two arguments - the configuration name and " \
"how many sweeps it contains. For example:"
echo "./launch_all.sh sym_metric_hgn_plus_plus_sweep 1"
exit 2
fi
DATASETS=(
"toy_physics/mass_spring"
"toy_physics/mass_spring_colors"
"toy_physics/mass_spring_colors_friction"
"toy_physics/pendulum"
"toy_physics/pendulum_colors"
"toy_physics/pendulum_colors_friction"
"toy_physics/two_body"
"toy_physics/two_body_colors"
"toy_physics/double_pendulum"
"toy_physics/double_pendulum_colors"
"toy_physics/double_pendulum_colors_friction"
"molecular_dynamics/lj_4"
"molecular_dynamics/lj_16"
"multi_agent/rock_paper_scissors"
"multi_agent/matching_pennies"
"mujoco_room/circle"
"mujoco_room/spiral"
)
readonly DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
for dataset in "${DATASETS[@]}"; do
"${DIR}/launch_local.sh" "${CONFIG_NAME}" "${NUM_SWEEPS}" "${dataset}"
done
+40
View File
@@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to execute a single configuration name on a given dataset.
if [[ "$#" -eq 3 ]]; then
readonly CONFIG_NAME="$1"
readonly NUM_SWEEPS="$2"
readonly DATASET="$3"
else
echo "You must provide exactly three arguments - the configuration name, " \
"the number of sweeps it contains and the dataset name. For example:"
echo "./launch_local.sh sym_metric_hgn_plus_plus_sweep 1 " \
"toy_physics/mass_spring"
exit 2
fi
echo "Running with config ${CONFIG_NAME} on ${DATASET}."
readonly EXPERIMENT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
readonly TRAIN_FILE="${EXPERIMENT_DIR}/jaxline_train.py"
readonly CONFIG_FILE="${EXPERIMENT_DIR}/jaxline_configs.py"
for sweep_id in $(seq 0 $((NUM_SWEEPS - 1))); do
python3 "${TRAIN_FILE}" \
--config="${CONFIG_FILE}:${CONFIG_NAME},${sweep_id},${DATASET}" \
--jaxline_mode="train" \
--logtostderr
done
+247
View File
@@ -0,0 +1,247 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing code for computing various metrics for training and evaluation."""
from typing import Callable, Dict, Optional
import distrax
import haiku as hk
import jax
import jax.nn as nn
import jax.numpy as jnp
import numpy as np
import physics_inspired_models.utils as utils
_ReconstructFunc = Callable[[utils.Params, jnp.ndarray, jnp.ndarray, bool],
distrax.Distribution]
def calculate_small_latents(dist, threshold=0.5):
"""Calculates the number of active latents by thresholding the variance of their distribution."""
if not isinstance(dist, distrax.Normal):
raise NotImplementedError()
latent_means = dist.mean()
latent_stddevs = dist.variance()
small_latents = jnp.sum(
(latent_stddevs < threshold) & (jnp.abs(latent_means) > 0.1), axis=1)
return jnp.mean(small_latents)
def compute_scale(
targets: jnp.ndarray,
rescale_by: str
) -> jnp.ndarray:
"""Compute a scaling factor based on targets shape and the rescale_by argument."""
if rescale_by == "pixels_and_time":
return jnp.asarray(np.prod(targets.shape[-4:]))
elif rescale_by is not None:
raise ValueError(f"Unrecognized rescale_by={rescale_by}.")
else:
return jnp.ones([])
def compute_data_domain_stats(
p_x: distrax.Distribution,
targets: jnp.ndarray
) -> Dict[str, jnp.ndarray]:
"""Compute several statistics in the data domain, such as L2 and negative log likelihood."""
axis = tuple(range(2, targets.ndim))
l2_over_time = jnp.sum((p_x.mean() - targets) ** 2, axis=axis)
l2 = jnp.sum(l2_over_time, axis=1)
# Calculate relative L2 normalised by image "length"
norm_factor = jnp.sum(targets**2, axis=(2, 3, 4))
l2_over_time_norm = l2_over_time / norm_factor
l2_norm = jnp.sum(l2_over_time_norm, axis=1)
# Compute negative log-likelihood under p(x)
neg_log_p_x_over_time = - np.sum(p_x.log_prob(targets), axis=axis)
neg_log_p_x = jnp.sum(neg_log_p_x_over_time, axis=1)
return dict(
neg_log_p_x_over_time=neg_log_p_x_over_time,
neg_log_p_x=neg_log_p_x,
l2_over_time=l2_over_time,
l2=l2,
l2_over_time_norm=l2_over_time_norm,
l2_norm=l2_norm,
)
def compute_vae_stats(
neg_log_p_x: jnp.ndarray,
rng: jnp.ndarray,
q_z: distrax.Distribution,
prior: distrax.Distribution
) -> Dict[str, jnp.ndarray]:
"""Compute the KL(q(z|x)||p(z)) and the negative ELBO, which are used for VAE models."""
# Compute the KL
kl = distrax.estimate_kl_best_effort(q_z, prior, rng_key=rng, num_samples=1)
kl = np.sum(kl, axis=list(range(1, kl.ndim)))
# Sanity check
assert kl.shape == neg_log_p_x.shape
return dict(
kl=kl,
neg_elbo=neg_log_p_x + kl,
)
def training_statistics(
p_x: distrax.Distribution,
targets: jnp.ndarray,
rescale_by: Optional[str],
rng: Optional[jnp.ndarray] = None,
q_z: Optional[distrax.Distribution] = None,
prior: Optional[distrax.Distribution] = None,
p_x_learned_sigma: bool = False
) -> Dict[str, jnp.ndarray]:
"""Computes various statistics we track during training."""
stats = compute_data_domain_stats(p_x, targets)
if rng is not None and q_z is not None and prior is not None:
stats.update(compute_vae_stats(stats["neg_log_p_x"], rng, q_z, prior))
else:
assert rng is None and q_z is None and prior is None
# Rescale these stats accordingly
scale = compute_scale(targets, rescale_by)
# Note that "_over_time" stats are getting normalised by time here
stats = jax.tree_map(lambda x: x / scale, stats)
if p_x_learned_sigma:
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0]
if q_z is not None:
stats["small_latents"] = calculate_small_latents(q_z)
return stats
def evaluation_only_statistics(
reconstruct_func: _ReconstructFunc,
params: hk.Params,
inputs: jnp.ndarray,
rng: jnp.ndarray,
rescale_by: str,
can_run_backwards: bool,
train_sequence_length: int,
reconstruction_skip: int,
p_x_learned_sigma: bool = False,
) -> Dict[str, jnp.ndarray]:
"""Computes various statistics we track only during evaluation."""
full_trajectory = utils.extract_image(inputs)
prefixes = ("forward", "backward") if can_run_backwards else ("forward",)
full_forward_targets = jax.tree_map(
lambda x: x[:, reconstruction_skip:], full_trajectory)
full_backward_targets = jax.tree_map(
lambda x: x[:, :x.shape[1]-reconstruction_skip], full_trajectory)
train_targets_length = train_sequence_length - reconstruction_skip
full_targets_length = full_forward_targets.shape[1]
stats = dict()
keys = ()
for prefix in prefixes:
# Fully unroll the model and reconstruct the whole sequence
full_prediction = reconstruct_func(params, full_trajectory, rng,
prefix == "forward")
assert isinstance(full_prediction, distrax.Normal)
full_targets = (full_forward_targets if prefix == "forward" else
full_backward_targets)
# In cases where the model can run backwards it is possible to reconstruct
# parts which were indented to be skipped, so here we take care of that.
if full_prediction.mean().shape[1] > full_targets_length:
if prefix == "forward":
full_prediction = jax.tree_map(lambda x: x[:, -full_targets_length:],
full_prediction)
else:
full_prediction = jax.tree_map(lambda x: x[:, :full_targets_length],
full_prediction)
# Based on the prefix and suffix fetch correct predictions and targets
for suffix in ("train", "extrapolation", "full"):
if prefix == "forward" and suffix == "train":
predict, targets = jax.tree_map(lambda x: x[:, :train_targets_length],
(full_prediction, full_targets))
elif prefix == "forward" and suffix == "extrapolation":
predict, targets = jax.tree_map(lambda x: x[:, train_targets_length:],
(full_prediction, full_targets))
elif prefix == "backward" and suffix == "train":
predict, targets = jax.tree_map(lambda x: x[:, -train_targets_length:],
(full_prediction, full_targets))
elif prefix == "backward" and suffix == "extrapolation":
predict, targets = jax.tree_map(lambda x: x[:, :-train_targets_length],
(full_prediction, full_targets))
else:
predict, targets = full_prediction, full_targets
# Compute train statistics
train_stats = training_statistics(predict, targets, rescale_by,
p_x_learned_sigma=p_x_learned_sigma)
for key, value in train_stats.items():
stats[prefix + "_" + suffix + "_" + key] = value
# Copy all stats keys
keys = tuple(train_stats.keys())
# Make a combined metric summing forward and backward
if can_run_backwards:
# Also compute
for suffix in ("train", "extrapolation", "full"):
for key in keys:
forward = stats["forward_" + suffix + "_" + key]
backward = stats["backward_" + suffix + "_" + key]
combined = (forward + backward) / 2
stats["combined_" + suffix + "_" + key] = combined
return stats
def geco_objective(
l2_loss,
kl,
alpha,
kappa,
constraint_ema,
lambda_var,
is_training
) -> Dict[str, jnp.ndarray]:
"""Computes the objective for GECO and some of it statistics used ofr updates."""
# C_t
constraint_t = l2_loss - kappa
if is_training:
# We update C_ma only during training
constraint_ema = alpha * constraint_ema + (1 - alpha) * constraint_t
lagrange = nn.softplus(lambda_var)
lagrange = jnp.broadcast_to(lagrange, constraint_ema.shape)
# Add this special op for getting all gradients correct
loss = utils.geco_lagrange_product(lagrange, constraint_ema, constraint_t)
return dict(
loss=loss + kl,
geco_multiplier=lagrange,
geco_constraint=constraint_t,
geco_constraint_ema=constraint_ema
)
def elbo_objective(neg_log_p_x, kl, final_beta, beta_delay, step):
"""Computes objective for optimizing the Evidence Lower Bound (ELBO)."""
if beta_delay == 0:
beta = final_beta
else:
delayed_beta = jnp.minimum(float(step) / float(beta_delay), 1.0)
beta = delayed_beta * final_beta
return dict(
loss=neg_log_p_x + beta * kl,
elbo_beta=beta
)
@@ -0,0 +1,14 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,345 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all autoregressive models."""
import functools
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
import distrax
import haiku as hk
from jax import lax
import jax.numpy as jnp
import jax.random as jnr
import physics_inspired_models.metrics as metrics
import physics_inspired_models.models.base as base
import physics_inspired_models.models.networks as nets
import physics_inspired_models.utils as utils
class TeacherForcingAutoregressiveModel(base.SequenceModel):
"""A standard autoregressive model trained via teacher forcing."""
def __init__(
self,
latent_system_dim: int,
latent_system_net_type: str,
latent_system_kwargs: Dict[str, Any],
latent_dynamics_type: str,
encoder_aggregation_type: Optional[str],
decoder_de_aggregation_type: Optional[str],
encoder_kwargs: Dict[str, Any],
decoder_kwargs: Dict[str, Any],
num_inference_steps: int,
num_target_steps: int,
name: Optional[str] = None,
**kwargs
):
# Remove any parameters from vae models
encoder_kwargs = dict(**encoder_kwargs)
encoder_kwargs["distribution_name"] = None
if kwargs.get("has_latent_transform", False):
raise ValueError("We do not support AR models with latent transform.")
super().__init__(
can_run_backwards=False,
latent_system_dim=latent_system_dim,
latent_system_net_type=latent_system_net_type,
latent_system_kwargs=latent_system_kwargs,
encoder_aggregation_type=encoder_aggregation_type,
decoder_de_aggregation_type=decoder_de_aggregation_type,
encoder_kwargs=encoder_kwargs,
decoder_kwargs=decoder_kwargs,
num_inference_steps=num_inference_steps,
num_target_steps=num_target_steps,
name=name,
**kwargs
)
self.latent_dynamics_type = latent_dynamics_type
# Arguments checks
if self.latent_system_net_type != "mlp":
raise ValueError("Currently we do not support non-mlp AR models.")
def recurrence_function(sequence, initial_state=None):
core = nets.make_flexible_recurrent_net(
core_type=latent_dynamics_type,
net_type=latent_system_net_type,
output_dims=self.latent_system_dim,
**self.latent_system_kwargs["net_kwargs"])
initial_state = initial_state or core.initial_state(sequence.shape[1])
core(sequence[0], initial_state)
return hk.dynamic_unroll(core, sequence, initial_state)
self.recurrence = hk.transform(recurrence_function)
def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
return x
def process_latents_for_dynamics(self, z: jnp.ndarray) -> jnp.ndarray:
return z
def process_latents_for_decoder(self, z: jnp.ndarray) -> jnp.ndarray:
return z
@property
def inferred_index(self) -> int:
return self.num_inference_steps - 1
@property
def train_sequence_length(self) -> int:
return self.num_target_steps
def train_data_split(
self,
images: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
images = images[:, :self.train_sequence_length]
inference_data = images[:, :-1]
target_data = images[:, 1:]
return inference_data, target_data, dict(
num_steps_forward=1,
num_steps_backward=0,
include_z0=False)
def unroll_without_inputs(
self,
params: utils.Params,
rng: jnp.ndarray,
x_init: jnp.ndarray,
h_init: jnp.ndarray,
num_steps: int,
is_training: bool
) -> Tuple[Tuple[distrax.Distribution, jnp.ndarray], Any]:
if num_steps < 1:
raise ValueError("`num_steps` must be at least 1.")
def step_fn(carry, key):
x_last, h_last = carry
enc_key, dec_key = jnr.split(key)
z_in_next = self.encoder.apply(params, enc_key, x_last,
is_training=is_training)
z_next, h_next = self.recurrence.apply(params, None, z_in_next[None],
h_last)
p_x_next = self.decode_latents(params, dec_key, z_next[0],
is_training=is_training)
return (p_x_next.mean(), h_next), (p_x_next, z_next[0])
return lax.scan(
step_fn,
init=(x_init, h_init),
xs=jnr.split(rng, num_steps)
)
def unroll_latent_dynamics(
self,
z: jnp.ndarray,
params: utils.Params,
key: jnp.ndarray,
num_steps_forward: int,
num_steps_backward: int,
include_z0: bool,
is_training: bool,
**kwargs: Any
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
init_key, unroll_key, dec_key = jnr.split(key, 3)
if num_steps_backward != 0:
raise ValueError("This model can not run backwards.")
# Change 'z' time dimension to be first
z = jnp.swapaxes(z, 0, 1)
# Run recurrent model on inputs
z_0, h_0 = self.recurrence.apply(params, init_key, z)
if num_steps_forward == 1:
z_t = z_0
elif num_steps_forward > 1:
p_x_0 = self.decode_latents(params, dec_key, z_0[-1], is_training=False)
_, (_, z_t) = self.unroll_without_inputs(
params=params,
rng=unroll_key,
x_init=p_x_0.mean(),
h_init=h_0,
num_steps=num_steps_forward-1,
is_training=is_training
)
z_t = jnp.concatenate([z_0, z_t], axis=0)
else:
raise ValueError("num_steps_forward should be at least 1.")
# Make time dimension second
return jnp.swapaxes(z_t, 0, 1), dict()
def _models_core(
self,
params: utils.Params,
keys: jnp.ndarray,
image_data: jnp.ndarray,
is_training: bool,
**unroll_kwargs: Any
) -> Tuple[distrax.Distribution, jnp.ndarray, jnp.ndarray]:
enc_key, _, transform_key, unroll_key, dec_key, _ = keys
# Calculate latent input representation
inference_data = self.process_inputs_for_encoder(image_data)
z_raw = self.encoder.apply(params, enc_key, inference_data,
is_training=is_training)
# Apply latent transformation (should be identity)
z0 = self.apply_latent_transform(params, transform_key, z_raw,
is_training=is_training)
z0 = self.process_latents_for_dynamics(z0)
# Calculate latent output representation
decoder_z, _ = self.unroll_latent_dynamics(
z=z0,
params=params,
key=unroll_key,
is_training=is_training,
**unroll_kwargs
)
decoder_z = self.process_latents_for_decoder(decoder_z)
# Compute p(x|z)
p_x = self.decode_latents(params, dec_key, decoder_z,
is_training=is_training)
return p_x, z0, decoder_z
def training_objectives(
self,
params: hk.Params,
state: hk.State,
rng: jnp.ndarray,
inputs: jnp.ndarray,
step: jnp.ndarray,
is_training: bool = True,
use_mean_for_eval_stats: bool = True
) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
"""Computes the training objective and any supporting stats."""
# Split all rng keys
keys = jnr.split(rng, 6)
# Process training data
images = utils.extract_image(inputs)
image_data, target_data, unroll_kwargs = self.train_data_split(images)
p_x, _, _ = self._models_core(
params=params,
keys=keys,
image_data=image_data,
is_training=is_training,
**unroll_kwargs
)
# Compute training statistics
stats = metrics.training_statistics(
p_x=p_x,
targets=target_data,
rescale_by=self.rescale_by,
p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
)
# The loss is just the negative log-likelihood (e.g. the L2 loss)
stats["loss"] = stats["neg_log_p_x"]
if not is_training:
# Optionally add the evaluation stats when not training
# Add also the evaluation statistics
# We need to be able to set `use_mean = False` for some of the tests
stats.update(metrics.evaluation_only_statistics(
reconstruct_func=functools.partial(
self.reconstruct, use_mean=use_mean_for_eval_stats),
params=params,
inputs=inputs,
rng=rng,
rescale_by=self.rescale_by,
can_run_backwards=self.can_run_backwards,
train_sequence_length=self.train_sequence_length,
reconstruction_skip=1,
p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
))
return stats["loss"], (dict(), stats, dict())
def reconstruct(
self,
params: utils.Params,
inputs: jnp.ndarray,
rng: jnp.ndarray,
forward: bool,
use_mean: bool = True,
) -> distrax.Distribution:
"""Reconstructs the input sequence."""
if not forward:
raise ValueError("This model can not run backwards.")
images = utils.extract_image(inputs)
image_data = images[:, :self.num_inference_steps]
return self._models_core(
params=params,
keys=jnr.split(rng, 6),
image_data=image_data,
is_training=False,
num_steps_forward=images.shape[1] - self.num_inference_steps,
num_steps_backward=0,
include_z0=False,
)[0]
def gt_state_and_latents(
self,
params: hk.Params,
rng: jnp.ndarray,
inputs: Dict[str, jnp.ndarray],
seq_length: int,
is_training: bool = False,
unroll_direction: str = "forward",
**kwargs: Dict[str, Any]
) -> Tuple[jnp.ndarray, jnp.ndarray,
Union[distrax.Distribution, jnp.ndarray]]:
"""Computes the ground state and matching latents."""
assert unroll_direction == "forward"
images = utils.extract_image(inputs)
gt_state = utils.extract_gt_state(inputs)
image_data = images[:, :self.num_inference_steps]
gt_state = gt_state[:, 1:seq_length + 1]
_, z_in, z_out = self._models_core(
params=params,
keys=jnr.split(rng, 6),
image_data=image_data,
is_training=False,
num_steps_forward=images.shape[1] - self.num_inference_steps,
num_steps_backward=0,
include_z0=False,
)
return gt_state, z_out, z_in
def _init_non_model_params_and_state(
self,
rng: jnp.ndarray
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
return dict(), dict()
def _init_latent_system(
self,
rng: jnp.ndarray,
z: jnp.ndarray,
**kwargs: Any
) -> utils.Params:
return self.recurrence.init(rng, z)
+360
View File
@@ -0,0 +1,360 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the base abstract classes for sequence models."""
import abc
from typing import Any, Dict, Generic, Mapping, Optional, Sequence, Tuple, TypeVar, Union
from absl import logging
import distrax
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jnr
from physics_inspired_models import utils
from physics_inspired_models.models import networks
T = TypeVar("T")
class SequenceModel(abc.ABC, Generic[T]):
"""An abstract class for sequence models."""
def __init__(
self,
can_run_backwards: bool,
latent_system_dim: int,
latent_system_net_type: str,
latent_system_kwargs: Dict[str, Any],
encoder_aggregation_type: Optional[str],
decoder_de_aggregation_type: Optional[str],
encoder_kwargs: Dict[str, Any],
decoder_kwargs: Dict[str, Any],
num_inference_steps: int,
num_target_steps: int,
name: str,
latent_spatial_shape: Optional[Tuple[int, int]] = (4, 4),
has_latent_transform: bool = False,
latent_transform_kwargs: Optional[Dict[str, Any]] = None,
rescale_by: Optional[str] = "pixels_and_time",
data_format: str = "NHWC",
**unused_kwargs
):
# Arguments checks
encoder_kwargs = encoder_kwargs or dict()
decoder_kwargs = decoder_kwargs or dict()
# Set the decoder de-aggregation type the "same" type as the encoder if not
# provided
if (decoder_de_aggregation_type is None and
encoder_aggregation_type is not None):
if encoder_aggregation_type == "linear_projection":
decoder_de_aggregation_type = "linear_projection"
elif encoder_aggregation_type in ("mean", "max"):
decoder_de_aggregation_type = "tile"
else:
raise ValueError(f"Unrecognized encoder_aggregation_type="
f"{encoder_aggregation_type}")
if latent_system_net_type == "conv":
if encoder_aggregation_type is not None:
raise ValueError("When the latent system is convolutional, the encoder "
"aggregation type should be None.")
if decoder_de_aggregation_type is not None:
raise ValueError("When the latent system is convolutional, the decoder "
"aggregation type should be None.")
else:
if encoder_aggregation_type is None:
raise ValueError("When the latent system is not convolutional, the "
"you must provide an encoder aggregation type.")
if decoder_de_aggregation_type is None:
raise ValueError("When the latent system is not convolutional, the "
"you must provide an decoder aggregation type.")
if has_latent_transform and latent_transform_kwargs is None:
raise ValueError("When using latent transformation you have to provide "
"the latent_transform_kwargs argument.")
if unused_kwargs:
logging.warning("Unused kwargs: %s", str(unused_kwargs))
super().__init__(**unused_kwargs)
self.can_run_backwards = can_run_backwards
self.latent_system_dim = latent_system_dim
self.latent_system_kwargs = latent_system_kwargs
self.latent_system_net_type = latent_system_net_type
self.latent_spatial_shape = latent_spatial_shape
self.num_inference_steps = num_inference_steps
self.num_target_steps = num_target_steps
self.rescale_by = rescale_by
self.data_format = data_format
self.name = name
# Encoder
self.encoder_kwargs = encoder_kwargs
self.encoder = hk.transform(
lambda *args, **kwargs: networks.SpatialConvEncoder( # pylint: disable=unnecessary-lambda,g-long-lambda
latent_dim=latent_system_dim,
aggregation_type=encoder_aggregation_type,
data_format=data_format,
name="Encoder",
**encoder_kwargs
)(*args, **kwargs))
# Decoder
self.decoder_kwargs = decoder_kwargs
self.decoder = hk.transform(
lambda *args, **kwargs: networks.SpatialConvDecoder( # pylint: disable=unnecessary-lambda,g-long-lambda
initial_spatial_shape=self.latent_spatial_shape,
de_aggregation_type=decoder_de_aggregation_type,
data_format=data_format,
max_de_aggregation_dims=self.latent_system_dim // 2,
name="Decoder",
**decoder_kwargs,
)(*args, **kwargs))
self.has_latent_transform = has_latent_transform
if has_latent_transform:
self.latent_transform = hk.transform(
lambda *args, **kwargs: networks.make_flexible_net( # pylint: disable=unnecessary-lambda,g-long-lambda
net_type=latent_system_net_type,
output_dims=latent_system_dim,
name="LatentTransform",
**latent_transform_kwargs
)(*args, **kwargs))
else:
self.latent_transform = None
self._jit_init = None
@property
@abc.abstractmethod
def train_sequence_length(self) -> int:
"""Computes the total length of a sequence needed for training or evaluation."""
pass
@abc.abstractmethod
def train_data_split(
self,
images: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
"""Extracts from the inputs the data splits for training."""
pass
def decode_latents(
self,
params: hk.Params,
rng: jnp.ndarray,
z: jnp.ndarray,
**kwargs: Any
) -> distrax.Distribution:
"""Decodes the latent variable given the parameters of the model."""
# Allow to run with both the full parameters and only the decoders
if self.latent_system_net_type == "mlp":
fixed_dims = 1
elif self.latent_system_net_type == "conv":
fixed_dims = 1 + len(self.latent_spatial_shape)
else:
raise NotImplementedError()
n_shape = z.shape[:-fixed_dims]
z = z.reshape((-1,) + z.shape[-fixed_dims:])
x = self.decoder.apply(params, rng, z, **kwargs)
return jax.tree_map(lambda a: a.reshape(n_shape + a.shape[1:]), x)
def apply_latent_transform(
self,
params: hk.Params,
key: jnp.ndarray,
z: jnp.ndarray,
**kwargs: Any
) -> jnp.ndarray:
if self.latent_transform is not None:
return self.latent_transform.apply(params, key, z, **kwargs)
else:
return z
@abc.abstractmethod
def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
pass
@abc.abstractmethod
def process_latents_for_dynamics(self, z: jnp.ndarray) -> T:
pass
@abc.abstractmethod
def process_latents_for_decoder(self, z: T) -> jnp.ndarray:
pass
@abc.abstractmethod
def unroll_latent_dynamics(
self,
z: T,
params: utils.Params,
key: jnp.ndarray,
num_steps_forward: int,
num_steps_backward: int,
include_z0: bool,
is_training: bool,
**kwargs: Any
) -> Tuple[T, Mapping[str, jnp.ndarray]]:
"""Unrolls the latent dynamics starting from z and pre-processing for the decoder."""
pass
@abc.abstractmethod
def reconstruct(
self,
params: utils.Params,
inputs: jnp.ndarray,
rng_key: Optional[jnp.ndarray],
forward: bool,
) -> distrax.Distribution:
"""Using the first `num_inference_steps` parts of inputs reconstructs the rest."""
pass
@abc.abstractmethod
def training_objectives(
self,
params: utils.Params,
state: hk.State,
rng: jnp.ndarray,
inputs: Union[Dict[str, jnp.ndarray], jnp.ndarray],
step: jnp.ndarray,
is_training: bool = True,
use_mean_for_eval_stats: bool = True
) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
"""Returns all training objectives statistics and update states."""
pass
@property
@abc.abstractmethod
def inferred_index(self):
"""Returns the time index in the input sequence, for which the encoder infers.
If the encoder takes as input the sequence x[0:n-1], where
`n = self.num_inference_steps`, then this outputs the index `k` relative to
the begging of the input sequence `x_0`, which the encoder infers.
"""
pass
@property
def inferred_right_offset(self):
return self.num_inference_steps - 1 - self.inferred_index
@abc.abstractmethod
def gt_state_and_latents(
self,
params: hk.Params,
rng: jnp.ndarray,
inputs: Dict[str, jnp.ndarray],
seq_len: int,
is_training: bool = False,
unroll_direction: str = "forward",
**kwargs: Dict[str, Any]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Computes the ground state and matching latents."""
pass
@abc.abstractmethod
def _init_non_model_params_and_state(
self,
rng: jnp.ndarray
) -> Tuple[utils.Params, utils.Params]:
"""Initializes any non-model parameters and state."""
pass
@abc.abstractmethod
def _init_latent_system(
self,
rng: jnp.ndarray,
z: jnp.ndarray,
**kwargs: Any
) -> hk.Params:
"""Initializes the parameters of the latent system."""
pass
def _init(
self,
rng: jnp.ndarray,
images: jnp.ndarray
) -> Tuple[hk.Params, hk.State]:
"""Initializes the whole model parameters and state."""
inference_data, _, _ = self.train_data_split(images)
# Initialize parameters and state for the vae training
rng, key = jnr.split(rng)
params, state = self._init_non_model_params_and_state(key)
# Initialize and run encoder
inference_data = self.process_inputs_for_encoder(inference_data)
rng, key = jnr.split(rng)
encoder_params = self.encoder.init(key, inference_data, is_training=True)
rng, key = jnr.split(rng)
z_in = self.encoder.apply(encoder_params, key, inference_data,
is_training=True)
# For probabilistic models this will be a distribution
if isinstance(z_in, distrax.Distribution):
z_in = z_in.mean()
# Initialize and run the optional latent transform
if self.latent_transform is not None:
rng, key = jnr.split(rng)
transform_params = self.latent_transform.init(key, z_in, is_training=True)
rng, key = jnr.split(rng)
z_in = self.latent_transform.apply(transform_params, key, z_in,
is_training=True)
else:
transform_params = dict()
# Initialize and run the latent system
z_in = self.process_latents_for_dynamics(z_in)
rng, key = jnr.split(rng)
latent_params = self._init_latent_system(key, z_in, is_training=True)
rng, key = jnr.split(rng)
z_out, _ = self.unroll_latent_dynamics(
z=z_in,
params=latent_params,
key=key,
num_steps_forward=1,
num_steps_backward=0,
include_z0=False,
is_training=True
)
z_out = self.process_latents_for_decoder(z_out)
# Initialize and run the decoder
rng, key = jnr.split(rng)
decoder_params = self.decoder.init(key, z_out[:, 0], is_training=True)
_ = self.decoder.apply(decoder_params, rng, z_out[:, 0], is_training=True)
# Combine all and make immutable
params = hk.data_structures.merge(params, encoder_params, transform_params,
latent_params, decoder_params)
params = hk.data_structures.to_immutable_dict(params)
state = hk.data_structures.to_immutable_dict(state)
return params, state
def init(
self,
rng: jnp.ndarray,
inputs_or_shape: Union[jnp.ndarray, Mapping[str, jnp.ndarray],
Sequence[int]],
) -> Tuple[utils.Params, hk.State]:
"""Initializes the whole model parameters and state."""
if (isinstance(inputs_or_shape, (tuple, list))
and isinstance(inputs_or_shape[0], int)):
images = jnp.zeros(inputs_or_shape)
else:
images = utils.extract_image(inputs_or_shape)
if self._jit_init is None:
self._jit_init = jax.jit(self._init)
return self._jit_init(rng, images)
+117
View File
@@ -0,0 +1,117 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all models."""
from typing import Any, Dict, Optional
import physics_inspired_models.models.autoregressive as autoregressive
import physics_inspired_models.models.deterministic_vae as deterministic_vae
_physics_arguments = (
"input_space", "simulation_space", "potential_func_form",
"kinetic_func_form", "hgn_kinetic_func_form", "lgn_kinetic_func_form",
"parametrize_mass_matrix", "hgn_parametrize_mass_matrix",
"lgn_parametrize_mass_matrix", "mass_eps"
)
def construct_model(
name: str,
*args,
**kwargs: Dict[str, Any]
):
"""Constructs the correct instance of a model given the short name."""
latent_dynamics_type: Optional[str] = kwargs.pop("latent_dynamics_type", None) # pytype: disable=annotation-type-mismatch
latent_system_kwargs = dict(**kwargs.pop("latent_system_kwargs", dict()))
if name == "AR":
assert latent_dynamics_type in ("vanilla", "lstm", "gru")
# This arguments are not part of the AR models
for k in _physics_arguments + ("integrator_method", "residual"):
latent_system_kwargs.pop(k, None)
return autoregressive.TeacherForcingAutoregressiveModel(
*args,
latent_dynamics_type=latent_dynamics_type,
latent_system_kwargs=latent_system_kwargs,
**kwargs
)
elif name == "RGN":
assert latent_dynamics_type in ("Discrete", None)
latent_dynamics_type = "Discrete"
# This arguments are not part of the RGN models
for k in _physics_arguments + ("integrator_method",):
latent_system_kwargs.pop(k, None)
elif name == "ODE":
assert latent_dynamics_type in ("ODE", None)
latent_dynamics_type = "ODE"
# This arguments are not part of the ODE models
for k in _physics_arguments + ("residual",):
latent_system_kwargs.pop(k, None)
elif name == "HGN":
assert latent_dynamics_type in ("Physics", None)
latent_dynamics_type = "Physics"
assert latent_system_kwargs.get("input_space", None) in ("momentum", None)
latent_system_kwargs["input_space"] = "momentum"
assert (latent_system_kwargs.get("simulation_space", None)
in ("momentum", None))
latent_system_kwargs["simulation_space"] = "momentum"
# Kinetic func form
hgn_specific = latent_system_kwargs.pop("hgn_kinetic_func_form", None)
if hgn_specific is not None:
latent_system_kwargs["kinetic_func_form"] = hgn_specific
# Mass matrix
hgn_specific = latent_system_kwargs.pop("hgn_parametrize_mass_matrix",
None)
if hgn_specific is not None:
latent_system_kwargs["parametrize_mass_matrix"] = hgn_specific
# This arguments are not part of the HGN models
latent_system_kwargs.pop("residual", None)
latent_system_kwargs.pop("lgn_kinetic_func_form", None)
latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
elif name == "LGN":
assert latent_dynamics_type in ("Physics", None)
latent_dynamics_type = "Physics"
assert latent_system_kwargs.get("input_space", None) in ("velocity", None)
latent_system_kwargs["input_space"] = "velocity"
assert (latent_system_kwargs.get("simulation_space", None) in
("velocity", None))
latent_system_kwargs["simulation_space"] = "velocity"
# Kinetic func form
lgn_specific = latent_system_kwargs.pop("lgn_kinetic_func_form", None)
if lgn_specific is not None:
latent_system_kwargs["kinetic_func_form"] = lgn_specific
# Mass matrix
lgn_specific = latent_system_kwargs.pop("lgn_parametrize_mass_matrix",
None)
if lgn_specific is not None:
latent_system_kwargs["parametrize_mass_matrix"] = lgn_specific
# This arguments are not part of the HGN models
latent_system_kwargs.pop("residual", None)
latent_system_kwargs.pop("hgn_kinetic_func_form", None)
latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
elif name == "PGN":
assert latent_dynamics_type in ("Physics", None)
latent_dynamics_type = "Physics"
# This arguments are not part of the PGN models
latent_system_kwargs.pop("residual")
latent_system_kwargs.pop("hgn_kinetic_func_form", None)
latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
latent_system_kwargs.pop("lgn_kinetic_func_form", None)
latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
else:
raise NotImplementedError()
return deterministic_vae.DeterministicLatentsGenerativeModel(
*args,
latent_dynamics_type=latent_dynamics_type,
latent_system_kwargs=latent_system_kwargs,
**kwargs)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+494
View File
@@ -0,0 +1,494 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing all of the networks as Haiku modules."""
from typing import Any, Callable, Mapping, Optional, Sequence, Union
from absl import logging
import distrax
import haiku as hk
import jax.numpy as jnp
from physics_inspired_models import utils
Activation = Union[str, Callable[[jnp.ndarray], jnp.ndarray]]
class DenseNet(hk.Module):
"""A feed forward network (MLP)."""
def __init__(
self,
num_units: Sequence[int],
activate_final: bool = False,
activation: Activation = "leaky_relu",
name: Optional[str] = None):
super().__init__(name=name)
self.num_units = num_units
self.num_layers = len(self.num_units)
self.activate_final = activate_final
self.activation = utils.get_activation(activation)
self.linear_modules = []
for i in range(self.num_layers):
self.linear_modules.append(
hk.Linear(
output_size=self.num_units[i],
name=f"ff_{i}"
)
)
def __call__(self, inputs: jnp.ndarray, is_training: bool):
net = inputs
for i, linear in enumerate(self.linear_modules):
net = linear(net)
if i < self.num_layers - 1 or self.activate_final:
net = self.activation(net)
return net
class Conv2DNet(hk.Module):
"""Convolutional Network."""
def __init__(
self,
output_channels: Sequence[int],
kernel_shapes: Union[int, Sequence[int]] = 3,
strides: Union[int, Sequence[int]] = 1,
padding: Union[str, Sequence[str]] = "SAME",
data_format: str = "NHWC",
with_batch_norm: bool = False,
activate_final: bool = False,
activation: Activation = "leaky_relu",
name: Optional[str] = None):
super().__init__(name=name)
self.output_channels = tuple(output_channels)
self.num_layers = len(self.output_channels)
self.kernel_shapes = utils.bcast_if(kernel_shapes, int, self.num_layers)
self.strides = utils.bcast_if(strides, int, self.num_layers)
self.padding = utils.bcast_if(padding, str, self.num_layers)
self.data_format = data_format
self.with_batch_norm = with_batch_norm
self.activate_final = activate_final
self.activation = utils.get_activation(activation)
if len(self.kernel_shapes) != self.num_layers:
raise ValueError(f"Kernel shapes is of size {len(self.kernel_shapes)}, "
f"while output_channels is of size{self.num_layers}.")
if len(self.strides) != self.num_layers:
raise ValueError(f"Strides is of size {len(self.kernel_shapes)}, while "
f"output_channels is of size{self.num_layers}.")
if len(self.padding) != self.num_layers:
raise ValueError(f"Padding is of size {len(self.padding)}, while "
f"output_channels is of size{self.num_layers}.")
self.conv_modules = []
self.bn_modules = []
for i in range(self.num_layers):
self.conv_modules.append(
hk.Conv2D(
output_channels=self.output_channels[i],
kernel_shape=self.kernel_shapes[i],
stride=self.strides[i],
padding=self.padding[i],
data_format=data_format,
name=f"conv_2d_{i}")
)
if with_batch_norm:
self.bn_modules.append(
hk.BatchNorm(
create_offset=True,
create_scale=False,
decay_rate=0.999,
name=f"batch_norm_{i}")
)
else:
self.bn_modules.append(None)
def __call__(self, inputs: jnp.ndarray, is_training: bool):
assert inputs.ndim == 4
net = inputs
for i, (conv, bn) in enumerate(zip(self.conv_modules, self.bn_modules)):
net = conv(net)
# Batch norm
if bn is not None:
net = bn(net, is_training=is_training)
if i < self.num_layers - 1 or self.activate_final:
net = self.activation(net)
return net
class SpatialConvEncoder(hk.Module):
"""Spatial Convolutional Encoder for learning the Hamiltonian."""
def __init__(
self,
latent_dim: int,
conv_channels: Union[Sequence[int], int],
num_blocks: int,
blocks_depth: int = 2,
distribution_name: str = "diagonal_normal",
aggregation_type: Optional[str] = None,
data_format: str = "NHWC",
activation: Activation = "leaky_relu",
scale_factor: int = 2,
kernel_shapes: Union[Sequence[int], int] = 3,
padding: Union[Sequence[str], str] = "SAME",
name: Optional[str] = None):
super().__init__(name=name)
if aggregation_type not in (None, "max", "mean", "linear_projection"):
raise ValueError(f"Unrecognized aggregation_type={aggregation_type}.")
self.latent_dim = latent_dim
self.conv_channels = conv_channels
self.num_blocks = num_blocks
self.scale_factor = scale_factor
self.data_format = data_format
self.distribution_name = distribution_name
self.aggregation_type = aggregation_type
# Compute the required size of the output
if distribution_name is None:
self.output_dim = latent_dim
elif distribution_name == "diagonal_normal":
self.output_dim = 2 * latent_dim
else:
raise ValueError(f"Unrecognized distribution_name={distribution_name}.")
if isinstance(conv_channels, int):
conv_channels = [[conv_channels] * blocks_depth
for _ in range(num_blocks)]
conv_channels[-1] += [self.output_dim]
else:
assert isinstance(conv_channels, (list, tuple))
assert len(conv_channels) == num_blocks
conv_channels = list(list(c) for c in conv_channels)
conv_channels[-1].append(self.output_dim)
if isinstance(kernel_shapes, tuple):
kernel_shapes = list(kernel_shapes)
# Convolutional blocks
self.blocks = []
for i, channels in enumerate(conv_channels):
if isinstance(kernel_shapes, int):
extra_kernel_shapes = 0
else:
extra_kernel_shapes = [3] * (len(channels) - len(kernel_shapes))
self.blocks.append(Conv2DNet(
output_channels=channels,
kernel_shapes=kernel_shapes + extra_kernel_shapes,
strides=[self.scale_factor] + [1] * (len(channels) - 1),
padding=padding,
data_format=data_format,
with_batch_norm=False,
activate_final=i < num_blocks - 1,
activation=activation,
name=f"block_{i}"
))
def spatial_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
if self.aggregation_type is None:
return x
axis = (1, 2) if self.data_format == "NHWC" else (2, 3)
if self.aggregation_type == "max":
return jnp.max(x, axis=axis)
if self.aggregation_type == "mean":
return jnp.mean(x, axis=axis)
if self.aggregation_type == "linear_projection":
x = x.reshape(x.shape[:-3] + (-1,))
return hk.Linear(self.output_dim, name="LinearProjection")(x)
raise NotImplementedError()
def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
if self.distribution_name is None:
return net_output
elif self.distribution_name == "diagonal_normal":
if self.aggregation_type is None:
split_axis, num_axes = self.data_format.index("C"), 3
else:
split_axis, num_axes = 1, 1
# Add an extra axis if the input has more than 1 batch dimension
split_axis += net_output.ndim - num_axes - 1
loc, log_scale = jnp.split(net_output, 2, axis=split_axis)
return distrax.Normal(loc, jnp.exp(log_scale))
else:
raise NotImplementedError()
def __call__(
self,
inputs: jnp.ndarray,
is_training: bool
) -> Union[jnp.ndarray, distrax.Distribution]:
# Treat any extra dimensions (like time) as the batch
batched_shape = inputs.shape[:-3]
net = jnp.reshape(inputs, (-1,) + inputs.shape[-3:])
# Apply all blocks in sequence
for block in self.blocks:
net = block(net, is_training=is_training)
# Final projection
net = self.spatial_aggregation(net)
# Reshape back to correct dimensions (like batch + time)
net = jnp.reshape(net, batched_shape + net.shape[1:])
# Return a distribution over the observations
return self.make_distribution(net)
class SpatialConvDecoder(hk.Module):
"""Spatial Convolutional Decoder for learning the Hamiltonian."""
def __init__(
self,
initial_spatial_shape: Sequence[int],
conv_channels: Union[Sequence[int], int],
num_blocks: int,
max_de_aggregation_dims: int,
blocks_depth: int = 2,
scale_factor: int = 2,
output_channels: int = 3,
h_const_channels: int = 2,
data_format: str = "NHWC",
activation: Activation = "leaky_relu",
learned_sigma: bool = False,
de_aggregation_type: Optional[str] = None,
final_activation: Activation = "sigmoid",
discard_half_de_aggregated: bool = False,
kernel_shapes: Union[Sequence[int], int] = 3,
padding: Union[Sequence[str], str] = "SAME",
name: Optional[str] = None):
super().__init__(name=name)
if de_aggregation_type not in (None, "tile", "linear_projection"):
raise ValueError(f"Unrecognized de_aggregation_type="
f"{de_aggregation_type}.")
self.num_blocks = num_blocks
self.scale_factor = scale_factor
self.h_const_channels = h_const_channels
self.data_format = data_format
self.learned_sigma = learned_sigma
self.initial_spatial_shape = tuple(initial_spatial_shape)
self.final_activation = utils.get_activation(final_activation)
self.de_aggregation_type = de_aggregation_type
self.max_de_aggregation_dims = max_de_aggregation_dims
self.discard_half_de_aggregated = discard_half_de_aggregated
if isinstance(conv_channels, int):
conv_channels = [[conv_channels] * blocks_depth
for _ in range(num_blocks)]
conv_channels[-1] += [output_channels]
else:
assert isinstance(conv_channels, (list, tuple))
assert len(conv_channels) == num_blocks
conv_channels = list(list(c) for c in conv_channels)
conv_channels[-1].append(output_channels)
# Convolutional blocks
self.blocks = []
for i, channels in enumerate(conv_channels):
is_final_block = i == num_blocks - 1
self.blocks.append(
Conv2DNet( # pylint: disable=g-complex-comprehension
output_channels=channels,
kernel_shapes=kernel_shapes,
strides=1,
padding=padding,
data_format=data_format,
with_batch_norm=False,
activate_final=not is_final_block,
activation=activation,
name=f"block_{i}"
))
def spatial_de_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
if self.de_aggregation_type is None:
assert x.ndim >= 4
if self.data_format == "NHWC":
assert x.shape[1:3] == self.initial_spatial_shape
elif self.data_format == "NCHW":
assert x.shape[2:4] == self.initial_spatial_shape
return x
elif self.de_aggregation_type == "linear_projection":
assert x.ndim == 2
n, d = x.shape
d = min(d, self.max_de_aggregation_dims or d)
out_d = d * self.initial_spatial_shape[0] * self.initial_spatial_shape[1]
x = hk.Linear(out_d, name="LinearProjection")(x)
if self.data_format == "NHWC":
shape = (n,) + self.initial_spatial_shape + (d,)
else:
shape = (n, d) + self.initial_spatial_shape
return x.reshape(shape)
elif self.de_aggregation_type == "tile":
assert x.ndim == 2
if self.data_format == "NHWC":
repeats = (1,) + self.initial_spatial_shape + (1,)
x = x[:, None, None, :]
else:
repeats = (1, 1) + self.initial_spatial_shape
x = x[:, :, None, None]
return jnp.tile(x, repeats)
else:
raise NotImplementedError()
def add_constant_channels(self, inputs: jnp.ndarray) -> jnp.ndarray:
# --------------------------------------------
# This is purely for TF compatibility purposes
if self.discard_half_de_aggregated:
axis = self.data_format.index("C")
inputs, _ = jnp.split(inputs, 2, axis=axis)
# --------------------------------------------
# An extra constant channels
if self.data_format == "NHWC":
h_shape = self.initial_spatial_shape + (self.h_const_channels,)
else:
h_shape = (self.h_const_channels,) + self.initial_spatial_shape
h_const = hk.get_parameter("h", h_shape, dtype=inputs.dtype,
init=hk.initializers.Constant(1))
h_const = jnp.tile(h_const, reps=[inputs.shape[0], 1, 1, 1])
return jnp.concatenate([h_const, inputs], axis=self.data_format.index("C"))
def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
if self.learned_sigma:
init = hk.initializers.Constant(- jnp.log(2.0) / 2.0)
log_scale = hk.get_parameter("log_scale", shape=(),
dtype=net_output.dtype, init=init)
scale = jnp.full_like(net_output, jnp.exp(log_scale))
else:
scale = jnp.full_like(net_output, 1 / jnp.sqrt(2.0))
return distrax.Normal(net_output, scale)
def __call__(
self,
inputs: jnp.ndarray,
is_training: bool
) -> distrax.Distribution:
# Apply the spatial de-aggregation
inputs = self.spatial_de_aggregation(inputs)
# Add the parameterized constant channels
net = self.add_constant_channels(inputs)
# Apply all the blocks
for block in self.blocks:
# Up-sample the image
net = utils.nearest_neighbour_upsampling(net, self.scale_factor)
# Apply the convolutional block
net = block(net, is_training=is_training)
# Apply any specific output nonlinearity
net = self.final_activation(net)
# Construct the distribution over the observations
return self.make_distribution(net)
def make_flexible_net(
net_type: str,
output_dims: int,
conv_channels: Union[Sequence[int], int],
num_units: Union[Sequence[int], int],
num_layers: Optional[int],
activation: Activation,
activate_final: bool = False,
kernel_shapes: Union[Sequence[int], int] = 3,
strides: Union[Sequence[int], int] = 1,
padding: Union[Sequence[str], str] = "SAME",
name: Optional[str] = None,
**unused_kwargs: Mapping[str, Any]
):
"""Commonly used for creating a flexible network."""
if unused_kwargs:
logging.warning("Unused kwargs of `make_flexible_net`: %s",
str(unused_kwargs))
if net_type == "mlp":
if isinstance(num_units, int):
assert num_layers is not None
num_units = [num_units] * (num_layers - 1) + [output_dims]
else:
num_units = list(num_units) + [output_dims]
return DenseNet(
num_units=num_units,
activation=activation,
activate_final=activate_final,
name=name
)
elif net_type == "conv":
if isinstance(conv_channels, int):
assert num_layers is not None
conv_channels = [conv_channels] * (num_layers - 1) + [output_dims]
else:
conv_channels = list(conv_channels) + [output_dims]
return Conv2DNet(
output_channels=conv_channels,
kernel_shapes=kernel_shapes,
strides=strides,
padding=padding,
activation=activation,
activate_final=activate_final,
name=name
)
elif net_type == "transformer":
raise NotImplementedError()
else:
raise ValueError(f"Unrecognized net_type={net_type}.")
def make_flexible_recurrent_net(
core_type: str,
net_type: str,
output_dims: int,
num_units: Union[Sequence[int], int],
num_layers: Optional[int],
activation: Activation,
activate_final: bool = False,
name: Optional[str] = None,
**unused_kwargs
):
"""Commonly used for creating a flexible recurrences."""
if net_type != "mlp":
raise ValueError("We do not support convolutional recurrent nets atm.")
if unused_kwargs:
logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
str(unused_kwargs))
if isinstance(num_units, (list, tuple)):
num_units = list(num_units) + [output_dims]
num_layers = len(num_units)
else:
assert num_layers is not None
num_units = [num_units] * (num_layers - 1) + [output_dims]
name = name or f"{core_type.upper()}"
activation = utils.get_activation(activation)
core_list = []
for i, n in enumerate(num_units):
if core_type.lower() == "vanilla":
core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}"))
elif core_type.lower() == "lstm":
core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}"))
elif core_type.lower() == "gru":
core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}"))
else:
raise ValueError(f"Unrecognized core_type={core_type}.")
if i != num_layers - 1:
core_list.append(activation)
if activate_final:
core_list.append(activation)
return hk.DeepRNN(core_list, name="RNN")
+10
View File
@@ -0,0 +1,10 @@
git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite@main#egg=dm_hamiltonian_dynamics_suite
absl-py==0.12.0
numpy>=1.16.4
scikit-learn>=1.0
typing>=3.7.4.3
jax==0.2.20
jaxline==0.0.3
distrax==0.0.2
optax==0.0.6
dm-haiku==0.0.3
+54
View File
@@ -0,0 +1,54 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup for pip package."""
from setuptools import setup
REQUIRED_PACKAGES = (
"dm_hamiltonian_dynamics_suite@git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite", # pylint: disable=line-too-long.
"absl-py>=0.12.0",
"numpy>=1.16.4",
"scikit-learn>=1.0",
"typing>=3.7.4.3",
"jax==0.2.20",
"jaxline==0.0.3",
"distrax==0.0.2",
"optax==0.0.6",
"dm-haiku==0.0.3",
)
LONG_DESCRIPTION = "\n".join([
"A codebase containing the implementation of the following models:",
"Hamiltonian Generative Network (HGN)",
"Lagrangian Generative Network (LGN)",
"Neural ODE",
"Recurrent Generative Network (RGN)",
"and RNN, LSTM and GRU.",
"This is code accompanying the publication of:"
])
setup(
name="physics_inspired_models",
version="0.0.1",
description="Implementation of multiple physically inspired models.",
long_description=LONG_DESCRIPTION,
url="https://github.com/deepmind/deepmind-research/physics_inspired_models",
author="DeepMind",
package_dir={"physics_inspired_models": "."},
packages=["physics_inspired_models", "physics_inspired_models.models"],
install_requires=REQUIRED_PACKAGES,
platforms=["any"],
license="Apache License, Version 2.0",
)
+274
View File
@@ -0,0 +1,274 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities functions for Jax."""
import collections
import functools
from typing import Any, Callable, Dict, Mapping, Union
import distrax
import jax
from jax import core
from jax import lax
from jax import nn
import jax.numpy as jnp
from jax.tree_util import register_pytree_node
from jaxline import utils
import numpy as np
HaikuParams = Mapping[str, Mapping[str, jnp.ndarray]]
Params = Union[Mapping[str, jnp.ndarray], HaikuParams, jnp.ndarray]
_Activation = Callable[[jnp.ndarray], jnp.ndarray]
tf_leaky_relu = functools.partial(nn.leaky_relu, negative_slope=0.2)
def filter_only_scalar_stats(stats):
return {k: v for k, v in stats.items() if v.size == 1}
def to_numpy(obj):
return jax.tree_map(np.array, obj)
@jax.custom_gradient
def geco_lagrange_product(lagrange_multiplier, constraint_ema, constraint_t):
"""Modifies the gradients so that they work as described in GECO.
The evaluation gives:
lagrange * C_ema
The gradient w.r.t lagrange:
- g * C_t
The gradient w.r.t constraint_ema:
0.0
The gradient w.r.t constraint_t:
g * lagrange
Note that if you pass the same value for `constraint_ema` and `constraint_t`
this would only flip the gradient for the lagrange multiplier.
Args:
lagrange_multiplier: The lagrange multiplier
constraint_ema: The moving average of the constraint
constraint_t: The current constraint
Returns:
"""
def grad(gradient):
return (- gradient * constraint_t,
jnp.zeros_like(constraint_ema),
gradient * lagrange_multiplier)
return lagrange_multiplier * constraint_ema, grad
def bcast_if(x, t, n):
return [x] * n if isinstance(x, t) else x
def stack_time_into_channels(
images: jnp.ndarray,
data_format: str
) -> jnp.ndarray:
axis = data_format.index("C")
list_of_time = [jnp.squeeze(v, axis=1) for v in
jnp.split(images, images.shape[1], axis=1)]
return jnp.concatenate(list_of_time, axis)
def stack_device_dim_into_batch(obj):
return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), obj)
def nearest_neighbour_upsampling(x, scale, data_format="NHWC"):
"""Performs nearest-neighbour upsampling."""
if data_format == "NCHW":
b, c, h, w = x.shape
x = jnp.reshape(x, [b, c, h, 1, w, 1])
ones = jnp.ones([1, 1, 1, scale, 1, scale], dtype=x.dtype)
return jnp.reshape(x * ones, [b, c, scale * h, scale * w])
elif data_format == "NHWC":
b, h, w, c = x.shape
x = jnp.reshape(x, [b, h, 1, w, 1, c])
ones = jnp.ones([1, 1, scale, 1, scale, 1], dtype=x.dtype)
return jnp.reshape(x * ones, [b, scale * h, scale * w, c])
else:
raise ValueError(f"Unrecognized data_format={data_format}.")
def get_activation(arg: Union[_Activation, str]) -> _Activation:
"""Returns an activation from provided string."""
if isinstance(arg, str):
# Try fetch in order - [this module, jax.nn, jax.numpy]
if arg in globals():
return globals()[arg]
if hasattr(nn, arg):
return getattr(nn, arg)
elif hasattr(jnp, arg):
return getattr(jnp, arg)
else:
raise ValueError(f"Unrecognized activation with name {arg}.")
if not callable(arg):
raise ValueError(f"Expected a callable, but got {type(arg)}")
return arg
def merge_first_dims(x: jnp.ndarray, num_dims_to_merge: int = 2) -> jnp.ndarray:
return x.reshape((-1,) + x.shape[num_dims_to_merge:])
def extract_image(
inputs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
) -> jnp.ndarray:
"""Extracts a tensor with key `image` or `x_image` if it is a dict, otherwise returns the inputs."""
if isinstance(inputs, dict):
if "image" in inputs:
return inputs["image"]
else:
return inputs["x_image"]
elif isinstance(inputs, jnp.ndarray):
return inputs
raise NotImplementedError(f"Not implemented of inputs of type"
f" {type(inputs)}.")
def extract_gt_state(inputs: Any) -> jnp.ndarray:
if isinstance(inputs, dict):
return inputs["x"]
elif not isinstance(inputs, jnp.ndarray):
raise NotImplementedError(f"Not implemented of inputs of type"
f" {type(inputs)}.")
return inputs
def reshape_latents_conv_to_flat(conv_latents, axis_n_to_keep=1):
q, p = jnp.split(conv_latents, 2, axis=-1)
q = jax.tree_map(lambda x: x.reshape(x.shape[:axis_n_to_keep] + (-1,)), q)
p = jax.tree_map(lambda x: x.reshape(x.shape[:axis_n_to_keep] + (-1,)), p)
flat_latents = jnp.concatenate([q, p], axis=-1)
return flat_latents
def triu_matrix_from_v(x, ndim):
assert x.shape[-1] == (ndim * (ndim + 1)) // 2
matrix = jnp.zeros(x.shape[:-1] + (ndim, ndim))
idx = jnp.triu_indices(ndim)
index_update = lambda x, idx, y: x.at[idx].set(y)
for _ in range(x.ndim - 1):
index_update = jax.vmap(index_update, in_axes=(0, None, 0))
return index_update(matrix, idx, x)
def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def convert_to_pytype(target, reference):
"""Makes target the same pytype as reference, by jax.tree_flatten."""
_, pytree = jax.tree_flatten(reference)
leaves, _ = jax.tree_flatten(target)
return jax.tree_unflatten(pytree, leaves)
def func_if_not_scalar(func):
"""Makes a function that uses func only on non-scalar values."""
@functools.wraps(func)
def wrapped(array, axis=0):
if array.ndim == 0:
return array
return func(array, axis=axis)
return wrapped
mean_if_not_scalar = func_if_not_scalar(jnp.mean)
class MultiBatchAccumulator(object):
"""Class for abstracting statistics accumulation over multiple batches."""
def __init__(self):
self._obj = None
self._obj_max = None
self._obj_min = None
self._num_samples = None
def add(self, averaged_values, num_samples):
"""Adds an element to the moving average and the max."""
if self._obj is None:
self._obj_max = jax.tree_map(lambda y: y * 1.0, averaged_values)
self._obj_min = jax.tree_map(lambda y: y * 1.0, averaged_values)
self._obj = jax.tree_map(lambda y: y * num_samples, averaged_values)
self._num_samples = num_samples
else:
self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
averaged_values)
self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
averaged_values)
self._obj = jax.tree_multimap(lambda x, y: x + y * num_samples, self._obj,
averaged_values)
self._num_samples += num_samples
def value(self):
return jax.tree_map(lambda x: x / self._num_samples, self._obj)
def max(self):
return jax.tree_map(float, self._obj_max)
def min(self):
return jax.tree_map(float, self._obj_min)
def sum(self):
return self._obj
register_pytree_node(
distrax.Normal,
lambda instance: ([instance.loc, instance.scale], None),
lambda _, args: distrax.Normal(*args)
)
def inner_product(x: Any, y: Any) -> jnp.ndarray:
products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
return sum(jax.tree_leaves(products))
get_first = utils.get_first
bcast_local_devices = utils.bcast_local_devices
py_prefetch = utils.py_prefetch
p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
static_broadcasted_argnums=1)
def wrap_if_pmap(p_func):
def p_func_if_pmap(obj, axis_name):
try:
core.axis_frame(axis_name)
return p_func(obj, axis_name)
except NameError:
return obj
return p_func_if_pmap
pmean_if_pmap = wrap_if_pmap(lax.pmean)
psum_if_pmap = wrap_if_pmap(lax.psum)