mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 21:56:38 +08:00
Open sourcing the physics inspired models code.
PiperOrigin-RevId: 408640606
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
9b751b7d20
commit
2c7c401024
@@ -85,6 +85,7 @@ https://deepmind.com/research/publications/
|
|||||||
* [REGAL: Transfer Learning for Fast Optimization of Computation Graphs](regal)
|
* [REGAL: Transfer Learning for Fast Optimization of Computation Graphs](regal)
|
||||||
* [Deep Ensembles: A Loss Landscape Perspective](ensemble_loss_landscape)
|
* [Deep Ensembles: A Loss Landscape Perspective](ensemble_loss_landscape)
|
||||||
* [Powerpropagation](powerpropagation)
|
* [Powerpropagation](powerpropagation)
|
||||||
|
* [Physics Inspired Models](physics_inspired_models)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
# Implementation of multiple physics inspired models for modelling dynamics
|
||||||
|
|
||||||
|
This repository contains an implementation of different physics inspired models
|
||||||
|
used in the papers: **SyMetric: Measuring the Quality of Learnt Hamiltonian
|
||||||
|
Dynamics Inferred from Vision** and **Which priors matter? Benchmarking models
|
||||||
|
for learning latent dynamics**.
|
||||||
|
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
This is purely research code, provided with no further intentions of support or
|
||||||
|
any guarantees of backward compatibility.
|
||||||
|
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
All package requirements are listed in `requirements.txt`.
|
||||||
|
You will still need to download and set up the datasets from the
|
||||||
|
[DeepMind Hamiltonian Dynamics Suite] manually.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone git@github.com:deepmind/deepmind-research.git
|
||||||
|
pip install -r ./deepmind-research/physics_inspired_models/requirements.txt
|
||||||
|
pip install ./deepmind-research/physics_inspired_models
|
||||||
|
pip install --upgrade "jax[XXX]"
|
||||||
|
```
|
||||||
|
|
||||||
|
where `XXX` is the correct type of accelerator that you have on your machine.
|
||||||
|
Note that if you are using a GPU you might need `XXX` to also include the
|
||||||
|
correct version of CUDA and cuDNN installed on your machine.
|
||||||
|
For more details please read [here](https://github.com/google/jax#installation).
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
The file `jaxline_configs.py` contains all the configuration specifications for
|
||||||
|
the experiments in the two papers. To run an experiment, in addition to passing
|
||||||
|
the location of the configs file, you must provide extra arguments in the
|
||||||
|
following manner:
|
||||||
|
|
||||||
|
`${name_of_configuration},${index_in_sweep},${dataset_name}`
|
||||||
|
|
||||||
|
For example, to run the second hyper-parameter configuration of the improved
|
||||||
|
Hamiltonian Generative Network (HGN++) on the mass-spring dataset you should
|
||||||
|
run in the command line (assuming that you are in the folder of the project):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python3 jaxline_train.py \
|
||||||
|
--config="jaxline_configs.py:sym_metric_hgn_plus_plus_sweep,1,toy_physics/mass_spring" \
|
||||||
|
--jaxline_mode="train" \
|
||||||
|
--logtostderr
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Reference
|
||||||
|
**SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics Inferred from Vision**
|
||||||
|
|
||||||
|
**Which priors matter? Benchmarking models for learning latent dynamics**
|
||||||
|
|
||||||
|
[DeepMind Hamiltonian Dynamics Suite]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,353 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module containing model evaluation metric."""
|
||||||
|
import _thread as thread
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from absl import logging
|
||||||
|
import distrax
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn import linear_model
|
||||||
|
from sklearn import model_selection
|
||||||
|
from sklearn import preprocessing
|
||||||
|
|
||||||
|
|
||||||
|
def quit_function(fn_name):
  """Logs that `fn_name` ran for too long and interrupts the main thread.

  Args:
    fn_name: Name of the function that timed out; only used in the log message.
  """
  logging.error('%s took too long', fn_name)
  # Flush stderr before the main thread gets interrupted.
  sys.stderr.flush()
  thread.interrupt_main()
|
||||||
|
|
||||||
|
|
||||||
|
def exit_after(s):
  """Returns a decorator that interrupts the wrapped call after `s` seconds.

  A `threading.Timer` is started right before the wrapped function is called.
  If it fires, `quit_function` logs the timeout and interrupts the main thread.
  The timer is always cancelled once the wrapped function returns or raises.

  Args:
    s: Timeout in seconds.

  Returns:
    A decorator. The decorated function keeps the name and docstring of the
    function it wraps.
  """
  import functools  # pylint: disable=g-import-not-at-top

  def outer(fn):

    # Preserve the metadata of `fn` so that logs and debuggers show the real
    # function rather than `inner`.
    @functools.wraps(fn)
    def inner(*args, **kwargs):
      timer = threading.Timer(s, quit_function, args=[fn.__name__])
      timer.start()
      try:
        return fn(*args, **kwargs)
      finally:
        timer.cancel()

    return inner

  return outer
|
||||||
|
|
||||||
|
|
||||||
|
@exit_after(400)
def do_grid_search(data_x_exp, data_y, clf, parameters, cv):
  """Fits a cross-validated grid search of `clf` over `parameters`.

  The call is limited to 400 seconds by the `exit_after` decorator.

  Args:
    data_x_exp: Inputs the estimator is fitted on.
    data_y: Regression targets.
    clf: Estimator whose hyper-parameters are searched.
    parameters: Parameter grid handed to `GridSearchCV`.
    cv: Cross-validation setting handed to `GridSearchCV`.

  Returns:
    The fitted `GridSearchCV` object, scored by explained variance and refitted
    on the best parameters.
  """
  grid_search = model_selection.GridSearchCV(
      estimator=clf,
      param_grid=parameters,
      scoring='explained_variance',
      refit=True,
      cv=cv)
  grid_search.fit(data_x_exp, data_y)
  return grid_search
|
||||||
|
|
||||||
|
|
||||||
|
def symplectic_matrix(dim):
  """Returns the block matrix [[0, -I], [I, 0]] for the given dimensionality.

  Args:
    dim: Total dimensionality. Each block has size `int(dim / 2)`, so an odd
      `dim` yields a matrix with one row and one column less than `dim`.

  Returns:
    A square float array with `2 * int(dim / 2)` rows.
  """
  half = int(dim / 2)
  identity = np.eye(half)
  matrix = np.zeros([2 * half, 2 * half])
  # Top-right block is -I, bottom-left block is +I, the rest stays zero.
  matrix[:half, half:] = -identity
  matrix[half:, :half] = identity
  return matrix
|
||||||
|
|
||||||
|
|
||||||
|
def create_latent_mask(z0, dist_std_threshold=0.5):
  """Builds a boolean mask selecting the informative latent dimensions.

  For stochastic models, latent dimensions that stay too close to the prior are
  likely to be uninformative and can be ignored. The latent vector is split
  into two halves (q and p); a (q, p) pair is kept when either of the two has
  a value below the threshold.

  NOTE(review): the values compared against `dist_std_threshold` are the
  variances of `z0` averaged over the leading axis, not standard deviations.
  Confirm that this is intended.

  Args:
    z0: A `distrax.Normal` distribution, or an array of phase space points.
    dist_std_threshold: Threshold below which a latent counts as informative.

  Returns:
    A boolean array. It is all `True` when `z0` is not a distribution or when
    no latent pair is informative.

  Raises:
    NotImplementedError: If `z0` is a distrax distribution other than `Normal`.
  """
  if isinstance(z0, distrax.Normal):
    avg_vals = np.mean(z0.variance(), axis=0)
  elif isinstance(z0, distrax.Distribution):
    raise NotImplementedError()
  else:
    # Deterministic latents: every dimension passes through.
    return np.array([True] * z0.shape[-1])

  num_dims = avg_vals.shape[-1]
  num_pairs = int(num_dims / 2)
  q_vals = avg_vals[:num_pairs]
  p_vals = avg_vals[num_pairs:]

  # A pair is kept if either its q or its p component is informative.
  keep_pair = [
      bool(q_vals[i] < dist_std_threshold or p_vals[i] < dist_std_threshold)
      for i in range(len(q_vals))
  ]

  if not any(keep_pair):
    return np.array([True] * num_dims)

  half_mask = np.array(keep_pair)
  # The same mask applies to the q half and to the p half.
  return np.concatenate([half_mask, half_mask])
|
||||||
|
|
||||||
|
|
||||||
|
def standardize_data(data):
  """Standardizes `data` with a freshly fitted sklearn `StandardScaler`.

  Args:
    data: Array to fit the scaler on and to transform.

  Returns:
    The transformed data.
  """
  return preprocessing.StandardScaler().fit_transform(data)
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_polynomial(data_x, data_y, max_poly_order, rsq_threshold,
                         max_dim_n=32,
                         alpha_sweep=None,
                         max_iter=1000, cv=2):
  """Finds the minimal polynomial expansion that explains the data.

  Polynomial features of increasing order are regressed onto `data_y` with
  Lasso regression, where `alpha` is selected by a grid search. The search
  stops once the score of the fit reaches `rsq_threshold`, once
  `max_poly_order` has been tried, or once a grid search times out.

  Args:
    data_x: Regression inputs; the last axis indexes the input dimensions.
    data_y: Regression targets.
    max_poly_order: Highest polynomial order to try.
    rsq_threshold: Score at which the search stops.
    max_dim_n: If `data_x` has more input dimensions than this, only a first
      order expansion is tried.
    alpha_sweep: Candidate values for the Lasso `alpha`. A default sweep from
      1e-8 to 1e-2 is used when this is `None` or contains no non-zero value.
    max_iter: Maximum number of Lasso iterations.
    cv: Cross-validation setting forwarded to the grid search.

  Returns:
    A tuple `(clf, poly, data_x_exp, rsq)` holding the fitted Lasso model, the
    fitted `PolynomialFeatures` object, the expanded inputs and the score of
    the last polynomial order that was fitted successfully.

  Raises:
    RuntimeError: If the grid search is interrupted before any polynomial
      order has been fitted.
    ValueError: If no polynomial order is evaluated at all, e.g. because
      `max_poly_order` is smaller than 1.
  """
  if not np.any(alpha_sweep):
    alpha_sweep = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]

  # Avoid a large polynomial expansion for large latent sizes
  if data_x.shape[-1] > max_dim_n:
    print(f'>WARNING! Data is too high dimensional at {data_x.shape[-1]}')
    print('>WARNING! Setting max_poly_order = 1')
    max_poly_order = 1

  rsq = 0
  poly_order = 1
  # Results of the last order that was fitted successfully. They are committed
  # only after a full fit, so a timeout can fall back to them safely.
  best_fit = None
  best_order = None
  best_params = None

  while rsq < rsq_threshold and poly_order <= max_poly_order:
    time_start = time.perf_counter()
    poly = preprocessing.PolynomialFeatures(poly_order, include_bias=False)
    data_x_exp = poly.fit_transform(data_x)
    time_end = time.perf_counter()
    print(
        f'Took {time_end-time_start}s to create polynomial features of order '
        f'{poly_order} and size {data_x_exp.shape[1]}.')

    with warnings.catch_warnings():
      warnings.simplefilter('ignore')
      time_start = time.perf_counter()
      # `normalize` is left at its default instead of being passed as False:
      # the argument no longer exists in recent scikit-learn releases.
      clf = linear_model.Lasso(
          random_state=0, max_iter=max_iter, warm_start=False)
      parameters = {'alpha': alpha_sweep}
      try:
        regressor = do_grid_search(data_x_exp, data_y, clf, parameters, cv)
        time_end = time.perf_counter()
        print(f'Took {time_end-time_start}s to do regression grid search.')

        # Get rsq results
        time_start = time.perf_counter()
        clf = linear_model.Lasso(
            random_state=0,
            alpha=regressor.best_params_['alpha'],
            max_iter=max_iter,
            warm_start=False)
        clf.fit(data_x_exp, data_y)
        rsq = clf.score(data_x_exp, data_y)
        time_end = time.perf_counter()
        print(f'Took {time_end-time_start}s to get rsq results.')

        best_fit = (clf, poly, data_x_exp, rsq)
        best_order = poly_order
        best_params = regressor.best_params_
        print(f'Polynomial of order {poly_order} with '
              f' alpha={regressor.best_params_} RSQ: {rsq}')
        poly_order += 1

      except KeyboardInterrupt as interrupt:
        # Raised in the main thread when the grid search exceeds its timeout.
        time_end = time.perf_counter()
        print(f'Timed out after {time_end-time_start}s of doing grid search.')
        if best_fit is None:
          raise RuntimeError(
              'Grid search was interrupted before any polynomial order could '
              'be fitted.') from interrupt
        print(f'Continuing with previous poly_order={best_order}...')
        print(f'Polynomial of order {best_order} with '
              f' alpha={best_params} RSQ: {best_fit[3]}')
        break

  if best_fit is None:
    raise ValueError(
        'No polynomial order was evaluated; check max_poly_order='
        f'{max_poly_order} and rsq_threshold={rsq_threshold}.')

  return best_fit
|
||||||
|
|
||||||
|
|
||||||
|
def eval_monomial_grad(feature, x, w, grad_acc):
  """Adds the gradient of one weighted monomial, evaluated at `x`.

  Args:
    feature: Monomial name made of space separated factors, each written as
      `x<index>` or `x<index>^<power>`, for example `'x0^2 x1'`.
    x: Point at which the gradient is evaluated, indexed by variable index.
    w: Weight of the monomial.
    grad_acc: Array indexed by variable index. It is updated in place.

  Returns:
    `grad_acc`, after the contribution of this monomial has been added.
  """
  factors = feature.split(' ')
  touched_vars = []
  # Entry k becomes d(w * monomial) / d(variable of factor k).
  partials = np.ones(len(factors)) * w
  for position, factor in enumerate(factors):
    name, separator, exponent = factor.partition('^')
    power = int(exponent) if separator else 1
    var_index = int(name[1:])
    touched_vars.append(var_index)

    scale = np.ones_like(partials) * (x[var_index] ** power)
    # The factor's own entry is multiplied by its derivative instead. The
    # linear case is set explicitly, which covers x[var_index] = 0.0.
    if power == 1:
      scale[position] = 1.0
    else:
      scale[position] = power * (x[var_index] ** (power - 1))
    partials = partials * scale

  grad_acc[touched_vars] += partials
  return grad_acc
|
||||||
|
|
||||||
|
|
||||||
|
def compute_jacobian_manual(x, polynomial_features, weight_matrix, tolerance):
  """Computes the Jacobian of a polynomial map at the point `x`.

  Args:
    x: Point at which the Jacobian is evaluated.
    polynomial_features: Names of the monomials, one per column of
      `weight_matrix`.
    weight_matrix: Weights indexed as `[output, monomial]`.
    tolerance: Weights whose absolute value is not above this are set to zero.

  Returns:
    An array with one row per output, holding that output's gradient.
  """
  # Zero out the weights that are too small to matter.
  pruned_weights = (np.abs(weight_matrix) > tolerance) * weight_matrix

  def output_gradient(output_index):
    gradient = np.zeros_like(x)
    for feature_index, feature in enumerate(polynomial_features):
      eval_monomial_grad(
          feature, x, pruned_weights[output_index, feature_index], gradient)
    return gradient

  return np.stack(
      [output_gradient(i) for i in range(pruned_weights.shape[0])])
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_jacobian_prod(jacobian, noise_eps=1e-6):
  """Calculates AA*, where A=JEJ^T and A*=JE^TJ^T, which should be I.

  Args:
    jacobian: Jacobian matrix J.
    noise_eps: Constant added to every entry of the Jacobian.

  Returns:
    The matrix product AA*.
  """
  # Add noise as 0 in jacobian creates issues in calculations later
  noisy_jacobian = jacobian + noise_eps
  sym_matrix = symplectic_matrix(noisy_jacobian.shape[1])

  forward = noisy_jacobian @ sym_matrix @ noisy_jacobian.T
  backward = noisy_jacobian @ sym_matrix.T @ noisy_jacobian.T
  return forward @ backward
|
||||||
|
|
||||||
|
|
||||||
|
def normalise_jacobian_prods(jacobian_preds):
  """Scales a list of Jacobian products by one shared constant.

  The constant is the mean, over all products, of each product's largest
  absolute entry.

  Args:
    jacobian_preds: Sequence of equally shaped matrices.

  Returns:
    The stacked matrices divided by the constant, or the unscaled stack when
    the constant is zero.
  """
  stacked = np.stack(jacobian_preds)
  # For each attempt at estimating E, get the max term, and take their average
  scale = np.abs(stacked).max(axis=(1, 2)).mean()
  if scale == 0:
    return stacked
  return stacked / scale
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_symetric_score(
    gt_data,
    model_data,
    max_poly_order,
    max_sym_score,
    rsq_threshold,
    sym_threshold,
    evaluation_point_n,
    trajectory_n=1,
    weight_tolerance=1e-5,
    alpha_sweep=None,
    max_iter=1000,
    cv=2):
  """Calculates the SyMetric score of a mapping from model to gt data.

  Finds the minimal polynomial expansion of `model_data` that explains
  `gt_data` using Lasso regression, gets the Jacobian of that mapping at
  randomly chosen points and calculates how symplectic the map is.

  Args:
    gt_data: Regression targets; its length fixes the number of points used.
    model_data: Regression inputs. It is truncated to the length of `gt_data`
      along its second to last axis, and NaN, infinite and very large values
      are replaced or clipped.
    max_poly_order: Highest polynomial order to try.
    max_sym_score: Upper bound applied to the symplecticity score.
    rsq_threshold: Score the regression must exceed for the SyMetric to be 1.
    sym_threshold: Value the symplecticity score must stay below for the
      SyMetric to be 1.
    evaluation_point_n: Number of points per trajectory at which the Jacobian
      is evaluated.
    trajectory_n: Number of equally long trajectories the data is split into.
    weight_tolerance: Regression weights whose absolute value is not above
      this are ignored in the Jacobian.
    alpha_sweep: Candidate Lasso `alpha` values; see `find_best_polynomial`.
    max_iter: Maximum number of Lasso iterations.
    cv: Cross-validation setting forwarded to the grid search.

  Returns:
    A tuple `(results, clf, poly, model_data_exp)`, where `results` is a dict
    with the keys 'poly_exp_order', 'rsq', 'sym' and 'SyMetric'.
  """
  model_data = model_data[..., :gt_data.shape[0], :]

  # Find polynomial expansion that explains enough variance in the gt data
  print('Finding best polynomial expansion...')
  time_start = time.perf_counter()
  # Clean up model data to ensure it doesn't contain NaN, infinity
  # or values too large for dtype('float32')
  model_data = np.nan_to_num(model_data)
  model_data = np.clip(model_data, -999999, 999999)

  clf, poly, model_data_exp, best_rsq = find_best_polynomial(
      model_data, gt_data, max_poly_order, rsq_threshold,
      32, alpha_sweep, max_iter, cv)
  time_end = time.perf_counter()
  print(f'Took {time_end - time_start}s to find best polynomial.')

  # Calculate Symplecticity score
  all_raw_scores = []
  # Newer scikit-learn releases only provide `get_feature_names_out`, older
  # ones only `get_feature_names`; use whichever is available.
  if hasattr(poly, 'get_feature_names_out'):
    features = np.array(poly.get_feature_names_out())
  else:
    features = np.array(poly.get_feature_names())

  points_per_trajectory = int(len(gt_data) / trajectory_n)
  for trajectory in range(trajectory_n):
    random_data_inds = np.random.permutation(
        range(points_per_trajectory))[:evaluation_point_n]

    jacobian_preds = []
    for point_ind in random_data_inds:
      input_data_point = model_data[points_per_trajectory * trajectory +
                                    point_ind]
      time_start = time.perf_counter()
      jacobian = compute_jacobian_manual(input_data_point, features,
                                         clf.coef_, weight_tolerance)
      pred = calculate_jacobian_prod(jacobian)
      jacobian_preds.append(pred)
      time_end = time.perf_counter()
      print(f'Took {time_end - time_start}s to evaluate jacobian '
            f'around point {point_ind}.')

    # Normalise
    normalised_jacobian_preds = normalise_jacobian_prods(jacobian_preds)
    # The score is measured as the deviation from I
    identity = np.eye(normalised_jacobian_preds.shape[-1])
    scores = np.mean(np.power(normalised_jacobian_preds - identity, 2),
                     axis=(1, 2))
    all_raw_scores.append(scores)

  sym_score = np.min([np.mean(all_raw_scores), max_sym_score])
  # Calculate final SyMetric score
  if best_rsq > rsq_threshold and sym_score < sym_threshold:
    sy_metric = 1.0
  else:
    sy_metric = 0.0

  results = {
      'poly_exp_order': poly.get_params()['degree'],
      'rsq': best_rsq,
      'sym': sym_score,
      'SyMetric': sy_metric,
  }
  with np.printoptions(precision=4, suppress=True):
    print(f'----------------FINAL RESULTS FOR {trajectory_n} '
          'TRAJECTORIES------------------')
    print(f'BEST POLYNOMIAL EXPANSION ORDER: {results["poly_exp_order"]}')
    print(f'BEST RSQ (1-best): {results["rsq"]}')
    print(f'SYMPLECTICITY SCORE AROUND ALL POINTS AND ALL '
          f'TRAJECTORIES (0-best): {sym_score}')
    print(f'SyMETRIC SCORE: {sy_metric}')
    print(f'----------------FINAL RESULTS FOR {trajectory_n} '
          f'TRAJECTORIES------------------')
  return results, clf, poly, model_data_exp
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,397 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module containing all of the configurations for various models."""
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
from jaxline import base_config
|
||||||
|
import ml_collections as collections
|
||||||
|
|
||||||
|
_DATASETS_PATH_VAR_NAME = "DM_HAMILTONIAN_DYNAMICS_SUITE_DATASETS"
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(arg_string):
  """Returns the jaxline config object selected by `arg_string`.

  Args:
    arg_string: Comma separated string of the form
      "model_config_name,sweep_index,dataset_name". `model_config_name` must be
      the name of a module-level function returning `(model_config, sweeps)`.

  Returns:
    The base jaxline config with the experiment settings filled in and the
    overrides of `sweeps[sweep_index]` applied.

  Raises:
    ValueError: If `arg_string` does not have exactly three fields, if
      `sweep_index` is not an integer, if `model_config_name` is not a
      module-level name, or if the datasets environment variable is not set.
  """
  fields = arg_string.split(",")
  if len(fields) != 3:
    raise ValueError("You must provide exactly three arguments separated by a "
                     "comma - model_config_name,sweep_index,dataset_name.")
  model_config_name, sweep_index, dataset_name = fields
  sweep_index = int(sweep_index)

  config = base_config.get_base_config()
  config.random_seed = 123109801
  config.eval_modes = ("eval", "eval_metric")

  # Get the model config and the sweeps
  module_globals = globals()
  if model_config_name not in module_globals:
    raise ValueError(f"The config name {model_config_name} does not exist in "
                     f"jaxline_configs.py")
  model_config, sweeps = module_globals[model_config_name]()

  dm_hamiltonian_suite_path = os.environ.get(_DATASETS_PATH_VAR_NAME)
  if not dm_hamiltonian_suite_path:
    raise ValueError(f"You need to set the {_DATASETS_PATH_VAR_NAME}")
  dataset_folder = os.path.join(dm_hamiltonian_suite_path, dataset_name)

  # Experiment config. Note that batch_size is per device.
  # In the experiments we run on 4 GPUs, so the effective batch size was 128.
  optimizer = {
      "name": "adam",
      "kwargs": {
          "learning_rate": 1.5e-4,
          "b1": 0.9,
          "b2": 0.999,
      },
  }
  training = {
      "batch_size": 32,
      "burnin_steps": 5,
      "num_epochs": None,
      "lagging_vae": False,
  }
  evaluation = {
      "batch_size": 64,
  }
  evaluation_metric = {
      "batch_size": 5,
      "batch_n": 20,
      "num_eval_metric_steps": 60,
      "max_poly_order": 5,
      "max_jacobian_score": 1000,
      "rsq_threshold": 0.9,
      "sym_threshold": 0.05,
      "evaluation_point_n": 10,
      "weight_tolerance": 1e-03,
      "max_iter": 1000,
      "cv": 2,
      "alpha_min_logspace": -4,
      "alpha_max_logspace": -0.5,
      "alpha_step_n": 10,
      "calculate_fully_after_steps": 40000,
  }
  evaluation_metric_mlp = {
      "batch_size": 64,
      "batch_n": 10000,
      "datapoint_param_multiplier": 1000,
      "num_eval_metric_steps": 60,
      "evaluation_point_n": 10,
      "evaluation_trajectory_n": 50,
      "rsq_threshold": 0.9,
      "sym_threshold": 0.05,
      "ridge_lambda": 0.01,
      "model": {
          "num_units": 4,
          "num_layers": 4,
          "activation": "tanh",
      },
      "optimizer": {
          "name": "adam",
          "kwargs": {
              "learning_rate": 1.5e-3,
          },
      },
  }
  evaluation_vpt = {
      "batch_size": 5,
      "batch_n": 2,
      "vpt_threshold": 0.025,
  }
  config.experiment_kwargs = collections.ConfigDict({
      "config": {
          "dataset_folder": dataset_folder,
          "model_kwargs": model_config,
          "num_extrapolation_steps": 60,
          "drop_stats_containing": (
              "neg_log_p_x", "l2_over_time", "neg_elbo"),
          "optimizer": optimizer,
          "training": training,
          "evaluation": evaluation,
          "evaluation_metric": evaluation_metric,
          "evaluation_metric_mlp": evaluation_metric_mlp,
          "evaluation_vpt": evaluation_vpt,
      },
  })

  # Training loop config.
  config.training_steps = 500000
  config.interval_type = "steps"
  config.log_tensors_interval = 50
  config.log_train_data_interval = 50
  config.log_all_train_data = False

  # Checkpointing config.
  config.save_checkpoint_interval = 100
  config.checkpoint_dir = "/tmp/physics_inspired_models/"
  config.train_checkpoint_all_hosts = False
  config.eval_specific_checkpoint_dir = ""

  # Apply the overrides of the selected sweep entry last.
  config.update_from_flattened_dict(sweeps[sweep_index])
  return config
|
||||||
|
|
||||||
|
|
||||||
|
# Prefixes of the flattened keys used by the sweep dictionaries.
config_prefix = "experiment_kwargs.config."
model_prefix = f"{config_prefix}model_kwargs."

# Default network settings of the encoder.
default_encoder_kwargs = collections.ConfigDict({
    "conv_channels": 64,
    "num_blocks": 3,
    "blocks_depth": 2,
    "activation": "leaky_relu",
})

# Default network settings of the decoder.
default_decoder_kwargs = collections.ConfigDict({
    "conv_channels": 64,
    "num_blocks": 3,
    "blocks_depth": 2,
    "activation": "leaky_relu",
})

# Default settings of the network used inside the latent system.
default_latent_system_net_kwargs = collections.ConfigDict({
    "conv_channels": 64,
    "num_units": 250,
    "num_layers": 5,
    "activation": "swish",
})


# Default settings of the latent system. Placeholders have no value yet.
default_latent_system_kwargs = collections.ConfigDict({
    # Physics model arguments
    "input_space": collections.config_dict.placeholder(str),
    "simulation_space": collections.config_dict.placeholder(str),
    "potential_func_form": "separable_net",
    "kinetic_func_form": collections.config_dict.placeholder(str),
    "hgn_kinetic_func_form": "separable_net",
    "lgn_kinetic_func_form": "matrix_dep_quad",
    "parametrize_mass_matrix": collections.config_dict.placeholder(bool),
    "hgn_parametrize_mass_matrix": False,
    "lgn_parametrize_mass_matrix": True,
    "mass_eps": 1.0,
    # ODE model arguments
    "integrator_method": collections.config_dict.placeholder(str),
    # RGN model arguments
    "residual": collections.config_dict.placeholder(bool),
    # General arguments
    "net_kwargs": default_latent_system_net_kwargs,
})

# Default model configuration that the sweep functions copy and modify.
default_config_dict = collections.ConfigDict({
    "name": collections.config_dict.placeholder(str),
    "latent_system_dim": 32,
    "latent_system_net_type": "mlp",
    "latent_system_kwargs": default_latent_system_kwargs,
    "encoder_aggregation_type": "linear_projection",
    "decoder_de_aggregation_type": collections.config_dict.placeholder(str),
    "encoder_kwargs": default_encoder_kwargs,
    "decoder_kwargs": default_decoder_kwargs,
    "has_latent_transform": False,
    "num_inference_steps": 5,
    "num_target_steps": 60,
    "latent_training_type": "forward",
    # Choices: overlap_by_one, no_overlap, include_inference
    "training_data_split": "overlap_by_one",
    "objective_type": "ELBO",
    "elbo_beta_delay": 0,
    "elbo_beta_final": 1.0,
    "geco_kappa": 0.001,
    "geco_alpha": 0.0,
    "dt": 0.125,
})
|
||||||
|
|
||||||
|
# Encoder settings of the HGN paper configuration.
hgn_paper_encoder_kwargs = collections.ConfigDict({
    "conv_channels": [[32, 64], [64, 64], [64]],
    "num_blocks": 3,
    "blocks_depth": 2,
    "activation": "relu",
    "kernel_shapes": [2, 4],
    "padding": ["VALID", "SAME"],
})

# Decoder settings of the HGN paper configuration.
hgn_paper_decoder_kwargs = collections.ConfigDict({
    "conv_channels": 64,
    "num_blocks": 3,
    "blocks_depth": 2,
    "activation": "tf_leaky_relu",
})

# Settings of the network used inside the latent system.
hgn_paper_latent_net_kwargs = collections.ConfigDict({
    "conv_channels": [32, 64, 64, 64],
    "num_units": 250,
    "num_layers": 5,
    "activation": "softplus",
    "kernel_shapes": [3, 2, 2, 2, 2],
    "strides": [1, 2, 1, 2, 1],
    "padding": ["SAME", "VALID", "SAME", "VALID", "SAME"],
})

# Latent system settings of the HGN paper configuration.
hgn_paper_latent_system_kwargs = collections.ConfigDict({
    "potential_func_form": "separable_net",
    "kinetic_func_form": "separable_net",
    "parametrize_mass_matrix": False,
    "net_kwargs": hgn_paper_latent_net_kwargs,
})

# Settings of the latent transform.
hgn_paper_latent_transform_kwargs = collections.ConfigDict({
    "num_layers": 5,
    "conv_channels": 64,
    "num_units": 64,
    "activation": "relu",
})

# The HGN paper configuration starts from a copy of the defaults and then
# overrides the fields below.
hgn_paper_config = copy.deepcopy(default_config_dict)
hgn_paper_config.training_data_split = "include_inference"
hgn_paper_config.latent_system_net_type = "conv"
hgn_paper_config.encoder_aggregation_type = (
    collections.config_dict.placeholder(str))
hgn_paper_config.decoder_de_aggregation_type = (
    collections.config_dict.placeholder(str))
hgn_paper_config.latent_system_kwargs = hgn_paper_latent_system_kwargs
hgn_paper_config.encoder_kwargs = hgn_paper_encoder_kwargs
hgn_paper_config.decoder_kwargs = hgn_paper_decoder_kwargs
hgn_paper_config.has_latent_transform = True
hgn_paper_config.latent_transform_kwargs = hgn_paper_latent_transform_kwargs
hgn_paper_config.num_inference_steps = 31
hgn_paper_config.num_target_steps = 0
hgn_paper_config.objective_type = "GECO"
|
||||||
|
|
||||||
|
|
||||||
|
# Overrides for forward-only training on data split with overlap by one.
forward_overlap_by_one = {
    f"{model_prefix}latent_training_type": "forward",
    f"{model_prefix}training_data_split": "overlap_by_one",
}

# Overrides for forward-backward training that includes the inference steps.
forward_backward_include_inference = {
    f"{model_prefix}latent_training_type": "forward_backward",
    f"{model_prefix}training_data_split": "include_inference",
}

# The latent training setups that the benchmark sweeps iterate over.
latent_training_sweep = [
    forward_overlap_by_one,
    forward_backward_include_inference,
]
|
||||||
|
|
||||||
|
|
||||||
|
def sym_metric_hgn_plus_plus_sweep():
  """HGN++ experimental sweep for the SyMetric paper.

  Returns:
    A tuple `(model_config, sweeps)`. `sweeps` holds eight override dicts: the
    four `elbo_beta_final` values with forward training first, followed by the
    same four values with forward-backward training.
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "HGN"

  # Pairs of (latent_training_type, training_data_split).
  training_setups = (
      ("forward", "overlap_by_one"),
      ("forward_backward", "include_inference"),
  )
  sweeps = [
      {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          model_prefix + "latent_training_type": training_type,
          model_prefix + "training_data_split": data_split,
          model_prefix + "elbo_beta_final": elbo_beta_final,
      }
      for training_type, data_split in training_setups
      for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]
  ]
  return model_config, sweeps
|
||||||
|
|
||||||
|
|
||||||
|
def sym_metric_hgn_sweep():
  """HGN experimental sweep for the SyMetric paper.

  Returns:
    A tuple `(model_config, sweeps)`, where `sweeps` holds one empty override
    dict, so sweep index 0 selects the unmodified HGN paper configuration.
  """
  model_config = copy.deepcopy(hgn_paper_config)
  model_config.name = "HGN"
  # `list(dict())` evaluates to an empty list, which leaves no valid sweep
  # index. A single empty dict of overrides is needed instead.
  return model_config, [dict()]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_hgn_overlap_sweep():
  """HGN++ sweep for the benchmark paper.

  Returns:
    A tuple `(model_config, sweeps)` with one override dict per combination of
    `elbo_beta_final` value and entry of `latent_training_sweep`.
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "HGN"

  sweeps = [
      {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          model_prefix + "elbo_beta_final": elbo_beta_final,
          **train_dict,
      }
      for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]
      for train_dict in latent_training_sweep
  ]
  return model_config, sweeps
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_lgn_sweep():
  """Builds the LGN sweep for the benchmark paper.

  Returns:
    A `(model_config, sweeps)` pair. `model_config` is a deep copy of
    `default_config_dict` named "LGN"; `sweeps` has one override dict per
    combination of `elbo_beta_final` value and `latent_training_sweep` entry,
    with the beta value varying slowest.
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "LGN"

  kinetic_key = model_prefix + "latent_system_kwargs.kinetic_func_form"
  sweeps = []
  for beta in [0.001, 0.1, 1.0, 2.0]:
    for train_dict in latent_training_sweep:
      overrides = {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          kinetic_key: "matrix_dep_pure_quad",
          model_prefix + "elbo_beta_final": beta,
      }
      # Entries of the training sweep are layered on top of the base overrides.
      overrides.update(train_dict)
      sweeps.append(overrides)

  return model_config, sweeps
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_ode_sweep():
  """Builds the Neural ODE sweep for the benchmark paper.

  Returns:
    A `(model_config, sweeps)` pair. `model_config` is a deep copy of
    `default_config_dict` named "ODE"; `sweeps` has one override dict per
    combination of `elbo_beta_final` value, integrator ("adaptive" or "rk2")
    and `latent_training_sweep` entry, nested in that order.
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "ODE"

  sweeps = []
  for beta in [0.001, 0.1, 1.0, 2.0]:
    for integrator in ("adaptive", "rk2"):
      for train_dict in latent_training_sweep:
        overrides = {
            config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
            model_prefix + "integrator_method": integrator,
            model_prefix + "elbo_beta_final": beta,
        }
        # Entries of the training sweep go on top of the base overrides.
        overrides.update(train_dict)
        sweeps.append(overrides)

  return model_config, sweeps
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_rgn_sweep():
  """Builds the RGN sweep for the benchmark paper.

  Returns:
    A `(model_config, sweeps)` pair. `model_config` is a deep copy of
    `default_config_dict` named "RGN"; `sweeps` has one override dict per
    combination of `elbo_beta_final` value and `residual` flag (True first).
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "RGN"

  sweeps = [
      {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          model_prefix + "latent_system_kwargs.residual": use_residual,
          model_prefix + "elbo_beta_final": beta,
      }
      for beta in [0.001, 0.1, 1.0, 2.0]
      for use_residual in (True, False)
  ]
  return model_config, sweeps
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_ar_sweep():
  """Builds the AR sweep for the benchmark paper.

  Returns:
    A `(model_config, sweeps)` pair. `model_config` is a deep copy of
    `default_config_dict` named "AR" with `latent_dynamics_type` set to
    "vanilla"; `sweeps` has one override dict per combination of
    `elbo_beta_final` value and recurrent core type ("vanilla", "lstm", "gru").
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "AR"
  model_config.latent_dynamics_type = "vanilla"

  sweeps = [
      {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          model_prefix + "latent_dynamics_type": core_type,
          model_prefix + "elbo_beta_final": beta,
      }
      for beta in [0.001, 0.1, 1.0, 2.0]
      for core_type in ("vanilla", "lstm", "gru")
  ]
  return model_config, sweeps
|
||||||
File diff suppressed because it is too large
Load Diff
Executable
+52
@@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Runs one configuration on every dataset listed below by delegating each
# dataset to launch_local.sh, which lives next to this script.
if [[ "$#" -ne 2 ]]; then
  echo "You must provide exactly two arguments - the configuration name and " \
       "how many sweeps it contains. For example:"
  echo "./launch_all.sh sym_metric_hgn_plus_plus_sweep 1"
  exit 2
fi

readonly CONFIG_NAME="$1"
readonly NUM_SWEEPS="$2"

# Directory containing this script, used to locate launch_local.sh.
readonly DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

DATASETS=(
  "toy_physics/mass_spring"
  "toy_physics/mass_spring_colors"
  "toy_physics/mass_spring_colors_friction"
  "toy_physics/pendulum"
  "toy_physics/pendulum_colors"
  "toy_physics/pendulum_colors_friction"
  "toy_physics/two_body"
  "toy_physics/two_body_colors"
  "toy_physics/double_pendulum"
  "toy_physics/double_pendulum_colors"
  "toy_physics/double_pendulum_colors_friction"
  "molecular_dynamics/lj_4"
  "molecular_dynamics/lj_16"
  "multi_agent/rock_paper_scissors"
  "multi_agent/matching_pennies"
  "mujoco_room/circle"
  "mujoco_room/spiral"
)

# Datasets are processed one after another, in the order listed above.
for dataset in "${DATASETS[@]}"; do
  "${DIR}/launch_local.sh" "${CONFIG_NAME}" "${NUM_SWEEPS}" "${dataset}"
done
|
||||||
Executable
+40
@@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# A script to execute a single configuration name on a given dataset.
# Arguments: <config name> <number of sweeps> <dataset name>.
# Runs jaxline_train.py once for every sweep index in [0, number of sweeps).
if [[ "$#" -eq 3 ]]; then
  readonly CONFIG_NAME="$1"
  readonly NUM_SWEEPS="$2"
  readonly DATASET="$3"
else
  echo "You must provide exactly three arguments - the configuration name, " \
       "the number of sweeps it contains and the dataset name. For example:"
  echo "./launch_local.sh sym_metric_hgn_plus_plus_sweep 1 " \
       "toy_physics/mass_spring"
  exit 2
fi
echo "Running with config ${CONFIG_NAME} on ${DATASET}."

# The training and configuration files are expected next to this script.
readonly EXPERIMENT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
readonly TRAIN_FILE="${EXPERIMENT_DIR}/jaxline_train.py"
readonly CONFIG_FILE="${EXPERIMENT_DIR}/jaxline_configs.py"

# An arithmetic loop is used instead of `seq 0 $((NUM_SWEEPS - 1))`: BSD/macOS
# `seq` counts downwards when the last value is below the first, so a sweep
# count of 0 would launch sweeps 0 and -1 there instead of nothing.
for (( sweep_id = 0; sweep_id < NUM_SWEEPS; sweep_id++ )); do
  python3 "${TRAIN_FILE}" \
    --config="${CONFIG_FILE}:${CONFIG_NAME},${sweep_id},${DATASET}" \
    --jaxline_mode="train" \
    --logtostderr
done
|
||||||
@@ -0,0 +1,247 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module containing code for computing various metrics for training and evaluation."""
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
|
import distrax
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.nn as nn
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import physics_inspired_models.utils as utils
|
||||||
|
|
||||||
|
|
||||||
|
# Type of the reconstruction callable used by `evaluation_only_statistics`.
# It is called as `(params, full_trajectory, rng, forward)`, where the boolean
# selects forward (True) or backward (False) reconstruction, and returns the
# predictive distribution.
_ReconstructFunc = Callable[[utils.Params, jnp.ndarray, jnp.ndarray, bool],
                            distrax.Distribution]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_small_latents(dist, threshold=0.5):
  """Returns the average per-example count of low-variance, non-zero latents.

  A latent is counted when its variance is below `threshold` and the absolute
  value of its mean is above 0.1. Counts are summed over axis 1 and then
  averaged over the remaining axes.

  Args:
    dist: The latent distribution; must be a `distrax.Normal`.
    threshold: Upper bound on the variance for a latent to be counted.

  Returns:
    A scalar array with the mean count.

  Raises:
    NotImplementedError: If `dist` is not a `distrax.Normal`.
  """
  if not isinstance(dist, distrax.Normal):
    raise NotImplementedError()
  means = dist.mean()
  variances = dist.variance()
  is_counted = (variances < threshold) & (jnp.abs(means) > 0.1)
  per_example_count = jnp.sum(is_counted, axis=1)
  return jnp.mean(per_example_count)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_scale(
    targets: jnp.ndarray,
    rescale_by: Optional[str]
) -> jnp.ndarray:
  """Computes the factor by which statistics of `targets` are divided.

  Args:
    targets: The target array; only its shape is used.
    rescale_by: Either "pixels_and_time", in which case the scale is the
      product of the last four dimensions of `targets`, or `None` for no
      rescaling.

  Returns:
    A scalar array with the scaling factor (one when `rescale_by` is `None`).

  Raises:
    ValueError: If `rescale_by` is neither "pixels_and_time" nor `None`.
  """
  if rescale_by is None:
    return jnp.ones([])
  if rescale_by == "pixels_and_time":
    return jnp.asarray(np.prod(targets.shape[-4:]))
  raise ValueError(f"Unrecognized rescale_by={rescale_by}.")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_data_domain_stats(
    p_x: distrax.Distribution,
    targets: jnp.ndarray
) -> Dict[str, jnp.ndarray]:
  """Computes L2 errors and the negative log-likelihood in the data domain.

  All axes from index 2 onwards are reduced first, giving the "_over_time"
  statistics; these are then summed over axis 1 (presumably `targets` is laid
  out as [batch, time, ...] - confirm against callers).

  Args:
    p_x: The predictive distribution; its `mean()` and `log_prob(targets)`
      are used.
    targets: The target array.

  Returns:
    A dict with the keys `neg_log_p_x_over_time`, `neg_log_p_x`,
    `l2_over_time`, `l2`, `l2_over_time_norm` and `l2_norm`.
  """
  axis = tuple(range(2, targets.ndim))
  l2_over_time = jnp.sum((p_x.mean() - targets) ** 2, axis=axis)
  l2 = jnp.sum(l2_over_time, axis=1)

  # Calculate relative L2 normalised by image "length". The same reduction
  # axes as above are used so both terms always have matching shapes.
  norm_factor = jnp.sum(targets**2, axis=axis)
  l2_over_time_norm = l2_over_time / norm_factor
  l2_norm = jnp.sum(l2_over_time_norm, axis=1)

  # Compute negative log-likelihood under p(x)
  neg_log_p_x_over_time = - jnp.sum(p_x.log_prob(targets), axis=axis)
  neg_log_p_x = jnp.sum(neg_log_p_x_over_time, axis=1)

  return dict(
      neg_log_p_x_over_time=neg_log_p_x_over_time,
      neg_log_p_x=neg_log_p_x,
      l2_over_time=l2_over_time,
      l2=l2,
      l2_over_time_norm=l2_over_time_norm,
      l2_norm=l2_norm,
  )
|
||||||
|
|
||||||
|
|
||||||
|
def compute_vae_stats(
    neg_log_p_x: jnp.ndarray,
    rng: jnp.ndarray,
    q_z: distrax.Distribution,
    prior: distrax.Distribution
) -> Dict[str, jnp.ndarray]:
  """Computes KL(q(z|x)||p(z)) and the negative ELBO used for VAE models.

  Args:
    neg_log_p_x: The negative log-likelihood; must have the same shape as the
      KL after it has been summed over all axes but the first.
    rng: Random key passed to the KL estimator.
    q_z: The approximate posterior distribution.
    prior: The prior distribution.

  Returns:
    A dict with the keys `kl` and `neg_elbo` (`neg_log_p_x + kl`).
  """
  # Compute the KL
  kl = distrax.estimate_kl_best_effort(q_z, prior, rng_key=rng, num_samples=1)
  # Sum over every axis except the leading one, staying within jax.numpy.
  kl = jnp.sum(kl, axis=tuple(range(1, kl.ndim)))
  # Sanity check
  assert kl.shape == neg_log_p_x.shape
  return dict(
      kl=kl,
      neg_elbo=neg_log_p_x + kl,
  )
|
||||||
|
|
||||||
|
|
||||||
|
def training_statistics(
    p_x: distrax.Distribution,
    targets: jnp.ndarray,
    rescale_by: Optional[str],
    rng: Optional[jnp.ndarray] = None,
    q_z: Optional[distrax.Distribution] = None,
    prior: Optional[distrax.Distribution] = None,
    p_x_learned_sigma: bool = False
) -> Dict[str, jnp.ndarray]:
  """Computes various statistics we track during training.

  Args:
    p_x: The predictive distribution over the targets.
    targets: The target array.
    rescale_by: Passed to `compute_scale` to pick the rescaling factor.
    rng: Random key for the KL estimate. `rng`, `q_z` and `prior` must either
      all be provided or all be `None`.
    q_z: The approximate posterior; enables the VAE and `small_latents` stats.
    prior: The prior distribution.
    p_x_learned_sigma: Whether to add the `p_x_sigma` entry.

  Returns:
    A dict with the data-domain statistics, plus `kl` and `neg_elbo` when the
    VAE arguments are given, all divided by the scale. The optional
    `p_x_sigma` and `small_latents` entries are added afterwards and are not
    rescaled.
  """
  stats = compute_data_domain_stats(p_x, targets)

  if rng is not None and q_z is not None and prior is not None:
    stats.update(compute_vae_stats(stats["neg_log_p_x"], rng, q_z, prior))
  else:
    assert rng is None and q_z is None and prior is None

  # Rescale these stats accordingly
  scale = compute_scale(targets, rescale_by)
  # Note that "_over_time" stats are getting normalised by time here.
  # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
  # alias has been deprecated and removed in recent JAX releases.
  stats = jax.tree_util.tree_map(lambda x: x / scale, stats)
  if p_x_learned_sigma:
    # Takes the first element of the flattened variance of `p_x`.
    stats["p_x_sigma"] = p_x.variance().reshape([-1])[0]
  if q_z is not None:
    stats["small_latents"] = calculate_small_latents(q_z)
  return stats
|
||||||
|
|
||||||
|
|
||||||
|
def evaluation_only_statistics(
    reconstruct_func: _ReconstructFunc,
    params: hk.Params,
    inputs: jnp.ndarray,
    rng: jnp.ndarray,
    rescale_by: str,
    can_run_backwards: bool,
    train_sequence_length: int,
    reconstruction_skip: int,
    p_x_learned_sigma: bool = False,
) -> Dict[str, jnp.ndarray]:
  """Computes various statistics we track only during evaluation.

  The model reconstructs the full trajectory (forward, and also backward when
  `can_run_backwards` is set). For each direction the training statistics are
  computed on three slices of the reconstruction: "train", "extrapolation"
  and "full".

  Args:
    reconstruct_func: Called as `(params, full_trajectory, rng, forward)`;
      must return a `distrax.Normal`.
    params: The model parameters.
    inputs: The inputs from which the image trajectory is extracted.
    rng: Random key passed to `reconstruct_func`.
    rescale_by: Passed on to `training_statistics`.
    can_run_backwards: Whether to also evaluate backward reconstruction.
    train_sequence_length: Sequence length used in training; together with
      `reconstruction_skip` it fixes the size of the "train" slice.
    reconstruction_skip: Number of steps dropped from the start (forward) or
      the end (backward) of the trajectory to form the targets.
    p_x_learned_sigma: Passed on to `training_statistics`.

  Returns:
    A dict keyed by "<prefix>_<suffix>_<stat>", with prefix "forward" or
    "backward", and additionally "combined" (the mean of the two) when
    `can_run_backwards` is set.
  """
  # `jax.tree_util.tree_map` is used throughout because the top-level
  # `jax.tree_map` alias has been deprecated and removed in recent JAX releases.
  tree_map = jax.tree_util.tree_map

  full_trajectory = utils.extract_image(inputs)
  prefixes = ("forward", "backward") if can_run_backwards else ("forward",)

  full_forward_targets = tree_map(
      lambda x: x[:, reconstruction_skip:], full_trajectory)
  full_backward_targets = tree_map(
      lambda x: x[:, :x.shape[1]-reconstruction_skip], full_trajectory)
  train_targets_length = train_sequence_length - reconstruction_skip
  full_targets_length = full_forward_targets.shape[1]

  stats = dict()
  keys = ()

  for prefix in prefixes:
    # Fully unroll the model and reconstruct the whole sequence
    full_prediction = reconstruct_func(params, full_trajectory, rng,
                                       prefix == "forward")
    assert isinstance(full_prediction, distrax.Normal)
    full_targets = (full_forward_targets if prefix == "forward" else
                    full_backward_targets)
    # In cases where the model can run backwards it is possible to reconstruct
    # parts which were intended to be skipped, so here we take care of that.
    if full_prediction.mean().shape[1] > full_targets_length:
      if prefix == "forward":
        full_prediction = tree_map(lambda x: x[:, -full_targets_length:],
                                   full_prediction)
      else:
        full_prediction = tree_map(lambda x: x[:, :full_targets_length],
                                   full_prediction)

    # Based on the prefix and suffix fetch correct predictions and targets
    for suffix in ("train", "extrapolation", "full"):
      if prefix == "forward" and suffix == "train":
        predict, targets = tree_map(lambda x: x[:, :train_targets_length],
                                    (full_prediction, full_targets))
      elif prefix == "forward" and suffix == "extrapolation":
        predict, targets = tree_map(lambda x: x[:, train_targets_length:],
                                    (full_prediction, full_targets))
      elif prefix == "backward" and suffix == "train":
        predict, targets = tree_map(lambda x: x[:, -train_targets_length:],
                                    (full_prediction, full_targets))
      elif prefix == "backward" and suffix == "extrapolation":
        predict, targets = tree_map(lambda x: x[:, :-train_targets_length],
                                    (full_prediction, full_targets))
      else:
        predict, targets = full_prediction, full_targets

      # Compute train statistics
      train_stats = training_statistics(predict, targets, rescale_by,
                                        p_x_learned_sigma=p_x_learned_sigma)
      for key, value in train_stats.items():
        stats[prefix + "_" + suffix + "_" + key] = value
      # Remember the stat names for the combined metrics below
      keys = tuple(train_stats.keys())

  # Make a combined metric averaging forward and backward
  if can_run_backwards:
    for suffix in ("train", "extrapolation", "full"):
      for key in keys:
        forward = stats["forward_" + suffix + "_" + key]
        backward = stats["backward_" + suffix + "_" + key]
        combined = (forward + backward) / 2
        stats["combined_" + suffix + "_" + key] = combined

  return stats
|
||||||
|
|
||||||
|
|
||||||
|
def geco_objective(
    l2_loss,
    kl,
    alpha,
    kappa,
    constraint_ema,
    lambda_var,
    is_training
) -> Dict[str, jnp.ndarray]:
  """Computes the GECO objective and the statistics used for its updates.

  Args:
    l2_loss: The reconstruction loss entering the constraint.
    kl: The KL term added to the loss.
    alpha: Decay rate of the moving average of the constraint.
    kappa: Value subtracted from `l2_loss` to form the constraint.
    constraint_ema: Current moving average of the constraint.
    lambda_var: Unconstrained variable; its softplus is the multiplier.
    is_training: Whether to update the moving average of the constraint.

  Returns:
    A dict with the keys `loss`, `geco_multiplier`, `geco_constraint` and
    `geco_constraint_ema`.
  """
  # C_t
  current_constraint = l2_loss - kappa
  # C_ma is only updated during training
  if is_training:
    constraint_ema = (
        alpha * constraint_ema + (1 - alpha) * current_constraint)
  multiplier = jnp.broadcast_to(nn.softplus(lambda_var), constraint_ema.shape)
  # This special op is needed for getting all gradients correct
  constrained_term = utils.geco_lagrange_product(
      multiplier, constraint_ema, current_constraint)
  return {
      "loss": constrained_term + kl,
      "geco_multiplier": multiplier,
      "geco_constraint": current_constraint,
      "geco_constraint_ema": constraint_ema,
  }
|
||||||
|
|
||||||
|
|
||||||
|
def elbo_objective(neg_log_p_x, kl, final_beta, beta_delay, step):
  """Computes objective for optimizing the Evidence Lower Bound (ELBO).

  The KL weight `beta` is ramped linearly from 0 to `final_beta` over the
  first `beta_delay` steps; with `beta_delay == 0` it equals `final_beta`.

  Args:
    neg_log_p_x: The negative log-likelihood term.
    kl: The KL term.
    final_beta: The value of `beta` once the ramp has finished.
    beta_delay: Number of steps over which `beta` is ramped up. It is compared
      and converted in Python, so it has to be a concrete number.
    step: The current step, as a Python number or an array.

  Returns:
    A dict with the keys `loss` (`neg_log_p_x + beta * kl`) and `elbo_beta`.
  """
  if beta_delay == 0:
    beta = final_beta
  else:
    # Plain division instead of `float(step)`: `float()` cannot convert a
    # traced array (e.g. when this runs under `jax.jit`), while division works
    # for Python numbers and arrays alike.
    delayed_beta = jnp.minimum(step / float(beta_delay), 1.0)
    beta = delayed_beta * final_beta
  return dict(
      loss=neg_log_p_x + beta * kl,
      elbo_beta=beta
  )
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,345 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module for all autoregressive models."""
|
||||||
|
import functools
|
||||||
|
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import distrax
|
||||||
|
import haiku as hk
|
||||||
|
from jax import lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax.random as jnr
|
||||||
|
|
||||||
|
import physics_inspired_models.metrics as metrics
|
||||||
|
import physics_inspired_models.models.base as base
|
||||||
|
import physics_inspired_models.models.networks as nets
|
||||||
|
import physics_inspired_models.utils as utils
|
||||||
|
|
||||||
|
|
||||||
|
class TeacherForcingAutoregressiveModel(base.SequenceModel):
  """A standard autoregressive model trained via teacher forcing.

  Images are encoded into latents, a recurrent network is run over the latent
  sequence and its outputs are decoded into predictions for the next images.
  When more steps are requested than there are input images, the model is
  unrolled by feeding the mean of its own prediction back in as the next input.
  The model can only be run forwards in time.
  """

  def __init__(
      self,
      latent_system_dim: int,
      latent_system_net_type: str,
      latent_system_kwargs: Dict[str, Any],
      latent_dynamics_type: str,
      encoder_aggregation_type: Optional[str],
      decoder_de_aggregation_type: Optional[str],
      encoder_kwargs: Dict[str, Any],
      decoder_kwargs: Dict[str, Any],
      num_inference_steps: int,
      num_target_steps: int,
      name: Optional[str] = None,
      **kwargs
  ):
    """Initializes the model.

    Apart from `latent_dynamics_type`, which selects the recurrent core type,
    all arguments are forwarded to the base class constructor.

    Raises:
      ValueError: If `has_latent_transform` is set in `kwargs`, or if
        `latent_system_net_type` is not "mlp".
    """
    # Remove any parameters from vae models. The kwargs are copied first so
    # that the caller's dict is not modified.
    encoder_kwargs = dict(**encoder_kwargs)
    encoder_kwargs["distribution_name"] = None

    if kwargs.get("has_latent_transform", False):
      raise ValueError("We do not support AR models with latent transform.")

    super().__init__(
        can_run_backwards=False,
        latent_system_dim=latent_system_dim,
        latent_system_net_type=latent_system_net_type,
        latent_system_kwargs=latent_system_kwargs,
        encoder_aggregation_type=encoder_aggregation_type,
        decoder_de_aggregation_type=decoder_de_aggregation_type,
        encoder_kwargs=encoder_kwargs,
        decoder_kwargs=decoder_kwargs,
        num_inference_steps=num_inference_steps,
        num_target_steps=num_target_steps,
        name=name,
        **kwargs
    )
    self.latent_dynamics_type = latent_dynamics_type

    # Arguments checks
    if self.latent_system_net_type != "mlp":
      raise ValueError("Currently we do not support non-mlp AR models.")

    def recurrence_function(sequence, initial_state=None):
      """Unrolls the recurrent core over a time-major `sequence`."""
      core = nets.make_flexible_recurrent_net(
          core_type=latent_dynamics_type,
          net_type=latent_system_net_type,
          output_dims=self.latent_system_dim,
          **self.latent_system_kwargs["net_kwargs"])
      # Test against `None` explicitly: taking the truth value of the state
      # (as `initial_state or ...` does) fails when the state is an array.
      if initial_state is None:
        initial_state = core.initial_state(sequence.shape[1])
      # NOTE(review): the result of this call is discarded; presumably it is
      # here so the core is invoked once outside of the unroll - confirm.
      core(sequence[0], initial_state)
      return hk.dynamic_unroll(core, sequence, initial_state)

    self.recurrence = hk.transform(recurrence_function)

  def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
    """Returns the inputs unchanged."""
    return x

  def process_latents_for_dynamics(self, z: jnp.ndarray) -> jnp.ndarray:
    """Returns the latents unchanged."""
    return z

  def process_latents_for_decoder(self, z: jnp.ndarray) -> jnp.ndarray:
    """Returns the latents unchanged."""
    return z

  @property
  def inferred_index(self) -> int:
    """The index of the last inference step."""
    return self.num_inference_steps - 1

  @property
  def train_sequence_length(self) -> int:
    """The number of images used from each training sequence."""
    return self.num_target_steps

  def train_data_split(
      self,
      images: jnp.ndarray
  ) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
    """Splits the images into teacher-forcing inputs and shifted targets.

    Only the first `train_sequence_length` images (along axis 1) are used. The
    inputs are all of those but the last and the targets all but the first, so
    every input is paired with the image following it.

    Returns:
      The inference data, the target data and the keyword arguments to use for
      unrolling (a single forward step, no backward steps, no `z0`).
    """
    images = images[:, :self.train_sequence_length]
    inference_data = images[:, :-1]
    target_data = images[:, 1:]
    return inference_data, target_data, dict(
        num_steps_forward=1,
        num_steps_backward=0,
        include_z0=False)

  def unroll_without_inputs(
      self,
      params: utils.Params,
      rng: jnp.ndarray,
      x_init: jnp.ndarray,
      h_init: jnp.ndarray,
      num_steps: int,
      is_training: bool
  ) -> Tuple[Tuple[distrax.Distribution, jnp.ndarray], Any]:
    """Unrolls the model by feeding its own predictions back in as inputs.

    Args:
      params: The model parameters.
      rng: Random key, split into one key per step.
      x_init: The first input to the encoder.
      h_init: The initial state of the recurrent core.
      num_steps: The number of steps to unroll for.
      is_training: Whether the model is in training mode.

    Returns:
      The output of `lax.scan`: the final `(x, h)` carry and the stacked
      per-step `(p_x, z)` pairs.

    Raises:
      ValueError: If `num_steps` is less than 1.
    """
    if num_steps < 1:
      raise ValueError("`num_steps` must be at least 1.")

    def step_fn(carry, key):
      x_last, h_last = carry
      enc_key, dec_key = jnr.split(key)
      z_in_next = self.encoder.apply(params, enc_key, x_last,
                                     is_training=is_training)
      # The recurrence expects a leading time axis, hence the `[None]`.
      z_next, h_next = self.recurrence.apply(params, None, z_in_next[None],
                                             h_last)
      p_x_next = self.decode_latents(params, dec_key, z_next[0],
                                     is_training=is_training)
      # The mean of the prediction becomes the next input.
      return (p_x_next.mean(), h_next), (p_x_next, z_next[0])

    return lax.scan(
        step_fn,
        init=(x_init, h_init),
        xs=jnr.split(rng, num_steps)
    )

  def unroll_latent_dynamics(
      self,
      z: jnp.ndarray,
      params: utils.Params,
      key: jnp.ndarray,
      num_steps_forward: int,
      num_steps_backward: int,
      include_z0: bool,
      is_training: bool,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Runs the recurrence over `z` and optionally unrolls further steps.

    With `num_steps_forward == 1` only the outputs of the recurrence on the
    given latents are returned. For larger values the model is additionally
    unrolled for `num_steps_forward - 1` steps on its own predictions and the
    resulting latents are appended.

    Args:
      z: The input latents, with time on axis 1.
      params: The model parameters.
      key: Random key.
      num_steps_forward: The number of forward steps; at least 1.
      num_steps_backward: Must be 0, as this model can not run backwards.
      include_z0: Not used by this model.
      is_training: Whether the model is in training mode.
      **kwargs: Not used by this model.

    Returns:
      The output latents, with time on axis 1, and an empty dict.

    Raises:
      ValueError: If `num_steps_backward` is not 0 or `num_steps_forward` is
        less than 1.
    """
    init_key, unroll_key, dec_key = jnr.split(key, 3)

    if num_steps_backward != 0:
      raise ValueError("This model can not run backwards.")

    # Change 'z' time dimension to be first
    z = jnp.swapaxes(z, 0, 1)

    # Run recurrent model on inputs
    z_0, h_0 = self.recurrence.apply(params, init_key, z)

    if num_steps_forward == 1:
      z_t = z_0
    elif num_steps_forward > 1:
      # NOTE(review): this decode uses `is_training=False` regardless of the
      # `is_training` argument - confirm that this is intended.
      p_x_0 = self.decode_latents(params, dec_key, z_0[-1], is_training=False)
      _, (_, z_t) = self.unroll_without_inputs(
          params=params,
          rng=unroll_key,
          x_init=p_x_0.mean(),
          h_init=h_0,
          num_steps=num_steps_forward-1,
          is_training=is_training
      )
      z_t = jnp.concatenate([z_0, z_t], axis=0)
    else:
      raise ValueError("num_steps_forward should be at least 1.")

    # Make time dimension second
    return jnp.swapaxes(z_t, 0, 1), dict()

  def _models_core(
      self,
      params: utils.Params,
      keys: jnp.ndarray,
      image_data: jnp.ndarray,
      is_training: bool,
      **unroll_kwargs: Any
  ) -> Tuple[distrax.Distribution, jnp.ndarray, jnp.ndarray]:
    """Encodes the images, unrolls the latent dynamics and decodes the result.

    Args:
      params: The model parameters.
      keys: Six random keys; only four of them are used here.
      image_data: The images to encode.
      is_training: Whether the model is in training mode.
      **unroll_kwargs: Passed on to `unroll_latent_dynamics`.

    Returns:
      The predictive distribution `p_x`, the latents fed to the dynamics and
      the latents fed to the decoder.
    """
    enc_key, _, transform_key, unroll_key, dec_key, _ = keys

    # Calculate latent input representation
    inference_data = self.process_inputs_for_encoder(image_data)
    z_raw = self.encoder.apply(params, enc_key, inference_data,
                               is_training=is_training)

    # Apply latent transformation (should be identity)
    z0 = self.apply_latent_transform(params, transform_key, z_raw,
                                     is_training=is_training)
    z0 = self.process_latents_for_dynamics(z0)

    # Calculate latent output representation
    decoder_z, _ = self.unroll_latent_dynamics(
        z=z0,
        params=params,
        key=unroll_key,
        is_training=is_training,
        **unroll_kwargs
    )
    decoder_z = self.process_latents_for_decoder(decoder_z)

    # Compute p(x|z)
    p_x = self.decode_latents(params, dec_key, decoder_z,
                              is_training=is_training)
    return p_x, z0, decoder_z

  def training_objectives(
      self,
      params: hk.Params,
      state: hk.State,
      rng: jnp.ndarray,
      inputs: jnp.ndarray,
      step: jnp.ndarray,
      is_training: bool = True,
      use_mean_for_eval_stats: bool = True
  ) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
    """Computes the training objective and any supporting stats.

    Args:
      params: The model parameters.
      state: Not used by this model.
      rng: Random key.
      inputs: The inputs from which the images are extracted.
      step: Not used by this model.
      is_training: Whether the model is in training mode. When `False` the
        evaluation-only statistics are added to the stats.
      use_mean_for_eval_stats: Passed as `use_mean` to `reconstruct` when
        computing the evaluation-only statistics.

    Returns:
      The loss and a triple `(dict(), stats, dict())`.
    """
    # Split all rng keys
    keys = jnr.split(rng, 6)

    # Process training data
    images = utils.extract_image(inputs)
    image_data, target_data, unroll_kwargs = self.train_data_split(images)

    p_x, _, _ = self._models_core(
        params=params,
        keys=keys,
        image_data=image_data,
        is_training=is_training,
        **unroll_kwargs
    )

    # Compute training statistics
    stats = metrics.training_statistics(
        p_x=p_x,
        targets=target_data,
        rescale_by=self.rescale_by,
        p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
    )

    # The loss is just the negative log-likelihood (e.g. the L2 loss)
    stats["loss"] = stats["neg_log_p_x"]

    if not is_training:
      # Add also the evaluation statistics when not training.
      # We need to be able to set `use_mean = False` for some of the tests
      stats.update(metrics.evaluation_only_statistics(
          reconstruct_func=functools.partial(
              self.reconstruct, use_mean=use_mean_for_eval_stats),
          params=params,
          inputs=inputs,
          rng=rng,
          rescale_by=self.rescale_by,
          can_run_backwards=self.can_run_backwards,
          train_sequence_length=self.train_sequence_length,
          reconstruction_skip=1,
          p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
      ))

    return stats["loss"], (dict(), stats, dict())

  def reconstruct(
      self,
      params: utils.Params,
      inputs: jnp.ndarray,
      rng: jnp.ndarray,
      forward: bool,
      use_mean: bool = True,
  ) -> distrax.Distribution:
    """Reconstructs the input sequence.

    The first `num_inference_steps` images are encoded and the model is
    unrolled for the remaining length of the sequence.

    Args:
      params: The model parameters.
      inputs: The inputs from which the images are extracted.
      rng: Random key.
      forward: Must be `True`, as this model can not run backwards.
      use_mean: Not used by this model.

    Returns:
      The predictive distribution over the reconstructed images.

    Raises:
      ValueError: If `forward` is `False`.
    """
    if not forward:
      raise ValueError("This model can not run backwards.")
    images = utils.extract_image(inputs)
    image_data = images[:, :self.num_inference_steps]

    return self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        is_training=False,
        num_steps_forward=images.shape[1] - self.num_inference_steps,
        num_steps_backward=0,
        include_z0=False,
    )[0]

  def gt_state_and_latents(
      self,
      params: hk.Params,
      rng: jnp.ndarray,
      inputs: Dict[str, jnp.ndarray],
      seq_length: int,
      is_training: bool = False,
      unroll_direction: str = "forward",
      **kwargs: Dict[str, Any]
  ) -> Tuple[jnp.ndarray, jnp.ndarray,
             Union[distrax.Distribution, jnp.ndarray]]:
    """Computes the ground state and matching latents.

    Args:
      params: The model parameters.
      rng: Random key.
      inputs: The inputs from which images and ground truth state are
        extracted.
      seq_length: Number of ground truth states to return, starting from
        index 1 along axis 1.
      is_training: Not used; the model is always run with `is_training=False`.
      unroll_direction: Must be "forward".
      **kwargs: Not used by this model.

    Returns:
      The ground truth state, the latents fed to the decoder and the latents
      fed to the dynamics.
    """
    assert unroll_direction == "forward"
    images = utils.extract_image(inputs)
    gt_state = utils.extract_gt_state(inputs)
    image_data = images[:, :self.num_inference_steps]
    gt_state = gt_state[:, 1:seq_length + 1]

    _, z_in, z_out = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        is_training=False,
        num_steps_forward=images.shape[1] - self.num_inference_steps,
        num_steps_backward=0,
        include_z0=False,
    )

    return gt_state, z_out, z_in

  def _init_non_model_params_and_state(
      self,
      rng: jnp.ndarray
  ) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
    """Returns empty dicts: there are no extra parameters or state."""
    return dict(), dict()

  def _init_latent_system(
      self,
      rng: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Initializes the parameters of the recurrence by running it on `z`."""
    return self.recurrence.init(rng, z)
|
||||||
@@ -0,0 +1,360 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module containing the base abstract classes for sequence models."""
|
||||||
|
import abc
|
||||||
|
from typing import Any, Dict, Generic, Mapping, Optional, Sequence, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
from absl import logging
|
||||||
|
import distrax
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax.random as jnr
|
||||||
|
|
||||||
|
|
||||||
|
from physics_inspired_models import utils
|
||||||
|
from physics_inspired_models.models import networks
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceModel(abc.ABC, Generic[T]):
  """An abstract class for sequence models.

  The constructor builds the Haiku-transformed encoder, decoder and the
  optional latent transform. Subclasses provide the latent dynamics and the
  training objectives by implementing the abstract methods.
  """

  def __init__(
      self,
      can_run_backwards: bool,
      latent_system_dim: int,
      latent_system_net_type: str,
      latent_system_kwargs: Dict[str, Any],
      encoder_aggregation_type: Optional[str],
      decoder_de_aggregation_type: Optional[str],
      encoder_kwargs: Dict[str, Any],
      decoder_kwargs: Dict[str, Any],
      num_inference_steps: int,
      num_target_steps: int,
      name: str,
      latent_spatial_shape: Optional[Tuple[int, int]] = (4, 4),
      has_latent_transform: bool = False,
      latent_transform_kwargs: Optional[Dict[str, Any]] = None,
      rescale_by: Optional[str] = "pixels_and_time",
      data_format: str = "NHWC",
      **unused_kwargs
  ):
    """Initializes the model and constructs its encoder and decoder.

    Args:
      can_run_backwards: Stored on the instance as `can_run_backwards`.
      latent_system_dim: Passed to the encoder as `latent_dim` and to the
        latent transform as `output_dims`; half of it is passed to the decoder
        as `max_de_aggregation_dims`.
      latent_system_net_type: The network type of the latent system. The
        value "conv" requires both aggregation types to be None, any other
        value requires both of them to be set.
      latent_system_kwargs: Stored on the instance for use by subclasses.
      encoder_aggregation_type: The `aggregation_type` of the encoder.
      decoder_de_aggregation_type: The `de_aggregation_type` of the decoder.
        When None while `encoder_aggregation_type` is set, it is derived from
        the encoder: "linear_projection" -> "linear_projection" and
        "mean"/"max" -> "tile".
      encoder_kwargs: Extra keyword arguments for the encoder network; a
        falsy value is replaced by an empty dict.
      decoder_kwargs: Extra keyword arguments for the decoder network; a
        falsy value is replaced by an empty dict.
      num_inference_steps: Stored on the instance as `num_inference_steps`.
      num_target_steps: Stored on the instance as `num_target_steps`.
      name: Stored on the instance as `name`.
      latent_spatial_shape: Passed to the decoder as `initial_spatial_shape`.
      has_latent_transform: Whether to construct the latent transform network.
      latent_transform_kwargs: Keyword arguments for the latent transform,
        required when `has_latent_transform` is True.
      rescale_by: Stored on the instance as `rescale_by`.
      data_format: The data format passed to the encoder and the decoder.
      **unused_kwargs: Any other arguments; a warning is logged when present
        and they are forwarded to the next `__init__` in the MRO.

    Raises:
      ValueError: If the aggregation types are inconsistent with
        `latent_system_net_type`, if `encoder_aggregation_type` is not
        recognized, or if `has_latent_transform` is set without
        `latent_transform_kwargs`.
    """
    # Arguments checks
    encoder_kwargs = encoder_kwargs or dict()
    decoder_kwargs = decoder_kwargs or dict()

    # Set the decoder de-aggregation type the "same" type as the encoder if not
    # provided
    if (decoder_de_aggregation_type is None and
        encoder_aggregation_type is not None):
      if encoder_aggregation_type == "linear_projection":
        decoder_de_aggregation_type = "linear_projection"
      elif encoder_aggregation_type in ("mean", "max"):
        decoder_de_aggregation_type = "tile"
      else:
        raise ValueError(f"Unrecognized encoder_aggregation_type="
                         f"{encoder_aggregation_type}")
    if latent_system_net_type == "conv":
      if encoder_aggregation_type is not None:
        raise ValueError("When the latent system is convolutional, the encoder "
                         "aggregation type should be None.")
      if decoder_de_aggregation_type is not None:
        raise ValueError("When the latent system is convolutional, the decoder "
                         "aggregation type should be None.")
    else:
      if encoder_aggregation_type is None:
        raise ValueError("When the latent system is not convolutional, "
                         "you must provide an encoder aggregation type.")
      if decoder_de_aggregation_type is None:
        raise ValueError("When the latent system is not convolutional, "
                         "you must provide a decoder de-aggregation type.")
    if has_latent_transform and latent_transform_kwargs is None:
      raise ValueError("When using latent transformation you have to provide "
                       "the latent_transform_kwargs argument.")
    if unused_kwargs:
      logging.warning("Unused kwargs: %s", str(unused_kwargs))
    # NOTE(review): unless another class follows in the MRO and consumes these
    # arguments, this call presumably reaches `object.__init__`, which rejects
    # any keyword arguments - confirm whether they should be dropped instead.
    super().__init__(**unused_kwargs)
    self.can_run_backwards = can_run_backwards
    self.latent_system_dim = latent_system_dim
    self.latent_system_kwargs = latent_system_kwargs
    self.latent_system_net_type = latent_system_net_type
    self.latent_spatial_shape = latent_spatial_shape
    self.num_inference_steps = num_inference_steps
    self.num_target_steps = num_target_steps
    self.rescale_by = rescale_by
    self.data_format = data_format
    self.name = name

    # Encoder
    self.encoder_kwargs = encoder_kwargs
    self.encoder = hk.transform(
        lambda *args, **kwargs: networks.SpatialConvEncoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
            latent_dim=latent_system_dim,
            aggregation_type=encoder_aggregation_type,
            data_format=data_format,
            name="Encoder",
            **encoder_kwargs
        )(*args, **kwargs))

    # Decoder
    self.decoder_kwargs = decoder_kwargs
    self.decoder = hk.transform(
        lambda *args, **kwargs: networks.SpatialConvDecoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
            initial_spatial_shape=self.latent_spatial_shape,
            de_aggregation_type=decoder_de_aggregation_type,
            data_format=data_format,
            max_de_aggregation_dims=self.latent_system_dim // 2,
            name="Decoder",
            **decoder_kwargs,
        )(*args, **kwargs))

    self.has_latent_transform = has_latent_transform
    if has_latent_transform:
      self.latent_transform = hk.transform(
          lambda *args, **kwargs: networks.make_flexible_net(  # pylint: disable=unnecessary-lambda,g-long-lambda
              net_type=latent_system_net_type,
              output_dims=latent_system_dim,
              name="LatentTransform",
              **latent_transform_kwargs
          )(*args, **kwargs))
    else:
      self.latent_transform = None

    # The jitted version of `_init`, created lazily on the first `init` call.
    self._jit_init = None

  @property
  @abc.abstractmethod
  def train_sequence_length(self) -> int:
    """Computes the total length of a sequence needed for training or evaluation."""
    pass

  @abc.abstractmethod
  def train_data_split(
      self,
      images: jnp.ndarray,
  ) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
    """Extracts from the inputs the data splits for training."""
    pass

  def decode_latents(
      self,
      params: hk.Params,
      rng: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> distrax.Distribution:
    """Decodes the latent variable given the parameters of the model.

    All leading axes of `z` are flattened into a single batch axis before the
    decoder is applied and are restored on every leaf of the decoder output.

    Args:
      params: The parameters passed to `self.decoder.apply`.
      rng: The PRNG key passed to `self.decoder.apply`.
      z: The latents. The trailing 1 axis (for "mlp" latent systems) or
        `1 + len(self.latent_spatial_shape)` axes (for "conv") are kept as
        they are; everything before them is treated as batch.
      **kwargs: Extra keyword arguments for the decoder.

    Returns:
      The decoder output with the leading axes of `z` restored.

    Raises:
      NotImplementedError: If `latent_system_net_type` is neither "mlp" nor
        "conv".
    """
    # Allow to run with both the full parameters and only the decoders
    if self.latent_system_net_type == "mlp":
      fixed_dims = 1
    elif self.latent_system_net_type == "conv":
      fixed_dims = 1 + len(self.latent_spatial_shape)
    else:
      raise NotImplementedError()
    n_shape = z.shape[:-fixed_dims]
    z = z.reshape((-1,) + z.shape[-fixed_dims:])
    x = self.decoder.apply(params, rng, z, **kwargs)
    # `jax.tree_map` is a deprecated alias that newer JAX releases removed.
    return jax.tree_util.tree_map(
        lambda a: a.reshape(n_shape + a.shape[1:]), x)

  def apply_latent_transform(
      self,
      params: hk.Params,
      key: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Applies the latent transform to `z`, or returns `z` if there is none."""
    if self.latent_transform is not None:
      return self.latent_transform.apply(params, key, z, **kwargs)
    else:
      return z

  @abc.abstractmethod
  def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
    """Pre-processes the inference data before it is given to the encoder."""
    pass

  @abc.abstractmethod
  def process_latents_for_dynamics(self, z: jnp.ndarray) -> T:
    """Converts the (transformed) encoder output for the latent dynamics."""
    pass

  @abc.abstractmethod
  def process_latents_for_decoder(self, z: T) -> jnp.ndarray:
    """Converts the output of the latent dynamics for the decoder."""
    pass

  @abc.abstractmethod
  def unroll_latent_dynamics(
      self,
      z: T,
      params: utils.Params,
      key: jnp.ndarray,
      num_steps_forward: int,
      num_steps_backward: int,
      include_z0: bool,
      is_training: bool,
      **kwargs: Any
  ) -> Tuple[T, Mapping[str, jnp.ndarray]]:
    """Unrolls the latent dynamics starting from z and pre-processing for the decoder."""
    pass

  @abc.abstractmethod
  def reconstruct(
      self,
      params: utils.Params,
      inputs: jnp.ndarray,
      rng_key: Optional[jnp.ndarray],
      forward: bool,
  ) -> distrax.Distribution:
    """Using the first `num_inference_steps` parts of inputs reconstructs the rest."""
    pass

  @abc.abstractmethod
  def training_objectives(
      self,
      params: utils.Params,
      state: hk.State,
      rng: jnp.ndarray,
      inputs: Union[Dict[str, jnp.ndarray], jnp.ndarray],
      step: jnp.ndarray,
      is_training: bool = True,
      use_mean_for_eval_stats: bool = True
  ) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
    """Returns all training objectives statistics and update states."""
    pass

  @property
  @abc.abstractmethod
  def inferred_index(self):
    """Returns the time index in the input sequence, for which the encoder infers.

    If the encoder takes as input the sequence x[0:n-1], where
    `n = self.num_inference_steps`, then this outputs the index `k` relative to
    the beginning of the input sequence `x_0`, which the encoder infers.
    """
    pass

  @property
  def inferred_right_offset(self):
    """Returns `num_inference_steps - 1 - inferred_index`."""
    return self.num_inference_steps - 1 - self.inferred_index

  @abc.abstractmethod
  def gt_state_and_latents(
      self,
      params: hk.Params,
      rng: jnp.ndarray,
      inputs: Dict[str, jnp.ndarray],
      seq_len: int,
      is_training: bool = False,
      unroll_direction: str = "forward",
      **kwargs: Dict[str, Any]
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Computes the ground state and matching latents."""
    pass

  @abc.abstractmethod
  def _init_non_model_params_and_state(
      self,
      rng: jnp.ndarray
  ) -> Tuple[utils.Params, utils.Params]:
    """Initializes any non-model parameters and state."""
    pass

  @abc.abstractmethod
  def _init_latent_system(
      self,
      rng: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> hk.Params:
    """Initializes the parameters of the latent system."""
    pass

  def _init(
      self,
      rng: jnp.ndarray,
      images: jnp.ndarray
  ) -> Tuple[hk.Params, hk.State]:
    """Initializes the whole model parameters and state.

    Every component (encoder, optional latent transform, latent system and
    decoder) is initialized and then applied once, so that its output can be
    used as the example input of the next component.

    Args:
      rng: The PRNG key, split for every initialization and application.
      images: Example images, split with `self.train_data_split`.

    Returns:
      The merged parameters of all components and the non-model state, both
      converted with `hk.data_structures.to_immutable_dict`.
    """
    inference_data, _, _ = self.train_data_split(images)
    # Initialize parameters and state for the vae training
    rng, key = jnr.split(rng)
    params, state = self._init_non_model_params_and_state(key)

    # Initialize and run encoder
    inference_data = self.process_inputs_for_encoder(inference_data)
    rng, key = jnr.split(rng)
    encoder_params = self.encoder.init(key, inference_data, is_training=True)
    rng, key = jnr.split(rng)
    z_in = self.encoder.apply(encoder_params, key, inference_data,
                              is_training=True)

    # For probabilistic models this will be a distribution
    if isinstance(z_in, distrax.Distribution):
      z_in = z_in.mean()

    # Initialize and run the optional latent transform
    if self.latent_transform is not None:
      rng, key = jnr.split(rng)
      transform_params = self.latent_transform.init(key, z_in, is_training=True)
      rng, key = jnr.split(rng)
      z_in = self.latent_transform.apply(transform_params, key, z_in,
                                         is_training=True)
    else:
      transform_params = dict()

    # Initialize and run the latent system
    z_in = self.process_latents_for_dynamics(z_in)
    rng, key = jnr.split(rng)
    latent_params = self._init_latent_system(key, z_in, is_training=True)
    rng, key = jnr.split(rng)
    z_out, _ = self.unroll_latent_dynamics(
        z=z_in,
        params=latent_params,
        key=key,
        num_steps_forward=1,
        num_steps_backward=0,
        include_z0=False,
        is_training=True
    )
    z_out = self.process_latents_for_decoder(z_out)

    # Initialize and run the decoder
    rng, key = jnr.split(rng)
    decoder_params = self.decoder.init(key, z_out[:, 0], is_training=True)
    _ = self.decoder.apply(decoder_params, rng, z_out[:, 0], is_training=True)

    # Combine all and make immutable
    params = hk.data_structures.merge(params, encoder_params, transform_params,
                                      latent_params, decoder_params)
    params = hk.data_structures.to_immutable_dict(params)
    state = hk.data_structures.to_immutable_dict(state)

    return params, state

  def init(
      self,
      rng: jnp.ndarray,
      inputs_or_shape: Union[jnp.ndarray, Mapping[str, jnp.ndarray],
                             Sequence[int]],
  ) -> Tuple[utils.Params, hk.State]:
    """Initializes the whole model parameters and state.

    Args:
      rng: The PRNG key used for the initialization.
      inputs_or_shape: Either a tuple/list whose first element is an int, in
        which case it is treated as a shape and an array of zeros with that
        shape is used as the example images, or the model inputs, from which
        the images are read with `utils.extract_image`.

    Returns:
      The result of the jitted `_init`, which is compiled lazily on the first
      call and cached on the instance.
    """
    if (isinstance(inputs_or_shape, (tuple, list))
        and isinstance(inputs_or_shape[0], int)):
      images = jnp.zeros(inputs_or_shape)
    else:
      images = utils.extract_image(inputs_or_shape)
    if self._jit_init is None:
      self._jit_init = jax.jit(self._init)
    return self._jit_init(rng, images)
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module for all models."""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import physics_inspired_models.models.autoregressive as autoregressive
|
||||||
|
import physics_inspired_models.models.deterministic_vae as deterministic_vae
|
||||||
|
|
||||||
|
_physics_arguments = (
    "input_space", "simulation_space", "potential_func_form",
    "kinetic_func_form", "hgn_kinetic_func_form", "lgn_kinetic_func_form",
    "parametrize_mass_matrix", "hgn_parametrize_mass_matrix",
    "lgn_parametrize_mass_matrix", "mass_eps"
)


def construct_model(
    name: str,
    *args,
    **kwargs: Dict[str, Any]
):
  """Constructs the correct instance of a model given the short name.

  The `latent_dynamics_type` and `latent_system_kwargs` entries are taken out
  of `kwargs`; the latter is copied, so the caller's dict is never mutated.
  Entries of `latent_system_kwargs` which do not apply to the selected model
  are dropped, and the `hgn_*`/`lgn_*` entries override their generic
  counterparts for the "HGN"/"LGN" models respectively.

  Args:
    name: The short name of the model, one of "AR", "RGN", "ODE", "HGN", "LGN"
      or "PGN".
    *args: Positional arguments forwarded to the model constructor.
    **kwargs: Keyword arguments forwarded to the model constructor.

  Returns:
    A `TeacherForcingAutoregressiveModel` for "AR", otherwise a
    `DeterministicLatentsGenerativeModel`.

  Raises:
    NotImplementedError: If `name` is not one of the supported models.
    AssertionError: If `latent_dynamics_type`, `input_space` or
      `simulation_space` are incompatible with the selected model.
  """
  latent_dynamics_type: Optional[str] = kwargs.pop("latent_dynamics_type", None)  # pytype: disable=annotation-type-mismatch
  latent_system_kwargs = dict(**kwargs.pop("latent_system_kwargs", dict()))
  if name == "AR":
    assert latent_dynamics_type in ("vanilla", "lstm", "gru")
    # These arguments are not part of the AR models
    for k in _physics_arguments + ("integrator_method", "residual"):
      latent_system_kwargs.pop(k, None)
    return autoregressive.TeacherForcingAutoregressiveModel(
        *args,
        latent_dynamics_type=latent_dynamics_type,
        latent_system_kwargs=latent_system_kwargs,
        **kwargs
    )
  elif name == "RGN":
    assert latent_dynamics_type in ("Discrete", None)
    latent_dynamics_type = "Discrete"
    # These arguments are not part of the RGN models
    for k in _physics_arguments + ("integrator_method",):
      latent_system_kwargs.pop(k, None)
  elif name == "ODE":
    assert latent_dynamics_type in ("ODE", None)
    latent_dynamics_type = "ODE"
    # These arguments are not part of the ODE models
    for k in _physics_arguments + ("residual",):
      latent_system_kwargs.pop(k, None)
  elif name == "HGN":
    assert latent_dynamics_type in ("Physics", None)
    latent_dynamics_type = "Physics"
    assert latent_system_kwargs.get("input_space", None) in ("momentum", None)
    latent_system_kwargs["input_space"] = "momentum"
    assert (latent_system_kwargs.get("simulation_space", None)
            in ("momentum", None))
    latent_system_kwargs["simulation_space"] = "momentum"
    # Kinetic func form
    hgn_specific = latent_system_kwargs.pop("hgn_kinetic_func_form", None)
    if hgn_specific is not None:
      latent_system_kwargs["kinetic_func_form"] = hgn_specific
    # Mass matrix
    hgn_specific = latent_system_kwargs.pop("hgn_parametrize_mass_matrix",
                                            None)
    if hgn_specific is not None:
      latent_system_kwargs["parametrize_mass_matrix"] = hgn_specific
    # These arguments are not part of the HGN models
    latent_system_kwargs.pop("residual", None)
    latent_system_kwargs.pop("lgn_kinetic_func_form", None)
    latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
  elif name == "LGN":
    assert latent_dynamics_type in ("Physics", None)
    latent_dynamics_type = "Physics"
    assert latent_system_kwargs.get("input_space", None) in ("velocity", None)
    latent_system_kwargs["input_space"] = "velocity"
    assert (latent_system_kwargs.get("simulation_space", None) in
            ("velocity", None))
    latent_system_kwargs["simulation_space"] = "velocity"
    # Kinetic func form
    lgn_specific = latent_system_kwargs.pop("lgn_kinetic_func_form", None)
    if lgn_specific is not None:
      latent_system_kwargs["kinetic_func_form"] = lgn_specific
    # Mass matrix
    lgn_specific = latent_system_kwargs.pop("lgn_parametrize_mass_matrix",
                                            None)
    if lgn_specific is not None:
      latent_system_kwargs["parametrize_mass_matrix"] = lgn_specific
    # These arguments are not part of the LGN models
    latent_system_kwargs.pop("residual", None)
    latent_system_kwargs.pop("hgn_kinetic_func_form", None)
    latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
  elif name == "PGN":
    assert latent_dynamics_type in ("Physics", None)
    latent_dynamics_type = "Physics"
    # These arguments are not part of the PGN models. Every pop has a default,
    # such that a config without these entries does not raise a KeyError.
    latent_system_kwargs.pop("residual", None)
    latent_system_kwargs.pop("hgn_kinetic_func_form", None)
    latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
    latent_system_kwargs.pop("lgn_kinetic_func_form", None)
    latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
  else:
    raise NotImplementedError()
  return deterministic_vae.DeterministicLatentsGenerativeModel(
      *args,
      latent_dynamics_type=latent_dynamics_type,
      latent_system_kwargs=latent_system_kwargs,
      **kwargs)
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,494 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module containing all of the networks as Haiku modules."""
|
||||||
|
from typing import Any, Callable, Mapping, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from absl import logging
|
||||||
|
import distrax
|
||||||
|
import haiku as hk
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from physics_inspired_models import utils
|
||||||
|
|
||||||
|
Activation = Union[str, Callable[[jnp.ndarray], jnp.ndarray]]
|
||||||
|
|
||||||
|
|
||||||
|
class DenseNet(hk.Module):
  """A feed forward network (MLP)."""

  def __init__(
      self,
      num_units: Sequence[int],
      activate_final: bool = False,
      activation: Activation = "leaky_relu",
      name: Optional[str] = None):
    """Creates one `hk.Linear` layer, named `ff_{i}`, per entry of `num_units`.

    Args:
      num_units: The output size of every linear layer.
      activate_final: Whether to apply the activation after the last layer.
      activation: The activation, resolved through `utils.get_activation`.
      name: The name of the module.
    """
    super().__init__(name=name)
    self.num_units = num_units
    self.num_layers = len(self.num_units)
    self.activate_final = activate_final
    self.activation = utils.get_activation(activation)

    self.linear_modules = [
        hk.Linear(output_size=self.num_units[i], name=f"ff_{i}")
        for i in range(self.num_layers)
    ]

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    """Applies all layers in sequence; `is_training` is accepted but unused."""
    last_index = self.num_layers - 1
    hidden = inputs
    for index, layer in enumerate(self.linear_modules):
      hidden = layer(hidden)
      # The last layer is activated only when `activate_final` is set.
      if self.activate_final or index < last_index:
        hidden = self.activation(hidden)
    return hidden
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2DNet(hk.Module):
  """Convolutional Network."""

  def __init__(
      self,
      output_channels: Sequence[int],
      kernel_shapes: Union[int, Sequence[int]] = 3,
      strides: Union[int, Sequence[int]] = 1,
      padding: Union[str, Sequence[str]] = "SAME",
      data_format: str = "NHWC",
      with_batch_norm: bool = False,
      activate_final: bool = False,
      activation: Activation = "leaky_relu",
      name: Optional[str] = None):
    """Creates one `hk.Conv2D` layer per entry of `output_channels`.

    Args:
      output_channels: The number of output channels of every layer.
      kernel_shapes: The kernel shape(s), passed through `utils.bcast_if` with
        the number of layers (presumably a single int is broadcast to all
        layers - confirm against `utils.bcast_if`).
      strides: The stride(s), handled in the same way as `kernel_shapes`.
      padding: The padding(s), handled in the same way as `kernel_shapes`.
      data_format: The data format passed to every convolution.
      with_batch_norm: Whether every convolution is followed by an
        `hk.BatchNorm`.
      activate_final: Whether to apply the activation after the last layer.
      activation: The activation, resolved through `utils.get_activation`.
      name: The name of the module.

    Raises:
      ValueError: If after broadcasting the number of kernel shapes, strides
        or paddings differs from the number of layers.
    """
    super().__init__(name=name)
    self.output_channels = tuple(output_channels)
    self.num_layers = len(self.output_channels)
    self.kernel_shapes = utils.bcast_if(kernel_shapes, int, self.num_layers)
    self.strides = utils.bcast_if(strides, int, self.num_layers)
    self.padding = utils.bcast_if(padding, str, self.num_layers)
    self.data_format = data_format
    self.with_batch_norm = with_batch_norm
    self.activate_final = activate_final
    self.activation = utils.get_activation(activation)

    # Every per-layer argument must have exactly one entry per layer. Each
    # message reports the length of the argument that is actually mismatched.
    per_layer_arguments = (
        ("Kernel shapes", self.kernel_shapes),
        ("Strides", self.strides),
        ("Padding", self.padding),
    )
    for label, values in per_layer_arguments:
      if len(values) != self.num_layers:
        raise ValueError(f"{label} is of size {len(values)}, while "
                         f"output_channels is of size {self.num_layers}.")

    self.conv_modules = []
    self.bn_modules = []
    for i in range(self.num_layers):
      self.conv_modules.append(
          hk.Conv2D(
              output_channels=self.output_channels[i],
              kernel_shape=self.kernel_shapes[i],
              stride=self.strides[i],
              padding=self.padding[i],
              data_format=data_format,
              name=f"conv_2d_{i}")
      )
      if with_batch_norm:
        self.bn_modules.append(
            hk.BatchNorm(
                create_offset=True,
                create_scale=False,
                decay_rate=0.999,
                name=f"batch_norm_{i}")
        )
      else:
        # Keeps `bn_modules` aligned with `conv_modules` for the `zip` below.
        self.bn_modules.append(None)

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    """Applies all layers in sequence to the 4 dimensional `inputs`.

    Args:
      inputs: The input array, which must have exactly 4 dimensions.
      is_training: Passed to the batch norm layers, when present.

    Returns:
      The output of the last layer.
    """
    assert inputs.ndim == 4
    net = inputs
    for i, (conv, bn) in enumerate(zip(self.conv_modules, self.bn_modules)):
      net = conv(net)
      # Batch norm
      if bn is not None:
        net = bn(net, is_training=is_training)
      # The last layer is activated only when `activate_final` is set.
      if i < self.num_layers - 1 or self.activate_final:
        net = self.activation(net)
    return net
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialConvEncoder(hk.Module):
  """Spatial Convolutional Encoder for learning the Hamiltonian."""

  def __init__(
      self,
      latent_dim: int,
      conv_channels: Union[Sequence[int], int],
      num_blocks: int,
      blocks_depth: int = 2,
      distribution_name: str = "diagonal_normal",
      aggregation_type: Optional[str] = None,
      data_format: str = "NHWC",
      activation: Activation = "leaky_relu",
      scale_factor: int = 2,
      kernel_shapes: Union[Sequence[int], int] = 3,
      padding: Union[Sequence[str], str] = "SAME",
      name: Optional[str] = None):
    """Builds `num_blocks` convolutional blocks, each one a `Conv2DNet`.

    Args:
      latent_dim: The dimensionality of the latent variable.
      conv_channels: Either a single int, used for all `blocks_depth` layers of
        every block, or one sequence of channels per block.
      num_blocks: The number of convolutional blocks.
      blocks_depth: The number of layers per block, used only when
        `conv_channels` is an int.
      distribution_name: None for a plain array output of size `latent_dim`,
        or "diagonal_normal" for which the network outputs `2 * latent_dim`
        values.
      aggregation_type: One of None, "max", "mean" or "linear_projection".
      data_format: The data format of the inputs.
      activation: The activation of the convolutional blocks.
      scale_factor: The stride of the first convolution of every block.
      kernel_shapes: The kernel shape(s) of the convolutions; a sequence
        shorter than a block is padded with kernels of shape 3.
      padding: The padding(s) of the convolutions.
      name: The name of the module.

    Raises:
      ValueError: If `aggregation_type` or `distribution_name` is not
        recognized.
    """
    super().__init__(name=name)
    if aggregation_type not in (None, "max", "mean", "linear_projection"):
      raise ValueError(f"Unrecognized aggregation_type={aggregation_type}.")
    self.latent_dim = latent_dim
    self.conv_channels = conv_channels
    self.num_blocks = num_blocks
    self.scale_factor = scale_factor
    self.data_format = data_format
    self.distribution_name = distribution_name
    self.aggregation_type = aggregation_type

    # Compute the required size of the output
    if distribution_name is None:
      self.output_dim = latent_dim
    elif distribution_name == "diagonal_normal":
      self.output_dim = 2 * latent_dim
    else:
      raise ValueError(f"Unrecognized distribution_name={distribution_name}.")

    # Expand the channels to one list per block; the last block receives an
    # additional layer producing the `output_dim` channels.
    if isinstance(conv_channels, int):
      channels_per_block = [[conv_channels] * blocks_depth
                            for _ in range(num_blocks)]
    else:
      assert isinstance(conv_channels, (list, tuple))
      assert len(conv_channels) == num_blocks
      channels_per_block = [list(c) for c in conv_channels]
    channels_per_block[-1].append(self.output_dim)

    if isinstance(kernel_shapes, tuple):
      kernel_shapes = list(kernel_shapes)

    # Convolutional blocks
    self.blocks = []
    for index, channels in enumerate(channels_per_block):
      if isinstance(kernel_shapes, int):
        block_kernel_shapes = kernel_shapes
      else:
        num_missing = len(channels) - len(kernel_shapes)
        block_kernel_shapes = kernel_shapes + [3] * num_missing
      # Only the first convolution of a block changes the spatial resolution.
      block_strides = [self.scale_factor] + [1] * (len(channels) - 1)
      block = Conv2DNet(
          output_channels=channels,
          kernel_shapes=block_kernel_shapes,
          strides=block_strides,
          padding=padding,
          data_format=data_format,
          with_batch_norm=False,
          activate_final=index < num_blocks - 1,
          activation=activation,
          name=f"block_{index}"
      )
      self.blocks.append(block)

  def spatial_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
    """Aggregates `x` over its spatial axes according to `aggregation_type`."""
    kind = self.aggregation_type
    if kind is None:
      return x
    if kind == "max":
      reduce_fn = jnp.max
    elif kind == "mean":
      reduce_fn = jnp.mean
    elif kind == "linear_projection":
      # Flatten the last three axes and project them linearly.
      flat = x.reshape(x.shape[:-3] + (-1,))
      return hk.Linear(self.output_dim, name="LinearProjection")(flat)
    else:
      raise NotImplementedError()
    spatial_axes = (1, 2) if self.data_format == "NHWC" else (2, 3)
    return reduce_fn(x, axis=spatial_axes)

  def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
    """Converts the network output to the configured distribution.

    Args:
      net_output: The output of the network after the spatial aggregation.

    Returns:
      `net_output` itself when `distribution_name` is None, otherwise a
      `distrax.Normal` whose location and log-scale are the two halves of
      `net_output` along the split axis.
    """
    if self.distribution_name is None:
      return net_output
    if self.distribution_name != "diagonal_normal":
      raise NotImplementedError()
    if self.aggregation_type is None:
      base_axis = self.data_format.index("C")
      num_axes = 3
    else:
      base_axis = 1
      num_axes = 1
    # Add an extra axis if the input has more than 1 batch dimension
    num_extra_batch_axes = net_output.ndim - num_axes - 1
    halves = jnp.split(net_output, 2, axis=base_axis + num_extra_batch_axes)
    loc, log_scale = halves
    return distrax.Normal(loc, jnp.exp(log_scale))

  def __call__(
      self,
      inputs: jnp.ndarray,
      is_training: bool
  ) -> Union[jnp.ndarray, distrax.Distribution]:
    """Encodes `inputs`, treating all but the last three axes as batch."""
    # Treat any extra dimensions (like time) as the batch
    leading_shape = inputs.shape[:-3]
    hidden = jnp.reshape(inputs, (-1,) + inputs.shape[-3:])

    # Apply all blocks in sequence
    for conv_block in self.blocks:
      hidden = conv_block(hidden, is_training=is_training)

    # Final projection
    hidden = self.spatial_aggregation(hidden)

    # Reshape back to correct dimensions (like batch + time)
    restored_shape = leading_shape + hidden.shape[1:]
    hidden = jnp.reshape(hidden, restored_shape)

    # Return a distribution over the observations
    return self.make_distribution(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialConvDecoder(hk.Module):
  """Spatial Convolutional Decoder for learning the Hamiltonian.

  The decoder optionally de-aggregates a flat latent into a spatial feature
  map, concatenates learned constant channels to it and then alternates
  nearest-neighbour up-sampling with convolutional blocks. The activated
  output of the last block is used as the mean of a `distrax.Normal`.
  """

  def __init__(
      self,
      initial_spatial_shape: Sequence[int],
      conv_channels: Union[Sequence[int], int],
      num_blocks: int,
      max_de_aggregation_dims: int,
      blocks_depth: int = 2,
      scale_factor: int = 2,
      output_channels: int = 3,
      h_const_channels: int = 2,
      data_format: str = "NHWC",
      activation: Activation = "leaky_relu",
      learned_sigma: bool = False,
      de_aggregation_type: Optional[str] = None,
      final_activation: Activation = "sigmoid",
      discard_half_de_aggregated: bool = False,
      kernel_shapes: Union[Sequence[int], int] = 3,
      padding: Union[Sequence[str], str] = "SAME",
      name: Optional[str] = None):
    """Initializes the decoder.

    Args:
      initial_spatial_shape: Spatial shape of the feature map that is fed to
        the first block (before any up-sampling).
      conv_channels: Either an int, in which case every block gets
        `blocks_depth` layers with that many channels, or a sequence of
        length `num_blocks` holding the per-block channel sequences.
      num_blocks: Number of (up-sample, convolutional block) stages.
      max_de_aggregation_dims: Upper bound on the number of channels produced
        by the "linear_projection" de-aggregation; a falsy value disables
        the bound.
      blocks_depth: Layers per block when `conv_channels` is an int.
      scale_factor: Up-sampling factor applied before every block.
      output_channels: Channels of the extra layer appended to the last block.
      h_const_channels: Number of learned constant channels that are
        concatenated to the de-aggregated input.
      data_format: Layout string containing "C", e.g. "NHWC" or "NCHW".
      activation: Activation forwarded to every `Conv2DNet` block.
      learned_sigma: If True the scale of the output distribution is a learned
        scalar, otherwise it is fixed to `1 / sqrt(2)`.
      de_aggregation_type: One of None, "tile" or "linear_projection".
      final_activation: Activation applied to the output of the last block.
      discard_half_de_aggregated: If True the second half of the channels is
        dropped before the constant channels are added (TF compatibility).
      kernel_shapes: Forwarded to every `Conv2DNet` block.
      padding: Forwarded to every `Conv2DNet` block.
      name: Optional module name.

    Raises:
      ValueError: If `de_aggregation_type` is not recognized.
    """
    super().__init__(name=name)
    if de_aggregation_type not in (None, "tile", "linear_projection"):
      raise ValueError(f"Unrecognized de_aggregation_type="
                       f"{de_aggregation_type}.")
    self.num_blocks = num_blocks
    self.scale_factor = scale_factor
    self.h_const_channels = h_const_channels
    self.data_format = data_format
    self.learned_sigma = learned_sigma
    self.initial_spatial_shape = tuple(initial_spatial_shape)
    self.final_activation = utils.get_activation(final_activation)
    self.de_aggregation_type = de_aggregation_type
    self.max_de_aggregation_dims = max_de_aggregation_dims
    self.discard_half_de_aggregated = discard_half_de_aggregated

    if isinstance(conv_channels, int):
      conv_channels = [[conv_channels] * blocks_depth
                       for _ in range(num_blocks)]
      conv_channels[-1] += [output_channels]
    else:
      assert isinstance(conv_channels, (list, tuple))
      assert len(conv_channels) == num_blocks
      # Copy, so that the caller's sequences are not mutated by the append.
      conv_channels = list(list(c) for c in conv_channels)
      conv_channels[-1].append(output_channels)

    # Convolutional blocks
    self.blocks = []
    for i, channels in enumerate(conv_channels):
      is_final_block = i == num_blocks - 1
      self.blocks.append(
          Conv2DNet(
              output_channels=channels,
              kernel_shapes=kernel_shapes,
              strides=1,
              padding=padding,
              data_format=data_format,
              with_batch_norm=False,
              # The last block stays linear; `final_activation` follows it.
              activate_final=not is_final_block,
              activation=activation,
              name=f"block_{i}"
          ))

  def spatial_de_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
    """Maps the input to a feature map with `initial_spatial_shape`.

    Args:
      x: Already spatial input when `de_aggregation_type` is None, otherwise
        a rank-2 array.

    Returns:
      A rank-4 feature map laid out according to `data_format`.

    Raises:
      NotImplementedError: For an unknown `de_aggregation_type`.
    """
    if self.de_aggregation_type is None:
      assert x.ndim >= 4
      if self.data_format == "NHWC":
        assert x.shape[1:3] == self.initial_spatial_shape
      elif self.data_format == "NCHW":
        assert x.shape[2:4] == self.initial_spatial_shape
      return x
    elif self.de_aggregation_type == "linear_projection":
      assert x.ndim == 2
      n, d = x.shape
      # Cap the number of channels unless the cap is disabled (falsy).
      d = min(d, self.max_de_aggregation_dims or d)
      out_d = d * self.initial_spatial_shape[0] * self.initial_spatial_shape[1]
      x = hk.Linear(out_d, name="LinearProjection")(x)
      if self.data_format == "NHWC":
        shape = (n,) + self.initial_spatial_shape + (d,)
      else:
        shape = (n, d) + self.initial_spatial_shape
      return x.reshape(shape)
    elif self.de_aggregation_type == "tile":
      assert x.ndim == 2
      if self.data_format == "NHWC":
        repeats = (1,) + self.initial_spatial_shape + (1,)
        x = x[:, None, None, :]
      else:
        repeats = (1, 1) + self.initial_spatial_shape
        x = x[:, :, None, None]
      return jnp.tile(x, repeats)
    else:
      raise NotImplementedError()

  def add_constant_channels(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Concatenates the learned constant channels in front of `inputs`.

    Args:
      inputs: Feature map returned by `spatial_de_aggregation`.

    Returns:
      `inputs` (or its first channel half when `discard_half_de_aggregated`
      is set) with `h_const_channels` extra channels prepended.
    """
    # --------------------------------------------
    # This is purely for TF compatibility purposes
    if self.discard_half_de_aggregated:
      axis = self.data_format.index("C")
      inputs, _ = jnp.split(inputs, 2, axis=axis)
    # --------------------------------------------

    # The extra constant channels
    if self.data_format == "NHWC":
      h_shape = self.initial_spatial_shape + (self.h_const_channels,)
    else:
      h_shape = (self.h_const_channels,) + self.initial_spatial_shape
    h_const = hk.get_parameter("h", h_shape, dtype=inputs.dtype,
                               init=hk.initializers.Constant(1))
    # The parameter has no batch axis; tiling with four repeats adds one and
    # replicates the parameter for every element of the batch.
    h_const = jnp.tile(h_const, reps=[inputs.shape[0], 1, 1, 1])
    return jnp.concatenate([h_const, inputs], axis=self.data_format.index("C"))

  def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
    """Builds a Normal centred at `net_output`.

    Args:
      net_output: Mean of the distribution.

    Returns:
      A `distrax.Normal` whose scale is a single learned scalar when
      `learned_sigma` is set and the constant `1 / sqrt(2)` otherwise.
    """
    if self.learned_sigma:
      # The initial value makes exp(log_scale) equal to 1 / sqrt(2), i.e. the
      # same scale as the fixed one below.
      init = hk.initializers.Constant(- jnp.log(2.0) / 2.0)
      log_scale = hk.get_parameter("log_scale", shape=(),
                                   dtype=net_output.dtype, init=init)
      scale = jnp.full_like(net_output, jnp.exp(log_scale))
    else:
      scale = jnp.full_like(net_output, 1 / jnp.sqrt(2.0))

    return distrax.Normal(net_output, scale)

  def __call__(
      self,
      inputs: jnp.ndarray,
      is_training: bool
  ) -> distrax.Distribution:
    """Decodes `inputs` into a distribution over the observations.

    Args:
      inputs: Latent representation accepted by `spatial_de_aggregation`.
      is_training: Forwarded to every convolutional block.

    Returns:
      The distribution produced by `make_distribution`.
    """
    # Apply the spatial de-aggregation
    inputs = self.spatial_de_aggregation(inputs)

    # Add the parameterized constant channels
    net = self.add_constant_channels(inputs)

    # Apply all the blocks
    for block in self.blocks:
      # Up-sample the image. The layout has to be passed explicitly, since
      # the helper defaults to "NHWC" and would otherwise scale the wrong
      # axes of an "NCHW" feature map.
      net = utils.nearest_neighbour_upsampling(
          net, self.scale_factor, data_format=self.data_format)
      # Apply the convolutional block
      net = block(net, is_training=is_training)

    # Apply any specific output nonlinearity
    net = self.final_activation(net)

    # Construct the distribution over the observations
    return self.make_distribution(net)
|
||||||
|
|
||||||
|
|
||||||
|
def make_flexible_net(
    net_type: str,
    output_dims: int,
    conv_channels: Union[Sequence[int], int],
    num_units: Union[Sequence[int], int],
    num_layers: Optional[int],
    activation: Activation,
    activate_final: bool = False,
    kernel_shapes: Union[Sequence[int], int] = 3,
    strides: Union[Sequence[int], int] = 1,
    padding: Union[Sequence[str], str] = "SAME",
    name: Optional[str] = None,
    **unused_kwargs: Mapping[str, Any]
):
  """Builds either a `DenseNet` or a `Conv2DNet` from a shared config.

  The hidden sizes come from `num_units` ("mlp") or `conv_channels` ("conv").
  An int is repeated `num_layers - 1` times, a sequence is used as is, and in
  both cases `output_dims` is appended as the size of the last layer.

  Args:
    net_type: "mlp" or "conv".
    output_dims: Size of the final layer.
    conv_channels: Hidden channels, only used for "conv".
    num_units: Hidden units, only used for "mlp".
    num_layers: Total number of layers; required when the sizes are an int.
    activation: Forwarded to the network.
    activate_final: Forwarded to the network.
    kernel_shapes: Forwarded to `Conv2DNet`.
    strides: Forwarded to `Conv2DNet`.
    padding: Forwarded to `Conv2DNet`.
    name: Forwarded to the network.
    **unused_kwargs: Ignored, but reported with a warning.

  Returns:
    The constructed network.

  Raises:
    NotImplementedError: For `net_type="transformer"`.
    ValueError: For any other unknown `net_type`.
  """
  if unused_kwargs:
    logging.warning("Unused kwargs of `make_flexible_net`: %s",
                    str(unused_kwargs))

  def with_output_layer(sizes):
    # Expands an int into the hidden sizes and always appends `output_dims`.
    if isinstance(sizes, int):
      assert num_layers is not None
      return [sizes] * (num_layers - 1) + [output_dims]
    return list(sizes) + [output_dims]

  if net_type == "mlp":
    return DenseNet(
        num_units=with_output_layer(num_units),
        activation=activation,
        activate_final=activate_final,
        name=name
    )
  if net_type == "conv":
    return Conv2DNet(
        output_channels=with_output_layer(conv_channels),
        kernel_shapes=kernel_shapes,
        strides=strides,
        padding=padding,
        activation=activation,
        activate_final=activate_final,
        name=name
    )
  if net_type == "transformer":
    raise NotImplementedError()
  raise ValueError(f"Unrecognized net_type={net_type}.")
|
||||||
|
|
||||||
|
|
||||||
|
def make_flexible_recurrent_net(
    core_type: str,
    net_type: str,
    output_dims: int,
    num_units: Union[Sequence[int], int],
    num_layers: Optional[int],
    activation: Activation,
    activate_final: bool = False,
    name: Optional[str] = None,
    **unused_kwargs
):
  """Builds a stack of recurrent cores wrapped in an `hk.DeepRNN`.

  Args:
    core_type: "vanilla", "lstm" or "gru" (case-insensitive).
    net_type: Has to be "mlp".
    output_dims: Hidden size of the last core.
    num_units: Hidden sizes of the preceding cores; an int is repeated
      `num_layers - 1` times.
    num_layers: Total number of cores; required when `num_units` is an int
      and overwritten when it is a list or tuple.
    activation: Inserted after every core whose index differs from
      `num_layers - 1`.
    activate_final: Whether to additionally append the activation at the end.
    name: Prefix of the core names, defaults to the upper-cased `core_type`.
    **unused_kwargs: Ignored, but reported with a warning.

  Returns:
    The `hk.DeepRNN` holding the cores and activations.

  Raises:
    ValueError: If `net_type` is not "mlp" or `core_type` is not recognized.
  """
  if net_type != "mlp":
    raise ValueError("We do not support convolutional recurrent nets atm.")
  if unused_kwargs:
    logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                    str(unused_kwargs))

  if isinstance(num_units, (list, tuple)):
    layer_sizes = list(num_units) + [output_dims]
    num_layers = len(layer_sizes)
  else:
    assert num_layers is not None
    layer_sizes = [num_units] * (num_layers - 1) + [output_dims]
  prefix = name or f"{core_type.upper()}"

  nonlinearity = utils.get_activation(activation)
  core_classes = {"vanilla": hk.VanillaRNN, "lstm": hk.LSTM, "gru": hk.GRU}
  core_cls = core_classes.get(core_type.lower())
  if core_cls is None:
    raise ValueError(f"Unrecognized core_type={core_type}.")

  last_index = num_layers - 1
  layers = []
  for index, size in enumerate(layer_sizes):
    layers.append(core_cls(hidden_size=size, name=f"{prefix}_{index}"))
    if index != last_index:
      layers.append(nonlinearity)
  if activate_final:
    layers.append(nonlinearity)

  return hk.DeepRNN(layers, name="RNN")
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite@main#egg=dm_hamiltonian_dynamics_suite
|
||||||
|
absl-py==0.12.0
|
||||||
|
numpy>=1.16.4
|
||||||
|
scikit-learn>=1.0
|
||||||
|
typing>=3.7.4.3
|
||||||
|
jax==0.2.20
|
||||||
|
jaxline==0.0.3
|
||||||
|
distrax==0.0.2
|
||||||
|
optax==0.0.6
|
||||||
|
dm-haiku==0.0.3
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Setup for pip package."""
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
# Runtime dependencies. Several of them are pinned to exact versions.
REQUIRED_PACKAGES = (
    "dm_hamiltonian_dynamics_suite@git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite",  # pylint: disable=line-too-long
    "absl-py>=0.12.0",
    "numpy>=1.16.4",
    "scikit-learn>=1.0",
    "typing>=3.7.4.3",
    "jax==0.2.20",
    "jaxline==0.0.3",
    "distrax==0.0.2",
    "optax==0.0.6",
    "dm-haiku==0.0.3",
)

# The list of publications used to be left dangling after the colon; the
# titles are the ones named in the package README.
LONG_DESCRIPTION = "\n".join([
    "A codebase containing the implementation of the following models:",
    "Hamiltonian Generative Network (HGN)",
    "Lagrangian Generative Network (LGN)",
    "Neural ODE",
    "Recurrent Generative Network (RGN)",
    "and RNN, LSTM and GRU.",
    "This is code accompanying the publication of:",
    ("SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics "
     "Inferred from Vision"),
    "Which priors matter? Benchmarking models for learning latent dynamics",
])


setup(
    name="physics_inspired_models",
    version="0.0.1",
    description="Implementation of multiple physically inspired models.",
    long_description=LONG_DESCRIPTION,
    # A sub-directory of a GitHub repository is addressed through `tree/`.
    url=("https://github.com/deepmind/deepmind-research/tree/master/"
         "physics_inspired_models"),
    author="DeepMind",
    package_dir={"physics_inspired_models": "."},
    packages=["physics_inspired_models", "physics_inspired_models.models"],
    install_requires=REQUIRED_PACKAGES,
    platforms=["any"],
    license="Apache License, Version 2.0",
)
|
||||||
@@ -0,0 +1,274 @@
|
|||||||
|
# Copyright 2020 DeepMind Technologies Limited.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Utilities functions for Jax."""
|
||||||
|
import collections
|
||||||
|
import functools
|
||||||
|
from typing import Any, Callable, Dict, Mapping, Union
|
||||||
|
|
||||||
|
import distrax
|
||||||
|
import jax
|
||||||
|
from jax import core
|
||||||
|
from jax import lax
|
||||||
|
from jax import nn
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.tree_util import register_pytree_node
|
||||||
|
from jaxline import utils
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Two-level, string-keyed mapping of arrays (presumably Haiku's
# module-name -> parameter-name -> array layout; confirm against callers).
HaikuParams = Mapping[str, Mapping[str, jnp.ndarray]]
# Either a flat mapping of arrays, a `HaikuParams` mapping or a single array.
Params = Union[Mapping[str, jnp.ndarray], HaikuParams, jnp.ndarray]
# A callable that takes an array and returns an array.
_Activation = Callable[[jnp.ndarray], jnp.ndarray]

# Leaky ReLU with the negative slope fixed to 0.2 (the `tf_` prefix suggests
# that this mirrors a TensorFlow default -- TODO confirm).
tf_leaky_relu = functools.partial(nn.leaky_relu, negative_slope=0.2)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_only_scalar_stats(stats):
  """Returns the entries of `stats` whose value has exactly one element."""
  scalar_stats = {}
  for name, value in stats.items():
    if value.size == 1:
      scalar_stats[name] = value
  return scalar_stats
|
||||||
|
|
||||||
|
|
||||||
|
def to_numpy(obj):
  """Converts every leaf of the pytree `obj` with `np.array`.

  `jax.tree_util.tree_map` is used rather than the top-level `jax.tree_map`
  alias, because the alias is deprecated and absent from recent JAX releases
  while the `tree_util` function is available in old and new versions alike.

  Args:
    obj: An arbitrary pytree.

  Returns:
    A pytree of the same structure whose leaves are the results of
    `np.array`.
  """
  return jax.tree_util.tree_map(np.array, obj)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.custom_gradient
def geco_lagrange_product(lagrange_multiplier, constraint_ema, constraint_t):
  """Product of multiplier and constraint with the gradients used by GECO.

  The forward value is:
    lagrange * C_ema
  For an incoming gradient `g` the backward pass returns:
    w.r.t. lagrange:        - g * C_t
    w.r.t. constraint_ema:  0.0
    w.r.t. constraint_t:    g * lagrange

  Passing the same value as `constraint_ema` and `constraint_t` therefore
  only flips the sign of the gradient of the lagrange multiplier.

  Args:
    lagrange_multiplier: The lagrange multiplier
    constraint_ema: The moving average of the constraint
    constraint_t: The current constraint

  Returns:
    The product `lagrange_multiplier * constraint_ema`, carrying the custom
    gradients described above.
  """
  def backward(g):
    d_lagrange = - g * constraint_t
    d_constraint_ema = jnp.zeros_like(constraint_ema)
    d_constraint_t = g * lagrange_multiplier
    return d_lagrange, d_constraint_ema, d_constraint_t

  value = lagrange_multiplier * constraint_ema
  return value, backward
|
||||||
|
|
||||||
|
|
||||||
|
def bcast_if(x, t, n):
  """Returns a list holding `x` `n` times if `x` is a `t`, otherwise `x`."""
  if isinstance(x, t):
    return [x] * n
  return x
|
||||||
|
|
||||||
|
|
||||||
|
def stack_time_into_channels(
    images: jnp.ndarray,
    data_format: str
) -> jnp.ndarray:
  """Removes axis 1 of `images` by concatenating its slices as channels.

  Args:
    images: Array that is split into single slices along axis 1.
    data_format: Layout of the array once axis 1 is removed; the position of
      "C" in it is the axis along which the slices are concatenated.

  Returns:
    The concatenation of all slices along the channel axis.
  """
  channel_axis = data_format.index("C")
  num_steps = images.shape[1]
  frames = []
  for frame in jnp.split(images, num_steps, axis=1):
    frames.append(jnp.squeeze(frame, axis=1))
  return jnp.concatenate(frames, channel_axis)
|
||||||
|
|
||||||
|
|
||||||
|
def stack_device_dim_into_batch(obj):
  """Merges the two leading axes of every leaf of the pytree `obj`.

  `jax.tree_util.tree_map` is used rather than the top-level `jax.tree_map`
  alias, because the alias is deprecated and absent from recent JAX releases
  while the `tree_util` function is available in old and new versions alike.

  Args:
    obj: A pytree whose leaves have at least two axes (judging by the
      function name, a device axis followed by a batch axis).

  Returns:
    A pytree of the same structure in which every leaf has been reshaped to
    `(-1,) + leaf.shape[2:]`.
  """
  def merge_leading_axes(leaf):
    return leaf.reshape((-1,) + leaf.shape[2:])

  return jax.tree_util.tree_map(merge_leading_axes, obj)
|
||||||
|
|
||||||
|
|
||||||
|
def nearest_neighbour_upsampling(x, scale, data_format="NHWC"):
  """Repeats every spatial element of `x` `scale` times along H and W.

  Args:
    x: Rank-4 array laid out according to `data_format`.
    scale: Integer up-sampling factor used for both spatial axes.
    data_format: Either "NHWC" or "NCHW".

  Returns:
    An array whose two spatial axes are `scale` times larger.

  Raises:
    ValueError: If `data_format` is not recognized.
  """
  if data_format not in ("NCHW", "NHWC"):
    raise ValueError(f"Unrecognized data_format={data_format}.")

  if data_format == "NCHW":
    batch, channels, height, width = x.shape
    expanded_shape = [batch, channels, height, 1, width, 1]
    ones_shape = [1, 1, 1, scale, 1, scale]
    output_shape = [batch, channels, scale * height, scale * width]
  else:
    batch, height, width, channels = x.shape
    expanded_shape = [batch, height, 1, width, 1, channels]
    ones_shape = [1, 1, scale, 1, scale, 1]
    output_shape = [batch, scale * height, scale * width, channels]

  # Broadcasting against the ones copies every element `scale` times along
  # the two freshly inserted unit axes, which are then folded into H and W.
  ones = jnp.ones(ones_shape, dtype=x.dtype)
  return jnp.reshape(jnp.reshape(x, expanded_shape) * ones, output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(arg: Union[_Activation, str]) -> _Activation:
  """Resolves `arg` into an activation function.

  A string is looked up by name, in order, in this module's globals, in
  `jax.nn` and in `jax.numpy`. Anything else has to be callable already and
  is returned as is.

  Args:
    arg: Name of the activation or the activation itself.

  Returns:
    The object found under the name, or `arg` if it is callable.

  Raises:
    ValueError: If the name cannot be found or `arg` is neither a string nor
      callable.
  """
  if not isinstance(arg, str):
    if callable(arg):
      return arg
    raise ValueError(f"Expected a callable, but got {type(arg)}")

  # Try fetch in order - [this module, jax.nn, jax.numpy]
  module_scope = globals()
  if arg in module_scope:
    return module_scope[arg]
  for namespace in (nn, jnp):
    if hasattr(namespace, arg):
      return getattr(namespace, arg)
  raise ValueError(f"Unrecognized activation with name {arg}.")
|
||||||
|
|
||||||
|
|
||||||
|
def merge_first_dims(x: jnp.ndarray, num_dims_to_merge: int = 2) -> jnp.ndarray:
  """Collapses the first `num_dims_to_merge` axes of `x` into a single axis."""
  trailing_shape = x.shape[num_dims_to_merge:]
  return x.reshape((-1,) + trailing_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_image(
    inputs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
) -> jnp.ndarray:
  """Picks the image out of `inputs`.

  Args:
    inputs: Either an array, which is returned unchanged, or a dict from
      which the entry "image" is taken, falling back to "x_image".

  Returns:
    The image array.

  Raises:
    NotImplementedError: If `inputs` is neither a dict nor an array.
  """
  if isinstance(inputs, dict):
    key = "image" if "image" in inputs else "x_image"
    return inputs[key]
  if isinstance(inputs, jnp.ndarray):
    return inputs
  raise NotImplementedError(f"Not implemented of inputs of type"
                            f" {type(inputs)}.")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_gt_state(inputs: Any) -> jnp.ndarray:
  """Returns `inputs["x"]` for a dict and `inputs` itself for an array.

  Raises:
    NotImplementedError: If `inputs` is neither a dict nor an array.
  """
  if isinstance(inputs, dict):
    return inputs["x"]
  if isinstance(inputs, jnp.ndarray):
    return inputs
  raise NotImplementedError(f"Not implemented of inputs of type"
                            f" {type(inputs)}.")
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_latents_conv_to_flat(conv_latents, axis_n_to_keep=1):
  """Flattens the two halves of `conv_latents` and joins them again.

  The array is split into two equal halves along its last axis, each half
  keeps its first `axis_n_to_keep` axes while the remaining ones are
  flattened, and the two results are concatenated along the last axis.

  Both halves are plain arrays, so they are reshaped directly rather than
  through `jax.tree_map`; besides being redundant, that top-level alias is
  deprecated and absent from recent JAX releases.

  Args:
    conv_latents: Array whose last axis has an even size.
    axis_n_to_keep: Number of leading axes that are left untouched.

  Returns:
    The flattened first half followed by the flattened second half.
  """
  q, p = jnp.split(conv_latents, 2, axis=-1)
  flat_q = q.reshape(q.shape[:axis_n_to_keep] + (-1,))
  flat_p = p.reshape(p.shape[:axis_n_to_keep] + (-1,))
  return jnp.concatenate([flat_q, flat_p], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def triu_matrix_from_v(x, ndim):
  """Scatters the last axis of `x` into upper-triangular matrices.

  Args:
    x: Array whose last axis has `ndim * (ndim + 1) / 2` entries; all other
      axes are treated as batch axes.
    ndim: Number of rows and columns of the resulting matrices.

  Returns:
    Array of shape `x.shape[:-1] + (ndim, ndim)` that is zero everywhere
    except for the upper triangle (diagonal included), which holds the
    entries of `x` in the order given by `jnp.triu_indices`.
  """
  assert x.shape[-1] == (ndim * (ndim + 1)) // 2
  rows, cols = jnp.triu_indices(ndim)
  zeros = jnp.zeros(x.shape[:-1] + (ndim, ndim))
  # The ellipsis carries every batch axis along, so a single indexed update
  # fills the triangles of all matrices at once.
  return zeros.at[(Ellipsis, rows, cols)].set(x)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
  """Flattens a nested mapping into a single-level dict.

  The key of a nested entry is the concatenation of all keys on the path to
  it, joined with `sep`.

  Args:
    d: Mapping whose values may themselves be mutable mappings.
    parent_key: Prefix put in front of every key; no separator is added when
      it is empty.
    sep: Separator placed between the keys of successive nesting levels.

  Returns:
    A flat dict with one entry per non-mapping value of `d`.
  """
  # The ABC aliases in `collections` itself were removed in Python 3.10, so
  # `MutableMapping` has to be taken from `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  items = []
  for k, v in d.items():
    new_key = parent_key + sep + k if parent_key else k
    if isinstance(v, collections_abc.MutableMapping):
      items.extend(flatten_dict(v, new_key, sep=sep).items())
    else:
      items.append((new_key, v))
  return dict(items)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_pytype(target, reference):
  """Rebuilds the leaves of `target` in the pytree structure of `reference`.

  The functions are taken from `jax.tree_util` rather than from the
  top-level `jax.tree_flatten` / `jax.tree_unflatten` aliases, because the
  aliases are deprecated and absent from recent JAX releases while the
  `tree_util` functions are available in old and new versions alike.

  Args:
    target: Pytree providing the leaves.
    reference: Pytree providing the structure; it has to hold as many leaves
      as `target`.

  Returns:
    A pytree with the structure of `reference` and the leaves of `target`.
  """
  _, treedef = jax.tree_util.tree_flatten(reference)
  leaves, _ = jax.tree_util.tree_flatten(target)
  return jax.tree_util.tree_unflatten(treedef, leaves)
