mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 21:56:38 +08:00
Open sourcing the physics inspired models code.
PiperOrigin-RevId: 408640606
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
9b751b7d20
commit
2c7c401024
@@ -0,0 +1,59 @@
|
||||
# Implementation of multiple physics inspired models for modelling dynamics
|
||||
|
||||
This repository contains an implementation of different physics inspired models
|
||||
used in the papers: **SyMetric: Measuring the Quality of Learnt Hamiltonian
|
||||
Dynamics Inferred from Vision** and **Which priors matter? Benchmarking models
|
||||
for learning latent dynamics**.
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
This is purely research code, provided with no further intentions of support or
|
||||
any guarantees of backward compatibility.
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
All package requirements are listed in `requirements.txt`.
|
||||
You will still need to download and setup the datasets from the
|
||||
[DeepMind Hamiltonian Dynamics Suite] manually.
|
||||
|
||||
```shell
|
||||
git clone git@github.com:deepmind/deepmind-research.git
|
||||
pip install -r ./deepmind_research/physics_inspired_models/requirements.txt
|
||||
pip install ./deepmind_research/physics_inspired_models
|
||||
pip install --upgrade "jax[XXX]"
|
||||
```
|
||||
|
||||
where `XXX` is the correct type of accelerator that you have on your machine.
|
||||
Note that if you are using a GPU you might need `XXX` to also include the
|
||||
correct version of CUDA and cuDNN installed on your machine.
|
||||
For more details please read [here](https://github.com/google/jax#installation).
|
||||
|
||||
## Usage
|
||||
|
||||
The file `jaxline_configs.py` contains all the configurations specifications for
|
||||
the experiments in the two papers. To run an experiment, in addition to passing
|
||||
the location of the configs file, you must provide extra arguments in the
|
||||
following manner:
|
||||
|
||||
`${name_of_configuration},${index_in_sweep},${dataset_name}`
|
||||
|
||||
For example to run the second hyper-parameter configuration of the improved
|
||||
Hamiltonian Generative Network (HGN++) on the mass-spring dataset you should
|
||||
run in the command line (assuming that you are in the folder of the project):
|
||||
|
||||
```shell
|
||||
python3 jaxline_train.py \
|
||||
--config="jaxline_configs.py:sym_metric_hgn_plus_plus_sweep,1,toy_physics/mass_spring" \
|
||||
--jaxline_mode="train" \
|
||||
--logtostderr
|
||||
```
|
||||
|
||||
|
||||
## Reference
|
||||
**SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics Inferred from Vision**
|
||||
|
||||
**Which priors matter? Benchmarking models for learning latent dynamics**
|
||||
|
||||
[DeepMind Hamiltonian Dynamics Suite]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,353 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module containing model evaluation metric."""
|
||||
import _thread as thread
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from absl import logging
|
||||
import distrax
|
||||
|
||||
import numpy as np
|
||||
from sklearn import linear_model
|
||||
from sklearn import model_selection
|
||||
from sklearn import preprocessing
|
||||
|
||||
|
||||
def quit_function(fn_name):
|
||||
logging.error('%s took too long', fn_name)
|
||||
sys.stderr.flush()
|
||||
thread.interrupt_main()
|
||||
|
||||
|
||||
def exit_after(s):
|
||||
"""Use as decorator to exit function after s seconds."""
|
||||
def outer(fn):
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
timer = threading.Timer(s, quit_function, args=[fn.__name__])
|
||||
timer.start()
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
finally:
|
||||
timer.cancel()
|
||||
return result
|
||||
|
||||
return inner
|
||||
|
||||
return outer
|
||||
|
||||
|
||||
@exit_after(400)
|
||||
def do_grid_search(data_x_exp, data_y, clf, parameters, cv):
|
||||
scoring_choice = 'explained_variance'
|
||||
regressor = model_selection.GridSearchCV(
|
||||
clf, parameters, cv=cv, refit=True, scoring=scoring_choice)
|
||||
regressor.fit(data_x_exp, data_y)
|
||||
return regressor
|
||||
|
||||
|
||||
def symplectic_matrix(dim):
|
||||
"""Return anti-symmetric identity matrix of given dimensionality."""
|
||||
half_dims = int(dim/2)
|
||||
eye = np.eye(half_dims)
|
||||
zeros = np.zeros([half_dims, half_dims])
|
||||
top_rows = np.concatenate([zeros, - eye], axis=1)
|
||||
bottom_rows = np.concatenate([eye, zeros], axis=1)
|
||||
return np.concatenate([top_rows, bottom_rows], axis=0)
|
||||
|
||||
|
||||
def create_latent_mask(z0, dist_std_threshold=0.5):
|
||||
"""Create mask based on informativeness of each latent dimension.
|
||||
|
||||
For stochastic models those latent dimensions that are too close to the prior
|
||||
are likely to be uninformative and can be ignored.
|
||||
|
||||
Args:
|
||||
z0: distribution or array of phase space
|
||||
dist_std_threshold: informative latents have average inferred stds <
|
||||
dist_std_threshold
|
||||
|
||||
Returns:
|
||||
latent_mask_final: boolean mask of the same dimensionality as z0
|
||||
"""
|
||||
if isinstance(z0, distrax.Normal):
|
||||
std_vals = np.mean(z0.variance(), axis=0)
|
||||
elif isinstance(z0, distrax.Distribution):
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
# If the latent is deterministic, pass through all dimensions
|
||||
return np.array([True]*z0.shape[-1])
|
||||
|
||||
tensor_shape = std_vals.shape
|
||||
half_dims = int(tensor_shape[-1] / 2)
|
||||
|
||||
std_vals_q = std_vals[:half_dims]
|
||||
std_vals_p = std_vals[half_dims:]
|
||||
|
||||
# Keep both q and corresponding p as either one is informative
|
||||
informative_latents_inds = np.array([
|
||||
x for x in range(len(std_vals_q)) if
|
||||
std_vals_q[x] < dist_std_threshold or std_vals_p[x] < dist_std_threshold
|
||||
])
|
||||
|
||||
if informative_latents_inds.shape[0] > 0:
|
||||
latent_mask_final = np.zeros_like(std_vals_q)
|
||||
latent_mask_final[informative_latents_inds] = 1
|
||||
latent_mask_final = np.concatenate([latent_mask_final, latent_mask_final])
|
||||
latent_mask_final = latent_mask_final == 1
|
||||
|
||||
return latent_mask_final
|
||||
else:
|
||||
return np.array([True]*tensor_shape[-1])
|
||||
|
||||
|
||||
def standardize_data(data):
|
||||
"""Applies the sklearn standardization to the data."""
|
||||
scaler = preprocessing.StandardScaler()
|
||||
scaler.fit(data)
|
||||
return scaler.transform(data)
|
||||
|
||||
|
||||
def find_best_polynomial(data_x, data_y, max_poly_order, rsq_threshold,
|
||||
max_dim_n=32,
|
||||
alpha_sweep=None,
|
||||
max_iter=1000, cv=2):
|
||||
"""Find minimal polynomial expansion that is sufficient to explain data using Lasso regression."""
|
||||
rsq = 0
|
||||
poly_order = 1
|
||||
|
||||
if not np.any(alpha_sweep):
|
||||
alpha_sweep = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
|
||||
|
||||
# Avoid a large polynomial expansion for large latent sizes
|
||||
if data_x.shape[-1] > max_dim_n:
|
||||
print(f'>WARNING! Data is too high dimensional at {data_x.shape[-1]}')
|
||||
print('>WARNING! Setting max_poly_order = 1')
|
||||
max_poly_order = 1
|
||||
|
||||
while rsq < rsq_threshold and poly_order <= max_poly_order:
|
||||
time_start = time.perf_counter()
|
||||
poly = preprocessing.PolynomialFeatures(poly_order, include_bias=False)
|
||||
data_x_exp = poly.fit_transform(data_x)
|
||||
time_end = time.perf_counter()
|
||||
print(
|
||||
f'Took {time_end-time_start}s to create polynomial features of order '
|
||||
f'{poly_order} and size {data_x_exp.shape[1]}.')
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
time_start = time.perf_counter()
|
||||
clf = linear_model.Lasso(
|
||||
random_state=0, max_iter=max_iter, normalize=False, warm_start=False)
|
||||
parameters = {'alpha': alpha_sweep}
|
||||
try:
|
||||
regressor = do_grid_search(data_x_exp, data_y, clf, parameters, cv)
|
||||
time_end = time.perf_counter()
|
||||
print(f'Took {time_end-time_start}s to do regression grid search.')
|
||||
|
||||
# Get rsq results
|
||||
time_start = time.perf_counter()
|
||||
clf = linear_model.Lasso(
|
||||
random_state=0,
|
||||
alpha=regressor.best_params_['alpha'],
|
||||
max_iter=max_iter,
|
||||
normalize=False,
|
||||
warm_start=False)
|
||||
clf.fit(data_x_exp, data_y)
|
||||
rsq = clf.score(data_x_exp, data_y)
|
||||
time_end = time.perf_counter()
|
||||
print(f'Took {time_end-time_start}s to get rsq results.')
|
||||
|
||||
old_regressor = regressor
|
||||
old_poly_order = poly_order
|
||||
old_poly = poly
|
||||
old_data_x_exp = data_x_exp
|
||||
old_rsq = rsq
|
||||
old_clf = clf
|
||||
print(f'Polynomial of order {poly_order} with '
|
||||
f' alpha={regressor.best_params_} RSQ: {rsq}')
|
||||
poly_order += 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
time_end = time.perf_counter()
|
||||
print(f'Timed out after {time_end-time_start}s of doing grid search.')
|
||||
print(f'Continuing with previous poly_order={old_poly_order}...')
|
||||
regressor = old_regressor
|
||||
poly_order = old_poly_order
|
||||
poly = old_poly
|
||||
data_x_exp = old_data_x_exp
|
||||
rsq = old_rsq
|
||||
clf = old_clf
|
||||
print(f'Polynomial of order {poly_order} with '
|
||||
f' alpha={regressor.best_params_} RSQ: {rsq}')
|
||||
break
|
||||
|
||||
return clf, poly, data_x_exp, rsq
|
||||
|
||||
|
||||
def eval_monomial_grad(feature, x, w, grad_acc):
|
||||
"""Accumulates gradient from polynomial features and their weights."""
|
||||
features = feature.split(' ')
|
||||
variable_indices = []
|
||||
grads = np.ones(len(features)) * w
|
||||
for i, feature in enumerate(features):
|
||||
name_and_power = feature.split('^')
|
||||
if len(name_and_power) == 1:
|
||||
name, power = name_and_power[0], 1
|
||||
else:
|
||||
name, power = name_and_power
|
||||
power = int(power)
|
||||
var_index = int(name[1:])
|
||||
variable_indices.append(var_index)
|
||||
new_prod = np.ones_like(grads) * (x[var_index] ** power)
|
||||
# This needs a special case, for situation where x[index] = 0.0
|
||||
if power == 1:
|
||||
new_prod[i] = 1.0
|
||||
else:
|
||||
new_prod[i] = power * (x[var_index] ** (power - 1))
|
||||
grads = grads * new_prod
|
||||
grad_acc[variable_indices] += grads
|
||||
return grad_acc
|
||||
|
||||
|
||||
def compute_jacobian_manual(x, polynomial_features, weight_matrix, tolerance):
|
||||
"""Computes the jacobian manually."""
|
||||
# Put together the equation for each output var
|
||||
# polynomial_features = np.array(polynomial_obj.get_feature_names())
|
||||
weight_mask = np.abs(weight_matrix) > tolerance
|
||||
weight_matrix = weight_mask * weight_matrix
|
||||
jacobians = list()
|
||||
for i in range(weight_matrix.shape[0]):
|
||||
grad_accumulator = np.zeros_like(x)
|
||||
for j, feature in enumerate(polynomial_features):
|
||||
eval_monomial_grad(feature, x, weight_matrix[i, j], grad_accumulator)
|
||||
jacobians.append(grad_accumulator)
|
||||
return np.stack(jacobians)
|
||||
|
||||
|
||||
def calculate_jacobian_prod(jacobian, noise_eps=1e-6):
|
||||
"""Calculates AA*, where A=JEJ^T and A*=JE^TJ^T, which should be I."""
|
||||
# Add noise as 0 in jacobian creates issues in calculations later
|
||||
jacobian = jacobian + noise_eps
|
||||
sym_matrix = symplectic_matrix(jacobian.shape[1])
|
||||
pred = np.matmul(jacobian, sym_matrix)
|
||||
pred = np.matmul(pred, np.transpose(jacobian))
|
||||
|
||||
pred_t = np.matmul(jacobian, np.transpose(sym_matrix))
|
||||
pred_t = np.matmul(pred_t, np.transpose(jacobian))
|
||||
|
||||
pred_id = np.matmul(pred, pred_t)
|
||||
|
||||
return pred_id
|
||||
|
||||
|
||||
def normalise_jacobian_prods(jacobian_preds):
|
||||
"""Normalises Jacobians evaluated at various points by a constant."""
|
||||
stacked_preds = np.stack(jacobian_preds)
|
||||
# For each attempt at estimating E, get the max term, and take their average
|
||||
normalisation_factor = np.mean(np.max(np.abs(stacked_preds), axis=(1, 2)))
|
||||
|
||||
if normalisation_factor != 0:
|
||||
stacked_preds = stacked_preds/normalisation_factor
|
||||
|
||||
return stacked_preds
|
||||
|
||||
|
||||
def calculate_symetric_score(
|
||||
gt_data,
|
||||
model_data,
|
||||
max_poly_order,
|
||||
max_sym_score,
|
||||
rsq_threshold,
|
||||
sym_threshold,
|
||||
evaluation_point_n,
|
||||
trajectory_n=1,
|
||||
weight_tolerance=1e-5,
|
||||
alpha_sweep=None,
|
||||
max_iter=1000,
|
||||
cv=2):
|
||||
"""Finds minimal polynomial expansion to explain data using Lasso regression, gets the Jacobian of the mapping and calculates how symplectic the map is."""
|
||||
model_data = model_data[..., :gt_data.shape[0], :]
|
||||
|
||||
# Fing polynomial expansion that explains enough variance in the gt data
|
||||
print('Finding best polynomial expansion...')
|
||||
time_start = time.perf_counter()
|
||||
# Clean up model data to ensure it doesn't contain NaN, infinity
|
||||
# or values too large for dtype('float32')
|
||||
model_data = np.nan_to_num(model_data)
|
||||
model_data = np.clip(model_data, -999999, 999999)
|
||||
|
||||
clf, poly, model_data_exp, best_rsq = find_best_polynomial(
|
||||
model_data, gt_data, max_poly_order, rsq_threshold,
|
||||
32, alpha_sweep, max_iter, cv)
|
||||
time_end = time.perf_counter()
|
||||
print(f'Took {time_end - time_start}s to find best polynomial.')
|
||||
|
||||
# Calculate Symplecticity score
|
||||
all_raw_scores = []
|
||||
features = np.array(poly.get_feature_names())
|
||||
|
||||
points_per_trajectory = int(len(gt_data) / trajectory_n)
|
||||
for trajectory in range(trajectory_n):
|
||||
random_data_inds = np.random.permutation(
|
||||
range(points_per_trajectory))[:evaluation_point_n]
|
||||
|
||||
jacobian_preds = []
|
||||
for point_ind in random_data_inds:
|
||||
input_data_point = model_data[points_per_trajectory * trajectory +
|
||||
point_ind]
|
||||
time_start = time.perf_counter()
|
||||
jacobian = compute_jacobian_manual(input_data_point, features,
|
||||
clf.coef_, weight_tolerance)
|
||||
pred = calculate_jacobian_prod(jacobian)
|
||||
jacobian_preds.append(pred)
|
||||
time_end = time.perf_counter()
|
||||
print(f'Took {time_end - time_start}s to evaluate jacobian '
|
||||
f'around point {point_ind}.')
|
||||
|
||||
# Normalise
|
||||
normalised_jacobian_preds = normalise_jacobian_prods(jacobian_preds)
|
||||
# The score is measured as the deviation from I
|
||||
identity = np.eye(normalised_jacobian_preds.shape[-1])
|
||||
scores = np.mean(np.power(normalised_jacobian_preds - identity, 2),
|
||||
axis=(1, 2))
|
||||
all_raw_scores.append(scores)
|
||||
|
||||
sym_score = np.min([np.mean(all_raw_scores), max_sym_score])
|
||||
# Calculate final SyMetric score
|
||||
if best_rsq > rsq_threshold and sym_score < sym_threshold:
|
||||
sy_metric = 1.0
|
||||
else:
|
||||
sy_metric = 0.0
|
||||
|
||||
results = {
|
||||
'poly_exp_order': poly.get_params()['degree'],
|
||||
'rsq': best_rsq,
|
||||
'sym': sym_score,
|
||||
'SyMetric': sy_metric,
|
||||
}
|
||||
with np.printoptions(precision=4, suppress=True):
|
||||
print(f'----------------FINAL RESULTS FOR {trajectory_n} '
|
||||
'TRAJECTORIES------------------')
|
||||
print(f'BEST POLYNOMIAL EXPANSION ORDER: {results["poly_exp_order"]}')
|
||||
print(f'BEST RSQ (1-best): {results["rsq"]}')
|
||||
print(f'SYMPLECTICITY SCORE AROUND ALL POINTS AND ALL '
|
||||
f'TRAJECTORIES (0-best): {sym_score}')
|
||||
print(f'SyMETRIC SCORE: {sy_metric}')
|
||||
print(f'----------------FINAL RESULTS FOR {trajectory_n} '
|
||||
f'TRAJECTORIES------------------')
|
||||
return results, clf, poly, model_data_exp
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,397 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module containing all of the configurations for various models."""
|
||||
import copy
|
||||
import os
|
||||
from jaxline import base_config
|
||||
import ml_collections as collections
|
||||
|
||||
_DATASETS_PATH_VAR_NAME = "DM_HAMILTONIAN_DYNAMICS_SUITE_DATASETS"
|
||||
|
||||
|
||||
def get_config(arg_string):
|
||||
"""Return config object for training."""
|
||||
args = arg_string.split(",")
|
||||
if len(args) != 3:
|
||||
raise ValueError("You must provide exactly three arguments separated by a "
|
||||
"comma - model_config_name,sweep_index,dataset_name.")
|
||||
model_config_name, sweep_index, dataset_name = args
|
||||
sweep_index = int(sweep_index)
|
||||
|
||||
config = base_config.get_base_config()
|
||||
config.random_seed = 123109801
|
||||
config.eval_modes = ("eval", "eval_metric")
|
||||
|
||||
# Get the model config and the sweeps
|
||||
if model_config_name not in globals():
|
||||
raise ValueError(f"The config name {model_config_name} does not exist in "
|
||||
f"jaxline_configs.py")
|
||||
config_and_sweep_fn = globals()[model_config_name]
|
||||
model_config, sweeps = config_and_sweep_fn()
|
||||
|
||||
if not os.environ.get(_DATASETS_PATH_VAR_NAME, None):
|
||||
raise ValueError(f"You need to set the {_DATASETS_PATH_VAR_NAME}")
|
||||
dm_hamiltonian_suite_path = os.environ[_DATASETS_PATH_VAR_NAME]
|
||||
dataset_folder = os.path.join(dm_hamiltonian_suite_path, dataset_name)
|
||||
|
||||
# Experiment config. Note that batch_size is per device.
|
||||
# In the experiments we run on 4 GPUs, so the effective batch size was 128.
|
||||
config.experiment_kwargs = collections.ConfigDict(
|
||||
dict(
|
||||
config=dict(
|
||||
dataset_folder=dataset_folder,
|
||||
model_kwargs=model_config,
|
||||
num_extrapolation_steps=60,
|
||||
drop_stats_containing=("neg_log_p_x", "l2_over_time", "neg_elbo"),
|
||||
optimizer=dict(
|
||||
name="adam",
|
||||
kwargs=dict(
|
||||
learning_rate=1.5e-4,
|
||||
b1=0.9,
|
||||
b2=0.999,
|
||||
)
|
||||
),
|
||||
training=dict(
|
||||
batch_size=32,
|
||||
burnin_steps=5,
|
||||
num_epochs=None,
|
||||
lagging_vae=False
|
||||
),
|
||||
evaluation=dict(
|
||||
batch_size=64,
|
||||
),
|
||||
evaluation_metric=dict(
|
||||
batch_size=5,
|
||||
batch_n=20,
|
||||
num_eval_metric_steps=60,
|
||||
max_poly_order=5,
|
||||
max_jacobian_score=1000,
|
||||
rsq_threshold=0.9,
|
||||
sym_threshold=0.05,
|
||||
evaluation_point_n=10,
|
||||
weight_tolerance=1e-03,
|
||||
max_iter=1000,
|
||||
cv=2,
|
||||
alpha_min_logspace=-4,
|
||||
alpha_max_logspace=-0.5,
|
||||
alpha_step_n=10,
|
||||
calculate_fully_after_steps=40000,
|
||||
),
|
||||
evaluation_metric_mlp=dict(
|
||||
batch_size=64,
|
||||
batch_n=10000,
|
||||
datapoint_param_multiplier=1000,
|
||||
num_eval_metric_steps=60,
|
||||
evaluation_point_n=10,
|
||||
evaluation_trajectory_n=50,
|
||||
rsq_threshold=0.9,
|
||||
sym_threshold=0.05,
|
||||
ridge_lambda=0.01,
|
||||
model=dict(
|
||||
num_units=4,
|
||||
num_layers=4,
|
||||
activation="tanh",
|
||||
),
|
||||
optimizer=dict(
|
||||
name="adam",
|
||||
kwargs=dict(
|
||||
learning_rate=1.5e-3,
|
||||
)
|
||||
),
|
||||
),
|
||||
evaluation_vpt=dict(
|
||||
batch_size=5,
|
||||
batch_n=2,
|
||||
vpt_threshold=0.025,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Training loop config.
|
||||
config.training_steps = int(500000)
|
||||
config.interval_type = "steps"
|
||||
config.log_tensors_interval = 50
|
||||
config.log_train_data_interval = 50
|
||||
config.log_all_train_data = False
|
||||
|
||||
config.save_checkpoint_interval = 100
|
||||
config.checkpoint_dir = "/tmp/physics_inspired_models/"
|
||||
config.train_checkpoint_all_hosts = False
|
||||
config.eval_specific_checkpoint_dir = ""
|
||||
|
||||
config.update_from_flattened_dict(sweeps[sweep_index])
|
||||
return config
|
||||
|
||||
|
||||
config_prefix = "experiment_kwargs.config."
|
||||
model_prefix = config_prefix + "model_kwargs."
|
||||
|
||||
default_encoder_kwargs = collections.ConfigDict(dict(
|
||||
conv_channels=64,
|
||||
num_blocks=3,
|
||||
blocks_depth=2,
|
||||
activation="leaky_relu",
|
||||
))
|
||||
|
||||
default_decoder_kwargs = collections.ConfigDict(dict(
|
||||
conv_channels=64,
|
||||
num_blocks=3,
|
||||
blocks_depth=2,
|
||||
activation="leaky_relu",
|
||||
))
|
||||
|
||||
default_latent_system_net_kwargs = collections.ConfigDict(dict(
|
||||
conv_channels=64,
|
||||
num_units=250,
|
||||
num_layers=5,
|
||||
activation="swish",
|
||||
))
|
||||
|
||||
|
||||
default_latent_system_kwargs = collections.ConfigDict(dict(
|
||||
# Physics model arguments
|
||||
input_space=collections.config_dict.placeholder(str),
|
||||
simulation_space=collections.config_dict.placeholder(str),
|
||||
potential_func_form="separable_net",
|
||||
kinetic_func_form=collections.config_dict.placeholder(str),
|
||||
hgn_kinetic_func_form="separable_net",
|
||||
lgn_kinetic_func_form="matrix_dep_quad",
|
||||
parametrize_mass_matrix=collections.config_dict.placeholder(bool),
|
||||
hgn_parametrize_mass_matrix=False,
|
||||
lgn_parametrize_mass_matrix=True,
|
||||
mass_eps=1.0,
|
||||
# ODE model arguments
|
||||
integrator_method=collections.config_dict.placeholder(str),
|
||||
# RGN model arguments
|
||||
residual=collections.config_dict.placeholder(bool),
|
||||
# General arguments
|
||||
net_kwargs=default_latent_system_net_kwargs
|
||||
))
|
||||
|
||||
default_config_dict = collections.ConfigDict(dict(
|
||||
name=collections.config_dict.placeholder(str),
|
||||
latent_system_dim=32,
|
||||
latent_system_net_type="mlp",
|
||||
latent_system_kwargs=default_latent_system_kwargs,
|
||||
encoder_aggregation_type="linear_projection",
|
||||
decoder_de_aggregation_type=collections.config_dict.placeholder(str),
|
||||
encoder_kwargs=default_encoder_kwargs,
|
||||
decoder_kwargs=default_decoder_kwargs,
|
||||
has_latent_transform=False,
|
||||
num_inference_steps=5,
|
||||
num_target_steps=60,
|
||||
latent_training_type="forward",
|
||||
# Choices: overlap_by_one, no_overlap, include_inference
|
||||
training_data_split="overlap_by_one",
|
||||
objective_type="ELBO",
|
||||
elbo_beta_delay=0,
|
||||
elbo_beta_final=1.0,
|
||||
geco_kappa=0.001,
|
||||
geco_alpha=0.0,
|
||||
dt=0.125,
|
||||
))
|
||||
|
||||
hgn_paper_encoder_kwargs = collections.ConfigDict(dict(
|
||||
conv_channels=[[32, 64], [64, 64], [64]],
|
||||
num_blocks=3,
|
||||
blocks_depth=2,
|
||||
activation="relu",
|
||||
kernel_shapes=[2, 4],
|
||||
padding=["VALID", "SAME"],
|
||||
))
|
||||
|
||||
hgn_paper_decoder_kwargs = collections.ConfigDict(dict(
|
||||
conv_channels=64,
|
||||
num_blocks=3,
|
||||
blocks_depth=2,
|
||||
activation="tf_leaky_relu",
|
||||
))
|
||||
|
||||
hgn_paper_latent_net_kwargs = collections.ConfigDict(dict(
|
||||
conv_channels=[32, 64, 64, 64],
|
||||
num_units=250,
|
||||
num_layers=5,
|
||||
activation="softplus",
|
||||
kernel_shapes=[3, 2, 2, 2, 2],
|
||||
strides=[1, 2, 1, 2, 1],
|
||||
padding=["SAME", "VALID", "SAME", "VALID", "SAME"]
|
||||
))
|
||||
|
||||
hgn_paper_latent_system_kwargs = collections.ConfigDict(dict(
|
||||
potential_func_form="separable_net",
|
||||
kinetic_func_form="separable_net",
|
||||
parametrize_mass_matrix=False,
|
||||
net_kwargs=hgn_paper_latent_net_kwargs
|
||||
))
|
||||
|
||||
hgn_paper_latent_transform_kwargs = collections.ConfigDict(dict(
|
||||
num_layers=5,
|
||||
conv_channels=64,
|
||||
num_units=64,
|
||||
activation="relu",
|
||||
))
|
||||
|
||||
hgn_paper_config = copy.deepcopy(default_config_dict)
|
||||
hgn_paper_config.training_data_split = "include_inference"
|
||||
hgn_paper_config.latent_system_net_type = "conv"
|
||||
hgn_paper_config.encoder_aggregation_type = (collections.config_dict.
|
||||
placeholder(str))
|
||||
hgn_paper_config.decoder_de_aggregation_type = (collections.config_dict.
|
||||
placeholder(str))
|
||||
hgn_paper_config.latent_system_kwargs = hgn_paper_latent_system_kwargs
|
||||
hgn_paper_config.encoder_kwargs = hgn_paper_encoder_kwargs
|
||||
hgn_paper_config.decoder_kwargs = hgn_paper_decoder_kwargs
|
||||
hgn_paper_config.has_latent_transform = True
|
||||
hgn_paper_config.latent_transform_kwargs = hgn_paper_latent_transform_kwargs
|
||||
hgn_paper_config.num_inference_steps = 31
|
||||
hgn_paper_config.num_target_steps = 0
|
||||
hgn_paper_config.objective_type = "GECO"
|
||||
|
||||
|
||||
forward_overlap_by_one = {
|
||||
model_prefix + "latent_training_type": "forward",
|
||||
model_prefix + "training_data_split": "overlap_by_one",
|
||||
}
|
||||
|
||||
forward_backward_include_inference = {
|
||||
model_prefix + "latent_training_type": "forward_backward",
|
||||
model_prefix + "training_data_split": "include_inference",
|
||||
}
|
||||
|
||||
latent_training_sweep = [
|
||||
forward_overlap_by_one,
|
||||
forward_backward_include_inference,
|
||||
]
|
||||
|
||||
|
||||
def sym_metric_hgn_plus_plus_sweep():
|
||||
"""HGN++ experimental sweep for the SyMetric paper."""
|
||||
model_config = copy.deepcopy(default_config_dict)
|
||||
model_config.name = "HGN"
|
||||
sweeps = list()
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "latent_training_type": "forward",
|
||||
model_prefix + "training_data_split": "overlap_by_one",
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "latent_training_type": "forward_backward",
|
||||
model_prefix + "training_data_split": "include_inference",
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
|
||||
return model_config, sweeps
|
||||
|
||||
|
||||
def sym_metric_hgn_sweep():
|
||||
"""HGN experimental sweep for the SyMetric paper."""
|
||||
model_config = copy.deepcopy(hgn_paper_config)
|
||||
model_config.name = "HGN"
|
||||
return model_config, list(dict())
|
||||
|
||||
|
||||
def benchmark_hgn_overlap_sweep():
|
||||
"""HGN++ sweep for the benchmark paper."""
|
||||
model_config = copy.deepcopy(default_config_dict)
|
||||
model_config.name = "HGN"
|
||||
|
||||
sweeps = list()
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
for train_dict in latent_training_sweep:
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
sweeps[-1].update(train_dict)
|
||||
|
||||
return model_config, sweeps
|
||||
|
||||
|
||||
def benchmark_lgn_sweep():
|
||||
"""LGN sweep for the benchmark paper."""
|
||||
model_config = copy.deepcopy(default_config_dict)
|
||||
model_config.name = "LGN"
|
||||
|
||||
sweeps = list()
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
for train_dict in latent_training_sweep:
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "latent_system_kwargs.kinetic_func_form":
|
||||
"matrix_dep_pure_quad",
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
sweeps[-1].update(train_dict)
|
||||
|
||||
return model_config, sweeps
|
||||
|
||||
|
||||
def benchmark_ode_sweep():
|
||||
"""Neural ODE sweep for the benchmark paper."""
|
||||
model_config = copy.deepcopy(default_config_dict)
|
||||
model_config.name = "ODE"
|
||||
|
||||
sweeps = list()
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
for integrator in ("adaptive", "rk2"):
|
||||
for train_dict in latent_training_sweep:
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "integrator_method": integrator,
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
sweeps[-1].update(train_dict)
|
||||
|
||||
return model_config, sweeps
|
||||
|
||||
|
||||
def benchmark_rgn_sweep():
|
||||
"""RGN sweep for the benchmark paper."""
|
||||
model_config = copy.deepcopy(default_config_dict)
|
||||
model_config.name = "RGN"
|
||||
|
||||
sweeps = list()
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
for residual in (True, False):
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "latent_system_kwargs.residual": residual,
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
|
||||
return model_config, sweeps
|
||||
|
||||
|
||||
def benchmark_ar_sweep():
|
||||
"""AR sweep for the benchmark paper."""
|
||||
model_config = copy.deepcopy(default_config_dict)
|
||||
model_config.name = "AR"
|
||||
model_config.latent_dynamics_type = "vanilla"
|
||||
|
||||
sweeps = list()
|
||||
for elbo_beta_final in [0.001, 0.1, 1.0, 2.0]:
|
||||
for ar_type in ("vanilla", "lstm", "gru"):
|
||||
sweeps.append({
|
||||
config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
|
||||
model_prefix + "latent_dynamics_type": ar_type,
|
||||
model_prefix + "elbo_beta_final": elbo_beta_final,
|
||||
})
|
||||
|
||||
return model_config, sweeps
|
||||
File diff suppressed because it is too large
Load Diff
Executable
+52
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Script to execute a single configuration on all datasets.
|
||||
if [[ "$#" -eq 2 ]]; then
|
||||
readonly CONFIG_NAME="$1"
|
||||
readonly NUM_SWEEPS="$2"
|
||||
else
|
||||
echo "You must provide exactly two arguments - the configuration name and " \
|
||||
"how many sweeps it contains. For example:"
|
||||
echo "./launch_all.sh sym_metric_hgn_plus_plus_sweep 1"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
DATASETS=(
|
||||
"toy_physics/mass_spring"
|
||||
"toy_physics/mass_spring_colors"
|
||||
"toy_physics/mass_spring_colors_friction"
|
||||
"toy_physics/pendulum"
|
||||
"toy_physics/pendulum_colors"
|
||||
"toy_physics/pendulum_colors_friction"
|
||||
"toy_physics/two_body"
|
||||
"toy_physics/two_body_colors"
|
||||
"toy_physics/double_pendulum"
|
||||
"toy_physics/double_pendulum_colors"
|
||||
"toy_physics/double_pendulum_colors_friction"
|
||||
"molecular_dynamics/lj_4"
|
||||
"molecular_dynamics/lj_16"
|
||||
"multi_agent/rock_paper_scissors"
|
||||
"multi_agent/matching_pennies"
|
||||
"mujoco_room/circle"
|
||||
"mujoco_room/spiral"
|
||||
)
|
||||
|
||||
readonly DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||
|
||||
for dataset in "${DATASETS[@]}"; do
|
||||
"${DIR}/launch_local.sh" "${CONFIG_NAME}" "${NUM_SWEEPS}" "${dataset}"
|
||||
done
|
||||
Executable
+40
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# A script to execute a single configuration name on a given dataset.
|
||||
if [[ "$#" -eq 3 ]]; then
|
||||
readonly CONFIG_NAME="$1"
|
||||
readonly NUM_SWEEPS="$2"
|
||||
readonly DATASET="$3"
|
||||
else
|
||||
echo "You must provide exactly three arguments - the configuration name, " \
|
||||
"the number of sweeps it contains and the dataset name. For example:"
|
||||
echo "./launch_local.sh sym_metric_hgn_plus_plus_sweep 1 " \
|
||||
"toy_physics/mass_spring"
|
||||
exit 2
|
||||
fi
|
||||
echo "Running with config ${CONFIG_NAME} on ${DATASET}."
|
||||
|
||||
readonly EXPERIMENT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
readonly TRAIN_FILE="${EXPERIMENT_DIR}/jaxline_train.py"
|
||||
readonly CONFIG_FILE="${EXPERIMENT_DIR}/jaxline_configs.py"
|
||||
|
||||
for sweep_id in $(seq 0 $((NUM_SWEEPS - 1))); do
|
||||
python3 "${TRAIN_FILE}" \
|
||||
--config="${CONFIG_FILE}:${CONFIG_NAME},${sweep_id},${DATASET}" \
|
||||
--jaxline_mode="train" \
|
||||
--logtostderr
|
||||
done
|
||||
@@ -0,0 +1,247 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module containing code for computing various metrics for training and evaluation."""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import distrax
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.nn as nn
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
import physics_inspired_models.utils as utils
|
||||
|
||||
|
||||
_ReconstructFunc = Callable[[utils.Params, jnp.ndarray, jnp.ndarray, bool],
|
||||
distrax.Distribution]
|
||||
|
||||
|
||||
def calculate_small_latents(dist, threshold=0.5):
|
||||
"""Calculates the number of active latents by thresholding the variance of their distribution."""
|
||||
if not isinstance(dist, distrax.Normal):
|
||||
raise NotImplementedError()
|
||||
latent_means = dist.mean()
|
||||
latent_stddevs = dist.variance()
|
||||
small_latents = jnp.sum(
|
||||
(latent_stddevs < threshold) & (jnp.abs(latent_means) > 0.1), axis=1)
|
||||
return jnp.mean(small_latents)
|
||||
|
||||
|
||||
def compute_scale(
|
||||
targets: jnp.ndarray,
|
||||
rescale_by: str
|
||||
) -> jnp.ndarray:
|
||||
"""Compute a scaling factor based on targets shape and the rescale_by argument."""
|
||||
if rescale_by == "pixels_and_time":
|
||||
return jnp.asarray(np.prod(targets.shape[-4:]))
|
||||
elif rescale_by is not None:
|
||||
raise ValueError(f"Unrecognized rescale_by={rescale_by}.")
|
||||
else:
|
||||
return jnp.ones([])
|
||||
|
||||
|
||||
def compute_data_domain_stats(
|
||||
p_x: distrax.Distribution,
|
||||
targets: jnp.ndarray
|
||||
) -> Dict[str, jnp.ndarray]:
|
||||
"""Compute several statistics in the data domain, such as L2 and negative log likelihood."""
|
||||
axis = tuple(range(2, targets.ndim))
|
||||
l2_over_time = jnp.sum((p_x.mean() - targets) ** 2, axis=axis)
|
||||
l2 = jnp.sum(l2_over_time, axis=1)
|
||||
|
||||
# Calculate relative L2 normalised by image "length"
|
||||
norm_factor = jnp.sum(targets**2, axis=(2, 3, 4))
|
||||
l2_over_time_norm = l2_over_time / norm_factor
|
||||
l2_norm = jnp.sum(l2_over_time_norm, axis=1)
|
||||
|
||||
# Compute negative log-likelihood under p(x)
|
||||
neg_log_p_x_over_time = - np.sum(p_x.log_prob(targets), axis=axis)
|
||||
neg_log_p_x = jnp.sum(neg_log_p_x_over_time, axis=1)
|
||||
|
||||
return dict(
|
||||
neg_log_p_x_over_time=neg_log_p_x_over_time,
|
||||
neg_log_p_x=neg_log_p_x,
|
||||
l2_over_time=l2_over_time,
|
||||
l2=l2,
|
||||
l2_over_time_norm=l2_over_time_norm,
|
||||
l2_norm=l2_norm,
|
||||
)
|
||||
|
||||
|
||||
def compute_vae_stats(
|
||||
neg_log_p_x: jnp.ndarray,
|
||||
rng: jnp.ndarray,
|
||||
q_z: distrax.Distribution,
|
||||
prior: distrax.Distribution
|
||||
) -> Dict[str, jnp.ndarray]:
|
||||
"""Compute the KL(q(z|x)||p(z)) and the negative ELBO, which are used for VAE models."""
|
||||
# Compute the KL
|
||||
kl = distrax.estimate_kl_best_effort(q_z, prior, rng_key=rng, num_samples=1)
|
||||
kl = np.sum(kl, axis=list(range(1, kl.ndim)))
|
||||
# Sanity check
|
||||
assert kl.shape == neg_log_p_x.shape
|
||||
return dict(
|
||||
kl=kl,
|
||||
neg_elbo=neg_log_p_x + kl,
|
||||
)
|
||||
|
||||
|
||||
def training_statistics(
|
||||
p_x: distrax.Distribution,
|
||||
targets: jnp.ndarray,
|
||||
rescale_by: Optional[str],
|
||||
rng: Optional[jnp.ndarray] = None,
|
||||
q_z: Optional[distrax.Distribution] = None,
|
||||
prior: Optional[distrax.Distribution] = None,
|
||||
p_x_learned_sigma: bool = False
|
||||
) -> Dict[str, jnp.ndarray]:
|
||||
"""Computes various statistics we track during training."""
|
||||
stats = compute_data_domain_stats(p_x, targets)
|
||||
|
||||
if rng is not None and q_z is not None and prior is not None:
|
||||
stats.update(compute_vae_stats(stats["neg_log_p_x"], rng, q_z, prior))
|
||||
else:
|
||||
assert rng is None and q_z is None and prior is None
|
||||
|
||||
# Rescale these stats accordingly
|
||||
scale = compute_scale(targets, rescale_by)
|
||||
# Note that "_over_time" stats are getting normalised by time here
|
||||
stats = jax.tree_map(lambda x: x / scale, stats)
|
||||
if p_x_learned_sigma:
|
||||
stats["p_x_sigma"] = p_x.variance().reshape([-1])[0]
|
||||
if q_z is not None:
|
||||
stats["small_latents"] = calculate_small_latents(q_z)
|
||||
return stats
|
||||
|
||||
|
||||
def evaluation_only_statistics(
|
||||
reconstruct_func: _ReconstructFunc,
|
||||
params: hk.Params,
|
||||
inputs: jnp.ndarray,
|
||||
rng: jnp.ndarray,
|
||||
rescale_by: str,
|
||||
can_run_backwards: bool,
|
||||
train_sequence_length: int,
|
||||
reconstruction_skip: int,
|
||||
p_x_learned_sigma: bool = False,
|
||||
) -> Dict[str, jnp.ndarray]:
|
||||
"""Computes various statistics we track only during evaluation."""
|
||||
full_trajectory = utils.extract_image(inputs)
|
||||
prefixes = ("forward", "backward") if can_run_backwards else ("forward",)
|
||||
|
||||
full_forward_targets = jax.tree_map(
|
||||
lambda x: x[:, reconstruction_skip:], full_trajectory)
|
||||
full_backward_targets = jax.tree_map(
|
||||
lambda x: x[:, :x.shape[1]-reconstruction_skip], full_trajectory)
|
||||
train_targets_length = train_sequence_length - reconstruction_skip
|
||||
full_targets_length = full_forward_targets.shape[1]
|
||||
|
||||
stats = dict()
|
||||
keys = ()
|
||||
|
||||
for prefix in prefixes:
|
||||
# Fully unroll the model and reconstruct the whole sequence
|
||||
full_prediction = reconstruct_func(params, full_trajectory, rng,
|
||||
prefix == "forward")
|
||||
assert isinstance(full_prediction, distrax.Normal)
|
||||
full_targets = (full_forward_targets if prefix == "forward" else
|
||||
full_backward_targets)
|
||||
# In cases where the model can run backwards it is possible to reconstruct
|
||||
# parts which were indented to be skipped, so here we take care of that.
|
||||
if full_prediction.mean().shape[1] > full_targets_length:
|
||||
if prefix == "forward":
|
||||
full_prediction = jax.tree_map(lambda x: x[:, -full_targets_length:],
|
||||
full_prediction)
|
||||
else:
|
||||
full_prediction = jax.tree_map(lambda x: x[:, :full_targets_length],
|
||||
full_prediction)
|
||||
|
||||
# Based on the prefix and suffix fetch correct predictions and targets
|
||||
for suffix in ("train", "extrapolation", "full"):
|
||||
if prefix == "forward" and suffix == "train":
|
||||
predict, targets = jax.tree_map(lambda x: x[:, :train_targets_length],
|
||||
(full_prediction, full_targets))
|
||||
elif prefix == "forward" and suffix == "extrapolation":
|
||||
predict, targets = jax.tree_map(lambda x: x[:, train_targets_length:],
|
||||
(full_prediction, full_targets))
|
||||
elif prefix == "backward" and suffix == "train":
|
||||
predict, targets = jax.tree_map(lambda x: x[:, -train_targets_length:],
|
||||
(full_prediction, full_targets))
|
||||
elif prefix == "backward" and suffix == "extrapolation":
|
||||
predict, targets = jax.tree_map(lambda x: x[:, :-train_targets_length],
|
||||
(full_prediction, full_targets))
|
||||
else:
|
||||
predict, targets = full_prediction, full_targets
|
||||
|
||||
# Compute train statistics
|
||||
train_stats = training_statistics(predict, targets, rescale_by,
|
||||
p_x_learned_sigma=p_x_learned_sigma)
|
||||
for key, value in train_stats.items():
|
||||
stats[prefix + "_" + suffix + "_" + key] = value
|
||||
# Copy all stats keys
|
||||
keys = tuple(train_stats.keys())
|
||||
|
||||
# Make a combined metric summing forward and backward
|
||||
if can_run_backwards:
|
||||
# Also compute
|
||||
for suffix in ("train", "extrapolation", "full"):
|
||||
for key in keys:
|
||||
forward = stats["forward_" + suffix + "_" + key]
|
||||
backward = stats["backward_" + suffix + "_" + key]
|
||||
combined = (forward + backward) / 2
|
||||
stats["combined_" + suffix + "_" + key] = combined
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def geco_objective(
|
||||
l2_loss,
|
||||
kl,
|
||||
alpha,
|
||||
kappa,
|
||||
constraint_ema,
|
||||
lambda_var,
|
||||
is_training
|
||||
) -> Dict[str, jnp.ndarray]:
|
||||
"""Computes the objective for GECO and some of it statistics used ofr updates."""
|
||||
# C_t
|
||||
constraint_t = l2_loss - kappa
|
||||
if is_training:
|
||||
# We update C_ma only during training
|
||||
constraint_ema = alpha * constraint_ema + (1 - alpha) * constraint_t
|
||||
lagrange = nn.softplus(lambda_var)
|
||||
lagrange = jnp.broadcast_to(lagrange, constraint_ema.shape)
|
||||
# Add this special op for getting all gradients correct
|
||||
loss = utils.geco_lagrange_product(lagrange, constraint_ema, constraint_t)
|
||||
return dict(
|
||||
loss=loss + kl,
|
||||
geco_multiplier=lagrange,
|
||||
geco_constraint=constraint_t,
|
||||
geco_constraint_ema=constraint_ema
|
||||
)
|
||||
|
||||
|
||||
def elbo_objective(neg_log_p_x, kl, final_beta, beta_delay, step):
|
||||
"""Computes objective for optimizing the Evidence Lower Bound (ELBO)."""
|
||||
if beta_delay == 0:
|
||||
beta = final_beta
|
||||
else:
|
||||
delayed_beta = jnp.minimum(float(step) / float(beta_delay), 1.0)
|
||||
beta = delayed_beta * final_beta
|
||||
return dict(
|
||||
loss=neg_log_p_x + beta * kl,
|
||||
elbo_beta=beta
|
||||
)
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for all autoregressive models."""
|
||||
import functools
|
||||
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
import distrax
|
||||
import haiku as hk
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jnr
|
||||
|
||||
import physics_inspired_models.metrics as metrics
|
||||
import physics_inspired_models.models.base as base
|
||||
import physics_inspired_models.models.networks as nets
|
||||
import physics_inspired_models.utils as utils
|
||||
|
||||
|
||||
class TeacherForcingAutoregressiveModel(base.SequenceModel):
|
||||
"""A standard autoregressive model trained via teacher forcing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latent_system_dim: int,
|
||||
latent_system_net_type: str,
|
||||
latent_system_kwargs: Dict[str, Any],
|
||||
latent_dynamics_type: str,
|
||||
encoder_aggregation_type: Optional[str],
|
||||
decoder_de_aggregation_type: Optional[str],
|
||||
encoder_kwargs: Dict[str, Any],
|
||||
decoder_kwargs: Dict[str, Any],
|
||||
num_inference_steps: int,
|
||||
num_target_steps: int,
|
||||
name: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# Remove any parameters from vae models
|
||||
encoder_kwargs = dict(**encoder_kwargs)
|
||||
encoder_kwargs["distribution_name"] = None
|
||||
|
||||
if kwargs.get("has_latent_transform", False):
|
||||
raise ValueError("We do not support AR models with latent transform.")
|
||||
|
||||
super().__init__(
|
||||
can_run_backwards=False,
|
||||
latent_system_dim=latent_system_dim,
|
||||
latent_system_net_type=latent_system_net_type,
|
||||
latent_system_kwargs=latent_system_kwargs,
|
||||
encoder_aggregation_type=encoder_aggregation_type,
|
||||
decoder_de_aggregation_type=decoder_de_aggregation_type,
|
||||
encoder_kwargs=encoder_kwargs,
|
||||
decoder_kwargs=decoder_kwargs,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_target_steps=num_target_steps,
|
||||
name=name,
|
||||
**kwargs
|
||||
)
|
||||
self.latent_dynamics_type = latent_dynamics_type
|
||||
|
||||
# Arguments checks
|
||||
if self.latent_system_net_type != "mlp":
|
||||
raise ValueError("Currently we do not support non-mlp AR models.")
|
||||
|
||||
def recurrence_function(sequence, initial_state=None):
|
||||
core = nets.make_flexible_recurrent_net(
|
||||
core_type=latent_dynamics_type,
|
||||
net_type=latent_system_net_type,
|
||||
output_dims=self.latent_system_dim,
|
||||
**self.latent_system_kwargs["net_kwargs"])
|
||||
initial_state = initial_state or core.initial_state(sequence.shape[1])
|
||||
core(sequence[0], initial_state)
|
||||
return hk.dynamic_unroll(core, sequence, initial_state)
|
||||
|
||||
self.recurrence = hk.transform(recurrence_function)
|
||||
|
||||
def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
|
||||
return x
|
||||
|
||||
def process_latents_for_dynamics(self, z: jnp.ndarray) -> jnp.ndarray:
|
||||
return z
|
||||
|
||||
def process_latents_for_decoder(self, z: jnp.ndarray) -> jnp.ndarray:
|
||||
return z
|
||||
|
||||
@property
|
||||
def inferred_index(self) -> int:
|
||||
return self.num_inference_steps - 1
|
||||
|
||||
@property
|
||||
def train_sequence_length(self) -> int:
|
||||
return self.num_target_steps
|
||||
|
||||
def train_data_split(
|
||||
self,
|
||||
images: jnp.ndarray
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
|
||||
images = images[:, :self.train_sequence_length]
|
||||
inference_data = images[:, :-1]
|
||||
target_data = images[:, 1:]
|
||||
return inference_data, target_data, dict(
|
||||
num_steps_forward=1,
|
||||
num_steps_backward=0,
|
||||
include_z0=False)
|
||||
|
||||
def unroll_without_inputs(
|
||||
self,
|
||||
params: utils.Params,
|
||||
rng: jnp.ndarray,
|
||||
x_init: jnp.ndarray,
|
||||
h_init: jnp.ndarray,
|
||||
num_steps: int,
|
||||
is_training: bool
|
||||
) -> Tuple[Tuple[distrax.Distribution, jnp.ndarray], Any]:
|
||||
if num_steps < 1:
|
||||
raise ValueError("`num_steps` must be at least 1.")
|
||||
|
||||
def step_fn(carry, key):
|
||||
x_last, h_last = carry
|
||||
enc_key, dec_key = jnr.split(key)
|
||||
z_in_next = self.encoder.apply(params, enc_key, x_last,
|
||||
is_training=is_training)
|
||||
z_next, h_next = self.recurrence.apply(params, None, z_in_next[None],
|
||||
h_last)
|
||||
p_x_next = self.decode_latents(params, dec_key, z_next[0],
|
||||
is_training=is_training)
|
||||
return (p_x_next.mean(), h_next), (p_x_next, z_next[0])
|
||||
|
||||
return lax.scan(
|
||||
step_fn,
|
||||
init=(x_init, h_init),
|
||||
xs=jnr.split(rng, num_steps)
|
||||
)
|
||||
|
||||
def unroll_latent_dynamics(
|
||||
self,
|
||||
z: jnp.ndarray,
|
||||
params: utils.Params,
|
||||
key: jnp.ndarray,
|
||||
num_steps_forward: int,
|
||||
num_steps_backward: int,
|
||||
include_z0: bool,
|
||||
is_training: bool,
|
||||
**kwargs: Any
|
||||
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
|
||||
init_key, unroll_key, dec_key = jnr.split(key, 3)
|
||||
|
||||
if num_steps_backward != 0:
|
||||
raise ValueError("This model can not run backwards.")
|
||||
|
||||
# Change 'z' time dimension to be first
|
||||
z = jnp.swapaxes(z, 0, 1)
|
||||
|
||||
# Run recurrent model on inputs
|
||||
z_0, h_0 = self.recurrence.apply(params, init_key, z)
|
||||
|
||||
if num_steps_forward == 1:
|
||||
z_t = z_0
|
||||
elif num_steps_forward > 1:
|
||||
p_x_0 = self.decode_latents(params, dec_key, z_0[-1], is_training=False)
|
||||
_, (_, z_t) = self.unroll_without_inputs(
|
||||
params=params,
|
||||
rng=unroll_key,
|
||||
x_init=p_x_0.mean(),
|
||||
h_init=h_0,
|
||||
num_steps=num_steps_forward-1,
|
||||
is_training=is_training
|
||||
)
|
||||
z_t = jnp.concatenate([z_0, z_t], axis=0)
|
||||
else:
|
||||
raise ValueError("num_steps_forward should be at least 1.")
|
||||
|
||||
# Make time dimension second
|
||||
return jnp.swapaxes(z_t, 0, 1), dict()
|
||||
|
||||
def _models_core(
|
||||
self,
|
||||
params: utils.Params,
|
||||
keys: jnp.ndarray,
|
||||
image_data: jnp.ndarray,
|
||||
is_training: bool,
|
||||
**unroll_kwargs: Any
|
||||
) -> Tuple[distrax.Distribution, jnp.ndarray, jnp.ndarray]:
|
||||
enc_key, _, transform_key, unroll_key, dec_key, _ = keys
|
||||
|
||||
# Calculate latent input representation
|
||||
inference_data = self.process_inputs_for_encoder(image_data)
|
||||
z_raw = self.encoder.apply(params, enc_key, inference_data,
|
||||
is_training=is_training)
|
||||
|
||||
# Apply latent transformation (should be identity)
|
||||
z0 = self.apply_latent_transform(params, transform_key, z_raw,
|
||||
is_training=is_training)
|
||||
z0 = self.process_latents_for_dynamics(z0)
|
||||
|
||||
# Calculate latent output representation
|
||||
decoder_z, _ = self.unroll_latent_dynamics(
|
||||
z=z0,
|
||||
params=params,
|
||||
key=unroll_key,
|
||||
is_training=is_training,
|
||||
**unroll_kwargs
|
||||
)
|
||||
decoder_z = self.process_latents_for_decoder(decoder_z)
|
||||
|
||||
# Compute p(x|z)
|
||||
p_x = self.decode_latents(params, dec_key, decoder_z,
|
||||
is_training=is_training)
|
||||
return p_x, z0, decoder_z
|
||||
|
||||
def training_objectives(
|
||||
self,
|
||||
params: hk.Params,
|
||||
state: hk.State,
|
||||
rng: jnp.ndarray,
|
||||
inputs: jnp.ndarray,
|
||||
step: jnp.ndarray,
|
||||
is_training: bool = True,
|
||||
use_mean_for_eval_stats: bool = True
|
||||
) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
|
||||
"""Computes the training objective and any supporting stats."""
|
||||
# Split all rng keys
|
||||
keys = jnr.split(rng, 6)
|
||||
|
||||
# Process training data
|
||||
images = utils.extract_image(inputs)
|
||||
image_data, target_data, unroll_kwargs = self.train_data_split(images)
|
||||
|
||||
p_x, _, _ = self._models_core(
|
||||
params=params,
|
||||
keys=keys,
|
||||
image_data=image_data,
|
||||
is_training=is_training,
|
||||
**unroll_kwargs
|
||||
)
|
||||
|
||||
# Compute training statistics
|
||||
stats = metrics.training_statistics(
|
||||
p_x=p_x,
|
||||
targets=target_data,
|
||||
rescale_by=self.rescale_by,
|
||||
p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
|
||||
)
|
||||
|
||||
# The loss is just the negative log-likelihood (e.g. the L2 loss)
|
||||
stats["loss"] = stats["neg_log_p_x"]
|
||||
|
||||
if not is_training:
|
||||
# Optionally add the evaluation stats when not training
|
||||
# Add also the evaluation statistics
|
||||
# We need to be able to set `use_mean = False` for some of the tests
|
||||
stats.update(metrics.evaluation_only_statistics(
|
||||
reconstruct_func=functools.partial(
|
||||
self.reconstruct, use_mean=use_mean_for_eval_stats),
|
||||
params=params,
|
||||
inputs=inputs,
|
||||
rng=rng,
|
||||
rescale_by=self.rescale_by,
|
||||
can_run_backwards=self.can_run_backwards,
|
||||
train_sequence_length=self.train_sequence_length,
|
||||
reconstruction_skip=1,
|
||||
p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
|
||||
))
|
||||
|
||||
return stats["loss"], (dict(), stats, dict())
|
||||
|
||||
def reconstruct(
|
||||
self,
|
||||
params: utils.Params,
|
||||
inputs: jnp.ndarray,
|
||||
rng: jnp.ndarray,
|
||||
forward: bool,
|
||||
use_mean: bool = True,
|
||||
) -> distrax.Distribution:
|
||||
"""Reconstructs the input sequence."""
|
||||
if not forward:
|
||||
raise ValueError("This model can not run backwards.")
|
||||
images = utils.extract_image(inputs)
|
||||
image_data = images[:, :self.num_inference_steps]
|
||||
|
||||
return self._models_core(
|
||||
params=params,
|
||||
keys=jnr.split(rng, 6),
|
||||
image_data=image_data,
|
||||
is_training=False,
|
||||
num_steps_forward=images.shape[1] - self.num_inference_steps,
|
||||
num_steps_backward=0,
|
||||
include_z0=False,
|
||||
)[0]
|
||||
|
||||
def gt_state_and_latents(
|
||||
self,
|
||||
params: hk.Params,
|
||||
rng: jnp.ndarray,
|
||||
inputs: Dict[str, jnp.ndarray],
|
||||
seq_length: int,
|
||||
is_training: bool = False,
|
||||
unroll_direction: str = "forward",
|
||||
**kwargs: Dict[str, Any]
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray,
|
||||
Union[distrax.Distribution, jnp.ndarray]]:
|
||||
"""Computes the ground state and matching latents."""
|
||||
assert unroll_direction == "forward"
|
||||
images = utils.extract_image(inputs)
|
||||
gt_state = utils.extract_gt_state(inputs)
|
||||
image_data = images[:, :self.num_inference_steps]
|
||||
gt_state = gt_state[:, 1:seq_length + 1]
|
||||
|
||||
_, z_in, z_out = self._models_core(
|
||||
params=params,
|
||||
keys=jnr.split(rng, 6),
|
||||
image_data=image_data,
|
||||
is_training=False,
|
||||
num_steps_forward=images.shape[1] - self.num_inference_steps,
|
||||
num_steps_backward=0,
|
||||
include_z0=False,
|
||||
)
|
||||
|
||||
return gt_state, z_out, z_in
|
||||
|
||||
def _init_non_model_params_and_state(
|
||||
self,
|
||||
rng: jnp.ndarray
|
||||
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
|
||||
return dict(), dict()
|
||||
|
||||
def _init_latent_system(
|
||||
self,
|
||||
rng: jnp.ndarray,
|
||||
z: jnp.ndarray,
|
||||
**kwargs: Any
|
||||
) -> utils.Params:
|
||||
return self.recurrence.init(rng, z)
|
||||
@@ -0,0 +1,360 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module containing the base abstract classes for sequence models."""
|
||||
import abc
|
||||
from typing import Any, Dict, Generic, Mapping, Optional, Sequence, Tuple, TypeVar, Union
|
||||
|
||||
from absl import logging
|
||||
import distrax
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jnr
|
||||
|
||||
|
||||
from physics_inspired_models import utils
|
||||
from physics_inspired_models.models import networks
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SequenceModel(abc.ABC, Generic[T]):
|
||||
"""An abstract class for sequence models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
can_run_backwards: bool,
|
||||
latent_system_dim: int,
|
||||
latent_system_net_type: str,
|
||||
latent_system_kwargs: Dict[str, Any],
|
||||
encoder_aggregation_type: Optional[str],
|
||||
decoder_de_aggregation_type: Optional[str],
|
||||
encoder_kwargs: Dict[str, Any],
|
||||
decoder_kwargs: Dict[str, Any],
|
||||
num_inference_steps: int,
|
||||
num_target_steps: int,
|
||||
name: str,
|
||||
latent_spatial_shape: Optional[Tuple[int, int]] = (4, 4),
|
||||
has_latent_transform: bool = False,
|
||||
latent_transform_kwargs: Optional[Dict[str, Any]] = None,
|
||||
rescale_by: Optional[str] = "pixels_and_time",
|
||||
data_format: str = "NHWC",
|
||||
**unused_kwargs
|
||||
):
|
||||
# Arguments checks
|
||||
encoder_kwargs = encoder_kwargs or dict()
|
||||
decoder_kwargs = decoder_kwargs or dict()
|
||||
|
||||
# Set the decoder de-aggregation type the "same" type as the encoder if not
|
||||
# provided
|
||||
if (decoder_de_aggregation_type is None and
|
||||
encoder_aggregation_type is not None):
|
||||
if encoder_aggregation_type == "linear_projection":
|
||||
decoder_de_aggregation_type = "linear_projection"
|
||||
elif encoder_aggregation_type in ("mean", "max"):
|
||||
decoder_de_aggregation_type = "tile"
|
||||
else:
|
||||
raise ValueError(f"Unrecognized encoder_aggregation_type="
|
||||
f"{encoder_aggregation_type}")
|
||||
if latent_system_net_type == "conv":
|
||||
if encoder_aggregation_type is not None:
|
||||
raise ValueError("When the latent system is convolutional, the encoder "
|
||||
"aggregation type should be None.")
|
||||
if decoder_de_aggregation_type is not None:
|
||||
raise ValueError("When the latent system is convolutional, the decoder "
|
||||
"aggregation type should be None.")
|
||||
else:
|
||||
if encoder_aggregation_type is None:
|
||||
raise ValueError("When the latent system is not convolutional, the "
|
||||
"you must provide an encoder aggregation type.")
|
||||
if decoder_de_aggregation_type is None:
|
||||
raise ValueError("When the latent system is not convolutional, the "
|
||||
"you must provide an decoder aggregation type.")
|
||||
if has_latent_transform and latent_transform_kwargs is None:
|
||||
raise ValueError("When using latent transformation you have to provide "
|
||||
"the latent_transform_kwargs argument.")
|
||||
if unused_kwargs:
|
||||
logging.warning("Unused kwargs: %s", str(unused_kwargs))
|
||||
super().__init__(**unused_kwargs)
|
||||
self.can_run_backwards = can_run_backwards
|
||||
self.latent_system_dim = latent_system_dim
|
||||
self.latent_system_kwargs = latent_system_kwargs
|
||||
self.latent_system_net_type = latent_system_net_type
|
||||
self.latent_spatial_shape = latent_spatial_shape
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.num_target_steps = num_target_steps
|
||||
self.rescale_by = rescale_by
|
||||
self.data_format = data_format
|
||||
self.name = name
|
||||
|
||||
# Encoder
|
||||
self.encoder_kwargs = encoder_kwargs
|
||||
self.encoder = hk.transform(
|
||||
lambda *args, **kwargs: networks.SpatialConvEncoder( # pylint: disable=unnecessary-lambda,g-long-lambda
|
||||
latent_dim=latent_system_dim,
|
||||
aggregation_type=encoder_aggregation_type,
|
||||
data_format=data_format,
|
||||
name="Encoder",
|
||||
**encoder_kwargs
|
||||
)(*args, **kwargs))
|
||||
|
||||
# Decoder
|
||||
self.decoder_kwargs = decoder_kwargs
|
||||
self.decoder = hk.transform(
|
||||
lambda *args, **kwargs: networks.SpatialConvDecoder( # pylint: disable=unnecessary-lambda,g-long-lambda
|
||||
initial_spatial_shape=self.latent_spatial_shape,
|
||||
de_aggregation_type=decoder_de_aggregation_type,
|
||||
data_format=data_format,
|
||||
max_de_aggregation_dims=self.latent_system_dim // 2,
|
||||
name="Decoder",
|
||||
**decoder_kwargs,
|
||||
)(*args, **kwargs))
|
||||
|
||||
self.has_latent_transform = has_latent_transform
|
||||
if has_latent_transform:
|
||||
self.latent_transform = hk.transform(
|
||||
lambda *args, **kwargs: networks.make_flexible_net( # pylint: disable=unnecessary-lambda,g-long-lambda
|
||||
net_type=latent_system_net_type,
|
||||
output_dims=latent_system_dim,
|
||||
name="LatentTransform",
|
||||
**latent_transform_kwargs
|
||||
)(*args, **kwargs))
|
||||
else:
|
||||
self.latent_transform = None
|
||||
|
||||
self._jit_init = None
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def train_sequence_length(self) -> int:
|
||||
"""Computes the total length of a sequence needed for training or evaluation."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def train_data_split(
|
||||
self,
|
||||
images: jnp.ndarray,
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
|
||||
"""Extracts from the inputs the data splits for training."""
|
||||
pass
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
params: hk.Params,
|
||||
rng: jnp.ndarray,
|
||||
z: jnp.ndarray,
|
||||
**kwargs: Any
|
||||
) -> distrax.Distribution:
|
||||
"""Decodes the latent variable given the parameters of the model."""
|
||||
# Allow to run with both the full parameters and only the decoders
|
||||
if self.latent_system_net_type == "mlp":
|
||||
fixed_dims = 1
|
||||
elif self.latent_system_net_type == "conv":
|
||||
fixed_dims = 1 + len(self.latent_spatial_shape)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
n_shape = z.shape[:-fixed_dims]
|
||||
z = z.reshape((-1,) + z.shape[-fixed_dims:])
|
||||
x = self.decoder.apply(params, rng, z, **kwargs)
|
||||
return jax.tree_map(lambda a: a.reshape(n_shape + a.shape[1:]), x)
|
||||
|
||||
def apply_latent_transform(
|
||||
self,
|
||||
params: hk.Params,
|
||||
key: jnp.ndarray,
|
||||
z: jnp.ndarray,
|
||||
**kwargs: Any
|
||||
) -> jnp.ndarray:
|
||||
if self.latent_transform is not None:
|
||||
return self.latent_transform.apply(params, key, z, **kwargs)
|
||||
else:
|
||||
return z
|
||||
|
||||
@abc.abstractmethod
|
||||
def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def process_latents_for_dynamics(self, z: jnp.ndarray) -> T:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def process_latents_for_decoder(self, z: T) -> jnp.ndarray:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def unroll_latent_dynamics(
|
||||
self,
|
||||
z: T,
|
||||
params: utils.Params,
|
||||
key: jnp.ndarray,
|
||||
num_steps_forward: int,
|
||||
num_steps_backward: int,
|
||||
include_z0: bool,
|
||||
is_training: bool,
|
||||
**kwargs: Any
|
||||
) -> Tuple[T, Mapping[str, jnp.ndarray]]:
|
||||
"""Unrolls the latent dynamics starting from z and pre-processing for the decoder."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reconstruct(
|
||||
self,
|
||||
params: utils.Params,
|
||||
inputs: jnp.ndarray,
|
||||
rng_key: Optional[jnp.ndarray],
|
||||
forward: bool,
|
||||
) -> distrax.Distribution:
|
||||
"""Using the first `num_inference_steps` parts of inputs reconstructs the rest."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def training_objectives(
|
||||
self,
|
||||
params: utils.Params,
|
||||
state: hk.State,
|
||||
rng: jnp.ndarray,
|
||||
inputs: Union[Dict[str, jnp.ndarray], jnp.ndarray],
|
||||
step: jnp.ndarray,
|
||||
is_training: bool = True,
|
||||
use_mean_for_eval_stats: bool = True
|
||||
) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
|
||||
"""Returns all training objectives statistics and update states."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def inferred_index(self):
|
||||
"""Returns the time index in the input sequence, for which the encoder infers.
|
||||
|
||||
If the encoder takes as input the sequence x[0:n-1], where
|
||||
`n = self.num_inference_steps`, then this outputs the index `k` relative to
|
||||
the begging of the input sequence `x_0`, which the encoder infers.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def inferred_right_offset(self):
|
||||
return self.num_inference_steps - 1 - self.inferred_index
|
||||
|
||||
@abc.abstractmethod
|
||||
def gt_state_and_latents(
|
||||
self,
|
||||
params: hk.Params,
|
||||
rng: jnp.ndarray,
|
||||
inputs: Dict[str, jnp.ndarray],
|
||||
seq_len: int,
|
||||
is_training: bool = False,
|
||||
unroll_direction: str = "forward",
|
||||
**kwargs: Dict[str, Any]
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
||||
"""Computes the ground state and matching latents."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _init_non_model_params_and_state(
|
||||
self,
|
||||
rng: jnp.ndarray
|
||||
) -> Tuple[utils.Params, utils.Params]:
|
||||
"""Initializes any non-model parameters and state."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _init_latent_system(
|
||||
self,
|
||||
rng: jnp.ndarray,
|
||||
z: jnp.ndarray,
|
||||
**kwargs: Any
|
||||
) -> hk.Params:
|
||||
"""Initializes the parameters of the latent system."""
|
||||
pass
|
||||
|
||||
def _init(
|
||||
self,
|
||||
rng: jnp.ndarray,
|
||||
images: jnp.ndarray
|
||||
) -> Tuple[hk.Params, hk.State]:
|
||||
"""Initializes the whole model parameters and state."""
|
||||
inference_data, _, _ = self.train_data_split(images)
|
||||
# Initialize parameters and state for the vae training
|
||||
rng, key = jnr.split(rng)
|
||||
params, state = self._init_non_model_params_and_state(key)
|
||||
|
||||
# Initialize and run encoder
|
||||
inference_data = self.process_inputs_for_encoder(inference_data)
|
||||
rng, key = jnr.split(rng)
|
||||
encoder_params = self.encoder.init(key, inference_data, is_training=True)
|
||||
rng, key = jnr.split(rng)
|
||||
z_in = self.encoder.apply(encoder_params, key, inference_data,
|
||||
is_training=True)
|
||||
|
||||
# For probabilistic models this will be a distribution
|
||||
if isinstance(z_in, distrax.Distribution):
|
||||
z_in = z_in.mean()
|
||||
|
||||
# Initialize and run the optional latent transform
|
||||
if self.latent_transform is not None:
|
||||
rng, key = jnr.split(rng)
|
||||
transform_params = self.latent_transform.init(key, z_in, is_training=True)
|
||||
rng, key = jnr.split(rng)
|
||||
z_in = self.latent_transform.apply(transform_params, key, z_in,
|
||||
is_training=True)
|
||||
else:
|
||||
transform_params = dict()
|
||||
|
||||
# Initialize and run the latent system
|
||||
z_in = self.process_latents_for_dynamics(z_in)
|
||||
rng, key = jnr.split(rng)
|
||||
latent_params = self._init_latent_system(key, z_in, is_training=True)
|
||||
rng, key = jnr.split(rng)
|
||||
z_out, _ = self.unroll_latent_dynamics(
|
||||
z=z_in,
|
||||
params=latent_params,
|
||||
key=key,
|
||||
num_steps_forward=1,
|
||||
num_steps_backward=0,
|
||||
include_z0=False,
|
||||
is_training=True
|
||||
)
|
||||
z_out = self.process_latents_for_decoder(z_out)
|
||||
|
||||
# Initialize and run the decoder
|
||||
rng, key = jnr.split(rng)
|
||||
decoder_params = self.decoder.init(key, z_out[:, 0], is_training=True)
|
||||
_ = self.decoder.apply(decoder_params, rng, z_out[:, 0], is_training=True)
|
||||
|
||||
# Combine all and make immutable
|
||||
params = hk.data_structures.merge(params, encoder_params, transform_params,
|
||||
latent_params, decoder_params)
|
||||
params = hk.data_structures.to_immutable_dict(params)
|
||||
state = hk.data_structures.to_immutable_dict(state)
|
||||
|
||||
return params, state
|
||||
|
||||
def init(
|
||||
self,
|
||||
rng: jnp.ndarray,
|
||||
inputs_or_shape: Union[jnp.ndarray, Mapping[str, jnp.ndarray],
|
||||
Sequence[int]],
|
||||
) -> Tuple[utils.Params, hk.State]:
|
||||
"""Initializes the whole model parameters and state."""
|
||||
if (isinstance(inputs_or_shape, (tuple, list))
|
||||
and isinstance(inputs_or_shape[0], int)):
|
||||
images = jnp.zeros(inputs_or_shape)
|
||||
else:
|
||||
images = utils.extract_image(inputs_or_shape)
|
||||
if self._jit_init is None:
|
||||
self._jit_init = jax.jit(self._init)
|
||||
return self._jit_init(rng, images)
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for all models."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import physics_inspired_models.models.autoregressive as autoregressive
|
||||
import physics_inspired_models.models.deterministic_vae as deterministic_vae
|
||||
|
||||
_physics_arguments = (
|
||||
"input_space", "simulation_space", "potential_func_form",
|
||||
"kinetic_func_form", "hgn_kinetic_func_form", "lgn_kinetic_func_form",
|
||||
"parametrize_mass_matrix", "hgn_parametrize_mass_matrix",
|
||||
"lgn_parametrize_mass_matrix", "mass_eps"
|
||||
)
|
||||
|
||||
|
||||
def construct_model(
|
||||
name: str,
|
||||
*args,
|
||||
**kwargs: Dict[str, Any]
|
||||
):
|
||||
"""Constructs the correct instance of a model given the short name."""
|
||||
latent_dynamics_type: Optional[str] = kwargs.pop("latent_dynamics_type", None) # pytype: disable=annotation-type-mismatch
|
||||
latent_system_kwargs = dict(**kwargs.pop("latent_system_kwargs", dict()))
|
||||
if name == "AR":
|
||||
assert latent_dynamics_type in ("vanilla", "lstm", "gru")
|
||||
# This arguments are not part of the AR models
|
||||
for k in _physics_arguments + ("integrator_method", "residual"):
|
||||
latent_system_kwargs.pop(k, None)
|
||||
return autoregressive.TeacherForcingAutoregressiveModel(
|
||||
*args,
|
||||
latent_dynamics_type=latent_dynamics_type,
|
||||
latent_system_kwargs=latent_system_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
elif name == "RGN":
|
||||
assert latent_dynamics_type in ("Discrete", None)
|
||||
latent_dynamics_type = "Discrete"
|
||||
# This arguments are not part of the RGN models
|
||||
for k in _physics_arguments + ("integrator_method",):
|
||||
latent_system_kwargs.pop(k, None)
|
||||
elif name == "ODE":
|
||||
assert latent_dynamics_type in ("ODE", None)
|
||||
latent_dynamics_type = "ODE"
|
||||
# This arguments are not part of the ODE models
|
||||
for k in _physics_arguments + ("residual",):
|
||||
latent_system_kwargs.pop(k, None)
|
||||
elif name == "HGN":
|
||||
assert latent_dynamics_type in ("Physics", None)
|
||||
latent_dynamics_type = "Physics"
|
||||
assert latent_system_kwargs.get("input_space", None) in ("momentum", None)
|
||||
latent_system_kwargs["input_space"] = "momentum"
|
||||
assert (latent_system_kwargs.get("simulation_space", None)
|
||||
in ("momentum", None))
|
||||
latent_system_kwargs["simulation_space"] = "momentum"
|
||||
# Kinetic func form
|
||||
hgn_specific = latent_system_kwargs.pop("hgn_kinetic_func_form", None)
|
||||
if hgn_specific is not None:
|
||||
latent_system_kwargs["kinetic_func_form"] = hgn_specific
|
||||
# Mass matrix
|
||||
hgn_specific = latent_system_kwargs.pop("hgn_parametrize_mass_matrix",
|
||||
None)
|
||||
if hgn_specific is not None:
|
||||
latent_system_kwargs["parametrize_mass_matrix"] = hgn_specific
|
||||
# This arguments are not part of the HGN models
|
||||
latent_system_kwargs.pop("residual", None)
|
||||
latent_system_kwargs.pop("lgn_kinetic_func_form", None)
|
||||
latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
|
||||
elif name == "LGN":
|
||||
assert latent_dynamics_type in ("Physics", None)
|
||||
latent_dynamics_type = "Physics"
|
||||
assert latent_system_kwargs.get("input_space", None) in ("velocity", None)
|
||||
latent_system_kwargs["input_space"] = "velocity"
|
||||
assert (latent_system_kwargs.get("simulation_space", None) in
|
||||
("velocity", None))
|
||||
latent_system_kwargs["simulation_space"] = "velocity"
|
||||
# Kinetic func form
|
||||
lgn_specific = latent_system_kwargs.pop("lgn_kinetic_func_form", None)
|
||||
if lgn_specific is not None:
|
||||
latent_system_kwargs["kinetic_func_form"] = lgn_specific
|
||||
# Mass matrix
|
||||
lgn_specific = latent_system_kwargs.pop("lgn_parametrize_mass_matrix",
|
||||
None)
|
||||
if lgn_specific is not None:
|
||||
latent_system_kwargs["parametrize_mass_matrix"] = lgn_specific
|
||||
# This arguments are not part of the HGN models
|
||||
latent_system_kwargs.pop("residual", None)
|
||||
latent_system_kwargs.pop("hgn_kinetic_func_form", None)
|
||||
latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
|
||||
elif name == "PGN":
|
||||
assert latent_dynamics_type in ("Physics", None)
|
||||
latent_dynamics_type = "Physics"
|
||||
# This arguments are not part of the PGN models
|
||||
latent_system_kwargs.pop("residual")
|
||||
latent_system_kwargs.pop("hgn_kinetic_func_form", None)
|
||||
latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
|
||||
latent_system_kwargs.pop("lgn_kinetic_func_form", None)
|
||||
latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return deterministic_vae.DeterministicLatentsGenerativeModel(
|
||||
*args,
|
||||
latent_dynamics_type=latent_dynamics_type,
|
||||
latent_system_kwargs=latent_system_kwargs,
|
||||
**kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,494 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module containing all of the networks as Haiku modules."""
|
||||
from typing import Any, Callable, Mapping, Optional, Sequence, Union
|
||||
|
||||
from absl import logging
|
||||
import distrax
|
||||
import haiku as hk
|
||||
import jax.numpy as jnp
|
||||
|
||||
from physics_inspired_models import utils
|
||||
|
||||
Activation = Union[str, Callable[[jnp.ndarray], jnp.ndarray]]
|
||||
|
||||
|
||||
class DenseNet(hk.Module):
|
||||
"""A feed forward network (MLP)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_units: Sequence[int],
|
||||
activate_final: bool = False,
|
||||
activation: Activation = "leaky_relu",
|
||||
name: Optional[str] = None):
|
||||
super().__init__(name=name)
|
||||
self.num_units = num_units
|
||||
self.num_layers = len(self.num_units)
|
||||
self.activate_final = activate_final
|
||||
self.activation = utils.get_activation(activation)
|
||||
|
||||
self.linear_modules = []
|
||||
for i in range(self.num_layers):
|
||||
self.linear_modules.append(
|
||||
hk.Linear(
|
||||
output_size=self.num_units[i],
|
||||
name=f"ff_{i}"
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, inputs: jnp.ndarray, is_training: bool):
|
||||
net = inputs
|
||||
for i, linear in enumerate(self.linear_modules):
|
||||
net = linear(net)
|
||||
if i < self.num_layers - 1 or self.activate_final:
|
||||
net = self.activation(net)
|
||||
return net
|
||||
|
||||
|
||||
class Conv2DNet(hk.Module):
|
||||
"""Convolutional Network."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_channels: Sequence[int],
|
||||
kernel_shapes: Union[int, Sequence[int]] = 3,
|
||||
strides: Union[int, Sequence[int]] = 1,
|
||||
padding: Union[str, Sequence[str]] = "SAME",
|
||||
data_format: str = "NHWC",
|
||||
with_batch_norm: bool = False,
|
||||
activate_final: bool = False,
|
||||
activation: Activation = "leaky_relu",
|
||||
name: Optional[str] = None):
|
||||
super().__init__(name=name)
|
||||
self.output_channels = tuple(output_channels)
|
||||
self.num_layers = len(self.output_channels)
|
||||
self.kernel_shapes = utils.bcast_if(kernel_shapes, int, self.num_layers)
|
||||
self.strides = utils.bcast_if(strides, int, self.num_layers)
|
||||
self.padding = utils.bcast_if(padding, str, self.num_layers)
|
||||
self.data_format = data_format
|
||||
self.with_batch_norm = with_batch_norm
|
||||
self.activate_final = activate_final
|
||||
self.activation = utils.get_activation(activation)
|
||||
|
||||
if len(self.kernel_shapes) != self.num_layers:
|
||||
raise ValueError(f"Kernel shapes is of size {len(self.kernel_shapes)}, "
|
||||
f"while output_channels is of size{self.num_layers}.")
|
||||
if len(self.strides) != self.num_layers:
|
||||
raise ValueError(f"Strides is of size {len(self.kernel_shapes)}, while "
|
||||
f"output_channels is of size{self.num_layers}.")
|
||||
if len(self.padding) != self.num_layers:
|
||||
raise ValueError(f"Padding is of size {len(self.padding)}, while "
|
||||
f"output_channels is of size{self.num_layers}.")
|
||||
|
||||
self.conv_modules = []
|
||||
self.bn_modules = []
|
||||
for i in range(self.num_layers):
|
||||
self.conv_modules.append(
|
||||
hk.Conv2D(
|
||||
output_channels=self.output_channels[i],
|
||||
kernel_shape=self.kernel_shapes[i],
|
||||
stride=self.strides[i],
|
||||
padding=self.padding[i],
|
||||
data_format=data_format,
|
||||
name=f"conv_2d_{i}")
|
||||
)
|
||||
if with_batch_norm:
|
||||
self.bn_modules.append(
|
||||
hk.BatchNorm(
|
||||
create_offset=True,
|
||||
create_scale=False,
|
||||
decay_rate=0.999,
|
||||
name=f"batch_norm_{i}")
|
||||
)
|
||||
else:
|
||||
self.bn_modules.append(None)
|
||||
|
||||
def __call__(self, inputs: jnp.ndarray, is_training: bool):
|
||||
assert inputs.ndim == 4
|
||||
net = inputs
|
||||
for i, (conv, bn) in enumerate(zip(self.conv_modules, self.bn_modules)):
|
||||
net = conv(net)
|
||||
# Batch norm
|
||||
if bn is not None:
|
||||
net = bn(net, is_training=is_training)
|
||||
if i < self.num_layers - 1 or self.activate_final:
|
||||
net = self.activation(net)
|
||||
return net
|
||||
|
||||
|
||||
class SpatialConvEncoder(hk.Module):
|
||||
"""Spatial Convolutional Encoder for learning the Hamiltonian."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latent_dim: int,
|
||||
conv_channels: Union[Sequence[int], int],
|
||||
num_blocks: int,
|
||||
blocks_depth: int = 2,
|
||||
distribution_name: str = "diagonal_normal",
|
||||
aggregation_type: Optional[str] = None,
|
||||
data_format: str = "NHWC",
|
||||
activation: Activation = "leaky_relu",
|
||||
scale_factor: int = 2,
|
||||
kernel_shapes: Union[Sequence[int], int] = 3,
|
||||
padding: Union[Sequence[str], str] = "SAME",
|
||||
name: Optional[str] = None):
|
||||
super().__init__(name=name)
|
||||
if aggregation_type not in (None, "max", "mean", "linear_projection"):
|
||||
raise ValueError(f"Unrecognized aggregation_type={aggregation_type}.")
|
||||
self.latent_dim = latent_dim
|
||||
self.conv_channels = conv_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.scale_factor = scale_factor
|
||||
self.data_format = data_format
|
||||
self.distribution_name = distribution_name
|
||||
self.aggregation_type = aggregation_type
|
||||
|
||||
# Compute the required size of the output
|
||||
if distribution_name is None:
|
||||
self.output_dim = latent_dim
|
||||
elif distribution_name == "diagonal_normal":
|
||||
self.output_dim = 2 * latent_dim
|
||||
else:
|
||||
raise ValueError(f"Unrecognized distribution_name={distribution_name}.")
|
||||
|
||||
if isinstance(conv_channels, int):
|
||||
conv_channels = [[conv_channels] * blocks_depth
|
||||
for _ in range(num_blocks)]
|
||||
conv_channels[-1] += [self.output_dim]
|
||||
else:
|
||||
assert isinstance(conv_channels, (list, tuple))
|
||||
assert len(conv_channels) == num_blocks
|
||||
conv_channels = list(list(c) for c in conv_channels)
|
||||
conv_channels[-1].append(self.output_dim)
|
||||
|
||||
if isinstance(kernel_shapes, tuple):
|
||||
kernel_shapes = list(kernel_shapes)
|
||||
|
||||
# Convolutional blocks
|
||||
self.blocks = []
|
||||
for i, channels in enumerate(conv_channels):
|
||||
if isinstance(kernel_shapes, int):
|
||||
extra_kernel_shapes = 0
|
||||
else:
|
||||
extra_kernel_shapes = [3] * (len(channels) - len(kernel_shapes))
|
||||
|
||||
self.blocks.append(Conv2DNet(
|
||||
output_channels=channels,
|
||||
kernel_shapes=kernel_shapes + extra_kernel_shapes,
|
||||
strides=[self.scale_factor] + [1] * (len(channels) - 1),
|
||||
padding=padding,
|
||||
data_format=data_format,
|
||||
with_batch_norm=False,
|
||||
activate_final=i < num_blocks - 1,
|
||||
activation=activation,
|
||||
name=f"block_{i}"
|
||||
))
|
||||
|
||||
def spatial_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
|
||||
if self.aggregation_type is None:
|
||||
return x
|
||||
axis = (1, 2) if self.data_format == "NHWC" else (2, 3)
|
||||
if self.aggregation_type == "max":
|
||||
return jnp.max(x, axis=axis)
|
||||
if self.aggregation_type == "mean":
|
||||
return jnp.mean(x, axis=axis)
|
||||
if self.aggregation_type == "linear_projection":
|
||||
x = x.reshape(x.shape[:-3] + (-1,))
|
||||
return hk.Linear(self.output_dim, name="LinearProjection")(x)
|
||||
raise NotImplementedError()
|
||||
|
||||
def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
|
||||
if self.distribution_name is None:
|
||||
return net_output
|
||||
elif self.distribution_name == "diagonal_normal":
|
||||
if self.aggregation_type is None:
|
||||
split_axis, num_axes = self.data_format.index("C"), 3
|
||||
else:
|
||||
split_axis, num_axes = 1, 1
|
||||
# Add an extra axis if the input has more than 1 batch dimension
|
||||
split_axis += net_output.ndim - num_axes - 1
|
||||
loc, log_scale = jnp.split(net_output, 2, axis=split_axis)
|
||||
return distrax.Normal(loc, jnp.exp(log_scale))
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: jnp.ndarray,
|
||||
is_training: bool
|
||||
) -> Union[jnp.ndarray, distrax.Distribution]:
|
||||
# Treat any extra dimensions (like time) as the batch
|
||||
batched_shape = inputs.shape[:-3]
|
||||
net = jnp.reshape(inputs, (-1,) + inputs.shape[-3:])
|
||||
|
||||
# Apply all blocks in sequence
|
||||
for block in self.blocks:
|
||||
net = block(net, is_training=is_training)
|
||||
|
||||
# Final projection
|
||||
net = self.spatial_aggregation(net)
|
||||
|
||||
# Reshape back to correct dimensions (like batch + time)
|
||||
net = jnp.reshape(net, batched_shape + net.shape[1:])
|
||||
|
||||
# Return a distribution over the observations
|
||||
return self.make_distribution(net)
|
||||
|
||||
|
||||
class SpatialConvDecoder(hk.Module):
|
||||
"""Spatial Convolutional Decoder for learning the Hamiltonian."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_spatial_shape: Sequence[int],
|
||||
conv_channels: Union[Sequence[int], int],
|
||||
num_blocks: int,
|
||||
max_de_aggregation_dims: int,
|
||||
blocks_depth: int = 2,
|
||||
scale_factor: int = 2,
|
||||
output_channels: int = 3,
|
||||
h_const_channels: int = 2,
|
||||
data_format: str = "NHWC",
|
||||
activation: Activation = "leaky_relu",
|
||||
learned_sigma: bool = False,
|
||||
de_aggregation_type: Optional[str] = None,
|
||||
final_activation: Activation = "sigmoid",
|
||||
discard_half_de_aggregated: bool = False,
|
||||
kernel_shapes: Union[Sequence[int], int] = 3,
|
||||
padding: Union[Sequence[str], str] = "SAME",
|
||||
name: Optional[str] = None):
|
||||
super().__init__(name=name)
|
||||
if de_aggregation_type not in (None, "tile", "linear_projection"):
|
||||
raise ValueError(f"Unrecognized de_aggregation_type="
|
||||
f"{de_aggregation_type}.")
|
||||
self.num_blocks = num_blocks
|
||||
self.scale_factor = scale_factor
|
||||
self.h_const_channels = h_const_channels
|
||||
self.data_format = data_format
|
||||
self.learned_sigma = learned_sigma
|
||||
self.initial_spatial_shape = tuple(initial_spatial_shape)
|
||||
self.final_activation = utils.get_activation(final_activation)
|
||||
self.de_aggregation_type = de_aggregation_type
|
||||
self.max_de_aggregation_dims = max_de_aggregation_dims
|
||||
self.discard_half_de_aggregated = discard_half_de_aggregated
|
||||
|
||||
if isinstance(conv_channels, int):
|
||||
conv_channels = [[conv_channels] * blocks_depth
|
||||
for _ in range(num_blocks)]
|
||||
conv_channels[-1] += [output_channels]
|
||||
else:
|
||||
assert isinstance(conv_channels, (list, tuple))
|
||||
assert len(conv_channels) == num_blocks
|
||||
conv_channels = list(list(c) for c in conv_channels)
|
||||
conv_channels[-1].append(output_channels)
|
||||
|
||||
# Convolutional blocks
|
||||
self.blocks = []
|
||||
for i, channels in enumerate(conv_channels):
|
||||
is_final_block = i == num_blocks - 1
|
||||
self.blocks.append(
|
||||
Conv2DNet( # pylint: disable=g-complex-comprehension
|
||||
output_channels=channels,
|
||||
kernel_shapes=kernel_shapes,
|
||||
strides=1,
|
||||
padding=padding,
|
||||
data_format=data_format,
|
||||
with_batch_norm=False,
|
||||
activate_final=not is_final_block,
|
||||
activation=activation,
|
||||
name=f"block_{i}"
|
||||
))
|
||||
|
||||
def spatial_de_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
|
||||
if self.de_aggregation_type is None:
|
||||
assert x.ndim >= 4
|
||||
if self.data_format == "NHWC":
|
||||
assert x.shape[1:3] == self.initial_spatial_shape
|
||||
elif self.data_format == "NCHW":
|
||||
assert x.shape[2:4] == self.initial_spatial_shape
|
||||
return x
|
||||
elif self.de_aggregation_type == "linear_projection":
|
||||
assert x.ndim == 2
|
||||
n, d = x.shape
|
||||
d = min(d, self.max_de_aggregation_dims or d)
|
||||
out_d = d * self.initial_spatial_shape[0] * self.initial_spatial_shape[1]
|
||||
x = hk.Linear(out_d, name="LinearProjection")(x)
|
||||
if self.data_format == "NHWC":
|
||||
shape = (n,) + self.initial_spatial_shape + (d,)
|
||||
else:
|
||||
shape = (n, d) + self.initial_spatial_shape
|
||||
return x.reshape(shape)
|
||||
elif self.de_aggregation_type == "tile":
|
||||
assert x.ndim == 2
|
||||
if self.data_format == "NHWC":
|
||||
repeats = (1,) + self.initial_spatial_shape + (1,)
|
||||
x = x[:, None, None, :]
|
||||
else:
|
||||
repeats = (1, 1) + self.initial_spatial_shape
|
||||
x = x[:, :, None, None]
|
||||
return jnp.tile(x, repeats)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_constant_channels(self, inputs: jnp.ndarray) -> jnp.ndarray:
|
||||
# --------------------------------------------
|
||||
# This is purely for TF compatibility purposes
|
||||
if self.discard_half_de_aggregated:
|
||||
axis = self.data_format.index("C")
|
||||
inputs, _ = jnp.split(inputs, 2, axis=axis)
|
||||
# --------------------------------------------
|
||||
|
||||
# An extra constant channels
|
||||
if self.data_format == "NHWC":
|
||||
h_shape = self.initial_spatial_shape + (self.h_const_channels,)
|
||||
else:
|
||||
h_shape = (self.h_const_channels,) + self.initial_spatial_shape
|
||||
h_const = hk.get_parameter("h", h_shape, dtype=inputs.dtype,
|
||||
init=hk.initializers.Constant(1))
|
||||
h_const = jnp.tile(h_const, reps=[inputs.shape[0], 1, 1, 1])
|
||||
return jnp.concatenate([h_const, inputs], axis=self.data_format.index("C"))
|
||||
|
||||
def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
|
||||
if self.learned_sigma:
|
||||
init = hk.initializers.Constant(- jnp.log(2.0) / 2.0)
|
||||
log_scale = hk.get_parameter("log_scale", shape=(),
|
||||
dtype=net_output.dtype, init=init)
|
||||
scale = jnp.full_like(net_output, jnp.exp(log_scale))
|
||||
else:
|
||||
scale = jnp.full_like(net_output, 1 / jnp.sqrt(2.0))
|
||||
|
||||
return distrax.Normal(net_output, scale)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: jnp.ndarray,
|
||||
is_training: bool
|
||||
) -> distrax.Distribution:
|
||||
# Apply the spatial de-aggregation
|
||||
inputs = self.spatial_de_aggregation(inputs)
|
||||
|
||||
# Add the parameterized constant channels
|
||||
net = self.add_constant_channels(inputs)
|
||||
|
||||
# Apply all the blocks
|
||||
for block in self.blocks:
|
||||
# Up-sample the image
|
||||
net = utils.nearest_neighbour_upsampling(net, self.scale_factor)
|
||||
# Apply the convolutional block
|
||||
net = block(net, is_training=is_training)
|
||||
|
||||
# Apply any specific output nonlinearity
|
||||
net = self.final_activation(net)
|
||||
|
||||
# Construct the distribution over the observations
|
||||
return self.make_distribution(net)
|
||||
|
||||
|
||||
def make_flexible_net(
|
||||
net_type: str,
|
||||
output_dims: int,
|
||||
conv_channels: Union[Sequence[int], int],
|
||||
num_units: Union[Sequence[int], int],
|
||||
num_layers: Optional[int],
|
||||
activation: Activation,
|
||||
activate_final: bool = False,
|
||||
kernel_shapes: Union[Sequence[int], int] = 3,
|
||||
strides: Union[Sequence[int], int] = 1,
|
||||
padding: Union[Sequence[str], str] = "SAME",
|
||||
name: Optional[str] = None,
|
||||
**unused_kwargs: Mapping[str, Any]
|
||||
):
|
||||
"""Commonly used for creating a flexible network."""
|
||||
if unused_kwargs:
|
||||
logging.warning("Unused kwargs of `make_flexible_net`: %s",
|
||||
str(unused_kwargs))
|
||||
if net_type == "mlp":
|
||||
if isinstance(num_units, int):
|
||||
assert num_layers is not None
|
||||
num_units = [num_units] * (num_layers - 1) + [output_dims]
|
||||
else:
|
||||
num_units = list(num_units) + [output_dims]
|
||||
return DenseNet(
|
||||
num_units=num_units,
|
||||
activation=activation,
|
||||
activate_final=activate_final,
|
||||
name=name
|
||||
)
|
||||
elif net_type == "conv":
|
||||
if isinstance(conv_channels, int):
|
||||
assert num_layers is not None
|
||||
conv_channels = [conv_channels] * (num_layers - 1) + [output_dims]
|
||||
else:
|
||||
conv_channels = list(conv_channels) + [output_dims]
|
||||
return Conv2DNet(
|
||||
output_channels=conv_channels,
|
||||
kernel_shapes=kernel_shapes,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
activation=activation,
|
||||
activate_final=activate_final,
|
||||
name=name
|
||||
)
|
||||
elif net_type == "transformer":
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise ValueError(f"Unrecognized net_type={net_type}.")
|
||||
|
||||
|
||||
def make_flexible_recurrent_net(
|
||||
core_type: str,
|
||||
net_type: str,
|
||||
output_dims: int,
|
||||
num_units: Union[Sequence[int], int],
|
||||
num_layers: Optional[int],
|
||||
activation: Activation,
|
||||
activate_final: bool = False,
|
||||
name: Optional[str] = None,
|
||||
**unused_kwargs
|
||||
):
|
||||
"""Commonly used for creating a flexible recurrences."""
|
||||
if net_type != "mlp":
|
||||
raise ValueError("We do not support convolutional recurrent nets atm.")
|
||||
if unused_kwargs:
|
||||
logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
|
||||
str(unused_kwargs))
|
||||
|
||||
if isinstance(num_units, (list, tuple)):
|
||||
num_units = list(num_units) + [output_dims]
|
||||
num_layers = len(num_units)
|
||||
else:
|
||||
assert num_layers is not None
|
||||
num_units = [num_units] * (num_layers - 1) + [output_dims]
|
||||
name = name or f"{core_type.upper()}"
|
||||
|
||||
activation = utils.get_activation(activation)
|
||||
core_list = []
|
||||
for i, n in enumerate(num_units):
|
||||
if core_type.lower() == "vanilla":
|
||||
core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}"))
|
||||
elif core_type.lower() == "lstm":
|
||||
core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}"))
|
||||
elif core_type.lower() == "gru":
|
||||
core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}"))
|
||||
else:
|
||||
raise ValueError(f"Unrecognized core_type={core_type}.")
|
||||
if i != num_layers - 1:
|
||||
core_list.append(activation)
|
||||
if activate_final:
|
||||
core_list.append(activation)
|
||||
|
||||
return hk.DeepRNN(core_list, name="RNN")
|
||||
@@ -0,0 +1,10 @@
|
||||
git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite@main#egg=dm_hamiltonian_dynamics_suite
|
||||
absl-py==0.12.0
|
||||
numpy>=1.16.4
|
||||
scikit-learn>=1.0
|
||||
typing>=3.7.4.3
|
||||
jax==0.2.20
|
||||
jaxline==0.0.3
|
||||
distrax==0.0.2
|
||||
optax==0.0.6
|
||||
dm-haiku==0.0.3
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Setup for pip package."""
|
||||
from setuptools import setup
|
||||
|
||||
REQUIRED_PACKAGES = (
|
||||
"dm_hamiltonian_dynamics_suite@git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite", # pylint: disable=line-too-long.
|
||||
"absl-py>=0.12.0",
|
||||
"numpy>=1.16.4",
|
||||
"scikit-learn>=1.0",
|
||||
"typing>=3.7.4.3",
|
||||
"jax==0.2.20",
|
||||
"jaxline==0.0.3",
|
||||
"distrax==0.0.2",
|
||||
"optax==0.0.6",
|
||||
"dm-haiku==0.0.3",
|
||||
)
|
||||
|
||||
LONG_DESCRIPTION = "\n".join([
|
||||
"A codebase containing the implementation of the following models:",
|
||||
"Hamiltonian Generative Network (HGN)",
|
||||
"Lagrangian Generative Network (LGN)",
|
||||
"Neural ODE",
|
||||
"Recurrent Generative Network (RGN)",
|
||||
"and RNN, LSTM and GRU.",
|
||||
"This is code accompanying the publication of:"
|
||||
])
|
||||
|
||||
|
||||
setup(
|
||||
name="physics_inspired_models",
|
||||
version="0.0.1",
|
||||
description="Implementation of multiple physically inspired models.",
|
||||
long_description=LONG_DESCRIPTION,
|
||||
url="https://github.com/deepmind/deepmind-research/physics_inspired_models",
|
||||
author="DeepMind",
|
||||
package_dir={"physics_inspired_models": "."},
|
||||
packages=["physics_inspired_models", "physics_inspired_models.models"],
|
||||
install_requires=REQUIRED_PACKAGES,
|
||||
platforms=["any"],
|
||||
license="Apache License, Version 2.0",
|
||||
)
|
||||
@@ -0,0 +1,274 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities functions for Jax."""
|
||||
import collections
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Mapping, Union
|
||||
|
||||
import distrax
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import nn
|
||||
import jax.numpy as jnp
|
||||
from jax.tree_util import register_pytree_node
|
||||
from jaxline import utils
|
||||
import numpy as np
|
||||
|
||||
HaikuParams = Mapping[str, Mapping[str, jnp.ndarray]]
|
||||
Params = Union[Mapping[str, jnp.ndarray], HaikuParams, jnp.ndarray]
|
||||
_Activation = Callable[[jnp.ndarray], jnp.ndarray]
|
||||
|
||||
tf_leaky_relu = functools.partial(nn.leaky_relu, negative_slope=0.2)
|
||||
|
||||
|
||||
def filter_only_scalar_stats(stats):
|
||||
return {k: v for k, v in stats.items() if v.size == 1}
|
||||
|
||||
|
||||
def to_numpy(obj):
|
||||
return jax.tree_map(np.array, obj)
|
||||
|
||||
|
||||
@jax.custom_gradient
|
||||
def geco_lagrange_product(lagrange_multiplier, constraint_ema, constraint_t):
|
||||
"""Modifies the gradients so that they work as described in GECO.
|
||||
|
||||
The evaluation gives:
|
||||
lagrange * C_ema
|
||||
The gradient w.r.t lagrange:
|
||||
- g * C_t
|
||||
The gradient w.r.t constraint_ema:
|
||||
0.0
|
||||
The gradient w.r.t constraint_t:
|
||||
g * lagrange
|
||||
|
||||
Note that if you pass the same value for `constraint_ema` and `constraint_t`
|
||||
this would only flip the gradient for the lagrange multiplier.
|
||||
|
||||
Args:
|
||||
lagrange_multiplier: The lagrange multiplier
|
||||
constraint_ema: The moving average of the constraint
|
||||
constraint_t: The current constraint
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
def grad(gradient):
|
||||
return (- gradient * constraint_t,
|
||||
jnp.zeros_like(constraint_ema),
|
||||
gradient * lagrange_multiplier)
|
||||
return lagrange_multiplier * constraint_ema, grad
|
||||
|
||||
|
||||
def bcast_if(x, t, n):
|
||||
return [x] * n if isinstance(x, t) else x
|
||||
|
||||
|
||||
def stack_time_into_channels(
|
||||
images: jnp.ndarray,
|
||||
data_format: str
|
||||
) -> jnp.ndarray:
|
||||
axis = data_format.index("C")
|
||||
list_of_time = [jnp.squeeze(v, axis=1) for v in
|
||||
jnp.split(images, images.shape[1], axis=1)]
|
||||
return jnp.concatenate(list_of_time, axis)
|
||||
|
||||
|
||||
def stack_device_dim_into_batch(obj):
|
||||
return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), obj)
|
||||
|
||||
|
||||
def nearest_neighbour_upsampling(x, scale, data_format="NHWC"):
|
||||
"""Performs nearest-neighbour upsampling."""
|
||||
|
||||
if data_format == "NCHW":
|
||||
b, c, h, w = x.shape
|
||||
x = jnp.reshape(x, [b, c, h, 1, w, 1])
|
||||
ones = jnp.ones([1, 1, 1, scale, 1, scale], dtype=x.dtype)
|
||||
return jnp.reshape(x * ones, [b, c, scale * h, scale * w])
|
||||
elif data_format == "NHWC":
|
||||
b, h, w, c = x.shape
|
||||
x = jnp.reshape(x, [b, h, 1, w, 1, c])
|
||||
ones = jnp.ones([1, 1, scale, 1, scale, 1], dtype=x.dtype)
|
||||
return jnp.reshape(x * ones, [b, scale * h, scale * w, c])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized data_format={data_format}.")
|
||||
|
||||
|
||||
def get_activation(arg: Union[_Activation, str]) -> _Activation:
|
||||
"""Returns an activation from provided string."""
|
||||
if isinstance(arg, str):
|
||||
# Try fetch in order - [this module, jax.nn, jax.numpy]
|
||||
if arg in globals():
|
||||
return globals()[arg]
|
||||
if hasattr(nn, arg):
|
||||
return getattr(nn, arg)
|
||||
elif hasattr(jnp, arg):
|
||||
return getattr(jnp, arg)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized activation with name {arg}.")
|
||||
if not callable(arg):
|
||||
raise ValueError(f"Expected a callable, but got {type(arg)}")
|
||||
return arg
|
||||
|
||||
|
||||
def merge_first_dims(x: jnp.ndarray, num_dims_to_merge: int = 2) -> jnp.ndarray:
|
||||
return x.reshape((-1,) + x.shape[num_dims_to_merge:])
|
||||
|
||||
|
||||
def extract_image(
|
||||
inputs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
|
||||
) -> jnp.ndarray:
|
||||
"""Extracts a tensor with key `image` or `x_image` if it is a dict, otherwise returns the inputs."""
|
||||
if isinstance(inputs, dict):
|
||||
if "image" in inputs:
|
||||
return inputs["image"]
|
||||
else:
|
||||
return inputs["x_image"]
|
||||
elif isinstance(inputs, jnp.ndarray):
|
||||
return inputs
|
||||
raise NotImplementedError(f"Not implemented of inputs of type"
|
||||
f" {type(inputs)}.")
|
||||
|
||||
|
||||
def extract_gt_state(inputs: Any) -> jnp.ndarray:
|
||||
if isinstance(inputs, dict):
|
||||
return inputs["x"]
|
||||
elif not isinstance(inputs, jnp.ndarray):
|
||||
raise NotImplementedError(f"Not implemented of inputs of type"
|
||||
f" {type(inputs)}.")
|
||||
return inputs
|
||||
|
||||
|
||||
def reshape_latents_conv_to_flat(conv_latents, axis_n_to_keep=1):
|
||||
q, p = jnp.split(conv_latents, 2, axis=-1)
|
||||
q = jax.tree_map(lambda x: x.reshape(x.shape[:axis_n_to_keep] + (-1,)), q)
|
||||
p = jax.tree_map(lambda x: x.reshape(x.shape[:axis_n_to_keep] + (-1,)), p)
|
||||
flat_latents = jnp.concatenate([q, p], axis=-1)
|
||||
|
||||
return flat_latents
|
||||
|
||||
|
||||
def triu_matrix_from_v(x, ndim):
|
||||
assert x.shape[-1] == (ndim * (ndim + 1)) // 2
|
||||
matrix = jnp.zeros(x.shape[:-1] + (ndim, ndim))
|
||||
idx = jnp.triu_indices(ndim)
|
||||
index_update = lambda x, idx, y: x.at[idx].set(y)
|
||||
for _ in range(x.ndim - 1):
|
||||
index_update = jax.vmap(index_update, in_axes=(0, None, 0))
|
||||
return index_update(matrix, idx, x)
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.MutableMapping):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
|
||||
def convert_to_pytype(target, reference):
|
||||
"""Makes target the same pytype as reference, by jax.tree_flatten."""
|
||||
_, pytree = jax.tree_flatten(reference)
|
||||
leaves, _ = jax.tree_flatten(target)
|
||||
return jax.tree_unflatten(pytree, leaves)
|
||||
|
||||
|
||||
def func_if_not_scalar(func):
|
||||
"""Makes a function that uses func only on non-scalar values."""
|
||||
@functools.wraps(func)
|
||||
def wrapped(array, axis=0):
|
||||
if array.ndim == 0:
|
||||
return array
|
||||
return func(array, axis=axis)
|
||||
return wrapped
|
||||
|
||||
|
||||
mean_if_not_scalar = func_if_not_scalar(jnp.mean)
|
||||
|
||||
|
||||
class MultiBatchAccumulator(object):
|
||||
"""Class for abstracting statistics accumulation over multiple batches."""
|
||||
|
||||
def __init__(self):
|
||||
self._obj = None
|
||||
self._obj_max = None
|
||||
self._obj_min = None
|
||||
self._num_samples = None
|
||||
|
||||
def add(self, averaged_values, num_samples):
|
||||
"""Adds an element to the moving average and the max."""
|
||||
if self._obj is None:
|
||||
self._obj_max = jax.tree_map(lambda y: y * 1.0, averaged_values)
|
||||
self._obj_min = jax.tree_map(lambda y: y * 1.0, averaged_values)
|
||||
self._obj = jax.tree_map(lambda y: y * num_samples, averaged_values)
|
||||
self._num_samples = num_samples
|
||||
else:
|
||||
self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
|
||||
averaged_values)
|
||||
self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
|
||||
averaged_values)
|
||||
self._obj = jax.tree_multimap(lambda x, y: x + y * num_samples, self._obj,
|
||||
averaged_values)
|
||||
self._num_samples += num_samples
|
||||
|
||||
def value(self):
|
||||
return jax.tree_map(lambda x: x / self._num_samples, self._obj)
|
||||
|
||||
def max(self):
|
||||
return jax.tree_map(float, self._obj_max)
|
||||
|
||||
def min(self):
|
||||
return jax.tree_map(float, self._obj_min)
|
||||
|
||||
def sum(self):
|
||||
return self._obj
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
distrax.Normal,
|
||||
lambda instance: ([instance.loc, instance.scale], None),
|
||||
lambda _, args: distrax.Normal(*args)
|
||||
)
|
||||
|
||||
|
||||
def inner_product(x: Any, y: Any) -> jnp.ndarray:
|
||||
products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
|
||||
return sum(jax.tree_leaves(products))
|
||||
|
||||
|
||||
get_first = utils.get_first
|
||||
bcast_local_devices = utils.bcast_local_devices
|
||||
py_prefetch = utils.py_prefetch
|
||||
p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
|
||||
static_broadcasted_argnums=1)
|
||||
|
||||
|
||||
def wrap_if_pmap(p_func):
|
||||
def p_func_if_pmap(obj, axis_name):
|
||||
try:
|
||||
core.axis_frame(axis_name)
|
||||
return p_func(obj, axis_name)
|
||||
except NameError:
|
||||
return obj
|
||||
return p_func_if_pmap
|
||||
|
||||
|
||||
pmean_if_pmap = wrap_if_pmap(lax.pmean)
|
||||
psum_if_pmap = wrap_if_pmap(lax.psum)
|
||||
Reference in New Issue
Block a user