|
||||||
|
|
||||||
|
|
||||||
|
def func_if_not_scalar(func):
  """Wraps `func` so that zero-dimensional arrays bypass it.

  Args:
    func: Callable invoked as `func(array, axis=axis)`.

  Returns:
    A function `(array, axis=0)` that returns `array` itself when it has no
    axes and `func(array, axis=axis)` otherwise.
  """
  @functools.wraps(func)
  def apply_unless_scalar(array, axis=0):
    return array if array.ndim == 0 else func(array, axis=axis)
  return apply_unless_scalar


# `jnp.mean` over `axis` that leaves zero-dimensional arrays untouched.
mean_if_not_scalar = func_if_not_scalar(jnp.mean)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiBatchAccumulator(object):
  """Accumulates per-batch averages into an overall mean, max, min and sum."""

  def __init__(self):
    # Everything stays None until the first call to `add`.
    self._obj = None
    self._obj_max = None
    self._obj_min = None
    self._num_samples = None

  def add(self, averaged_values, num_samples):
    """Folds one batch into the running statistics.

    Args:
      averaged_values: Pytree with the values averaged over the batch.
      num_samples: Number of samples the averages were computed from; used
        as the weight of this batch.
    """
    if self._obj is not None:
      self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
                                        averaged_values)
      self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
                                        averaged_values)
      self._obj = jax.tree_multimap(
          lambda total, batch: total + batch * num_samples,
          self._obj, averaged_values)
      self._num_samples += num_samples
      return

    # First batch: initialise all statistics from it.
    self._obj_max = jax.tree_map(lambda leaf: leaf * 1.0, averaged_values)
    self._obj_min = jax.tree_map(lambda leaf: leaf * 1.0, averaged_values)
    self._obj = jax.tree_map(lambda leaf: leaf * num_samples, averaged_values)
    self._num_samples = num_samples

  def value(self):
    """Returns the weighted sum divided by the total number of samples."""
    total_samples = self._num_samples
    return jax.tree_map(lambda total: total / total_samples, self._obj)

  def max(self):
    """Returns the element-wise maximum of the added values as floats."""
    return jax.tree_map(float, self._obj_max)

  def min(self):
    """Returns the element-wise minimum of the added values as floats."""
    return jax.tree_map(float, self._obj_min)

  def sum(self):
    """Returns the sample-weighted sum of the added values."""
    return self._obj
|
||||||
|
|
||||||
|
|
||||||
|
# Registers `distrax.Normal` as a pytree node: an instance is flattened into
# its `loc` and `scale` (no auxiliary data) and rebuilt by passing these two
# children back to the constructor.
register_pytree_node(
    distrax.Normal,
    lambda instance: ([instance.loc, instance.scale], None),
    lambda _, args: distrax.Normal(*args)
)
|
||||||
|
|
||||||
|
|
||||||
|
def inner_product(x: Any, y: Any) -> jnp.ndarray:
  """Sums the element-wise products of the matching leaves of `x` and `y`.

  Args:
    x: A pytree.
    y: A pytree that is mapped together with `x`.

  Returns:
    The sum over all leaves of `sum(x_leaf * y_leaf)`.
  """
  def leaf_product(x_leaf, y_leaf):
    return jnp.sum(x_leaf * y_leaf)

  per_leaf = jax.tree_multimap(leaf_product, x, y)
  return sum(jax.tree_leaves(per_leaf))
|
||||||
|
|
||||||
|
|
||||||
|
# Helpers re-exported from `jaxline.utils` under this module's namespace.
get_first = utils.get_first
bcast_local_devices = utils.bcast_local_devices
py_prefetch = utils.py_prefetch
# pmapped version of `jax.random.split` that returns the new keys as a list;
# `num` (argument 1) is static and broadcast rather than mapped over.
p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
                   static_broadcasted_argnums=1)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_if_pmap(p_func):
  """Wraps `p_func` so that it turns into a no-op when a NameError occurs.

  Args:
    p_func: Callable invoked as `p_func(obj, axis_name)`.

  Returns:
    A function `(obj, axis_name)` that returns `p_func(obj, axis_name)`, or
    `obj` unchanged if a `NameError` is raised along the way.
  """
  def p_func_if_pmap(obj, axis_name):
    try:
      # Only used as a probe, the result is discarded. Presumably raises a
      # NameError when `axis_name` is not bound by an enclosing pmap --
      # TODO confirm for the pinned JAX version.
      core.axis_frame(axis_name)
      return p_func(obj, axis_name)
    except NameError:
      # NOTE(review): a NameError raised inside `p_func` itself ends up here
      # as well, because the call is part of the `try` body.
      return obj
  return p_func_if_pmap


# `lax.pmean` and `lax.psum` wrapped with `wrap_if_pmap`.
pmean_if_pmap = wrap_if_pmap(lax.pmean)
psum_if_pmap = wrap_if_pmap(lax.psum)
|
||||||
Reference in New Issue
Block a user