diff --git a/README.md b/README.md index ba81b8b..130e476 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ https://deepmind.com/research/publications/ * [REGAL: Transfer Learning for Fast Optimization of Computation Graphs](regal) * [Deep Ensembles: A Loss Landscape Perspective](ensemble_loss_landscape) * [Powerpropagation](powerpropagation) +* [Physics Inspired Models](physics_inspired_models) diff --git a/physics_inspired_models/README.md b/physics_inspired_models/README.md new file mode 100644 index 0000000..7b54c94 --- /dev/null +++ b/physics_inspired_models/README.md @@ -0,0 +1,59 @@ +# Implementation of multiple physics inspired models for modelling dynamics + +This repository contains an implementation of different physics inspired models +used in the papers: **SyMetric: Measuring the Quality of Learnt Hamiltonian +Dynamics Inferred from Vision** and **Which priors matter? Benchmarking models +for learning latent dynamics**. + + +## Contributing + +This is purely research code, provided with no further intentions of support or +any guarantees of backward compatibility. + + +## Installation + +All package requirements are listed in `requirements.txt`. +You will still need to download and set up the datasets from the +[DeepMind Hamiltonian Dynamics Suite] manually. + +```shell +git clone git@github.com:deepmind/deepmind-research.git +pip install -r ./deepmind-research/physics_inspired_models/requirements.txt +pip install ./deepmind-research/physics_inspired_models +pip install --upgrade "jax[XXX]" +``` + +where `XXX` is the correct type of accelerator that you have on your machine. +Note that if you are using a GPU you might need `XXX` to also include the +correct version of CUDA and cuDNN installed on your machine. +For more details please read [here](https://github.com/google/jax#installation). + +## Usage + +The file `jaxline_configs.py` contains all the configuration specifications for +the experiments in the two papers. 
To run an experiment, in addition to passing +the location of the configs file, you must provide extra arguments in the +following manner: + +`${name_of_configuration},${index_in_sweep},${dataset_name}` + +For example to run the second hyper-parameter configuration of the improved +Hamiltonian Generative Network (HGN++) on the mass-spring dataset you should +run in the command line (assuming that you are in the folder of the project): + +```shell +python3 jaxline_train.py \ + --config="jaxline_configs.py:sym_metric_hgn_plus_plus_sweep,1,toy_physics/mass_spring" \ + --jaxline_mode="train" \ + --logtostderr +``` + + +## Reference +**SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics Inferred from Vision** + +**Which priors matter? Benchmarking models for learning latent dynamics** + +[DeepMind Hamiltonian Dynamics Suite]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite diff --git a/physics_inspired_models/__init__.py b/physics_inspired_models/__init__.py new file mode 100644 index 0000000..204a70c --- /dev/null +++ b/physics_inspired_models/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/physics_inspired_models/eval_metric.py b/physics_inspired_models/eval_metric.py new file mode 100644 index 0000000..db48fbe --- /dev/null +++ b/physics_inspired_models/eval_metric.py @@ -0,0 +1,353 @@ +# Copyright 2020 DeepMind Technologies Limited. 
def symplectic_matrix(dim):
  """Return the canonical symplectic (block anti-symmetric) matrix.

  The result is the `dim x dim` block matrix `[[0, -I], [I, 0]]`, where `I` is
  the identity matrix of size `dim // 2`.

  Args:
    dim: Total (even) dimensionality of the phase space.

  Returns:
    A float `np.ndarray` of shape `[dim, dim]`.

  Raises:
    ValueError: If `dim` is odd. Previously an odd `dim` silently produced a
      `(dim - 1) x (dim - 1)` matrix, which only failed later with an opaque
      shape error when multiplied with a Jacobian.
  """
  half_dims, remainder = divmod(int(dim), 2)
  if remainder:
    raise ValueError(
        f'The symplectic matrix requires an even dimensionality, got {dim}.')
  eye = np.eye(half_dims)
  zeros = np.zeros([half_dims, half_dims])
  top_rows = np.concatenate([zeros, -eye], axis=1)
  bottom_rows = np.concatenate([eye, zeros], axis=1)
  return np.concatenate([top_rows, bottom_rows], axis=0)
def find_best_polynomial(data_x, data_y, max_poly_order, rsq_threshold,
                         max_dim_n=32,
                         alpha_sweep=None,
                         max_iter=1000, cv=2):
  """Find minimal polynomial expansion that explains the data using Lasso.

  Starting from order 1, the inputs are expanded into polynomial features of
  increasing order. For every order a grid search over the Lasso `alpha` is
  run, the best `alpha` is refit on all the data and its score on that same
  data is recorded as `rsq`. The search stops as soon as `rsq` reaches
  `rsq_threshold`, `max_poly_order` is exceeded, or the grid search times out.

  Args:
    data_x: Input data; its last axis is treated as the feature axis.
    data_y: Regression targets.
    max_poly_order: Highest polynomial order to try.
    rsq_threshold: The search stops once `rsq >= rsq_threshold`.
    max_dim_n: If `data_x` has more features than this, only a first order
      polynomial is tried, to avoid a very large feature expansion.
    alpha_sweep: Candidate Lasso `alpha` values. If `None` (or containing no
      non-zero entry) a default logarithmic sweep is used.
    max_iter: Maximum number of Lasso iterations.
    cv: Cross-validation setting forwarded to the grid search.

  Returns:
    A tuple `(clf, poly, data_x_exp, rsq)` with the fitted Lasso model, the
    polynomial feature transformer, the expanded inputs and the score of the
    last polynomial order that was fitted successfully.

  Raises:
    RuntimeError: If the grid search timed out before any polynomial order
      was fitted, so there is no previous result to fall back on.
    ValueError: If no polynomial order was tried at all (for example because
      `max_poly_order < 1` or `rsq_threshold <= 0`).
  """
  rsq = 0
  poly_order = 1
  # Snapshot of the last polynomial order that completed successfully:
  # (regressor, poly_order, poly, data_x_exp, rsq, clf).
  last_completed = None

  if not np.any(alpha_sweep):
    alpha_sweep = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]

  # Avoid a large polynomial expansion for large latent sizes
  if data_x.shape[-1] > max_dim_n:
    print(f'>WARNING! Data is too high dimensional at {data_x.shape[-1]}')
    print('>WARNING! Setting max_poly_order = 1')
    max_poly_order = 1

  while rsq < rsq_threshold and poly_order <= max_poly_order:
    time_start = time.perf_counter()
    poly = preprocessing.PolynomialFeatures(poly_order, include_bias=False)
    data_x_exp = poly.fit_transform(data_x)
    time_end = time.perf_counter()
    print(
        f'Took {time_end-time_start}s to create polynomial features of order '
        f'{poly_order} and size {data_x_exp.shape[1]}.')

    with warnings.catch_warnings():
      warnings.simplefilter('ignore')
      time_start = time.perf_counter()
      # `normalize=False` is intentionally not passed: it was the default
      # behaviour and the argument has been removed from newer scikit-learn.
      clf = linear_model.Lasso(
          random_state=0, max_iter=max_iter, warm_start=False)
      parameters = {'alpha': alpha_sweep}
      try:
        regressor = do_grid_search(data_x_exp, data_y, clf, parameters, cv)
        time_end = time.perf_counter()
        print(f'Took {time_end-time_start}s to do regression grid search.')

        # Get rsq results
        time_start = time.perf_counter()
        clf = linear_model.Lasso(
            random_state=0,
            alpha=regressor.best_params_['alpha'],
            max_iter=max_iter,
            warm_start=False)
        clf.fit(data_x_exp, data_y)
        rsq = clf.score(data_x_exp, data_y)
        time_end = time.perf_counter()
        print(f'Took {time_end-time_start}s to get rsq results.')

        last_completed = (regressor, poly_order, poly, data_x_exp, rsq, clf)
        print(f'Polynomial of order {poly_order} with '
              f' alpha={regressor.best_params_} RSQ: {rsq}')
        poly_order += 1

      except KeyboardInterrupt as interrupt:
        # The timeout of `do_grid_search` is delivered as a KeyboardInterrupt
        # raised in the main thread.
        time_end = time.perf_counter()
        print(f'Timed out after {time_end-time_start}s of doing grid search.')
        if last_completed is None:
          # Nothing to fall back on: the very first order already timed out.
          raise RuntimeError(
              'The regression grid search timed out before any polynomial '
              'order could be fitted.') from interrupt
        regressor, poly_order, poly, data_x_exp, rsq, clf = last_completed
        print(f'Continuing with previous poly_order={poly_order}...')
        print(f'Polynomial of order {poly_order} with '
              f' alpha={regressor.best_params_} RSQ: {rsq}')
        break

  if last_completed is None:
    raise ValueError(
        'No polynomial order was fitted; check that `max_poly_order >= 1` '
        'and `rsq_threshold > 0`.')

  return clf, poly, data_x_exp, rsq
def normalise_jacobian_prods(jacobian_preds):
  """Rescales a set of Jacobian products by one shared constant.

  The constant is the mean, taken over all evaluation points, of the largest
  absolute entry of each product. If that constant is zero the stacked
  products are returned unscaled.

  Args:
    jacobian_preds: Sequence of equally shaped 2-D arrays, one per point.

  Returns:
    An array with the products stacked along a new leading axis and divided
    by the shared constant.
  """
  products = np.stack(jacobian_preds)
  # Largest magnitude entry for each evaluation point, averaged over points.
  peak_per_point = np.max(np.abs(products), axis=(1, 2))
  scale = np.mean(peak_per_point)
  if scale == 0:
    return products
  return products / scale
random_data_inds: + input_data_point = model_data[points_per_trajectory * trajectory + + point_ind] + time_start = time.perf_counter() + jacobian = compute_jacobian_manual(input_data_point, features, + clf.coef_, weight_tolerance) + pred = calculate_jacobian_prod(jacobian) + jacobian_preds.append(pred) + time_end = time.perf_counter() + print(f'Took {time_end - time_start}s to evaluate jacobian ' + f'around point {point_ind}.') + + # Normalise + normalised_jacobian_preds = normalise_jacobian_prods(jacobian_preds) + # The score is measured as the deviation from I + identity = np.eye(normalised_jacobian_preds.shape[-1]) + scores = np.mean(np.power(normalised_jacobian_preds - identity, 2), + axis=(1, 2)) + all_raw_scores.append(scores) + + sym_score = np.min([np.mean(all_raw_scores), max_sym_score]) + # Calculate final SyMetric score + if best_rsq > rsq_threshold and sym_score < sym_threshold: + sy_metric = 1.0 + else: + sy_metric = 0.0 + + results = { + 'poly_exp_order': poly.get_params()['degree'], + 'rsq': best_rsq, + 'sym': sym_score, + 'SyMetric': sy_metric, + } + with np.printoptions(precision=4, suppress=True): + print(f'----------------FINAL RESULTS FOR {trajectory_n} ' + 'TRAJECTORIES------------------') + print(f'BEST POLYNOMIAL EXPANSION ORDER: {results["poly_exp_order"]}') + print(f'BEST RSQ (1-best): {results["rsq"]}') + print(f'SYMPLECTICITY SCORE AROUND ALL POINTS AND ALL ' + f'TRAJECTORIES (0-best): {sym_score}') + print(f'SyMETRIC SCORE: {sy_metric}') + print(f'----------------FINAL RESULTS FOR {trajectory_n} ' + f'TRAJECTORIES------------------') + return results, clf, poly, model_data_exp diff --git a/physics_inspired_models/integrators.py b/physics_inspired_models/integrators.py new file mode 100644 index 0000000..c41719f --- /dev/null +++ b/physics_inspired_models/integrators.py @@ -0,0 +1,1041 @@ +# Copyright 2020 DeepMind Technologies Limited. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module containing the implementations of the various numerical integrators. + +Higher order methods mostly taken from [1]. + +References: + [1] Leimkuhler, Benedict and Sebastian Reich. Simulating hamiltonian dynamics. + Vol. 14. Cambridge university press, 2004. + [2] Forest, Etienne and Ronald D. Ruth. Fourth-order symplectic integration. + Physica D: Nonlinear Phenomena 43.1 (1990): 105-117. + [3] Blanes, Sergio and Per Christian Moan. Practical symplectic partitioned + Runge–Kutta and Runge–Kutta–Nyström methods. Journal of Computational and + Applied Mathematics 142.2 (2002): 313-330. + [4] McLachlan, Robert I. On the numerical integration of ordinary differential + equations by symmetric composition methods. SIAM Journal on Scientific + Computing 16.1 (1995): 151-168. + [5] Yoshida, Haruo. Construction of higher order symplectic integrators. + Physics letters A 150.5-7 (1990): 262-268. + [6] Süli, Endre; Mayers, David (2003), An Introduction to Numerical Analysis, + Cambridge University Press, ISBN 0-521-00794-1. + [7] Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), Solving + ordinary differential equations I: Nonstiff problems, Berlin, New York: + Springer-Verlag, ISBN 978-3-540-56670-0. 
def solve_ivp_dt(
    fun: GeneralTangentFunction,
    y0: M,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, GeneralIntegrator],
    num_steps: Optional[int] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, M]:
  """Solve an initial value problem for a system of ODEs.

  This function numerically integrates a system of ordinary differential
  equations given an initial value::
      dy / dt = f(t, y)
      y(t0) = y0
  Here t is a one-dimensional independent variable (time), y(t) is an
  n-dimensional vector-valued function (state), and an n-dimensional
  vector-valued function f(t, y) determines the differential equations.
  The goal is to find y(t) approximately satisfying the differential
  equations, given an initial value y(t0)=y0.

  Unless `method="adaptive"`, the solver is an explicit, fixed step method.
  This makes it easy to run with a fixed amount of computation and ensures
  solutions are easily differentiable. With `method="adaptive"` the problem is
  delegated to `jax.experimental.ode.odeint`.

  Args:
    fun: callable
      Right-hand side of the system. The calling signature is ``fun(t, y)``.
      Here `t` is a scalar representing the time instance. `y` can be any
      type `M`, including a flat array, that is registered as a
      pytree. In addition, there is a type denoted as `TM` that represents
      the tangent space to `M`. It is assumed that any element of `TM` can be
      multiplied by arrays and scalars, can be added to other `TM` instances
      as well as they can be right added to an element of `M`, that is
      add(M, TM) exists. The function should return an element of `TM` that
      defines the time derivative of `y`. With `method="adaptive"` the
      function is called with `t=None`, so it must not depend on time.
    y0: an instance of `M`
      Initial state at `t0`.
    t0: float or array.
      The initial time point of integration.
    dt: array
      Array containing all consecutive increments in time, at which the
      integral is to be evaluated. The size of this array along axis 0 defines
      the number of steps that the integrator would do.
    method: string or `GeneralIntegrator`
      The integrator method to use. Possible values for string are:
        * general_euler - see `GeneralEuler`
        * rk2 - see `RungaKutta2`
        * rk4 - see `RungaKutta4`
        * rk38 - see `RungaKutta38`
        * adaptive - use `ode.odeint`; requires `num_steps` and a `dt` whose
          absolute value is the same for all entries.
    num_steps: Optional int.
      If provided the `dt` will be treated as the same per step time interval,
      applied for this many steps. In other words setting this argument is
      equivalent to replicating `dt` num_steps times and stacking over axis=0.
    steps_per_dt: int
      Number of integrator sub-steps taken for every entry of `dt`, i.e. the
      actual step size of the integrator is `dt / steps_per_dt`.
    use_scan: bool
      Whether for the loop to use `lax.scan` or a python loop
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive"

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of M
      Values of the solution at `t`. The initial point is not included.

  Raises:
    ValueError: With `method="adaptive"`, if `num_steps` is not provided, if
      the entries of `dt` do not all have the same absolute value, or if `dt`
      is a non-scalar `jax.numpy` array.
  """
  if method == "adaptive":
    if num_steps is None:
      raise ValueError("`num_steps` must be provided when method='adaptive'.")
    ndim = y0.q.ndim if isinstance(y0, phase_space.PhaseSpace) else y0.ndim
    # The integration below always runs forward in |t|; the direction of time
    # is restored by multiplying the tangent with the sign of `dt`.
    signs = jnp.asarray(jnp.sign(dt))
    signs = signs.reshape([-1] + [1] * (ndim - 1))
    if isinstance(dt, float) or dt.ndim == 0:
      true_t_eval = t0 + dt * np.arange(1, num_steps + 1)
    else:
      true_t_eval = t0 + dt[None] * np.arange(1, num_steps + 1)[:, None]
    if isinstance(dt, float):
      dt = np.asarray(dt)
    if isinstance(dt, np.ndarray) and dt.ndim > 0:
      # `np.any` (not `np.all`): a single differing entry must be rejected.
      # With `np.all` the check could never trigger, since the first entry
      # always equals itself.
      if np.any(np.abs(dt) != np.abs(dt[0])):
        raise ValueError("Not all values of `dt` where the same.")
    elif isinstance(dt, jnp.ndarray) and dt.ndim > 0:
      raise ValueError("The code here works only when `dy_dt` is time "
                       "independent and `np.abs(dt)` is the same. For this we "
                       "allow calling this only with numpy (not jax.numpy) "
                       "arrays.")
    dt: jnp.ndarray = jnp.abs(jnp.asarray(dt))
    dt = dt.reshape([-1])[0]
    t_eval = t0 + dt * np.arange(num_steps + 1)

    outputs = ode.odeint(
        func=lambda y_, t_: fun(None, y_) * signs,
        y0=y0,
        t=jnp.abs(t_eval - t0),
        **(ode_int_kwargs or dict())
    )
    # Note that we do not return the initial point
    return true_t_eval, jax.tree_map(lambda x: x[1:], outputs)

  method = get_integrator(method)
  if num_steps is not None:
    dt = jnp.repeat(jnp.asarray(dt)[None], repeats=num_steps, axis=0)
  t_eval = t0 + jnp.cumsum(dt, axis=0)
  # Start time of every step. Steps are laid out along axis 0 (the axis that
  # `cumsum` and the loop below run over), so the shift by one step must be
  # done along axis 0 as well. For one-dimensional `dt` this is identical to
  # shifting along the last axis; for batched `dt` it keeps the batch axes
  # intact.
  t_start = jnp.ones_like(t_eval[:1]) * t0
  t = jnp.concatenate([t_start, t_eval[:-1]], axis=0)

  def loop_body(y_: M, t_dt: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[M, M]:
    """Advances the state by one entry of `dt` using `steps_per_dt` steps."""
    t_, dt_ = t_dt
    dt_: jnp.ndarray = dt_ / steps_per_dt
    for _ in range(steps_per_dt):
      y_ = method(fun, t_, y_, dt_)
      t_ = t_ + dt_
    return y_, y_

  if use_scan:
    return t_eval, lax.scan(loop_body, init=y0, xs=(t, dt))[1]
  else:
    y = [y0]
    for t_and_dt_i in zip(t, dt):
      y.append(loop_body(y[-1], t_and_dt_i)[0])
    # Note that we do not return the initial point
    return t_eval, jax.tree_multimap(lambda *args: jnp.stack(args, axis=0),
                                     *y[1:])
def solve_ivp_t_eval(
    fun: GeneralTangentFunction,
    t_span: TimeInterval,
    y0: M,
    method: Union[str, GeneralIntegrator],
    t_eval: Optional[jnp.ndarray] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, M]:
  """Solve an initial value problem for a system of ODEs at given time points.

  This function numerically integrates a system of ordinary differential
  equations given an initial value::
      dy / dt = f(t, y)
      y(t0) = y0
  Here t is a one-dimensional independent variable (time), y(t) is an
  n-dimensional vector-valued function (state), and an n-dimensional
  vector-valued function f(t, y) determines the differential equations.
  The goal is to find y(t) approximately satisfying the differential
  equations, given an initial value y(t0)=y0.

  The function validates `t_eval` against `t_span`, converts the evaluation
  times into consecutive time increments and delegates the integration to
  `solve_ivp_dt`.

  Args:
    fun: callable
      Right-hand side of the system. The calling signature is ``fun(t, y)``.
      Here `t` is a scalar representing the time instance. `y` can be any
      type `M`, including a flat array, that is registered as a
      pytree. In addition, there is a type denoted as `TM` that represents
      the tangent space to `M`. It is assumed that any element of `TM` can be
      multiplied by arrays and scalars, can be added to other `TM` instances
      as well as they can be right added to an element of `M`, that is
      add(M, TM) exists. The function should return an element of `TM` that
      defines the time derivative of `y`.
    t_span: 2-tuple of floats or arrays
      Interval of integration (t0, tf). The solver starts with t=t0 and
      integrates until it reaches t=tf. If `t0 < tf` the integration runs
      forward in time, otherwise backward.
    y0: an instance of `M`
      Initial state at `t_span[0]`.
    method: string or `GeneralIntegrator`
      The integrator method to use; forwarded unchanged to `solve_ivp_dt`.
      Possible values for string are:
        * general_euler - see `GeneralEuler`
        * rk2 - see `RungaKutta2`
        * rk4 - see `RungaKutta4`
        * rk38 - see `RungaKutta38`
    t_eval: array or None.
      Times at which to store the computed solution. Must be strictly sorted
      in the direction of integration and lie within `t_span`. If None then
      t_eval = [t_span[-1]]
    steps_per_dt: int
      This determines the overall step size. Between any two values of t_eval
      the step size is `dt = (t_eval[i+1] - t_eval[i]) / steps_per_dt`.
    use_scan: bool
      Whether for the loop to use `lax.scan` or a python loop
    ode_int_kwargs: dict
      Extra arguments forwarded to `solve_ivp_dt`.

  Returns:
    The result of `solve_ivp_dt`:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of M
      Values of the solution at `t`.

  Raises:
    ValueError: If the values in `t_eval` are not within `t_span`, are not
      strictly sorted in the direction of integration, or if `t_span` has
      rank 2 while `t_eval` does not.
  """
  # NOTE(review): `num_steps` is not forwarded to `solve_ivp_dt`, while its
  # method="adaptive" branch makes use of `num_steps` - confirm whether the
  # adaptive method is meant to be supported through this entry point.
  # Check for t_eval
  if t_eval is None:
    t_eval = np.asarray([t_span[-1]])
  # Bring `t_span` into a single array with the (start, end) pair on axis 0,
  # broadcasting a float end point against an array end point if needed.
  if isinstance(t_span[0], float) and isinstance(t_span[1], float):
    t_span = np.asarray(t_span)
  elif isinstance(t_span[0], float) and isinstance(t_span[1], jnp.ndarray):
    t_span = (np.full_like(t_span[1], t_span[0]), t_span[1])
    t_span = np.stack(t_span, axis=0)
  elif isinstance(t_span[1], float) and isinstance(t_span[0], jnp.ndarray):
    t_span = (t_span[0], jnp.full_like(t_span[0], t_span[1]))
    t_span = np.stack(t_span, axis=0)
  else:
    t_span = np.stack(t_span, axis=0)
  def check_span(span, ts):
    # Verify t_span and t_eval
    if span[0] < span[1]:
      # Forward in time
      if not np.all(np.logical_and(span[0] <= ts, ts <= span[1])):
        raise ValueError("Values in `t_eval` are not within `t_span`.")
      if not np.all(ts[:-1] < ts[1:]):
        raise ValueError("Values in `t_eval` are not properly sorted.")
    else:
      # Backward in time
      if not np.all(np.logical_and(span[0] >= ts, ts >= span[1])):
        raise ValueError("Values in `t_eval` are not within `t_span`.")
      if not np.all(ts[:-1] > ts[1:]):
        raise ValueError("Values in `t_eval` are not properly sorted.")
  if t_span.ndim == 1:
    check_span(t_span, t_eval)
  elif t_span.ndim == 2:
    # One (start, end) pair per column: validate every column separately.
    if t_eval.ndim != 2:
      raise ValueError("t_eval should have rank 2.")
    for i in range(t_span.shape[1]):
      check_span(t_span[:, i], t_eval[:, i])

  # Start time of every interval: the initial time followed by all evaluation
  # times except the last one, so that `t_eval - t` are the time increments.
  t = np.concatenate([t_span[:1], t_eval[:-1]], axis=0)

  return solve_ivp_dt(
      fun=fun,
      y0=y0,
      t0=t_span[0],
      dt=t_eval - t,
      method=method,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs
  )
+ " the length of c_tableau.") + if len(b_tableau) != len(a_tableau) + 1: + raise ValueError("The length of b_tableau should be exactly one more than" + " the length of a_tableau.") + self.a_tableau = a_tableau + self.b_tableau = b_tableau + self.c_tableau = c_tableau + self.order = order + + def __call__( + self, + tangent_func: GeneralTangentFunction, + t: jnp.ndarray, + y: M, + dt: jnp.ndarray + ) -> M: # pytype: disable=invalid-annotation + k = [tangent_func(t, y)] + zero = jax.tree_map(jnp.zeros_like, k[0]) + # We always broadcast opposite to numpy (e.g. leading dims (batch) count) + if dt.ndim > 0: + dt = dt.reshape(dt.shape + (1,) * (y.ndim - dt.ndim)) + if t.ndim > 0: + t = t.reshape(t.shape + (1,) * (y.ndim - t.ndim)) + for c_n, a_n_row in zip(self.c_tableau, self.a_tableau): + t_n = t + dt * c_n + products = [a_i * k_i for a_i, k_i in zip(a_n_row, k) if a_i != 0.0] + delta_n = sum(products, zero) + y_n = y + dt * delta_n + k.append(tangent_func(t_n, y_n)) + products = [b_i * k_i for b_i, k_i in zip(self.b_tableau, k) if b_i != 0.0] + delta = sum(products, zero) + return y + dt * delta + + +class GeneralEuler(RungaKutta): + """The standard Euler method (for general ODE problems).""" + + def __init__(self): + super().__init__( + a_tableau=[], + b_tableau=[1.0], + c_tableau=[], + order=1 + ) + + +class RungaKutta2(RungaKutta): + """The second order Runga-Kutta method corresponding to the mid-point rule.""" + + def __init__(self): + super().__init__( + a_tableau=[[1.0 / 2.0]], + b_tableau=[0.0, 1.0], + c_tableau=[1.0 / 2.0], + order=2 + ) + + +class RungaKutta4(RungaKutta): + """The fourth order Runga-Kutta method from [6].""" + + def __init__(self): + super().__init__( + a_tableau=[[1.0 / 2.0], + [0.0, 1.0 / 2.0], + [0.0, 0.0, 1.0]], + b_tableau=[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0], + c_tableau=[1.0 / 2.0, 1.0 / 2.0, 1.0], + order=4 + ) + + +class RungaKutta38(RungaKutta): + """The fourth order 3/8 rule Runga-Kutta method from [7].""" + + def 
def solve_hamiltonian_ivp_dt(
    hamiltonian: phase_space.HamiltonianFunction,
    y0: phase_space.PhaseSpace,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, SymplecticIntegrator],
    num_steps: Optional[int] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, phase_space.PhaseSpace]:
  """Solves a Hamiltonian initial value problem given time increments.

  The system being integrated is::
      dq / dt = dH / dp
      dp / dt = - dH / dq
      q(t0), p(t0) = y0.q, y0.p
  where t is the one-dimensional independent variable (time) and H(t, q, p)
  gives the value of the Hamiltonian. The goal is to find q(t) and p(t)
  approximately satisfying these equations from the initial values in `y0`.

  The Hamiltonian is converted to a tangent function with
  `phase_space.poisson_bracket_with_q_and_p` and the actual integration is
  delegated to `solve_ivp_dt`.

  The solvers listed below are explicit and non-adaptive, which makes them easy
  to run with a fixed amount of computation and keeps the solutions easily
  differentiable.

  Args:
    hamiltonian: callable
      The Hamiltonian function. The calling signature is ``h(t, s)``, where
      `s` is an instance of `PhaseSpace`.
    y0: an instance of `PhaseSpace`
      Initial state at t=t0.
    t0: float or array.
      The initial time point of integration.
    dt: array
      Array containing all consecutive increments in time at which the
      integral is to be evaluated. The size of this array along axis 0 defines
      the number of steps that the integrator takes.
    method: string or `SymplecticIntegrator`
      The integrator method to use. Possible values for string are:
        * symp_euler - see `SymplecticEuler`
        * symp_euler_q - a `SymplecticEuler` with position_first=True
        * symp_euler_p - a `SymplecticEuler` with position_first=False
        * leap_frog - see `LeapFrog`
        * leap_frog_q - a `LeapFrog` with position_first=True
        * leap_frog_p - a `LeapFrog` with position_first=False
        * stormer_verlet - same as leap_frog
        * stormer_verlet_q - same as leap_frog_q
        * stormer_verlet_p - same as leap_frog_p
        * ruth4 - see `Ruth4`,
        * sym4 - see `Symmetric4`
        * sym6 - see `Symmetric6`
        * so4 - see `SymmetricSo4`
        * so4_q - a `SymmetricSo4` with position_first=True
        * so4_p - a `SymmetricSo4` with position_first=False
        * so6 - see `SymmetricSo6`
        * so6_q - a `SymmetricSo6` with position_first=True
        * so6_p - a `SymmetricSo6` with position_first=False
        * so8 - see `SymmetricSo8`
        * so8_q - a `SymmetricSo8` with position_first=True
        * so8_p - a `SymmetricSo8` with position_first=False
    num_steps: Optional int.
      If provided, `dt` is treated as a single per-step time interval that is
      applied this many times. In other words setting this argument is
      equivalent to replicating `dt` num_steps times and stacking over axis=0.
    steps_per_dt: int
      This determines the overall step size. Between any two consecutive
      evaluation times the step size is the time increment divided by
      `steps_per_dt`.
    use_scan: bool
      Whether the loop uses `lax.scan` or a python loop.
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive".
      NOTE(review): unlike `solve_hamiltonian_ivp_t_eval`, the tangent function
      is not transformed to operate on arrays for method="adaptive" here -
      confirm whether the adaptive method is supported by this entry point.

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of `PhaseSpace`
      Values of the solution at `t`.

  Raises:
    ValueError: If `y0` is not an instance of `PhaseSpace`.
  """
  if isinstance(y0, phase_space.PhaseSpace):
    # Hamilton's equations expressed as a tangent function over (q, p).
    tangent_fn = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
    return solve_ivp_dt(
        fun=tangent_fn,
        y0=y0,
        t0=t0,
        dt=dt,
        method=method,
        num_steps=num_steps,
        steps_per_dt=steps_per_dt,
        use_scan=use_scan,
        ode_int_kwargs=ode_int_kwargs
    )
  raise ValueError("The initial state must be an instance of `PhaseSpace`.")
This in + terms makes them easy to run with fixed amount of computation and + the solutions to be easily differentiable. + + Args: + hamiltonian: callable + The Hamiltonian function. The calling signature is ``h(t, s)``, where + `s` is an instance of `PhaseSpace`. + t_span: 2-tuple of floats + Interval of integration (t0, tf). The solver starts with t=t0 and + integrates until it reaches t=tf. + y0: an instance of `M` + Initial state at `t_span[0]`. + method: string or `GeneralIntegrator` + The integrator method to use. Possible values for string are: + * symp_euler - see `SymplecticEuler` + * symp_euler_q - a `SymplecticEuler` with position_first=True + * symp_euler_p - a `SymplecticEuler` with position_first=False + * leap_frog - see `LeapFrog` + * leap_frog_q - a `LeapFrog` with position_first=True + * leap_frog_p - a `LeapFrog` with position_first=False + * stormer_verlet - same as leap_frog + * stormer_verlet_q - same as leap_frog_q + * stormer_verlet_p - same as leap_frog_p + * ruth4 - see `Ruth4`, + * sym4 - see `Symmetric4` + * sym6 - see `Symmetric6` + * so4 - see `SymmetricSo4` + * so4_q - a `SymmetricSo4` with position_first=True + * so4_p - a `SymmetricSo4` with position_first=False + * so6 - see `SymmetricSo6` + * so6_q - a `SymmetricSo6` with position_first=True + * so6_p - a `SymmetricSo6` with position_first=False + * so8 - see `SymmetricSo8` + * so8_q - a `SymmetricSo8` with position_first=True + * so8_p - a `SymmetricSo8` with position_first=False + t_eval: array or None. + Times at which to store the computed solution. Must be sorted and lie + within `t_span`. If None then t_eval = [t_span[-1]] + steps_per_dt: int + This determines the overall step size. Between any two values of t_eval + the step size is `dt = (t_eval[i+1] - t_eval[i]) / steps_per_dt. 
class CompositionSymplectic(SymplecticIntegrator):
  """A symplectic integrator built by composing momentum and position updates.

  Starting from (q, p), every pair of coefficients (c_i, d_i) performs:
    p <- p + c_i * dt * dp/dt(q, p)   # i.e. p - c_i * dH/dq * dt
    q <- q + d_i * dt * dq/dt(q, p)   # i.e. q + d_i * dH/dp * dt
  and the final (q, p) is returned as the next state.

  The momentum is always updated before the position within a pair; a zero
  coefficient skips the corresponding update altogether. The `order` attribute
  is used mainly for testing, to estimate the error when integrating various
  systems.
  """

  def __init__(
      self,
      momentum_coefficients: Sequence[float],
      position_coefficients: Sequence[float],
      order: int):
    """Validates and stores the composition coefficients.

    Args:
      momentum_coefficients: Weights of the momentum updates; must sum to 1.
      position_coefficients: Weights of the position updates; must sum to 1.
      order: The order of accuracy of the resulting method.

    Raises:
      ValueError: If the two sequences differ in length, or if either does not
        sum to one.
    """
    num_momentum = len(momentum_coefficients)
    num_position = len(position_coefficients)
    if num_position != num_momentum:
      raise ValueError("The number of momentum_coefficients and "
                       "position_coefficients must be the same.")
    position_total = sum(position_coefficients)
    if not np.allclose(position_total, 1.0):
      raise ValueError("The sum of the position_coefficients "
                       "must be equal to 1.")
    momentum_total = sum(momentum_coefficients)
    if not np.allclose(momentum_total, 1.0):
      raise ValueError("The sum of the momentum_coefficients "
                       "must be equal to 1.")
    self.momentum_coefficients = momentum_coefficients
    self.position_coefficients = position_coefficients
    self.order = order

  def __call__(
      self,
      tangent_func: phase_space.SymplecticTangentFunction,
      t: jnp.ndarray,
      y: phase_space.PhaseSpace,
      dt: jnp.ndarray
  ) -> phase_space.PhaseSpace:
    """Performs one composed symplectic step of size `dt` starting at `y`."""
    position, momentum = y.q, y.p
    # Dropped on purpose: any later use of the stale `y` would be a bug.
    del y

    def pad_trailing_dims(value):
      # We broadcast opposite to numpy: leading (batch) dims are aligned, so
      # non-scalars get trailing singleton dims up to the rank of q.
      if value.ndim > 0:
        return value.reshape(
            value.shape + (1,) * (position.ndim - value.ndim))
      return value

    dt = pad_trailing_dims(dt)
    t = pad_trailing_dims(t)

    # Position and momentum each advance along their own time variable.
    time_q = t
    time_p = t
    coefficient_pairs = zip(self.momentum_coefficients,
                            self.position_coefficients)
    for momentum_weight, position_weight in coefficient_pairs:
      # Momentum update.
      if momentum_weight != 0.0:
        state = phase_space.PhaseSpace(position, momentum)
        momentum = momentum + momentum_weight * dt * tangent_func(
            time_p, state).p
        time_p = time_p + momentum_weight * dt
      # Position update.
      if position_weight != 0.0:
        state = phase_space.PhaseSpace(position, momentum)
        position = position + position_weight * dt * tangent_func(
            time_q, state).q
        time_q = time_q + position_weight * dt
    return phase_space.PhaseSpace(position=position, momentum=momentum)
class SymmetricCompositionSymplectic(CompositionSymplectic):
  """A generalized composition integrator that is symmetric.

  The integrators produced are always of the form:
    [update_q, update_p, ..., update_p, update_q]
  or
    [update_p, update_q, ..., update_q, update_p]
  based on the position_first argument. The method expects whichever is
  updated first to have one more coefficient than the other.
  """

  def __init__(
      self,
      momentum_coefficients: Sequence[float],
      position_coefficients: Sequence[float],
      position_first: bool,
      order: int):
    """Pads the shorter coefficient list with a zero and defers to the base.

    Args:
      momentum_coefficients: Weights of the momentum updates.
      position_coefficients: Weights of the position updates.
      position_first: Whether the composition starts (and ends) with a position
        update. If False it starts (and ends) with a momentum update.
      order: The order of accuracy of the resulting method.

    Raises:
      ValueError: If the sequence of whichever variable is updated first does
        not have exactly one more coefficient than the other.
    """
    # Copy to lists so the padding below never touches the caller's sequences.
    position_coefficients = list(position_coefficients)
    momentum_coefficients = list(momentum_coefficients)
    if position_first:
      if len(position_coefficients) != len(momentum_coefficients) + 1:
        raise ValueError("The number of position_coefficients must be one more "
                         "than momentum_coefficients when position_first=True.")
      # The base class updates momentum first within each pair; a leading zero
      # momentum coefficient is skipped there, so the step starts with q.
      momentum_coefficients = [0.0] + momentum_coefficients
    else:
      if len(position_coefficients) + 1 != len(momentum_coefficients):
        # This branch is the position_first=False case (the message used to
        # incorrectly report position_first=True).
        raise ValueError("The number of momentum_coefficients must be one more "
                         "than position_coefficients when "
                         "position_first=False.")
      # A trailing zero position coefficient makes the step end with p.
      position_coefficients = position_coefficients + [0.0]
    super().__init__(
        position_coefficients=position_coefficients,
        momentum_coefficients=momentum_coefficients,
        order=order
    )
class LeapFrog(SymmetricCompositionSymplectic):
  """The standard Leap-Frog integrator (also known as Stormer-Verlet).

  With position_first = True:
    q_half = q_{t} + dH/dp(p_{t}) * dt / 2
    p_{t+1} = p_{t} - dH/dq(q_half) * dt
    q_{t+1} = q_half + dH/dp(p_{t+1}) * dt / 2
  otherwise:
    p_half = p_{t} - dH/dq(q_{t}) * dt / 2
    q_{t+1} = q_{t} + dH/dp(p_half) * dt
    p_{t+1} = p_half - dH/dq(q_{t+1}) * dt / 2
  """

  def __init__(self, position_first=False):
    # Whichever variable goes first takes two half steps around a single full
    # step of the other variable.
    two_half_steps = [0.5, 0.5]
    one_full_step = [1.0]
    if position_first:
      q_coefficients, p_coefficients = two_half_steps, one_full_step
    else:
      q_coefficients, p_coefficients = one_full_step, two_half_steps
    super().__init__(
        position_coefficients=q_coefficients,
        momentum_coefficients=p_coefficients,
        position_first=bool(position_first),
        order=2
    )
class Symmetric6(SymmetricCompositionSymplectic):
  """The sixth order method from Table 6.1 in [1] (originally from [3])."""

  def __init__(self):
    """Builds the 11 momentum / 10 position coefficient symmetric scheme."""
    c = [0.0502627644003922, 0.413514300428344, 0.0450798897943977,
         -0.188054853819569, 0.541960678450780]
    # 11 : [c1, c2, c3, c4, c5, 1.0 - 2 * sum(ci), c5, c4, c3, c2, c1]
    c = symmetrize_coefficients(c, odd_number=True)

    d = [0.148816447901042, -0.132385865767784, 0.067307604692185,
         0.432666402578175]
    # 10: [d1, d2, d3, d4, 0.5 - sum(di), 0.5 - sum(di), d4, d3, d2, d1]
    d = symmetrize_coefficients(d, odd_number=False)

    super().__init__(
        position_coefficients=d,
        momentum_coefficients=c,
        position_first=False,
        # This is the sixth order scheme; it was previously registered with
        # order=4, which misreports the expected error scaling to anything
        # that reads `order`.
        order=6
    )
class SymmetricSo8(SymmetricCompositionSymplectic):
  """The eighth order method from Table 6.2 in [1] (originally from [4])."""

  def __init__(self, position_first: bool = False):
    """Composes second-order steps using the symmetrized weights below.

    Args:
      position_first: Whether the composition starts with a position update.
    """
    half_weights = [
        0.74167036435061295345, -0.40910082580003159400,
        0.19075471029623837995, -0.57386247111608226666,
        0.29906418130365592384, 0.33462491824529818378,
        0.31529309239676659663,
    ]
    # Symmetrizing with an odd count yields 15 weights in total.
    weights = symmetrize_coefficients(half_weights, odd_number=True)
    outer, inner = coefficients_based_on_composing_second_order(weights)
    # `outer` has one more entry than `inner`, so it belongs to whichever
    # variable is updated first.
    if position_first:
      q_coefficients, p_coefficients = outer, inner
    else:
      q_coefficients, p_coefficients = inner, outer
    super().__init__(
        position_coefficients=q_coefficients,
        momentum_coefficients=p_coefficients,
        position_first=position_first,
        order=8
    )
def get_integrator(
    name_or_callable: Union[str, GeneralIntegrator]
) -> GeneralIntegrator:
  """Resolves an integrator given either its registered name or a callable.

  Args:
    name_or_callable: A key of `general_integrators` or
      `symplectic_integrators`, or an integrator callable, which is returned
      unchanged.

  Returns:
    The integrator registered under the name, or the callable itself.

  Raises:
    ValueError: If a string does not match any registered integrator, or if
      the argument is neither a string nor callable.
  """
  if not isinstance(name_or_callable, str):
    if callable(name_or_callable):
      return name_or_callable
    raise ValueError(f"Expected a callable, but got {type(name_or_callable)}.")
  # General integrators take precedence over symplectic ones.
  for registry in (general_integrators, symplectic_integrators):
    if name_or_callable in registry:
      return registry[name_or_callable]
  raise ValueError(f"Unrecognized integrator with name {name_or_callable}.")
def get_config(arg_string):
  """Return config object for training.

  Args:
    arg_string: A string of the form
      `model_config_name,sweep_index,dataset_name`, where `model_config_name`
      is the name of a function in this module returning
      `(model_config, sweeps)`, `sweep_index` selects one entry of `sweeps` and
      `dataset_name` is a path relative to the datasets folder given by the
      `DM_HAMILTONIAN_DYNAMICS_SUITE_DATASETS` environment variable.

  Returns:
    The jaxline config with the selected sweep overrides applied.

  Raises:
    ValueError: If `arg_string` does not have exactly three comma separated
      fields, `sweep_index` is not an integer, the config name is unknown or
      not callable, or the datasets environment variable is not set.
    IndexError: If `sweep_index` is out of range for the selected sweep.
  """
  args = arg_string.split(",")
  if len(args) != 3:
    raise ValueError("You must provide exactly three arguments separated by a "
                     "comma - model_config_name,sweep_index,dataset_name.")
  model_config_name, sweep_index, dataset_name = args
  try:
    sweep_index = int(sweep_index)
  except ValueError as err:
    raise ValueError(
        f"The sweep_index must be an integer, but got {sweep_index!r}."
    ) from err

  config = base_config.get_base_config()
  config.random_seed = 123109801
  config.eval_modes = ("eval", "eval_metric")

  # Get the model config and the sweeps
  if model_config_name not in globals():
    raise ValueError(f"The config name {model_config_name} does not exist in "
                     f"jaxline_configs.py")
  config_and_sweep_fn = globals()[model_config_name]
  if not callable(config_and_sweep_fn):
    # Module level constants (e.g. `config_prefix`) are also in globals().
    raise ValueError(f"The config name {model_config_name} does not refer to "
                     f"a config function in jaxline_configs.py")
  model_config, sweeps = config_and_sweep_fn()

  dm_hamiltonian_suite_path = os.environ.get(_DATASETS_PATH_VAR_NAME, None)
  if not dm_hamiltonian_suite_path:
    raise ValueError(f"You need to set the {_DATASETS_PATH_VAR_NAME}")
  dataset_folder = os.path.join(dm_hamiltonian_suite_path, dataset_name)

  # Fail with a descriptive message rather than a bare `IndexError` at the very
  # end. Negative indices keep their usual Python meaning.
  if not -len(sweeps) <= sweep_index < len(sweeps):
    raise IndexError(
        f"The sweep_index {sweep_index} is out of range - the config "
        f"{model_config_name} defines {len(sweeps)} sweep configuration(s).")

  # Experiment config. Note that batch_size is per device.
  # In the experiments we run on 4 GPUs, so the effective batch size was 128.
  config.experiment_kwargs = collections.ConfigDict(
      dict(
          config=dict(
              dataset_folder=dataset_folder,
              model_kwargs=model_config,
              num_extrapolation_steps=60,
              drop_stats_containing=("neg_log_p_x", "l2_over_time", "neg_elbo"),
              optimizer=dict(
                  name="adam",
                  kwargs=dict(
                      learning_rate=1.5e-4,
                      b1=0.9,
                      b2=0.999,
                  )
              ),
              training=dict(
                  batch_size=32,
                  burnin_steps=5,
                  num_epochs=None,
                  lagging_vae=False
              ),
              evaluation=dict(
                  batch_size=64,
              ),
              evaluation_metric=dict(
                  batch_size=5,
                  batch_n=20,
                  num_eval_metric_steps=60,
                  max_poly_order=5,
                  max_jacobian_score=1000,
                  rsq_threshold=0.9,
                  sym_threshold=0.05,
                  evaluation_point_n=10,
                  weight_tolerance=1e-03,
                  max_iter=1000,
                  cv=2,
                  alpha_min_logspace=-4,
                  alpha_max_logspace=-0.5,
                  alpha_step_n=10,
                  calculate_fully_after_steps=40000,
              ),
              evaluation_metric_mlp=dict(
                  batch_size=64,
                  batch_n=10000,
                  datapoint_param_multiplier=1000,
                  num_eval_metric_steps=60,
                  evaluation_point_n=10,
                  evaluation_trajectory_n=50,
                  rsq_threshold=0.9,
                  sym_threshold=0.05,
                  ridge_lambda=0.01,
                  model=dict(
                      num_units=4,
                      num_layers=4,
                      activation="tanh",
                  ),
                  optimizer=dict(
                      name="adam",
                      kwargs=dict(
                          learning_rate=1.5e-3,
                      )
                  ),
              ),
              evaluation_vpt=dict(
                  batch_size=5,
                  batch_n=2,
                  vpt_threshold=0.025,
              )
          )
      )
  )

  # Training loop config.
  config.training_steps = int(500000)
  config.interval_type = "steps"
  config.log_tensors_interval = 50
  config.log_train_data_interval = 50
  config.log_all_train_data = False

  config.save_checkpoint_interval = 100
  config.checkpoint_dir = "/tmp/physics_inspired_models/"
  config.train_checkpoint_all_hosts = False
  config.eval_specific_checkpoint_dir = ""

  # The sweep overrides are applied last so that they win over the defaults.
  config.update_from_flattened_dict(sweeps[sweep_index])
  return config
+ +default_encoder_kwargs = collections.ConfigDict(dict( + conv_channels=64, + num_blocks=3, + blocks_depth=2, + activation="leaky_relu", +)) + +default_decoder_kwargs = collections.ConfigDict(dict( + conv_channels=64, + num_blocks=3, + blocks_depth=2, + activation="leaky_relu", +)) + +default_latent_system_net_kwargs = collections.ConfigDict(dict( + conv_channels=64, + num_units=250, + num_layers=5, + activation="swish", +)) + + +default_latent_system_kwargs = collections.ConfigDict(dict( + # Physics model arguments + input_space=collections.config_dict.placeholder(str), + simulation_space=collections.config_dict.placeholder(str), + potential_func_form="separable_net", + kinetic_func_form=collections.config_dict.placeholder(str), + hgn_kinetic_func_form="separable_net", + lgn_kinetic_func_form="matrix_dep_quad", + parametrize_mass_matrix=collections.config_dict.placeholder(bool), + hgn_parametrize_mass_matrix=False, + lgn_parametrize_mass_matrix=True, + mass_eps=1.0, + # ODE model arguments + integrator_method=collections.config_dict.placeholder(str), + # RGN model arguments + residual=collections.config_dict.placeholder(bool), + # General arguments + net_kwargs=default_latent_system_net_kwargs +)) + +default_config_dict = collections.ConfigDict(dict( + name=collections.config_dict.placeholder(str), + latent_system_dim=32, + latent_system_net_type="mlp", + latent_system_kwargs=default_latent_system_kwargs, + encoder_aggregation_type="linear_projection", + decoder_de_aggregation_type=collections.config_dict.placeholder(str), + encoder_kwargs=default_encoder_kwargs, + decoder_kwargs=default_decoder_kwargs, + has_latent_transform=False, + num_inference_steps=5, + num_target_steps=60, + latent_training_type="forward", + # Choices: overlap_by_one, no_overlap, include_inference + training_data_split="overlap_by_one", + objective_type="ELBO", + elbo_beta_delay=0, + elbo_beta_final=1.0, + geco_kappa=0.001, + geco_alpha=0.0, + dt=0.125, +)) + +hgn_paper_encoder_kwargs = 
collections.ConfigDict(dict( + conv_channels=[[32, 64], [64, 64], [64]], + num_blocks=3, + blocks_depth=2, + activation="relu", + kernel_shapes=[2, 4], + padding=["VALID", "SAME"], +)) + +hgn_paper_decoder_kwargs = collections.ConfigDict(dict( + conv_channels=64, + num_blocks=3, + blocks_depth=2, + activation="tf_leaky_relu", +)) + +hgn_paper_latent_net_kwargs = collections.ConfigDict(dict( + conv_channels=[32, 64, 64, 64], + num_units=250, + num_layers=5, + activation="softplus", + kernel_shapes=[3, 2, 2, 2, 2], + strides=[1, 2, 1, 2, 1], + padding=["SAME", "VALID", "SAME", "VALID", "SAME"] +)) + +hgn_paper_latent_system_kwargs = collections.ConfigDict(dict( + potential_func_form="separable_net", + kinetic_func_form="separable_net", + parametrize_mass_matrix=False, + net_kwargs=hgn_paper_latent_net_kwargs +)) + +hgn_paper_latent_transform_kwargs = collections.ConfigDict(dict( + num_layers=5, + conv_channels=64, + num_units=64, + activation="relu", +)) + +hgn_paper_config = copy.deepcopy(default_config_dict) +hgn_paper_config.training_data_split = "include_inference" +hgn_paper_config.latent_system_net_type = "conv" +hgn_paper_config.encoder_aggregation_type = (collections.config_dict. + placeholder(str)) +hgn_paper_config.decoder_de_aggregation_type = (collections.config_dict. 
def sym_metric_hgn_sweep():
  """HGN experimental sweep for the SyMetric paper.

  Returns:
    A tuple `(model_config, sweeps)` where `sweeps` holds a single empty
    override, i.e. the paper configuration is run unchanged with sweep index 0.
  """
  model_config = copy.deepcopy(hgn_paper_config)
  model_config.name = "HGN"
  # `list(dict())` iterates over the keys of an empty dict and is therefore an
  # empty list, which made every sweep index fail in `get_config`. A single
  # empty dict of overrides is what is needed here.
  return model_config, [dict()]
def benchmark_rgn_sweep():
  """Builds the RGN model config and its sweep for the benchmark paper.

  Returns:
    A tuple `(model_config, sweeps)`. The sweep covers every combination of
    the final ELBO beta and of the `residual` flag, in that nesting order.
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "RGN"

  sweeps = [
      {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          model_prefix + "latent_system_kwargs.residual": use_residual,
          model_prefix + "elbo_beta_final": beta_final,
      }
      for beta_final in [0.001, 0.1, 1.0, 2.0]
      for use_residual in (True, False)
  ]

  return model_config, sweeps
def benchmark_ar_sweep():
  """Builds the AR model config and its sweep for the benchmark paper.

  Returns:
    A tuple `(model_config, sweeps)`. The sweep covers every combination of
    the final ELBO beta and of the autoregressive latent dynamics type, in
    that nesting order.
  """
  model_config = copy.deepcopy(default_config_dict)
  model_config.name = "AR"
  # Default dynamics type; every sweep entry overrides it explicitly.
  model_config.latent_dynamics_type = "vanilla"

  sweeps = [
      {
          config_prefix + "optimizer.kwargs.learning_rate": 1.5e-4,
          model_prefix + "latent_dynamics_type": dynamics_type,
          model_prefix + "elbo_beta_final": beta_final,
      }
      for beta_final in [0.001, 0.1, 1.0, 2.0]
      for dynamics_type in ("vanilla", "lstm", "gru")
  ]

  return model_config, sweeps
"""The training script for the HGN models."""
import functools

from absl import app
from absl import flags
from absl import logging
from dm_hamiltonian_dynamics_suite import load_datasets
import haiku as hk
import jax
import jax.numpy as jnp
from jaxline import experiment
from jaxline import platform
import numpy as np
import optax

from physics_inspired_models import eval_metric
from physics_inspired_models import utils
from physics_inspired_models.models import common

AutoregressiveModel = common.autoregressive.TeacherForcingAutoregressiveModel


class HGNExperiment(experiment.AbstractExperiment):
  """Jaxline experiment that trains and evaluates the HGN family of models.

  Parameters, network state and optimizer state are checkpointed attributes.
  Input pipelines and pmapped functions are created lazily, the first time
  `step` or `evaluate` needs them.
  """
  CHECKPOINT_ATTRS = {
      "_params": "params",
      "_state": "state",
      "_opt_state": "opt_state",
  }
  NON_BROADCAST_CHECKPOINT_ATTRS = {
      "_python_step": "python_step"
  }

  def __init__(self, mode, init_rng, config):
    """Builds the model and the optimizer described by `config`.

    Args:
      mode: The jaxline mode; `evaluate` handles "eval" and "eval_metric".
      init_rng: Random key used to initialize the model parameters.
      config: Experiment configuration. This constructor reads
        `config.model_kwargs`, `config.optimizer.name` and
        `config.optimizer.kwargs`.
    """
    super().__init__(mode=mode)
    self.mode = mode
    self.init_rng = init_rng
    self.config = config

    # Checkpointed experiment state.
    self._python_step = None
    self._params = None
    self._state = None
    self._opt_state = None

    # Input pipelines and pmapped functions, all created lazily.
    self._train_input = None
    self._step_fn = None
    self._burnin_fn = None
    self._eval_input = None
    self._eval_batch = None
    self._eval_input_metric = None
    self._eval_input_vpt = None
    self._compute_gt_state_and_latents = None
    self._get_reconstructions = None
    self._get_samples = None

    # Construct the model
    model_kwargs = dict(**self.config.model_kwargs)
    self.model = common.construct_model(**model_kwargs)
    # Construct the optimizer (looked up by name in optax)
    optimizer_ctor = getattr(optax, self.config.optimizer.name)
    self.optimizer = optimizer_ctor(**self.config.optimizer.kwargs)
    self.model_init = jax.pmap(self.model.init)
    self.opt_init = jax.pmap(self.optimizer.init)
    logging.info("Number of hosts: %d/%d",
                 jax.process_index(), jax.process_count())
    logging.info("Number of local devices: %d/%d", jax.local_device_count(),
                 jax.device_count())

  def _process_stats(self, stats, axis_name=None):
    """Filters `stats` and reduces every entry to a scalar.

    Note that the dropped keys are removed from `stats` in place.

    Args:
      stats: Mapping from statistic name to value.
      axis_name: Optional pmap axis name. When given, the statistics are
        additionally averaged across that axis.

    Returns:
      The statistics whose names contain none of the substrings listed in
      `config.drop_stats_containing`, averaged and restricted to scalars.
    """
    dropped_substrings = self.config.drop_stats_containing
    keys_to_remove = [
        key for key in stats
        if any(dropped in key for dropped in dropped_substrings)
    ]
    for key in keys_to_remove:
      stats.pop(key)
    # Take average statistics
    stats = jax.tree_map(utils.mean_if_not_scalar, stats)
    stats = utils.filter_only_scalar_stats(stats)
    if axis_name is not None:
      # Use the axis the caller asked for rather than a hard-coded name.
      stats = utils.pmean_if_pmap(stats, axis_name=axis_name)
    return stats

  #  _             _
  # | |_ _ __ __ _(_)_ __
  # | __| '__/ _` | | '_ \
  # | |_| | | (_| | | | | |
  #  \__|_|  \__,_|_|_| |_|
  #
  def step(self, global_step, rng, **unused_args):
    """See base class."""
    if self._train_input is None:
      self._initialize_train()

    # Do a small burnin to accumulate any persistent network state. It is
    # skipped when no burnin steps are configured, since the accumulated state
    # is divided by the number of steps below.
    burnin_steps = self.config.training.burnin_steps
    if self._python_step == 0 and self._state and burnin_steps > 0:
      for _ in range(burnin_steps):
        rng, key = utils.p_split(rng, 2)
        batch = next(self._train_input)
        self._state = self._burnin_fn(self._params, self._state, key, batch)
      self._state = jax.tree_map(lambda x: x / burnin_steps, self._state)

    batch = next(self._train_input)
    self._params, self._state, self._opt_state, stats = self._step_fn(
        self._params, self._state, self._opt_state, rng, batch, global_step)
    self._python_step += 1

    stats = utils.get_first(stats)
    logging.info("global_step: %d, %s", self._python_step,
                 jax.tree_map(float, stats))
    return stats

  def _initialize_train(self):
    """Creates the training pipeline, the pmapped functions and the state.

    Parameters and optimizer state are initialized only when they were not
    already restored from a checkpoint.
    """
    self._train_input = utils.py_prefetch(
        load_datasets.dataset_as_iter(self._build_train_input))
    self._burnin_fn = jax.pmap(
        self._jax_burnin_fn, axis_name="i", donate_argnums=list(range(1, 4)))
    self._step_fn = jax.pmap(
        self._jax_train_step_fn, axis_name="i", donate_argnums=list(range(5)))

    if self._params is not None:
      logging.info("Not running initialization - loaded from checkpoint.")
      assert self._opt_state is not None
      return

    logging.info("Initializing parameters - NOT loading from checkpoint.")

    # Use the same rng on all devices, so that the initialization is identical
    init_rng = utils.bcast_local_devices(self.init_rng)

    # Initialize the parameters and the optimizer
    batch = next(self._train_input)
    self._params, self._state = self.model_init(init_rng, batch)
    self._python_step = 0
    self._opt_state = self.opt_init(self._params)

  def _build_train_input(self):
    """Loads the shuffled "train" split, keeping only the "image" key."""
    batch_size = self.config.training.batch_size
    return load_datasets.load_dataset(
        path=self.config.dataset_folder,
        tfrecord_prefix="train",
        sub_sample_length=self.model.train_sequence_length,
        per_device_batch_size=batch_size,
        num_epochs=self.config.training.num_epochs,
        drop_remainder=True,
        multi_device=True,
        shuffle=True,
        shuffle_buffer=100 * batch_size,
        cache=False,
        keys_to_preserve=["image"],
    )

  def _jax_train_step_fn(self, params, state, opt_state, rng_key, batch, step):
    """Runs a single optimization step; meant to be called through pmap.

    Returns:
      A tuple `(params, state, opt_state, stats)` with the updated values and
      the processed scalar statistics.
    """
    # The loss and the stats are averaged over the batch
    def loss_func(*args):
      outs = self.model.training_objectives(*args, is_training=True)
      # Average everything over the batch
      return jax.tree_map(utils.mean_if_not_scalar, outs)

    # Compute gradients
    grad_fn = jax.grad(loss_func, has_aux=True)
    grads, (state, stats, _) = grad_fn(params, state, rng_key, batch, step)
    # Average everything over the devices (e.g. average and sync)
    grads, state = utils.pmean_if_pmap((grads, state), axis_name="i")
    # Apply updates
    updates, opt_state = self.optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, state, opt_state, self._process_stats(stats, axis_name="i")

  def _jax_burnin_fn(self, params, state, rng_key, batch):
    """Adds the network state produced by one batch to the running `state`."""
    _, (new_state, _, _) = self.model.training_objectives(
        params, state, rng_key, batch, jnp.zeros([]), is_training=True)
    new_state = jax.tree_map(utils.mean_if_not_scalar, new_state)
    new_state = utils.pmean_if_pmap(new_state, axis_name="i")
    new_state = hk.data_structures.to_mutable_dict(new_state)
    new_state = hk.data_structures.to_immutable_dict(new_state)
    return jax.tree_multimap(jnp.add, new_state, state)

  #                  _
  #   _____   ____ _| |
  #  / _ \ \ / / _` | |
  # |  __/\ V / (_| | |
  #  \___| \_/ \__,_|_|
  #
  def evaluate(self, global_step, rng, writer):
    """See base class."""
    logging.info("Starting evaluation.")
    if self.mode == "eval":
      if self._eval_input is None:
        self._initialize_eval()
        self._initialize_eval_vpt()
      key1, _ = utils.p_split(rng, 2)
      stats = utils.to_numpy(self._eval_epoch(global_step, key1))
      stats.update(utils.to_numpy(self._eval_epoch_vpt(global_step, rng)))
    elif self.mode == "eval_metric":
      if self._eval_input_metric is None:
        self._initialize_eval_metric()
      stats = utils.to_numpy(self._eval_epoch_metric(global_step, rng))
    else:
      raise NotImplementedError()
    logging.info("Finished evaluation.")
    return stats

  def _eval_epoch(self, step, rng):
    """Evaluates one pass over the test set and returns accumulated stats."""
    accumulator = utils.MultiBatchAccumulator()
    for batch in self._eval_input():
      rng, key = utils.p_split(rng, 2)
      stats, num_samples = utils.get_first(
          self._eval_batch(self._params, self._state, key, batch, step)
      )
      accumulator.add(stats, num_samples)
    return accumulator.value()

  def _eval_epoch_metric(self, step, rng):
    """Evaluates the SyMetric statistics over a number of batches.

    Returns:
      A flat dict with the accumulated statistics. For every key containing
      "sym" or "SyMetric" it also holds "_max", "_min" and "_sum" variants.
    """
    # To prevent from calculating SyMetric early on in training where a large
    # polynomial expansion is likely to be required and the score is likely
    # to be bad anyway, we only compute using a single batch to save compute
    if step[0] > self.config.evaluation_metric.calculate_fully_after_steps:
      batch_n = self.config.evaluation_metric.batch_n
    else:
      batch_n = 1
    logging.info("Step: %d, batch_n: %d", step[0], batch_n)

    accumulator = utils.MultiBatchAccumulator()
    # Iterate over the batch count chosen above, not the configured maximum.
    for _ in range(batch_n):
      batch = next(self._eval_input_metric)
      rng, key = utils.p_split(rng, 2)
      stats = self._eval_batch_metric(
          self._params, key, batch,
          eval_seq_len=self.config.evaluation_metric.num_eval_metric_steps,
      )
      accumulator.add(stats, 1)
    stats = utils.flatten_dict(accumulator.value())

    tracked_keys = ("sym", "SyMetric")
    reductions = (
        ("_max", accumulator.max),
        ("_min", accumulator.min),
        ("_sum", accumulator.sum),
    )
    for key_suffix, reduce_fn in reductions:
      for k, v in utils.flatten_dict(reduce_fn()).items():
        if any(m in k for m in tracked_keys):
          stats[k + key_suffix] = v
    return stats

  def _eval_epoch_vpt(self, step, rng):
    """Evaluates the VPT scores over `config.evaluation_vpt.batch_n` batches."""
    accumulator = utils.MultiBatchAccumulator()
    for _ in range(self.config.evaluation_vpt.batch_n):
      batch = next(self._eval_input_vpt)
      rng, key = utils.p_split(rng, 2)
      stats = self._eval_batch_vpt(self._params, self._state, key, batch)
      accumulator.add(stats, 1)
    stats = utils.flatten_dict(accumulator.value())
    return stats

  def _reconstruct_and_align(self, rng_key, full_trajectory, prefix, suffix):
    """Reconstructs a trajectory and returns aligned predictions and targets.

    Axis 2 of `full_trajectory` is sliced as the sequence axis.

    Args:
      rng_key: Random key forwarded to the pmapped reconstruction function.
      full_trajectory: The images to reconstruct.
      prefix: "forward" to unroll forward in time; any other value selects the
        backward targets.
      suffix: "train" for the part covered by the training sequence length,
        "extrapolation" for the remainder. With any other value (or a prefix
        other than "forward"/"backward") the full sequence is returned.

    Returns:
      A tuple `(predict, targets)` of mean reconstructions and the matching
      ground truth images.

    Raises:
      NotImplementedError: If `model.training_data_split` is not recognised.
    """
    if hasattr(self.model, "training_data_split"):
      if self.model.training_data_split == "overlap_by_one":
        reconstruction_skip = self.model.num_inference_steps - 1
      elif self.model.training_data_split == "no_overlap":
        reconstruction_skip = self.model.num_inference_steps
      elif self.model.training_data_split == "include_inference":
        reconstruction_skip = 0
      else:
        raise NotImplementedError()
    else:
      reconstruction_skip = 1

    full_forward_targets = jax.tree_map(
        lambda x: x[:, :, reconstruction_skip:], full_trajectory)
    full_backward_targets = jax.tree_map(
        lambda x: x[:, :, :x.shape[2] - reconstruction_skip], full_trajectory)
    train_targets_length = (self.model.train_sequence_length -
                            reconstruction_skip)
    full_targets_length = full_forward_targets.shape[2]

    # Fully unroll the model and reconstruct the whole sequence, take the mean
    full_prediction = self._get_reconstructions(self._params, full_trajectory,
                                                rng_key, prefix == "forward",
                                                True).mean()
    full_targets = (full_forward_targets if prefix == "forward" else
                    full_backward_targets)

    # In cases where the model can run backwards it is possible to reconstruct
    # parts which were intended to be skipped, so here we take care of that.
    # `full_prediction` is already the mean array, so its shape is read
    # directly; reducing it again would collapse it to a scalar.
    if full_prediction.shape[2] > full_targets_length:
      if prefix == "forward":
        full_prediction = jax.tree_map(
            lambda x: x[:, :, -full_targets_length:], full_prediction)
      else:
        full_prediction = jax.tree_map(
            lambda x: x[:, :, :full_targets_length], full_prediction)

    # Based on the prefix and suffix fetch correct predictions and targets
    if prefix == "forward" and suffix == "train":
      predict, targets = jax.tree_map(
          lambda x: x[:, :, :train_targets_length],
          (full_prediction, full_targets))
    elif prefix == "forward" and suffix == "extrapolation":
      predict, targets = jax.tree_map(
          lambda x: x[:, :, train_targets_length:],
          (full_prediction, full_targets))
    elif prefix == "backward" and suffix == "train":
      predict, targets = jax.tree_map(
          lambda x: x[:, :, -train_targets_length:],
          (full_prediction, full_targets))
    elif prefix == "backward" and suffix == "extrapolation":
      predict, targets = jax.tree_map(
          lambda x: x[:, :, :-train_targets_length],
          (full_prediction, full_targets))
    else:
      predict, targets = full_prediction, full_targets

    return predict, targets

  def _initialize_eval(self):
    """Creates the "test" input pipeline and the pmapped evaluation functions."""
    length = (self.model.train_sequence_length +
              self.config.num_extrapolation_steps)
    batch_size = self.config.evaluation.batch_size
    self._eval_input = load_datasets.dataset_as_iter(
        load_datasets.load_dataset,
        path=self.config.dataset_folder,
        tfrecord_prefix="test",
        sub_sample_length=length,
        per_device_batch_size=batch_size,
        num_epochs=1,
        drop_remainder=False,
        shuffle=False,
        cache=False,
        keys_to_preserve=["image"]
    )
    self._eval_batch = jax.pmap(
        self._jax_eval_step_fn, axis_name="i")
    self._get_reconstructions = jax.pmap(
        self.model.reconstruct, axis_name="i",
        static_broadcasted_argnums=(3, 4))
    if isinstance(self.model,
                  common.deterministic_vae.DeterministicLatentsGenerativeModel):
      self._get_samples = jax.pmap(
          self.model.sample_trajectories_from_prior,
          static_broadcasted_argnums=(1, 3, 4))
+ def _initialize_eval_metric(self): + self._eval_input_metric = utils.py_prefetch( + load_datasets.dataset_as_iter( + load_datasets.load_dataset, + path=self.config.dataset_folder, + tfrecord_prefix="test", + sub_sample_length=None, + per_device_batch_size=self.config.evaluation_metric.batch_size, + num_epochs=None, + drop_remainder=False, + cache=False, + shuffle=False, + keys_to_preserve=["image", "x"] + ) + ) + def compute_gt_state_and_latents(*args): + # Note that the `dt` has to be passed as a kwargs argument + if len(args) == 4: + return self.model.gt_state_and_latents(*args[:4]) + elif len(args) == 5: + return self.model.gt_state_and_latents(*args[:4], dt=args[4]) + else: + raise NotImplementedError() + self._compute_gt_state_and_latents = jax.pmap( + compute_gt_state_and_latents, static_broadcasted_argnums=3) + + def _initialize_eval_vpt(self): + dataset_name = self.config.dataset_folder.split("/")[-1] + dataset_folder = self.config.dataset_folder + if dataset_name in ("hnn_mass_spring_dt_0_05", + "mass_spring_colors_v1_dt_0_05", + "hnn_pendulum_dt_0_05", + "pendulum_colors_v1_dt_0_05", + "matrix_rps_dt_0_1", + "matrix_mp_dt_0_1"): + dataset_folder += "_long_trajectory" + + self._eval_input_vpt = utils.py_prefetch( + load_datasets.dataset_as_iter( + load_datasets.load_dataset, + path=dataset_folder, + tfrecord_prefix="test", + sub_sample_length=None, + per_device_batch_size=self.config.evaluation_vpt.batch_size, + num_epochs=None, + drop_remainder=False, + cache=False, + shuffle=False, + keys_to_preserve=["image", "x"] + ) + ) + + self._get_reconstructions = jax.pmap( + self.model.reconstruct, axis_name="i", + static_broadcasted_argnums=(3, 4)) + + def _jax_eval_step_fn(self, params, state, rng_key, batch, step): + # We care only about the statistics + _, (_, stats, _) = self.model.training_objectives(params, state, rng_key, + batch, step, + is_training=False) + # Compute the full batch size + batch_size = jax.tree_flatten(batch)[0][0].shape[0] + 
batch_size = utils.psum_if_pmap(batch_size, axis_name="i") + + return self._process_stats(stats, axis_name="i"), batch_size + + def _eval_batch_vpt(self, params, state, rng_key, batch): + full_trajectory = utils.extract_image(batch) + prefixes = ("forward", + "backward") if self.model.can_run_backwards else ("forward",) + stats = dict() + vpt_abs_scores = [] + vpt_rel_scores = [] + seq_length = None + for prefix in prefixes: + reconstruction, gt_images = self._reconstruct_and_align( + rng_key, full_trajectory, prefix, "extrapolation") + seq_length = gt_images.shape[2] + + mse_norm = np.mean( + (gt_images - reconstruction)**2, axis=(3, 4, 5)) / np.mean( + gt_images**2, axis=(3, 4, 5)) + + vpt_scores = [] + for i in range(mse_norm.shape[1]): + vpt_ind = np.argwhere( + mse_norm[:, i:i + 1, :] > self.config.evaluation_vpt.vpt_threshold) + + if vpt_ind.shape[0] > 0: + vpt_ind = vpt_ind[0][2] + else: + vpt_ind = mse_norm.shape[-1] + + vpt_scores.append(vpt_ind) + + vpt_abs_scores.append(np.median(vpt_scores)) + vpt_rel_scores.append(np.median(vpt_scores) / seq_length) + scores = {"vpt_abs": vpt_abs_scores[-1], "vpt_rel": vpt_rel_scores[-1]} + scores = utils.to_numpy(scores) + scores = utils.filter_only_scalar_stats(scores) + stats[prefix] = scores + + stats["vpt_abs"] = utils.to_numpy(np.mean(vpt_abs_scores)) + stats["vpt_rel"] = utils.to_numpy(np.mean(vpt_rel_scores)) + logging.info("vpt_abs: %s, seq_length: %d}", + str(vpt_abs_scores), seq_length) + return stats + + def _eval_batch_metric(self, params, rng, batch, eval_seq_len=200): + # Initialise alpha values for Lasso regression + alpha_sweep = np.logspace(self.config.evaluation_metric.alpha_min_logspace, + self.config.evaluation_metric.alpha_max_logspace, + self.config.evaluation_metric.alpha_step_n) + trajectory_n = self.config.evaluation_metric.batch_size + subsection = f"{trajectory_n}tr" + stats = dict() + + # Get data + (gt_trajectory, + model_trajectory, + informative_dim_n) = 
self._get_gt_and_model_phase_space_for_eval( + params, rng, batch, eval_seq_len) + + # Calculate SyMetric scores + if informative_dim_n > 1: + scores, *_ = eval_metric.calculate_symetric_score( + gt_trajectory, + model_trajectory, + self.config.evaluation_metric.max_poly_order, + self.config.evaluation_metric.max_jacobian_score, + self.config.evaluation_metric.rsq_threshold, + self.config.evaluation_metric.sym_threshold, + self.config.evaluation_metric.evaluation_point_n, + trajectory_n=trajectory_n, + weight_tolerance=self.config.evaluation_metric.weight_tolerance, + alpha_sweep=alpha_sweep, + max_iter=self.config.evaluation_metric.max_iter, + cv=self.config.evaluation_metric.cv) + + scores["unmasked_latents"] = informative_dim_n + scores = utils.to_numpy(scores) + scores = utils.filter_only_scalar_stats(scores) + stats[subsection] = scores + else: + scores = { + "poly_exp_order": + self.config.evaluation_metric.max_poly_order, + "rsq": + 0, + "sym": + self.config.evaluation_metric.max_jacobian_score, + "SyMetric": 0.0, + "unmasked_latents": + informative_dim_n + } + scores = utils.to_numpy(scores) + scores = utils.filter_only_scalar_stats(scores) + stats[subsection] = scores + + return stats + + def _get_gt_and_model_phase_space_for_eval(self, params, rng, batch, + eval_seq_len): + # Get data + gt_data, model_data, z0 = utils.stack_device_dim_into_batch( + self._compute_gt_state_and_latents(params, rng, batch, eval_seq_len) + ) + + if isinstance(self.model, AutoregressiveModel): + # These models return the `z` for the whole sequence + z0 = z0[:, 0] + + # If latent space is image like, reshape it down to vector + if self.model.latent_system_net_type == "conv": + z0 = jax.tree_map(utils.reshape_latents_conv_to_flat, z0) + model_data = jax.tree_map( + lambda x: utils.reshape_latents_conv_to_flat(x, axis_n_to_keep=2), + model_data) + + # Create mask to get rid of uninformative latents + latent_mask = eval_metric.create_latent_mask(z0) + informative_dim_n = 
np.sum(latent_mask) + + model_data = model_data[:, :, latent_mask] + logging.info("Masking out model data, leaving dim_n=%d dimensions.", + model_data.shape[-1]) + + gt_trajectory = np.reshape( + gt_data, + [np.product(gt_data.shape[:-1]), gt_data.shape[-1]] + ) + + model_trajectory = np.reshape(model_data, [ + np.product(model_data.shape[:-1]), model_data.shape[-1] + ]) + + # Standardize data + gt_trajectory = eval_metric.standardize_data(gt_trajectory) + model_trajectory = eval_metric.standardize_data(model_trajectory) + + return gt_trajectory, model_trajectory, informative_dim_n + +if __name__ == "__main__": + flags.mark_flag_as_required("config") + logging.set_stderrthreshold(logging.INFO) + app.run(functools.partial(platform.main, HGNExperiment)) diff --git a/physics_inspired_models/launch_all.sh b/physics_inspired_models/launch_all.sh new file mode 100755 index 0000000..d96d0c4 --- /dev/null +++ b/physics_inspired_models/launch_all.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script to execute a single configuration on all datasets. +if [[ "$#" -eq 2 ]]; then + readonly CONFIG_NAME="$1" + readonly NUM_SWEEPS="$2" +else + echo "You must provide exactly two arguments - the configuration name and " \ + "how many sweeps it contains. 
For example:" + echo "./launch_all.sh sym_metric_hgn_plus_plus_sweep 1" + exit 2 +fi + +DATASETS=( + "toy_physics/mass_spring" + "toy_physics/mass_spring_colors" + "toy_physics/mass_spring_colors_friction" + "toy_physics/pendulum" + "toy_physics/pendulum_colors" + "toy_physics/pendulum_colors_friction" + "toy_physics/two_body" + "toy_physics/two_body_colors" + "toy_physics/double_pendulum" + "toy_physics/double_pendulum_colors" + "toy_physics/double_pendulum_colors_friction" + "molecular_dynamics/lj_4" + "molecular_dynamics/lj_16" + "multi_agent/rock_paper_scissors" + "multi_agent/matching_pennies" + "mujoco_room/circle" + "mujoco_room/spiral" +) + +readonly DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +for dataset in "${DATASETS[@]}"; do + "${DIR}/launch_local.sh" "${CONFIG_NAME}" "${NUM_SWEEPS}" "${dataset}" +done diff --git a/physics_inspired_models/launch_local.sh b/physics_inspired_models/launch_local.sh new file mode 100755 index 0000000..65c97b1 --- /dev/null +++ b/physics_inspired_models/launch_local.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A script to execute a single configuration name on a given dataset. 
+if [[ "$#" -eq 3 ]]; then + readonly CONFIG_NAME="$1" + readonly NUM_SWEEPS="$2" + readonly DATASET="$3" +else + echo "You must provide exactly three arguments - the configuration name, " \ + "the number of sweeps it contains and the dataset name. For example:" + echo "./launch_local.sh sym_metric_hgn_plus_plus_sweep 1 " \ + "toy_physics/mass_spring" + exit 2 +fi +echo "Running with config ${CONFIG_NAME} on ${DATASET}." + +readonly EXPERIMENT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +readonly TRAIN_FILE="${EXPERIMENT_DIR}/jaxline_train.py" +readonly CONFIG_FILE="${EXPERIMENT_DIR}/jaxline_configs.py" + +for sweep_id in $(seq 0 $((NUM_SWEEPS - 1))); do + python3 "${TRAIN_FILE}" \ + --config="${CONFIG_FILE}:${CONFIG_NAME},${sweep_id},${DATASET}" \ + --jaxline_mode="train" \ + --logtostderr +done diff --git a/physics_inspired_models/metrics.py b/physics_inspired_models/metrics.py new file mode 100644 index 0000000..d4c8a7d --- /dev/null +++ b/physics_inspired_models/metrics.py @@ -0,0 +1,247 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Module containing code for computing various metrics for training and evaluation."""
from typing import Callable, Dict, Optional

import distrax
import haiku as hk
import jax
import jax.nn as nn
import jax.numpy as jnp
import numpy as np

import physics_inspired_models.utils as utils


_ReconstructFunc = Callable[[utils.Params, jnp.ndarray, jnp.ndarray, bool],
                            distrax.Distribution]


def calculate_small_latents(dist, threshold=0.5):
  """Calculates the number of active latents by thresholding the variance of their distribution.

  A latent is counted when its variance is below `threshold` and the absolute
  value of its mean is above 0.1. Counts are summed over axis 1 and then
  averaged.

  Args:
    dist: The latent distribution; only `distrax.Normal` is supported.
    threshold: Upper bound on the variance of a counted latent.

  Returns:
    The mean number of counted latents.

  Raises:
    NotImplementedError: If `dist` is not a `distrax.Normal`.
  """
  if not isinstance(dist, distrax.Normal):
    raise NotImplementedError()
  latent_means = dist.mean()
  latent_variances = dist.variance()
  small_latents = jnp.sum(
      (latent_variances < threshold) & (jnp.abs(latent_means) > 0.1), axis=1)
  return jnp.mean(small_latents)


def compute_scale(
    targets: jnp.ndarray,
    rescale_by: Optional[str]
) -> jnp.ndarray:
  """Compute a scaling factor based on targets shape and the rescale_by argument.

  Args:
    targets: The targets whose shape determines the scale.
    rescale_by: Either "pixels_and_time" or `None`.

  Returns:
    The product of the last four dimensions of `targets` for
    "pixels_and_time", and one for `None`.

  Raises:
    ValueError: If `rescale_by` is any other value.
  """
  if rescale_by is None:
    return jnp.ones([])
  if rescale_by == "pixels_and_time":
    return jnp.asarray(np.prod(targets.shape[-4:]))
  raise ValueError(f"Unrecognized rescale_by={rescale_by}.")


def compute_data_domain_stats(
    p_x: distrax.Distribution,
    targets: jnp.ndarray
) -> Dict[str, jnp.ndarray]:
  """Compute several statistics in the data domain, such as L2 and negative log likelihood.

  All axes after the first two are reduced to obtain the "_over_time"
  statistics, and axis 1 is then summed for the per-sequence totals.

  Args:
    p_x: The predictive distribution over the targets.
    targets: The ground truth values.

  Returns:
    A dict with the negative log-likelihood, the L2 error and the L2 error
    normalised by the squared norm of the targets, each both "_over_time" and
    summed over axis 1.
  """
  axis = tuple(range(2, targets.ndim))
  l2_over_time = jnp.sum((p_x.mean() - targets) ** 2, axis=axis)
  l2 = jnp.sum(l2_over_time, axis=1)

  # Calculate relative L2 normalised by image "length". This reduces over the
  # same axes as the L2 itself, so the two always have matching shapes.
  norm_factor = jnp.sum(targets**2, axis=axis)
  l2_over_time_norm = l2_over_time / norm_factor
  l2_norm = jnp.sum(l2_over_time_norm, axis=1)

  # Compute negative log-likelihood under p(x)
  neg_log_p_x_over_time = - jnp.sum(p_x.log_prob(targets), axis=axis)
  neg_log_p_x = jnp.sum(neg_log_p_x_over_time, axis=1)

  return dict(
      neg_log_p_x_over_time=neg_log_p_x_over_time,
      neg_log_p_x=neg_log_p_x,
      l2_over_time=l2_over_time,
      l2=l2,
      l2_over_time_norm=l2_over_time_norm,
      l2_norm=l2_norm,
  )


def compute_vae_stats(
    neg_log_p_x: jnp.ndarray,
    rng: jnp.ndarray,
    q_z: distrax.Distribution,
    prior: distrax.Distribution
) -> Dict[str, jnp.ndarray]:
  """Compute the KL(q(z|x)||p(z)) and the negative ELBO, which are used for VAE models.

  Args:
    neg_log_p_x: The negative log-likelihood of the data.
    rng: Random key used if the KL has to be estimated by sampling.
    q_z: The approximate posterior.
    prior: The prior over the latents.

  Returns:
    A dict with the KL summed over all non-leading axes ("kl") and the
    negative ELBO ("neg_elbo").
  """
  # Compute the KL
  kl = distrax.estimate_kl_best_effort(q_z, prior, rng_key=rng, num_samples=1)
  kl = jnp.sum(kl, axis=tuple(range(1, kl.ndim)))
  # Sanity check
  assert kl.shape == neg_log_p_x.shape
  return dict(
      kl=kl,
      neg_elbo=neg_log_p_x + kl,
  )


def training_statistics(
    p_x: distrax.Distribution,
    targets: jnp.ndarray,
    rescale_by: Optional[str],
    rng: Optional[jnp.ndarray] = None,
    q_z: Optional[distrax.Distribution] = None,
    prior: Optional[distrax.Distribution] = None,
    p_x_learned_sigma: bool = False
) -> Dict[str, jnp.ndarray]:
  """Computes various statistics we track during training.

  Args:
    p_x: The predictive distribution over the targets.
    targets: The ground truth values.
    rescale_by: Passed to `compute_scale`; every statistic is divided by the
      resulting scale.
    rng: Random key for the KL estimate. Must be given together with `q_z`
      and `prior`, or all three must be `None`.
    q_z: Optional approximate posterior.
    prior: Optional prior over the latents.
    p_x_learned_sigma: Whether to also report the first entry of the variance
      of `p_x` as "p_x_sigma".

  Returns:
    A dict of statistics.
  """
  stats = compute_data_domain_stats(p_x, targets)

  if rng is not None and q_z is not None and prior is not None:
    stats.update(compute_vae_stats(stats["neg_log_p_x"], rng, q_z, prior))
  else:
    assert rng is None and q_z is None and prior is None

  # Rescale these stats accordingly
  scale = compute_scale(targets, rescale_by)
  # Note that "_over_time" stats are getting normalised by time here
  stats = jax.tree_map(lambda x: x / scale, stats)
  if p_x_learned_sigma:
    stats["p_x_sigma"] = p_x.variance().reshape([-1])[0]
  if q_z is not None:
    stats["small_latents"] = calculate_small_latents(q_z)
  return stats


def evaluation_only_statistics(
    reconstruct_func: _ReconstructFunc,
    params: hk.Params,
    inputs: jnp.ndarray,
    rng: jnp.ndarray,
    rescale_by: str,
    can_run_backwards: bool,
    train_sequence_length: int,
    reconstruction_skip: int,
    p_x_learned_sigma: bool = False,
) -> Dict[str, jnp.ndarray]:
  """Computes various statistics we track only during evaluation.

  The trajectory is reconstructed forward and, if `can_run_backwards`, also
  backward. Statistics are computed on the "train", "extrapolation" and
  "full" parts of the sequence, where axis 1 is sliced as the sequence axis.

  Args:
    reconstruct_func: Function returning the reconstruction distribution; the
      last argument tells it whether to run forward.
    params: The model parameters.
    inputs: The inputs from which the images are extracted.
    rng: Random key passed to `reconstruct_func`.
    rescale_by: Passed on to `training_statistics`.
    can_run_backwards: Whether backward statistics are computed as well.
    train_sequence_length: Length of the sequences used in training.
    reconstruction_skip: Number of steps excluded from the targets.
    p_x_learned_sigma: Passed on to `training_statistics`.

  Returns:
    A dict keyed by "{prefix}_{suffix}_{stat}". When `can_run_backwards` it
    also has "combined_{suffix}_{stat}" entries, the average of the forward
    and the backward values.
  """
  full_trajectory = utils.extract_image(inputs)
  prefixes = ("forward", "backward") if can_run_backwards else ("forward",)

  full_forward_targets = jax.tree_map(
      lambda x: x[:, reconstruction_skip:], full_trajectory)
  full_backward_targets = jax.tree_map(
      lambda x: x[:, :x.shape[1]-reconstruction_skip], full_trajectory)
  train_targets_length = train_sequence_length - reconstruction_skip
  full_targets_length = full_forward_targets.shape[1]

  stats = dict()
  keys = ()

  for prefix in prefixes:
    # Fully unroll the model and reconstruct the whole sequence
    full_prediction = reconstruct_func(params, full_trajectory, rng,
                                       prefix == "forward")
    assert isinstance(full_prediction, distrax.Normal)
    full_targets = (full_forward_targets if prefix == "forward" else
                    full_backward_targets)
    # In cases where the model can run backwards it is possible to reconstruct
    # parts which were intended to be skipped, so here we take care of that.
    if full_prediction.mean().shape[1] > full_targets_length:
      if prefix == "forward":
        full_prediction = jax.tree_map(lambda x: x[:, -full_targets_length:],
                                       full_prediction)
      else:
        full_prediction = jax.tree_map(lambda x: x[:, :full_targets_length],
                                       full_prediction)

    # Based on the prefix and suffix fetch correct predictions and targets
    for suffix in ("train", "extrapolation", "full"):
      if prefix == "forward" and suffix == "train":
        predict, targets = jax.tree_map(lambda x: x[:, :train_targets_length],
                                        (full_prediction, full_targets))
      elif prefix == "forward" and suffix == "extrapolation":
        predict, targets = jax.tree_map(lambda x: x[:, train_targets_length:],
                                        (full_prediction, full_targets))
      elif prefix == "backward" and suffix == "train":
        predict, targets = jax.tree_map(lambda x: x[:, -train_targets_length:],
                                        (full_prediction, full_targets))
      elif prefix == "backward" and suffix == "extrapolation":
        predict, targets = jax.tree_map(lambda x: x[:, :-train_targets_length],
                                        (full_prediction, full_targets))
      else:
        predict, targets = full_prediction, full_targets

      # Compute train statistics
      train_stats = training_statistics(predict, targets, rescale_by,
                                        p_x_learned_sigma=p_x_learned_sigma)
      for key, value in train_stats.items():
        stats[prefix + "_" + suffix + "_" + key] = value
      # Copy all stats keys
      keys = tuple(train_stats.keys())

  # Make a combined metric averaging forward and backward
  if can_run_backwards:
    for suffix in ("train", "extrapolation", "full"):
      for key in keys:
        forward = stats["forward_" + suffix + "_" + key]
        backward = stats["backward_" + suffix + "_" + key]
        combined = (forward + backward) / 2
        stats["combined_" + suffix + "_" + key] = combined

  return stats


def geco_objective(
    l2_loss,
    kl,
    alpha,
    kappa,
    constraint_ema,
    lambda_var,
    is_training
) -> Dict[str, jnp.ndarray]:
  """Computes the objective for GECO and some of its statistics used for updates.

  Args:
    l2_loss: The reconstruction loss being constrained.
    kl: The KL term added to the loss.
    alpha: Decay of the moving average of the constraint.
    kappa: The tolerance subtracted from `l2_loss` to form the constraint.
    constraint_ema: Current moving average of the constraint.
    lambda_var: Unconstrained variable; its softplus is the multiplier.
    is_training: Whether to update the moving average of the constraint.

  Returns:
    A dict with the "loss", the "geco_multiplier", the "geco_constraint" and
    the (possibly updated) "geco_constraint_ema".
  """
  # C_t
  constraint_t = l2_loss - kappa
  if is_training:
    # We update C_ma only during training
    constraint_ema = alpha * constraint_ema + (1 - alpha) * constraint_t
  lagrange = nn.softplus(lambda_var)
  lagrange = jnp.broadcast_to(lagrange, constraint_ema.shape)
  # Add this special op for getting all gradients correct
  loss = utils.geco_lagrange_product(lagrange, constraint_ema, constraint_t)
  return dict(
      loss=loss + kl,
      geco_multiplier=lagrange,
      geco_constraint=constraint_t,
      geco_constraint_ema=constraint_ema
  )


def elbo_objective(neg_log_p_x, kl, final_beta, beta_delay, step):
  """Computes objective for optimizing the Evidence Lower Bound (ELBO).

  Args:
    neg_log_p_x: The negative log-likelihood of the data.
    kl: The KL term.
    final_beta: The weight of the KL once the warm-up has finished.
    beta_delay: Number of steps over which beta grows linearly from zero to
      `final_beta`; zero disables the warm-up.
    step: The current step; it may be a traced array.

  Returns:
    A dict with the "loss" and the beta that was used ("elbo_beta").
  """
  if beta_delay == 0:
    beta = final_beta
  else:
    # `step` can be a tracer under jit/pmap, where Python's `float()` cannot
    # be applied, so the conversion is done with jnp instead.
    progress = jnp.asarray(step, dtype=jnp.float32) / float(beta_delay)
    beta = jnp.minimum(progress, 1.0) * final_beta
  return dict(
      loss=neg_log_p_x + beta * kl,
      elbo_beta=beta
  )
# File: physics_inspired_models/models/autoregressive.py
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all autoregressive models."""
import functools
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

import distrax
import haiku as hk
from jax import lax
import jax.numpy as jnp
import jax.random as jnr

import physics_inspired_models.metrics as metrics
import physics_inspired_models.models.base as base
import physics_inspired_models.models.networks as nets
import physics_inspired_models.utils as utils


class TeacherForcingAutoregressiveModel(base.SequenceModel):
  """A standard autoregressive model trained via teacher forcing.

  During training every input frame is encoded, the resulting latents are fed
  through a recurrent core and each core output is decoded into a distribution
  over the following frame (see `train_data_split`). When more than one forward
  step is requested, the model keeps rolling by feeding the mean of its own
  decoded prediction back through the encoder (see `unroll_without_inputs`).
  """

  def __init__(
      self,
      latent_system_dim: int,
      latent_system_net_type: str,
      latent_system_kwargs: Dict[str, Any],
      latent_dynamics_type: str,
      encoder_aggregation_type: Optional[str],
      decoder_de_aggregation_type: Optional[str],
      encoder_kwargs: Dict[str, Any],
      decoder_kwargs: Dict[str, Any],
      num_inference_steps: int,
      num_target_steps: int,
      name: Optional[str] = None,
      **kwargs
  ):
    """Builds the encoder/decoder (via the base class) and the recurrence.

    Args:
      latent_system_dim: Passed to the base class and used as the output
        dimensionality of the recurrent core.
      latent_system_net_type: Must be "mlp"; anything else raises.
      latent_system_kwargs: Must contain a "net_kwargs" entry, which is
        forwarded to `nets.make_flexible_recurrent_net`.
      latent_dynamics_type: Forwarded as `core_type` to
        `nets.make_flexible_recurrent_net`.
      encoder_aggregation_type: Forwarded to the base class.
      decoder_de_aggregation_type: Forwarded to the base class.
      encoder_kwargs: Encoder arguments; "distribution_name" is always
        overwritten with None.
      decoder_kwargs: Forwarded to the base class.
      num_inference_steps: Forwarded to the base class.
      num_target_steps: Forwarded to the base class.
      name: Forwarded to the base class.
      **kwargs: Any further arguments for the base class.

    Raises:
      ValueError: If `has_latent_transform` is requested or if
        `latent_system_net_type` is not "mlp".
    """

    # Remove any parameters from vae models. Copy first so that the caller's
    # dict is not mutated; `None` is treated as empty, as the base class does.
    encoder_kwargs = dict(encoder_kwargs or {})
    encoder_kwargs["distribution_name"] = None

    if kwargs.get("has_latent_transform", False):
      raise ValueError("We do not support AR models with latent transform.")

    super().__init__(
        can_run_backwards=False,
        latent_system_dim=latent_system_dim,
        latent_system_net_type=latent_system_net_type,
        latent_system_kwargs=latent_system_kwargs,
        encoder_aggregation_type=encoder_aggregation_type,
        decoder_de_aggregation_type=decoder_de_aggregation_type,
        encoder_kwargs=encoder_kwargs,
        decoder_kwargs=decoder_kwargs,
        num_inference_steps=num_inference_steps,
        num_target_steps=num_target_steps,
        name=name,
        **kwargs
    )
    self.latent_dynamics_type = latent_dynamics_type

    # Arguments checks
    if self.latent_system_net_type != "mlp":
      raise ValueError("Currently we do not support non-mlp AR models.")

    def recurrence_function(sequence, initial_state=None):
      """Unrolls the recurrent core over a time-major `sequence`."""
      core = nets.make_flexible_recurrent_net(
          core_type=latent_dynamics_type,
          net_type=latent_system_net_type,
          output_dims=self.latent_system_dim,
          **self.latent_system_kwargs["net_kwargs"])
      # An explicit `is None` check is required here: the state may be a JAX
      # array (or a tracer inside `lax.scan`), whose truth value is ambiguous
      # or not available, so `initial_state or ...` is not safe.
      if initial_state is None:
        initial_state = core.initial_state(sequence.shape[1])
      # NOTE(review): the result of this call is discarded; presumably it
      # forces the core's parameters to be created before the unroll - confirm.
      core(sequence[0], initial_state)
      return hk.dynamic_unroll(core, sequence, initial_state)

    self.recurrence = hk.transform(recurrence_function)

  def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
    """Returns the inputs unchanged."""
    return x

  def process_latents_for_dynamics(self, z: jnp.ndarray) -> jnp.ndarray:
    """Returns the latents unchanged."""
    return z

  def process_latents_for_decoder(self, z: jnp.ndarray) -> jnp.ndarray:
    """Returns the latents unchanged."""
    return z

  @property
  def inferred_index(self) -> int:
    """Index of the last inference frame, `num_inference_steps - 1`."""
    return self.num_inference_steps - 1

  @property
  def train_sequence_length(self) -> int:
    """Number of frames used for training, equal to `num_target_steps`."""
    return self.num_target_steps

  def train_data_split(
      self,
      images: jnp.ndarray
  ) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
    """Splits a sequence into teacher-forcing inputs and next-step targets.

    Args:
      images: Array whose second axis is time.

    Returns:
      A tuple `(inference_data, target_data, unroll_kwargs)`: all frames but
      the last, all frames but the first, and the keyword arguments for a
      single forward step of `unroll_latent_dynamics`.
    """
    images = images[:, :self.train_sequence_length]
    inference_data = images[:, :-1]
    target_data = images[:, 1:]
    return inference_data, target_data, dict(
        num_steps_forward=1,
        num_steps_backward=0,
        include_z0=False)

  def unroll_without_inputs(
      self,
      params: utils.Params,
      rng: jnp.ndarray,
      x_init: jnp.ndarray,
      h_init: jnp.ndarray,
      num_steps: int,
      is_training: bool
  ) -> Tuple[Tuple[distrax.Distribution, jnp.ndarray], Any]:
    """Rolls the model forward, feeding its own predictions back as inputs.

    Each step encodes the last image, advances the recurrence by one step from
    the last hidden state and decodes the new latent; the mean of the decoded
    distribution becomes the next step's input image.

    Args:
      params: The model parameters.
      rng: Random key, split into one key per step.
      x_init: The image fed to the first step.
      h_init: The recurrent state fed to the first step.
      num_steps: How many steps to unroll; must be at least 1.
      is_training: Forwarded to the encoder and the decoder.

    Returns:
      The result of `lax.scan`: the final carry `(x_last, h_last)` followed by
      the per-step outputs `(p_x, z)` stacked along a leading time axis.

    Raises:
      ValueError: If `num_steps` is smaller than 1.
    """
    if num_steps < 1:
      raise ValueError("`num_steps` must be at least 1.")

    def step_fn(carry, key):
      x_last, h_last = carry
      enc_key, dec_key = jnr.split(key)
      z_in_next = self.encoder.apply(params, enc_key, x_last,
                                     is_training=is_training)
      # The recurrence expects a leading time axis, hence the `[None]`.
      z_next, h_next = self.recurrence.apply(params, None, z_in_next[None],
                                             h_last)
      p_x_next = self.decode_latents(params, dec_key, z_next[0],
                                     is_training=is_training)
      return (p_x_next.mean(), h_next), (p_x_next, z_next[0])

    return lax.scan(
        step_fn,
        init=(x_init, h_init),
        xs=jnr.split(rng, num_steps)
    )

  def unroll_latent_dynamics(
      self,
      z: jnp.ndarray,
      params: utils.Params,
      key: jnp.ndarray,
      num_steps_forward: int,
      num_steps_backward: int,
      include_z0: bool,
      is_training: bool,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Runs the recurrence on `z` and optionally keeps rolling without inputs.

    Args:
      z: Encoded inputs with time as the second axis.
      params: The model parameters.
      key: Random key.
      num_steps_forward: With 1 only the recurrence over `z` is run; with a
        larger value `num_steps_forward - 1` additional input-free steps are
        appended.
      num_steps_backward: Must be 0.
      include_z0: Not used by this model.
      is_training: Forwarded to the input-free unroll.
      **kwargs: Not used by this model.

    Returns:
      The latents with time as the second axis, and an empty dict.

    Raises:
      ValueError: If `num_steps_backward` is not 0 or `num_steps_forward` is
        smaller than 1.
    """
    init_key, unroll_key, dec_key = jnr.split(key, 3)

    if num_steps_backward != 0:
      raise ValueError("This model can not run backwards.")

    # Change 'z' time dimension to be first
    z = jnp.swapaxes(z, 0, 1)

    # Run recurrent model on inputs
    z_0, h_0 = self.recurrence.apply(params, init_key, z)

    if num_steps_forward == 1:
      z_t = z_0
    elif num_steps_forward > 1:
      # NOTE(review): this decode hard-codes `is_training=False` while the
      # subsequent unroll uses `is_training` - confirm this is intentional.
      p_x_0 = self.decode_latents(params, dec_key, z_0[-1], is_training=False)
      _, (_, z_t) = self.unroll_without_inputs(
          params=params,
          rng=unroll_key,
          x_init=p_x_0.mean(),
          h_init=h_0,
          num_steps=num_steps_forward-1,
          is_training=is_training
      )
      z_t = jnp.concatenate([z_0, z_t], axis=0)
    else:
      raise ValueError("num_steps_forward should be at least 1.")

    # Make time dimension second
    return jnp.swapaxes(z_t, 0, 1), dict()

  def _models_core(
      self,
      params: utils.Params,
      keys: jnp.ndarray,
      image_data: jnp.ndarray,
      is_training: bool,
      **unroll_kwargs: Any
  ) -> Tuple[distrax.Distribution, jnp.ndarray, jnp.ndarray]:
    """Encodes the images, unrolls the latents and decodes them.

    Args:
      params: The model parameters.
      keys: Six random keys; only the encoder, transform, unroll and decoder
        keys (positions 0, 2, 3 and 4) are used.
      image_data: The images to encode.
      is_training: Forwarded to all sub-networks.
      **unroll_kwargs: Forwarded to `unroll_latent_dynamics`.

    Returns:
      A tuple `(p_x, z0, decoder_z)`: the decoded distribution, the latents
      given to the dynamics and the latents given to the decoder.
    """
    enc_key, _, transform_key, unroll_key, dec_key, _ = keys

    # Calculate latent input representation
    inference_data = self.process_inputs_for_encoder(image_data)
    z_raw = self.encoder.apply(params, enc_key, inference_data,
                               is_training=is_training)

    # Apply latent transformation (should be identity)
    z0 = self.apply_latent_transform(params, transform_key, z_raw,
                                     is_training=is_training)
    z0 = self.process_latents_for_dynamics(z0)

    # Calculate latent output representation
    decoder_z, _ = self.unroll_latent_dynamics(
        z=z0,
        params=params,
        key=unroll_key,
        is_training=is_training,
        **unroll_kwargs
    )
    decoder_z = self.process_latents_for_decoder(decoder_z)

    # Compute p(x|z)
    p_x = self.decode_latents(params, dec_key, decoder_z,
                              is_training=is_training)
    return p_x, z0, decoder_z

  def training_objectives(
      self,
      params: hk.Params,
      state: hk.State,
      rng: jnp.ndarray,
      inputs: jnp.ndarray,
      step: jnp.ndarray,
      is_training: bool = True,
      use_mean_for_eval_stats: bool = True
  ) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
    """Computes the training objective and any supporting stats.

    Args:
      params: The model parameters.
      state: Not used by this model.
      rng: Random key.
      inputs: The raw inputs, from which the images are extracted.
      step: Not used by this model.
      is_training: When False the evaluation-only statistics are added.
      use_mean_for_eval_stats: Forwarded as `use_mean` to `reconstruct` when
        computing the evaluation-only statistics.

    Returns:
      The loss and a tuple `(dict(), stats, dict())`.
    """
    # Split all rng keys
    keys = jnr.split(rng, 6)

    # Process training data
    images = utils.extract_image(inputs)
    image_data, target_data, unroll_kwargs = self.train_data_split(images)

    p_x, _, _ = self._models_core(
        params=params,
        keys=keys,
        image_data=image_data,
        is_training=is_training,
        **unroll_kwargs
    )

    # Compute training statistics
    stats = metrics.training_statistics(
        p_x=p_x,
        targets=target_data,
        rescale_by=self.rescale_by,
        p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
    )

    # The loss is just the negative log-likelihood (e.g. the L2 loss)
    stats["loss"] = stats["neg_log_p_x"]

    if not is_training:
      # Add also the evaluation statistics.
      # We need to be able to set `use_mean = False` for some of the tests
      stats.update(metrics.evaluation_only_statistics(
          reconstruct_func=functools.partial(
              self.reconstruct, use_mean=use_mean_for_eval_stats),
          params=params,
          inputs=inputs,
          rng=rng,
          rescale_by=self.rescale_by,
          can_run_backwards=self.can_run_backwards,
          train_sequence_length=self.train_sequence_length,
          reconstruction_skip=1,
          p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False)
      ))

    return stats["loss"], (dict(), stats, dict())

  def reconstruct(
      self,
      params: utils.Params,
      inputs: jnp.ndarray,
      rng: jnp.ndarray,
      forward: bool,
      use_mean: bool = True,
  ) -> distrax.Distribution:
    """Reconstructs the input sequence.

    Args:
      params: The model parameters.
      inputs: The raw inputs, from which the images are extracted.
      rng: Random key.
      forward: Must be True.
      use_mean: Accepted for interface compatibility; not used by this model.

    Returns:
      The decoded distribution obtained by encoding the first
      `num_inference_steps` frames and unrolling for the remaining ones.

    Raises:
      ValueError: If `forward` is False.
    """
    if not forward:
      raise ValueError("This model can not run backwards.")
    images = utils.extract_image(inputs)
    image_data = images[:, :self.num_inference_steps]

    return self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        is_training=False,
        num_steps_forward=images.shape[1] - self.num_inference_steps,
        num_steps_backward=0,
        include_z0=False,
    )[0]

  def gt_state_and_latents(
      self,
      params: hk.Params,
      rng: jnp.ndarray,
      inputs: Dict[str, jnp.ndarray],
      seq_length: int,
      is_training: bool = False,
      unroll_direction: str = "forward",
      **kwargs: Dict[str, Any]
  ) -> Tuple[jnp.ndarray, jnp.ndarray,
             Union[distrax.Distribution, jnp.ndarray]]:
    """Computes the ground state and matching latents.

    Args:
      params: The model parameters.
      rng: Random key.
      inputs: The raw inputs, from which images and ground truth state are
        extracted.
      seq_length: Number of ground truth steps to return, starting at index 1.
      is_training: Not used; the model is always run with `is_training=False`.
      unroll_direction: Must be "forward".
      **kwargs: Not used by this model.

    Returns:
      A tuple `(gt_state, z_out, z_in)`: the sliced ground truth state, the
      latents given to the decoder and the latents given to the dynamics.
    """
    assert unroll_direction == "forward"
    images = utils.extract_image(inputs)
    gt_state = utils.extract_gt_state(inputs)
    image_data = images[:, :self.num_inference_steps]
    gt_state = gt_state[:, 1:seq_length + 1]

    _, z_in, z_out = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        is_training=False,
        num_steps_forward=images.shape[1] - self.num_inference_steps,
        num_steps_backward=0,
        include_z0=False,
    )

    return gt_state, z_out, z_in

  def _init_non_model_params_and_state(
      self,
      rng: jnp.ndarray
  ) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
    """Returns empty parameters and state; this model has none of its own."""
    return dict(), dict()

  def _init_latent_system(
      self,
      rng: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Initializes the parameters of the recurrence; `kwargs` are ignored."""
    return self.recurrence.init(rng, z)


# File: physics_inspired_models/models/base.py
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the base abstract classes for sequence models."""
import abc
from typing import Any, Dict, Generic, Mapping, Optional, Sequence, Tuple, TypeVar, Union

from absl import logging
import distrax
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jnr


from physics_inspired_models import utils
from physics_inspired_models.models import networks

T = TypeVar("T")


class SequenceModel(abc.ABC, Generic[T]):
  """An abstract class for sequence models.

  The class owns a Haiku-transformed encoder, decoder and optional latent
  transform, and implements parameter initialization. Subclasses provide the
  latent dynamics and the training objective. The type parameter `T` is the
  representation of the latents used by the dynamics.
  """

  def __init__(
      self,
      can_run_backwards: bool,
      latent_system_dim: int,
      latent_system_net_type: str,
      latent_system_kwargs: Dict[str, Any],
      encoder_aggregation_type: Optional[str],
      decoder_de_aggregation_type: Optional[str],
      encoder_kwargs: Dict[str, Any],
      decoder_kwargs: Dict[str, Any],
      num_inference_steps: int,
      num_target_steps: int,
      name: Optional[str],
      latent_spatial_shape: Optional[Tuple[int, int]] = (4, 4),
      has_latent_transform: bool = False,
      latent_transform_kwargs: Optional[Dict[str, Any]] = None,
      rescale_by: Optional[str] = "pixels_and_time",
      data_format: str = "NHWC",
      **unused_kwargs
  ):
    """Validates the configuration and builds the encoder and decoder.

    Args:
      can_run_backwards: Stored on the instance.
      latent_system_dim: Latent dimensionality given to the encoder and the
        latent transform; half of it is given to the decoder as
        `max_de_aggregation_dims`.
      latent_system_net_type: "conv" forbids aggregation types; any other value
        requires them.
      latent_system_kwargs: Stored on the instance for use by subclasses.
      encoder_aggregation_type: Aggregation type of the encoder.
      decoder_de_aggregation_type: De-aggregation type of the decoder. When
        None it is derived from `encoder_aggregation_type`.
      encoder_kwargs: Extra encoder arguments; None is treated as empty.
      decoder_kwargs: Extra decoder arguments; None is treated as empty.
      num_inference_steps: Stored on the instance.
      num_target_steps: Stored on the instance.
      name: Stored on the instance.
      latent_spatial_shape: Given to the decoder as `initial_spatial_shape`.
      has_latent_transform: Whether to build the latent transform network.
      latent_transform_kwargs: Arguments of the latent transform; required when
        `has_latent_transform` is True.
      rescale_by: Stored on the instance.
      data_format: Given to the encoder and the decoder.
      **unused_kwargs: Ignored, with a logged warning.

    Raises:
      ValueError: If the aggregation types are inconsistent with
        `latent_system_net_type`, or if the latent transform is requested
        without its arguments.
    """
    # Arguments checks
    encoder_kwargs = encoder_kwargs or dict()
    decoder_kwargs = decoder_kwargs or dict()

    # Set the decoder de-aggregation type the "same" type as the encoder if not
    # provided
    if (decoder_de_aggregation_type is None and
        encoder_aggregation_type is not None):
      if encoder_aggregation_type == "linear_projection":
        decoder_de_aggregation_type = "linear_projection"
      elif encoder_aggregation_type in ("mean", "max"):
        decoder_de_aggregation_type = "tile"
      else:
        raise ValueError(f"Unrecognized encoder_aggregation_type="
                         f"{encoder_aggregation_type}")
    if latent_system_net_type == "conv":
      if encoder_aggregation_type is not None:
        raise ValueError("When the latent system is convolutional, the encoder "
                         "aggregation type should be None.")
      if decoder_de_aggregation_type is not None:
        raise ValueError("When the latent system is convolutional, the decoder "
                         "aggregation type should be None.")
    else:
      if encoder_aggregation_type is None:
        raise ValueError("When the latent system is not convolutional, the "
                         "you must provide an encoder aggregation type.")
      if decoder_de_aggregation_type is None:
        raise ValueError("When the latent system is not convolutional, the "
                         "you must provide an decoder aggregation type.")
    if has_latent_transform and latent_transform_kwargs is None:
      raise ValueError("When using latent transformation you have to provide "
                       "the latent_transform_kwargs argument.")
    if unused_kwargs:
      logging.warning("Unused kwargs: %s", str(unused_kwargs))
    # The unused arguments must not be forwarded: `object.__init__` raises a
    # `TypeError` on any extra argument, which would turn the warning above
    # into a crash.
    super().__init__()
    self.can_run_backwards = can_run_backwards
    self.latent_system_dim = latent_system_dim
    self.latent_system_kwargs = latent_system_kwargs
    self.latent_system_net_type = latent_system_net_type
    self.latent_spatial_shape = latent_spatial_shape
    self.num_inference_steps = num_inference_steps
    self.num_target_steps = num_target_steps
    self.rescale_by = rescale_by
    self.data_format = data_format
    self.name = name

    # Encoder
    self.encoder_kwargs = encoder_kwargs
    self.encoder = hk.transform(
        lambda *args, **kwargs: networks.SpatialConvEncoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
            latent_dim=latent_system_dim,
            aggregation_type=encoder_aggregation_type,
            data_format=data_format,
            name="Encoder",
            **encoder_kwargs
        )(*args, **kwargs))

    # Decoder
    self.decoder_kwargs = decoder_kwargs
    self.decoder = hk.transform(
        lambda *args, **kwargs: networks.SpatialConvDecoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
            initial_spatial_shape=self.latent_spatial_shape,
            de_aggregation_type=decoder_de_aggregation_type,
            data_format=data_format,
            max_de_aggregation_dims=self.latent_system_dim // 2,
            name="Decoder",
            **decoder_kwargs,
        )(*args, **kwargs))

    self.has_latent_transform = has_latent_transform
    if has_latent_transform:
      self.latent_transform = hk.transform(
          lambda *args, **kwargs: networks.make_flexible_net(  # pylint: disable=unnecessary-lambda,g-long-lambda
              net_type=latent_system_net_type,
              output_dims=latent_system_dim,
              name="LatentTransform",
              **latent_transform_kwargs
          )(*args, **kwargs))
    else:
      self.latent_transform = None

    # The jitted version of `_init`, created lazily in `init`.
    self._jit_init = None

  @property
  @abc.abstractmethod
  def train_sequence_length(self) -> int:
    """Computes the total length of a sequence needed for training or evaluation."""
    pass

  @abc.abstractmethod
  def train_data_split(
      self,
      images: jnp.ndarray,
  ) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]:
    """Extracts from the inputs the data splits for training.

    Returns:
      A tuple whose first element is the inference data, as used by `_init`.
      NOTE(review): the subclasses visible to this reviewer return
      `(inference_data, target_data, unroll_kwargs)` - confirm for others.
    """
    pass

  def decode_latents(
      self,
      params: hk.Params,
      rng: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> distrax.Distribution:
    """Decodes the latent variable given the parameters of the model.

    All leading axes of `z` are flattened into a single batch axis before the
    decoder is applied and are restored on every array of the result.

    Args:
      params: Parameters accepted by the decoder.
      rng: Random key for the decoder.
      z: The latents. The trailing 1 ("mlp") or `1 + len(latent_spatial_shape)`
        ("conv") axes are the per-example latent axes.
      **kwargs: Forwarded to the decoder.

    Returns:
      The decoder output with the leading axes of `z` restored.

    Raises:
      NotImplementedError: If `latent_system_net_type` is neither "mlp" nor
        "conv".
    """
    # Allow to run with both the full parameters and only the decoders
    if self.latent_system_net_type == "mlp":
      fixed_dims = 1
    elif self.latent_system_net_type == "conv":
      fixed_dims = 1 + len(self.latent_spatial_shape)
    else:
      raise NotImplementedError()
    n_shape = z.shape[:-fixed_dims]
    z = z.reshape((-1,) + z.shape[-fixed_dims:])
    x = self.decoder.apply(params, rng, z, **kwargs)
    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the supported
    # spelling of the same function.
    return jax.tree_util.tree_map(
        lambda a: a.reshape(n_shape + a.shape[1:]), x)

  def apply_latent_transform(
      self,
      params: hk.Params,
      key: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Applies the latent transform, or returns `z` when there is none."""
    if self.latent_transform is not None:
      return self.latent_transform.apply(params, key, z, **kwargs)
    else:
      return z

  @abc.abstractmethod
  def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray:
    """Converts the inference data into the form expected by the encoder."""
    pass

  @abc.abstractmethod
  def process_latents_for_dynamics(self, z: jnp.ndarray) -> T:
    """Converts the encoded latents into the form used by the dynamics."""
    pass

  @abc.abstractmethod
  def process_latents_for_decoder(self, z: T) -> jnp.ndarray:
    """Converts the unrolled latents into the form expected by the decoder."""
    pass

  @abc.abstractmethod
  def unroll_latent_dynamics(
      self,
      z: T,
      params: utils.Params,
      key: jnp.ndarray,
      num_steps_forward: int,
      num_steps_backward: int,
      include_z0: bool,
      is_training: bool,
      **kwargs: Any
  ) -> Tuple[T, Mapping[str, jnp.ndarray]]:
    """Unrolls the latent dynamics starting from z and pre-processing for the decoder."""
    pass

  @abc.abstractmethod
  def reconstruct(
      self,
      params: utils.Params,
      inputs: jnp.ndarray,
      rng_key: Optional[jnp.ndarray],
      forward: bool,
  ) -> distrax.Distribution:
    """Using the first `num_inference_steps` parts of inputs reconstructs the rest."""
    pass

  @abc.abstractmethod
  def training_objectives(
      self,
      params: utils.Params,
      state: hk.State,
      rng: jnp.ndarray,
      inputs: Union[Dict[str, jnp.ndarray], jnp.ndarray],
      step: jnp.ndarray,
      is_training: bool = True,
      use_mean_for_eval_stats: bool = True
  ) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
    """Returns all training objectives statistics and update states."""
    pass

  @property
  @abc.abstractmethod
  def inferred_index(self):
    """Returns the time index in the input sequence, for which the encoder infers.

    If the encoder takes as input the sequence x[0:n-1], where
    `n = self.num_inference_steps`, then this outputs the index `k` relative to
    the beginning of the input sequence `x_0`, which the encoder infers.
    """
    pass

  @property
  def inferred_right_offset(self):
    """Number of inference frames that come after the inferred index."""
    return self.num_inference_steps - 1 - self.inferred_index

  @abc.abstractmethod
  def gt_state_and_latents(
      self,
      params: hk.Params,
      rng: jnp.ndarray,
      inputs: Dict[str, jnp.ndarray],
      seq_len: int,
      is_training: bool = False,
      unroll_direction: str = "forward",
      **kwargs: Dict[str, Any]
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Computes the ground state and matching latents."""
    pass

  @abc.abstractmethod
  def _init_non_model_params_and_state(
      self,
      rng: jnp.ndarray
  ) -> Tuple[utils.Params, utils.Params]:
    """Initializes any non-model parameters and state."""
    pass

  @abc.abstractmethod
  def _init_latent_system(
      self,
      rng: jnp.ndarray,
      z: jnp.ndarray,
      **kwargs: Any
  ) -> hk.Params:
    """Initializes the parameters of the latent system."""
    pass

  def _init(
      self,
      rng: jnp.ndarray,
      images: jnp.ndarray
  ) -> Tuple[hk.Params, hk.State]:
    """Initializes the whole model parameters and state.

    Every sub-network is initialized and then applied once, so that its output
    can serve as the example input of the following sub-network: encoder,
    optional latent transform, latent system and finally the decoder.

    Args:
      rng: Random key; a fresh sub-key is split off for every init and apply.
      images: Example images, passed through `train_data_split`.

    Returns:
      The merged, immutable parameters and the immutable state.
    """
    inference_data, _, _ = self.train_data_split(images)
    # Initialize parameters and state for the vae training
    rng, key = jnr.split(rng)
    params, state = self._init_non_model_params_and_state(key)

    # Initialize and run encoder
    inference_data = self.process_inputs_for_encoder(inference_data)
    rng, key = jnr.split(rng)
    encoder_params = self.encoder.init(key, inference_data, is_training=True)
    rng, key = jnr.split(rng)
    z_in = self.encoder.apply(encoder_params, key, inference_data,
                              is_training=True)

    # For probabilistic models this will be a distribution
    if isinstance(z_in, distrax.Distribution):
      z_in = z_in.mean()

    # Initialize and run the optional latent transform
    if self.latent_transform is not None:
      rng, key = jnr.split(rng)
      transform_params = self.latent_transform.init(key, z_in, is_training=True)
      rng, key = jnr.split(rng)
      z_in = self.latent_transform.apply(transform_params, key, z_in,
                                         is_training=True)
    else:
      transform_params = dict()

    # Initialize and run the latent system
    z_in = self.process_latents_for_dynamics(z_in)
    rng, key = jnr.split(rng)
    latent_params = self._init_latent_system(key, z_in, is_training=True)
    rng, key = jnr.split(rng)
    z_out, _ = self.unroll_latent_dynamics(
        z=z_in,
        params=latent_params,
        key=key,
        num_steps_forward=1,
        num_steps_backward=0,
        include_z0=False,
        is_training=True
    )
    z_out = self.process_latents_for_decoder(z_out)

    # Initialize and run the decoder on the first time step only
    rng, key = jnr.split(rng)
    decoder_params = self.decoder.init(key, z_out[:, 0], is_training=True)
    _ = self.decoder.apply(decoder_params, rng, z_out[:, 0], is_training=True)

    # Combine all and make immutable
    params = hk.data_structures.merge(params, encoder_params, transform_params,
                                      latent_params, decoder_params)
    params = hk.data_structures.to_immutable_dict(params)
    state = hk.data_structures.to_immutable_dict(state)

    return params, state

  def init(
      self,
      rng: jnp.ndarray,
      inputs_or_shape: Union[jnp.ndarray, Mapping[str, jnp.ndarray],
                             Sequence[int]],
  ) -> Tuple[utils.Params, hk.State]:
    """Initializes the whole model parameters and state.

    Args:
      rng: Random key.
      inputs_or_shape: Either example inputs, from which the images are
        extracted, or a tuple/list of ints giving the shape of the images, in
        which case zeros of that shape are used.

    Returns:
      The parameters and state produced by the jitted `_init`.
    """
    if (isinstance(inputs_or_shape, (tuple, list))
        and isinstance(inputs_or_shape[0], int)):
      images = jnp.zeros(inputs_or_shape)
    else:
      images = utils.extract_image(inputs_or_shape)
    # Jit lazily and cache, so that repeated calls reuse the compiled function.
    if self._jit_init is None:
      self._jit_init = jax.jit(self._init)
    return self._jit_init(rng, images)


# File: physics_inspired_models/models/common.py
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all models."""
from typing import Any, Dict, Optional

import physics_inspired_models.models.autoregressive as autoregressive
import physics_inspired_models.models.deterministic_vae as deterministic_vae

# Keys of `latent_system_kwargs` that only make sense for the physics models.
_physics_arguments = (
    "input_space", "simulation_space", "potential_func_form",
    "kinetic_func_form", "hgn_kinetic_func_form", "lgn_kinetic_func_form",
    "parametrize_mass_matrix", "hgn_parametrize_mass_matrix",
    "lgn_parametrize_mass_matrix", "mass_eps"
)


def construct_model(
    name: str,
    *args,
    **kwargs: Dict[str, Any]
):
  """Constructs the correct instance of a model given the short name.

  The `latent_system_kwargs` entry of `kwargs` is copied and stripped of the
  keys that do not apply to the requested model, and `latent_dynamics_type` is
  checked against (and, except for "AR", set to) the value implied by `name`.

  Args:
    name: One of "AR", "RGN", "ODE", "HGN", "LGN" or "PGN".
    *args: Forwarded to the model constructor.
    **kwargs: Forwarded to the model constructor, after `latent_dynamics_type`
      and `latent_system_kwargs` have been extracted and processed.

  Returns:
    A `TeacherForcingAutoregressiveModel` for "AR", otherwise a
    `DeterministicLatentsGenerativeModel`.

  Raises:
    NotImplementedError: If `name` is not one of the supported names.
    AssertionError: If `latent_dynamics_type`, or for "HGN"/"LGN" the
      "input_space"/"simulation_space" entries, contradict `name`.
  """
  latent_dynamics_type: Optional[str] = kwargs.pop("latent_dynamics_type", None)  # pytype: disable=annotation-type-mismatch
  # Copy, so that the caller's dict is not mutated by the pops below.
  latent_system_kwargs = dict(**kwargs.pop("latent_system_kwargs", dict()))
  if name == "AR":
    assert latent_dynamics_type in ("vanilla", "lstm", "gru")
    # These arguments are not part of the AR models
    for k in _physics_arguments + ("integrator_method", "residual"):
      latent_system_kwargs.pop(k, None)
    return autoregressive.TeacherForcingAutoregressiveModel(
        *args,
        latent_dynamics_type=latent_dynamics_type,
        latent_system_kwargs=latent_system_kwargs,
        **kwargs
    )
  elif name == "RGN":
    assert latent_dynamics_type in ("Discrete", None)
    latent_dynamics_type = "Discrete"
    # These arguments are not part of the RGN models
    for k in _physics_arguments + ("integrator_method",):
      latent_system_kwargs.pop(k, None)
  elif name == "ODE":
    assert latent_dynamics_type in ("ODE", None)
    latent_dynamics_type = "ODE"
    # These arguments are not part of the ODE models
    for k in _physics_arguments + ("residual",):
      latent_system_kwargs.pop(k, None)
  elif name == "HGN":
    assert latent_dynamics_type in ("Physics", None)
    latent_dynamics_type = "Physics"
    assert latent_system_kwargs.get("input_space", None) in ("momentum", None)
    latent_system_kwargs["input_space"] = "momentum"
    assert (latent_system_kwargs.get("simulation_space", None)
            in ("momentum", None))
    latent_system_kwargs["simulation_space"] = "momentum"
    # Kinetic func form: the HGN specific value overrides the generic one
    hgn_specific = latent_system_kwargs.pop("hgn_kinetic_func_form", None)
    if hgn_specific is not None:
      latent_system_kwargs["kinetic_func_form"] = hgn_specific
    # Mass matrix: the HGN specific value overrides the generic one
    hgn_specific = latent_system_kwargs.pop("hgn_parametrize_mass_matrix",
                                            None)
    if hgn_specific is not None:
      latent_system_kwargs["parametrize_mass_matrix"] = hgn_specific
    # These arguments are not part of the HGN models
    latent_system_kwargs.pop("residual", None)
    latent_system_kwargs.pop("lgn_kinetic_func_form", None)
    latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
  elif name == "LGN":
    assert latent_dynamics_type in ("Physics", None)
    latent_dynamics_type = "Physics"
    assert latent_system_kwargs.get("input_space", None) in ("velocity", None)
    latent_system_kwargs["input_space"] = "velocity"
    assert (latent_system_kwargs.get("simulation_space", None) in
            ("velocity", None))
    latent_system_kwargs["simulation_space"] = "velocity"
    # Kinetic func form: the LGN specific value overrides the generic one
    lgn_specific = latent_system_kwargs.pop("lgn_kinetic_func_form", None)
    if lgn_specific is not None:
      latent_system_kwargs["kinetic_func_form"] = lgn_specific
    # Mass matrix: the LGN specific value overrides the generic one
    lgn_specific = latent_system_kwargs.pop("lgn_parametrize_mass_matrix",
                                            None)
    if lgn_specific is not None:
      latent_system_kwargs["parametrize_mass_matrix"] = lgn_specific
    # These arguments are not part of the LGN models
    latent_system_kwargs.pop("residual", None)
    latent_system_kwargs.pop("hgn_kinetic_func_form", None)
    latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
  elif name == "PGN":
    assert latent_dynamics_type in ("Physics", None)
    latent_dynamics_type = "Physics"
    # These arguments are not part of the PGN models. A default is given to
    # every `pop`, as in the other branches, so that a missing "residual" entry
    # does not raise a `KeyError`.
    latent_system_kwargs.pop("residual", None)
    latent_system_kwargs.pop("hgn_kinetic_func_form", None)
    latent_system_kwargs.pop("hgn_parametrize_mass_matrix", None)
    latent_system_kwargs.pop("lgn_kinetic_func_form", None)
    latent_system_kwargs.pop("lgn_parametrize_mass_matrix", None)
  else:
    raise NotImplementedError(f"Unrecognized model name: {name!r}")
  return deterministic_vae.DeterministicLatentsGenerativeModel(
      *args,
      latent_dynamics_type=latent_dynamics_type,
      latent_system_kwargs=latent_system_kwargs,
      **kwargs)


# File: physics_inspired_models/models/deterministic_vae.py
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Module containing the main models code.""" +import functools +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union + +import distrax +from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space +import haiku as hk +import jax.numpy as jnp +import jax.random as jnr +import numpy as np + +from physics_inspired_models import metrics +from physics_inspired_models import utils +from physics_inspired_models.models import base +from physics_inspired_models.models import dynamics + +_ArrayOrPhase = Union[jnp.ndarray, phase_space.PhaseSpace] + + +class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPhase]): + """Common class for generative models with deterministic latent dynamics.""" + + def __init__( + self, + latent_system_dim: int, + latent_system_net_type: str, + latent_system_kwargs: Dict[str, Any], + latent_dynamics_type: str, + encoder_aggregation_type: Optional[str], + decoder_de_aggregation_type: Optional[str], + encoder_kwargs: Dict[str, Any], + decoder_kwargs: Dict[str, Any], + num_inference_steps: int, + num_target_steps: int, + latent_training_type: str, + training_data_split: str, + objective_type: str, + dt: float = 0.125, + render_from_q_only: bool = True, + prior_type: str = "standard_normal", + use_analytical_kl: bool = True, + geco_kappa: float = 0.001, + geco_alpha: Optional[float] = 0.0, + elbo_beta_delay: int = 0, + elbo_beta_final: float = 1.0, + name: Optional[str] = None, + **kwargs + ): + can_run_backwards = latent_dynamics_type in ("ODE", "Physics") + + # Verify arguments + if objective_type not in ("GECO", "ELBO", "NON-PROB"): + raise ValueError(f"Unrecognized training type - {objective_type}") + if geco_alpha is None: + geco_alpha = 0 + if geco_alpha < 0 or geco_alpha >= 1: + raise ValueError("GECO alpha parameter must be in [0, 1).") + if prior_type not in ("standard_normal", "made", "made_gated"): + raise ValueError(f"Unrecognized prior_type='{prior_type}.") + if (latent_training_type == 
"forward_backward" and + training_data_split != "include_inference"): + raise ValueError("Training forward_backward works only when " + "training_data_split=include_inference.") + if (latent_training_type == "forward_backward" and + num_inference_steps % 2 == 0): + raise ValueError("Training forward_backward works only when " + "num_inference_steps are odd.") + if latent_training_type == "forward_backward" and not can_run_backwards: + raise ValueError("Training forward_backward works only when the model can" + " be run backwards.") + if prior_type != "standard_normal": + raise ValueError("For now we support only `standard_normal`.") + + super().__init__( + can_run_backwards=can_run_backwards, + latent_system_dim=latent_system_dim, + latent_system_net_type=latent_system_net_type, + latent_system_kwargs=latent_system_kwargs, + encoder_aggregation_type=encoder_aggregation_type, + decoder_de_aggregation_type=decoder_de_aggregation_type, + encoder_kwargs=encoder_kwargs, + decoder_kwargs=decoder_kwargs, + num_inference_steps=num_inference_steps, + num_target_steps=num_target_steps, + name=name, + **kwargs + ) + # VAE specific arguments + self.prior_type = prior_type + self.objective_type = objective_type + self.use_analytical_kl = use_analytical_kl + self.geco_kappa = geco_kappa + self.geco_alpha = geco_alpha + self.elbo_beta_delay = elbo_beta_delay + self.elbo_beta_final = jnp.asarray(elbo_beta_final) + + # The dynamics module and arguments + self.latent_dynamics_type = latent_dynamics_type + self.latent_training_type = latent_training_type + self.training_data_split = training_data_split + self.dt = dt + self.render_from_q_only = render_from_q_only + latent_system_kwargs["net_kwargs"] = dict( + latent_system_kwargs["net_kwargs"]) + latent_system_kwargs["net_kwargs"]["net_type"] = self.latent_system_net_type + + if self.latent_dynamics_type == "Physics": + # Note that here system_dim means the dimensionality of `q` and `p`. 
+ model_constructor = functools.partial( + dynamics.PhysicsSimulationNetwork, + system_dim=self.latent_system_dim // 2, + name="Physics", + **latent_system_kwargs + ) + elif self.latent_dynamics_type == "ODE": + model_constructor = functools.partial( + dynamics.OdeNetwork, + system_dim=self.latent_system_dim, + name="ODE", + **latent_system_kwargs + ) + elif self.latent_dynamics_type == "Discrete": + model_constructor = functools.partial( + dynamics.DiscreteDynamicsNetwork, + system_dim=self.latent_system_dim, + name="Discrete", + **latent_system_kwargs + ) + else: + raise NotImplementedError() + self.dynamics = hk.transform( + lambda *args, **kwargs_: model_constructor()(*args, **kwargs_)) # pylint: disable=unnecessary-lambda + + def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray: + return utils.stack_time_into_channels(x, self.data_format) + + def process_latents_for_dynamics(self, z: jnp.ndarray) -> _ArrayOrPhase: + if self.latent_dynamics_type == "Physics": + return phase_space.PhaseSpace.from_state(z) + return z + + def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray: + if self.latent_dynamics_type == "Physics": + return z.q if self.render_from_q_only else z.single_state + return z + + @property + def inferred_index(self) -> int: + if self.latent_training_type == "forward": + return self.num_inference_steps - 1 + elif self.latent_training_type == "forward_backward": + assert self.num_inference_steps % 2 == 1 + return self.num_inference_steps // 2 + else: + raise NotImplementedError() + + @property + def targets_index_offset(self) -> int: + if self.training_data_split == "overlap_by_one": + return -1 + elif self.training_data_split == "no_overlap": + return 0 + elif self.training_data_split == "include_inference": + return - self.num_inference_steps + else: + raise NotImplementedError() + + @property + def targets_length(self) -> int: + if self.training_data_split == "include_inference": + return self.num_inference_steps + 
self.num_target_steps + return self.num_target_steps + + @property + def train_sequence_length(self) -> int: + """Computes the total length of a sequence needed for training.""" + if self.training_data_split == "overlap_by_one": + # Input - [-------------------------------------------------] + # Inference - [---------------] + # Targets - [---------------------------------] + return self.num_inference_steps + self.num_target_steps - 1 + elif self.training_data_split == "no_overlap": + # Input - [-------------------------------------------------] + # Inference - [---------------] + # Targets - [--------------------------------] + return self.num_inference_steps + self.num_target_steps + elif self.training_data_split == "include_inference": + # Input - [-------------------------------------------------] + # Inference - [---------------] + # Targets - [-------------------------------------------------] + return self.num_inference_steps + self.num_target_steps + else: + raise NotImplementedError() + + def train_data_split( + self, + images: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, Any]]: + images = images[:, :self.train_sequence_length] + inf_idx = self.num_inference_steps + t_idx = self.num_inference_steps + self.targets_index_offset + if self.latent_training_type == "forward": + inference_data = images[:, :inf_idx] + target_data = images[:, t_idx:] + if self.training_data_split == "include_inference": + num_steps_backward = self.inferred_index + else: + num_steps_backward = 0 + num_steps_forward = self.num_target_steps + if self.training_data_split == "overlap_by_one": + num_steps_forward -= 1 + unroll_kwargs = dict( + num_steps_backward=num_steps_backward, + include_z0=self.training_data_split != "no_overlap", + num_steps_forward=num_steps_forward, + dt=self.dt + ) + elif self.latent_training_type == "forward_backward": + assert self.training_data_split == "include_inference" + n_fwd = images.shape[0] // 2 + inference_fwd = images[:n_fwd, 
:inf_idx] + targets_fwd = images[:n_fwd, t_idx:] + inference_bckwd = images[n_fwd:, -inf_idx:] + targets_bckwd = jnp.flip(images[n_fwd:, :images.shape[1] - t_idx], axis=1) + inference_data = jnp.concatenate([inference_fwd, inference_bckwd], axis=0) + target_data = jnp.concatenate([targets_fwd, targets_bckwd], axis=0) + # This needs to by numpy rather than jax.numpy, because we make some + # verification checks in `integrators.py:149-161`. + dt_fwd = np.full([n_fwd], self.dt) + dt_bckwd = np.full([images.shape[0] - n_fwd], self.dt) + dt = np.concatenate([dt_fwd, -dt_bckwd], axis=0) + unroll_kwargs = dict( + num_steps_backward=self.inferred_index, + include_z0=True, + num_steps_forward=self.targets_length - self.inferred_index - 1, + dt=dt + ) + else: + raise NotImplementedError() + return inference_data, target_data, unroll_kwargs + + def prior(self) -> distrax.Distribution: + """Given the parameters returns the prior distribution of the model.""" + # Allow to run with both the full parameters and only the priors + if self.prior_type == "standard_normal": + # assert self.prior_nets is None and self.gated_made is None + if self.latent_system_net_type == "mlp": + event_shape = (self.latent_system_dim,) + elif self.latent_system_net_type == "conv": + if self.data_format == "NHWC": + event_shape = self.latent_spatial_shape + (self.latent_system_dim,) + else: + event_shape = (self.latent_system_dim,) + self.latent_spatial_shape + else: + raise NotImplementedError() + return distrax.Normal(jnp.zeros(event_shape), jnp.ones(event_shape)) + else: + raise ValueError(f"Unrecognized prior_type='{self.prior_type}'.") + + def sample_latent_from_prior( + self, + params: utils.Params, + rng: jnp.ndarray, + num_samples: int = 1, + **kwargs: Any) -> jnp.ndarray: + """Takes sample from the prior (and optionally puts them through the latent transform function.""" + _, sample_key, transf_key = jnr.split(rng, 3) + prior = self.prior() + z_raw = prior.sample(seed=sample_key, 
def verify_unroll_args(
    self,
    num_steps_forward: int,
    num_steps_backward: int,
    include_z0: bool
) -> None:
  """Validates the arguments that describe a latent unroll.

  Args:
    num_steps_forward: Number of steps to unroll forward in time.
    num_steps_backward: Number of steps to unroll backward in time.
    include_z0: Whether the initial latent state is part of the result.

  Raises:
    ValueError: If either step count is negative, if both are zero, if both
      are positive while `include_z0` is False, or if a backward unroll is
      requested from a model with `can_run_backwards == False`.
  """
  if num_steps_forward < 0 or num_steps_backward < 0:
    raise ValueError("num_steps_forward and num_steps_backward can not be "
                     "negative.")
  if num_steps_forward == 0 and num_steps_backward == 0:
    # The message names the real argument (`num_steps_backward`).
    raise ValueError("You need one of num_steps_forward or "
                     "num_steps_backward to be positive.")
  if num_steps_forward > 0 and num_steps_backward > 0 and not include_z0:
    # The message names the real argument (`include_z0`).
    raise ValueError("When both num_steps_forward and num_steps_backward are "
                     "positive include_z0 should be True.")
  if num_steps_backward > 0 and not self.can_run_backwards:
    raise ValueError("This model can not be unrolled backward in time.")
def _models_core(
    self,
    params: utils.Params,
    keys: jnp.ndarray,
    image_data: jnp.ndarray,
    use_mean: bool,
    is_training: bool,
    **unroll_kwargs: Any
) -> Tuple[distrax.Distribution, distrax.Distribution, distrax.Distribution,
           jnp.ndarray, jnp.ndarray, Mapping[str, jnp.ndarray]]:
  """Runs the encode -> transform -> unroll -> decode pipeline.

  Args:
    params: The model parameters.
    keys: Six random keys. They are consumed, in order, by the encoder, the
      posterior sampling, the latent transform, the dynamics unroll and the
      decoder; the sixth key is not used here.
    image_data: The images given to the encoder.
    use_mean: If True the posterior mean is used instead of a sample.
    is_training: Forwarded to every sub-network.
    **unroll_kwargs: Forwarded to `unroll_latent_dynamics`.

  Returns:
    A tuple `(p_x, q_z, prior, z0, z, dyn_stats)` holding the decoder
    distribution, the approximate posterior, the prior, the transformed
    initial latent, the unrolled latents (reduced to `single_state` when the
    unroll returns a `PhaseSpace`) and the statistics from the dynamics.
  """
  enc_key, sample_key, transform_key, unroll_key, dec_key, _ = keys

  # Approximate posterior q(z|x).
  encoder_inputs = self.process_inputs_for_encoder(image_data)
  q_z = self.encoder.apply(
      params, enc_key, encoder_inputs, is_training=is_training)

  # Either the deterministic mean or a stochastic sample of the posterior.
  if use_mean:
    z_raw = q_z.mean()
  else:
    z_raw = q_z.sample(seed=sample_key)

  # Latent transformation.
  z0 = self.apply_latent_transform(
      params, transform_key, z_raw, is_training=is_training)

  # Unroll the latent state through the dynamics.
  dynamics_z0 = self.process_latents_for_dynamics(z0)
  z, dyn_stats = self.unroll_latent_dynamics(
      z=dynamics_z0,
      params=params,
      key=unroll_key,
      is_training=is_training,
      **unroll_kwargs
  )

  # Likelihood p(x|z).
  p_x = self.decode_latents(
      params,
      dec_key,
      self.process_latents_for_decoder(z),
      is_training=is_training)

  if isinstance(z, phase_space.PhaseSpace):
    z = z.single_state
  return p_x, q_z, self.prior(), z0, z, dyn_stats
self.train_data_split(images) + + p_x, q_z, prior, _, _, dyn_stats = self._models_core( + params=params, + keys=keys, + image_data=image_data, + use_mean=False, + is_training=is_training, + **unroll_kwargs + ) + + # Note: we reuse the rng key used to sample the latent variable here + # so that it can be reused to evaluate a (non-analytical) KL at that sample. + stats = metrics.training_statistics( + p_x=p_x, + targets=target_data, + rescale_by=self.rescale_by, + rng=keys[1], + q_z=q_z, + prior=prior, + p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False) + ) + stats.update(dyn_stats) + + # Compute other (non-reported statistics) + z_stats = dict() + other_stats = dict(x_reconstruct=p_x.mean(), z_stats=z_stats) + + # The loss computation and GECO state update + new_state = dict() + if self.objective_type == "GECO": + geco_stats = metrics.geco_objective( + l2_loss=stats["l2"], + kl=stats["kl"], + alpha=self.geco_alpha, + kappa=self.geco_kappa, + constraint_ema=state["GECO"]["geco_constraint_ema"], + lambda_var=params["GECO"]["geco_lambda_var"], + is_training=is_training + ) + new_state["GECO"] = dict( + geco_constraint_ema=geco_stats["geco_constraint_ema"]) + stats.update(geco_stats) + elif self.objective_type == "ELBO": + elbo_stats = metrics.elbo_objective( + neg_log_p_x=stats["neg_log_p_x"], + kl=stats["kl"], + final_beta=self.elbo_beta_final, + beta_delay=self.elbo_beta_delay, + step=step + ) + stats.update(elbo_stats) + elif self.objective_type == "NON-PROB": + stats["loss"] = stats["neg_log_p_x"] + else: + raise ValueError() + + if not is_training: + if self.training_data_split == "overlap_by_one": + reconstruction_skip = self.num_inference_steps - 1 + elif self.training_data_split == "no_overlap": + reconstruction_skip = self.num_inference_steps + elif self.training_data_split == "include_inference": + reconstruction_skip = 0 + else: + raise NotImplementedError() + # We intentionally reuse the same rng as the training, in order to be able + # to 
def reconstruct(
    self,
    params: utils.Params,
    inputs: jnp.ndarray,
    rng: Optional[jnp.ndarray],
    forward: bool,
    use_mean: bool = True,
) -> distrax.Distribution:
  """Reconstructs a full sequence from a window of inference frames.

  Args:
    params: The model parameters.
    inputs: The inputs from which images are read via `utils.extract_image`.
    rng: Random key; it is split into six keys for `_models_core`.
    forward: If True the first `num_inference_steps` frames are encoded,
      otherwise the last `num_inference_steps` frames are encoded.
    use_mean: Whether to use the posterior mean rather than a sample.

  Returns:
    The first element returned by `_models_core` (the decoder distribution).

  Raises:
    ValueError: If `forward` is False and the model can not run backwards.
  """
  if not (self.can_run_backwards or forward):
    raise ValueError("This model can not be run backwards.")

  images = utils.extract_image(inputs)
  sequence_length = images.shape[1]

  # The step counts intentionally match the split used for training stats.
  if forward:
    image_data = images[:, :self.num_inference_steps]
    num_steps_backward = self.inferred_index
    num_steps_forward = sequence_length - num_steps_backward - 1
  else:
    image_data = images[:, -self.num_inference_steps:]
    num_steps_forward = self.num_inference_steps - self.inferred_index - 1
    num_steps_backward = sequence_length - num_steps_forward - 1

  # Models that can not run backwards never unroll into the past; the forward
  # count computed above is left unchanged.
  if not self.can_run_backwards:
    num_steps_backward = 0

  outputs = self._models_core(
      params=params,
      keys=jnr.split(rng, 6),
      image_data=image_data,
      use_mean=use_mean,
      is_training=False,
      num_steps_forward=num_steps_forward,
      num_steps_backward=num_steps_backward,
      include_z0=True,
  )
  return outputs[0]
unroll_direction: str = "forward", + **kwargs: Dict[str, Any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, + Union[distrax.Distribution, jnp.ndarray]]: + """Computes the ground state and matching latents.""" + assert unroll_direction in ("forward", "backward") + if unroll_direction == "backward" and not self.can_run_backwards: + raise ValueError("This model can not be unrolled backwards.") + + images = utils.extract_image(inputs) + gt_state = utils.extract_gt_state(inputs) + + if unroll_direction == "forward": + image_data = images[:, :self.num_inference_steps] + if self.can_run_backwards: + num_steps_backward = self.inferred_index + gt_start_idx = 0 + else: + num_steps_backward = 0 + gt_start_idx = self.inferred_index + num_steps_forward = seq_length - num_steps_backward - 1 + gt_state = gt_state[:, gt_start_idx: seq_length + gt_start_idx] + elif unroll_direction == "backward": + inference_start_idx = seq_length - self.num_inference_steps + image_data = images[:, inference_start_idx: seq_length] + num_steps_forward = self.num_inference_steps - self.inferred_index - 1 + num_steps_backward = seq_length - num_steps_forward - 1 + gt_state = gt_state[:, :seq_length] + else: + raise NotImplementedError() + + _, q_z, _, z0, z, _ = self._models_core( + params=params, + keys=jnr.split(rng, 6), + image_data=image_data, + use_mean=True, + is_training=False, + num_steps_forward=num_steps_forward, + num_steps_backward=num_steps_backward, + include_z0=True, + ) + + if self.has_latent_transform: + return gt_state, z, z0 + else: + return gt_state, z, q_z + + def _init_non_model_params_and_state( + self, + rng: jnp.ndarray + ) -> Tuple[utils.Params, utils.Params]: + if self.objective_type == "GECO": + # Initialize such that softplus(lambda_var) = 1 + geco_lambda_var = jnp.asarray(jnp.log(jnp.e - 1.0)) + geco_constraint_ema = jnp.asarray(0.0) + return (dict(GECO=dict(geco_lambda_var=geco_lambda_var)), + dict(GECO=dict(geco_constraint_ema=geco_constraint_ema))) + else: + return dict(), 
dict() + + def _init_latent_system( + self, + rng: jnp.ndarray, + z: jnp.ndarray, + **kwargs: Mapping[str, Any] + ) -> hk.Params: + """Initializes the parameters of the latent system.""" + return self.dynamics.init( + rng, + y0=z, + dt=self.dt, + num_steps_forward=1, + num_steps_backward=0, + include_y0=True, + **kwargs + ) diff --git a/physics_inspired_models/models/dynamics.py b/physics_inspired_models/models/dynamics.py new file mode 100644 index 0000000..25e22b4 --- /dev/null +++ b/physics_inspired_models/models/dynamics.py @@ -0,0 +1,839 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Module containing all of the networks as Haiku modules.""" +from typing import Any, Mapping, Optional, Tuple, Union + +from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space +import haiku as hk +import jax +import jax.numpy as jnp + +from physics_inspired_models import integrators +from physics_inspired_models import utils +from physics_inspired_models.models import networks + +_PhysicsSimulationOutput = Union[ + phase_space.PhaseSpace, + Tuple[phase_space.PhaseSpace, Mapping[str, jnp.ndarray]] +] + + +class PhysicsSimulationNetwork(hk.Module): + """A model for simulating an abstract physical system, whose energy is defined by a neural network.""" + + def __init__( + self, + system_dim: int, + input_space: str, + simulation_space: str, + potential_func_form: str, + kinetic_func_form: str, + parametrize_mass_matrix: bool, + net_kwargs: Mapping[str, Any], + mass_eps: float = 1.0, + integrator_method: Optional[str] = None, + steps_per_dt: int = 1, + ode_int_kwargs: Optional[Mapping[str, float]] = None, + use_scan: bool = True, + feature_axis: int = -1, + features_extra_dims: Optional[int] = None, + network_creation_func=networks.make_flexible_net, + name: Optional[str] = None + ): + """Initializes the model. + + Args: + system_dim: The number of system dimensions. Note that this specifies the + number of dimensions only of the position vectors, not of position and + momentum. Hence the generalized coordinates would be of dimension + `2 * system_dim`. + input_space: Either `velocity` or `momentum`. Specifies whether the inputs + to the model are to be interpreted as `(position, velocity)` or as + `(position, momentum)`. + simulation_space: Either `velocity` or `momentum`. Specifies whether the + model should simulate the dynamics in `(position, velocity)` space + using the Lagrangian formulation or in `(position, momentum)` space + using the Hamiltonian formulation. 
If this is different than the value + of `input_space` then `kinetic_func_form` must be one of pure_quad, + matrix_diag_quad, matrix_quad, matrix_dep_diag_quad, matrix_dep_quad. + In all other cases one can not compute analytically the form of the + functional (Lagrangian or Hamiltonian) from the other. + potential_func_form: String specifying the form of the potential energy: + * separable_net - The network uses only the position: + U(q, q_dot/p) = f(q) f: R^d -> R + * dep_net - The network uses both the position and velocity/momentum: + U(q, q_dot/p) = f(q, q_dot/p) f: R^d x R^d -> R + * embed_quad - A quadratic of the embedding of a network embedding of + the velocity/momentum: + U(q, q_dot/p) = f(q)^T f(q) / 2 f: R^d -> R^d + kinetic_func_form: String specifying the form of the potential energy: + * separable_net - The network uses only the velocity/momentum: + K(q, q_dot/p) = f(q_dot/p) f: R^d -> R + * dep_net - The network uses both the position and velocity/momentum: + K(q, q_dot/p) = f(q, q_dot/p) f: R^d x R^d -> R + * pure_quad - A quadratic function of the velocity/momentum: + K(q, q_dot/p) = (q_dot/p)^T (q_dot/p) / 2 + * matrix_diag_quad - A quadratic function of the velocity/momentum, + where there is diagonal mass matrix, whose log `P` is a parameter: + K(q, q_dot) = q_dot^T M q_dot / 2 + K(q, p) = p^T M^-1 p / 2 + [if `parameterize_mass_matrix`] + M = diag(exp(P) + mass_eps) + [else] + M^-1 = diag(exp(P) + mass_eps) + * matrix_quad - A quadratic function of the velocity/momentum, where + there is a full mass matrix, whose Cholesky factor L is a parameter: + K(q, q_dot) = q_dot^T M q_dot / 2 + K(q, p) = p^T M^-1 p / 2 + [if `parameterize_mass_matrix`] + M = LL^T + mass_eps * I + [else] + M^-1 = LL^T + mass_eps * I + * matrix_dep_quad - A quadratic function of the velocity/momentum, where + there is a full mass matrix defined as a function of the position: + K(q, q_dot) = q_dot^T M(q) q_dot / 2 + K(q, p) = p^T M(q)^-1 p / 2 + [if 
`parameterize_mass_matrix`] + M(q) = g(q) g(q)^T + mass_eps * I g: R^d -> R^(d(d+1)/2) + [else] + M(q)^-1 = g(q) g(q)^T + mass_eps * I g: R^d -> R^(d(d+1)/2) + * embed_quad - A quadratic of the embedding of a network embedding of + the velocity/momentum: + K(q, q_dot/p) = f(q_dot/p)^T f(q_dot/p) / 2 f: R^d -> R^d + * matrix_dep_diag_embed_quad - A quadratic of the embedding of a network + embedding of the velocity/momentum where there is diagonal mass matrix + defined as a function of the position: + K(q, q_dot) = f(q_dot)^T M(q) f(q_dot) / 2 f: R^d -> R^d + K(q, p) = f(p)^T M(q)^-1 f(p) / 2 f: R^d -> R^d + [if `parameterize_mass_matrix`] + M(q) = diag(exp(g(q)) + mass_eps * I g: R^d -> R^d + [else] + M(q)^-1 = diag(exp(g(q)) + mass_eps * I g: R^d -> R^d + * matrix_dep_embed_quad - A quadratic of the embedding of a network + embedding of the velocity/momentum where there is a full mass matrix + defined as a function of the position: + K(q, q_dot) = f(q_dot)^T M(q) f(q_dot) / 2 f: R^d -> R^d + K(q, p) = f(p)^T M(q)^-1 f(p) / 2 f: R^d -> R^d + [if `parameterize_mass_matrix`] + M(q) = g(q) g(q)^T + mass_eps * I g: R^d -> R^(d(d+1)/2) + [else] + M(q)^-1 = g(q) g(q)^T + mass_eps * I g: R^d -> R^(d(d+1)/2) + For any of the function forms with mass matrices, if we have a + convolutional input it is assumed that the matrix is shared across all + spatial locations. + parametrize_mass_matrix: Defines for the kinetic functional form, whether + the network output defines the mass or the inverse of the mass matrix. + net_kwargs: Any keyword arguments to pass down to the networks. + mass_eps: The additional weight of the identity added to the mass matrix, + when relevant. + integrator_method: What method to use for integrating the system. + steps_per_dt: How many internal steps per a single `dt` step to do. + ode_int_kwargs: Extra arguments when using "implicit" integrator method. + use_scan: Whether to use `lax.scan` for explicit integrators. 
+ feature_axis: The number of the features axis in the inputs. + features_extra_dims: If the inputs have extra features (like spatial for + convolutions) this specifies how many of them there are. + network_creation_func: A function that creates the networks. Should have a + signature `network_creation_func(output_dims, name, **net_kwargs)`. + name: The name of this Haiku module. + """ + super().__init__(name=name) + if input_space not in ("velocity", "momentum"): + raise ValueError("input_space must be either velocity or momentum.") + if simulation_space not in ("velocity", "momentum"): + raise ValueError("simulation_space must be either velocity or momentum.") + if potential_func_form not in ("separable_net", "dep_net", "embed_quad"): + raise ValueError("The potential network can be only a network.") + if kinetic_func_form not in ("separable_net", "dep_net", "pure_quad", + "matrix_diag_quad", "matrix_quad", + "matrix_dep_diag_quad", "matrix_dep_quad", + "embed_quad", "matrix_dep_diag_embed_quad", + "matrix_dep_embed_quad"): + raise ValueError(f"Unrecognized kinetic func form {kinetic_func_form}.") + if input_space != simulation_space: + if kinetic_func_form not in ( + "pure_quad", "matrix_diag_quad", "matrix_quad", + "matrix_dep_diag_quad", "matrix_dep_quad"): + raise ValueError( + "When the input and simulation space are not the same, it is " + "possible to simulate the physical system only if kinetic_func_form" + " is one of pure_quad, matrix_diag_quad, matrix_quad, " + "matrix_dep_diag_quad, matrix_dep_quad. 
In all other cases one can" + "not compute analytically the form of the functional (Lagrangian or" + " Hamiltonian) from the other.") + if feature_axis != -1: + raise ValueError("Currently we only support features_axis=-1.") + if integrator_method is None: + if simulation_space == "velocity": + integrator_method = "rk2" + else: + integrator_method = "leap_frog" + if features_extra_dims is None: + if net_kwargs["net_type"] == "mlp": + features_extra_dims = 0 + elif net_kwargs["net_type"] == "conv": + features_extra_dims = 2 + else: + raise NotImplementedError() + ode_int_kwargs = dict(ode_int_kwargs or {}) + ode_int_kwargs.setdefault("rtol", 1e-6) + ode_int_kwargs.setdefault("atol", 1e-6) + ode_int_kwargs.setdefault("mxstep", 50) + + self.system_dim = system_dim + self.input_space = input_space + self.simulation_space = simulation_space + self.potential_func_form = potential_func_form + self.kinetic_func_form = kinetic_func_form + self.parametrize_mass_matrix = parametrize_mass_matrix + self.features_axis = feature_axis + self.features_extra_dims = features_extra_dims + self.integrator_method = integrator_method + self.steps_per_dt = steps_per_dt + self.ode_int_kwargs = ode_int_kwargs + self.net_kwargs = net_kwargs + self.mass_eps = mass_eps + self.use_scan = use_scan + self.name = name + + self.potential_net = network_creation_func( + output_dims=1, name="PotentialNet", **net_kwargs) + + if kinetic_func_form in ("separable_net", "dep_net"): + self.kinetic_net = network_creation_func( + output_dims=1, name="KineticNet", **net_kwargs) + else: + self.kinetic_net = None + if kinetic_func_form in ("matrix_dep_quad", "matrix_dep_embed_quad"): + output_dims = (system_dim * (system_dim + 1)) // 2 + name = "MatrixNet" if parametrize_mass_matrix else "InvMatrixNet" + self.mass_matrix_net = network_creation_func( + output_dims=output_dims, name=name, **net_kwargs) + elif kinetic_func_form in ("matrix_dep_diag_quad", + "matrix_dep_diag_embed_quad", + "matrix_dep_embed_quad"): 
def sum_per_dim_energy(self, energy: jnp.ndarray) -> jnp.ndarray:
  """Sums the per dimension energy.

  Args:
    energy: The energy to reduce.

  Returns:
    `energy` summed over its last `self.features_extra_dims + 1` axes.
  """
  num_reduced = self.features_extra_dims + 1
  # Axes -1, -2, ..., -num_reduced.
  trailing_axes = tuple(range(-1, -num_reduced - 1, -1))
  return jnp.sum(energy, axis=trailing_axes)
init=hk.initializers.Identity()) + m_triu = jnp.triu(m_triu) + m = jnp.matmul(m_triu.T, m_triu) + m = m + self.mass_eps * jnp.eye(self.system_dim) + return self.feature_matrix_vector(m, v) + else: + m_inv_triu = hk.get_parameter("InvMassMatrixU", + shape=[self.system_dim, self.system_dim], + init=hk.initializers.Identity()) + m_inv_triu = jnp.triu(m_inv_triu) + m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu) + m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim) + solve = jnp.linalg.solve + for _ in range(v.ndim + 1 - m_inv.ndim): + solve = jax.vmap(solve, in_axes=(None, 0)) + return solve(m_inv, v) + if self.kinetic_func_form in ("matrix_dep_diag_quad", + "matrix_dep_diag_embed_quad"): + if self.parametrize_mass_matrix: + m_diag_log = self.mass_matrix_net(q, **kwargs) + m_diag = jnp.exp(m_diag_log) + self.mass_eps + else: + m_inv_diag_log = self.mass_matrix_net(q, **kwargs) + m_diag = 1.0 / (jnp.exp(m_inv_diag_log) + self.mass_eps) + return m_diag * v + if self.kinetic_func_form in ("matrix_dep_quad", + "matrix_dep_embed_quad"): + if self.parametrize_mass_matrix: + m_triu = self.mass_matrix_net(q, **kwargs) + m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim) + m = jnp.matmul(jnp.swapaxes(m_triu, -1, -2), m_triu) + m = m + self.mass_eps * jnp.eye(self.system_dim) + return self.feature_matrix_vector(m, v) + else: + m_inv_triu = self.mass_matrix_net(q, **kwargs) + m_inv_triu = utils.triu_matrix_from_v(m_inv_triu, self.system_dim) + m_inv = jnp.matmul(jnp.swapaxes(m_inv_triu, -1, -2), m_inv_triu) + m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim) + return jnp.linalg.solve(m_inv, v) + raise NotImplementedError() + + def mass_matrix_inv_mul( + self, + q: jnp.ndarray, + v: jnp.ndarray, + **kwargs + ) -> jnp.ndarray: + """Computes the product of the inverse mass matrix with a vector.""" + if self.kinetic_func_form in ("separable_net", "dep_net"): + raise ValueError("It is not possible to compute `M^-1 p` when using a " + "network for the kinetic 
energy.") + if self.kinetic_func_form in ("pure_quad", "embed_quad"): + return v + if self.kinetic_func_form == "matrix_diag_quad": + if self.parametrize_mass_matrix: + m_diag_log = hk.get_parameter("MassMatrixDiagLog", + shape=[self.system_dim], + init=hk.initializers.Constant(0.0)) + m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps) + else: + m_inv_diag_log = hk.get_parameter("InvMassMatrixDiagLog", + shape=[self.system_dim], + init=hk.initializers.Constant(0.0)) + m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps + return m_inv_diag * v + if self.kinetic_func_form == "matrix_quad": + if self.parametrize_mass_matrix: + m_triu = hk.get_parameter("MassMatrixU", + shape=[self.system_dim, self.system_dim], + init=hk.initializers.Identity()) + m_triu = jnp.triu(m_triu) + m = jnp.matmul(m_triu.T, m_triu) + m = m + self.mass_eps * jnp.eye(self.system_dim) + solve = jnp.linalg.solve + for _ in range(v.ndim + 1 - m.ndim): + solve = jax.vmap(solve, in_axes=(None, 0)) + return solve(m, v) + else: + m_inv_triu = hk.get_parameter("InvMassMatrixU", + shape=[self.system_dim, self.system_dim], + init=hk.initializers.Identity()) + m_inv_triu = jnp.triu(m_inv_triu) + m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu) + m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim) + return self.feature_matrix_vector(m_inv, v) + if self.kinetic_func_form in ("matrix_dep_diag_quad", + "matrix_dep_diag_embed_quad"): + if self.parametrize_mass_matrix: + m_diag_log = self.mass_matrix_net(q, **kwargs) + m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps) + else: + m_inv_diag_log = self.mass_matrix_net(q, **kwargs) + m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps + return m_inv_diag * v + if self.kinetic_func_form in ("matrix_dep_quad", + "matrix_dep_embed_quad"): + if self.parametrize_mass_matrix: + m_triu = self.mass_matrix_net(q, **kwargs) + m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim) + m = jnp.matmul(jnp.swapaxes(m_triu, -2, -1), m_triu) + m = m + 
def momentum_from_velocity(
    self,
    q: jnp.ndarray,
    q_dot: jnp.ndarray,
    **kwargs
) -> jnp.ndarray:
  """Computes the momentum from position and velocity.

  The momentum is the gradient of the Lagrangian with respect to the
  velocity, evaluated at `(q, q_dot)`.

  Args:
    q: The position.
    q_dot: The velocity.
    **kwargs: Forwarded to `self.lagrangian`.

  Returns:
    The gradient of the summed Lagrangian with respect to `q_dot`.
  """
  def total_lagrangian(velocity):
    state = phase_space.PhaseSpace(q, velocity)
    # `jax.grad` needs a scalar output, hence the sum over all entries.
    return jnp.sum(self.lagrangian(state, **kwargs))

  grad_wrt_velocity = jax.grad(total_lagrangian)
  return grad_wrt_velocity(q_dot)
def potential_energy_velocity(
    self,
    q: jnp.ndarray,
    q_dot: jnp.ndarray,
    **kwargs
) -> jnp.ndarray:
  """Computes the potential energy in velocity coordinates.

  Args:
    q: The position.
    q_dot: The velocity.
    **kwargs: Forwarded to the potential network.

  Returns:
    The output of the potential network reduced with `sum_per_dim_energy`.
    For "separable_net" the network sees only `q`; otherwise it sees `q` and
    `q_dot` concatenated along the last axis.

  Raises:
    ValueError: If the potential depends on the velocity but the inputs of
      the model live in momentum space.
  """
  if self.potential_func_form == "separable_net":
    per_dim_energy = self.potential_net(q, **kwargs)
  elif self.input_space != "velocity":
    # A velocity-dependent potential is only defined when the network was
    # built for velocity inputs. The same guard is used by
    # `kinetic_energy_velocity`; checking against "momentum" here would reject
    # exactly the valid configuration.
    raise ValueError("Can not evaluate the Potential energy from velocity, "
                     "when the input space is momentum and "
                     "potential_func_form is dep_net.")
  else:
    s = jnp.concatenate([q, q_dot], axis=-1)
    per_dim_energy = self.potential_net(s, **kwargs)
  return self.sum_per_dim_energy(per_dim_energy)
def energy_from_velocity(
    self,
    s: phase_space.PhaseSpace,
    **kwargs
) -> jnp.ndarray:
  """Computes the energy of the system in velocity coordinates.

  The energy is `q_dot . p - L(q, q_dot)`, where `p` is obtained from
  `momentum_from_velocity`.

  Args:
    s: The state; `s.q` is the position and `s.p` holds the velocity.
    **kwargs: Forwarded to the networks.

  Returns:
    The energy, with the same shape as the output of `self.lagrangian`.
  """
  q, q_dot = s.q, s.p
  p = self.momentum_from_velocity(q, q_dot, **kwargs)
  # Reduce over the same axes as the Lagrangian (the feature axis plus any
  # extra feature dimensions), so that both terms have matching shapes. With
  # no extra feature dimensions this equals a sum over the last axis.
  q_dot_p = self.sum_per_dim_energy(q_dot * p)
  return q_dot_p - self.lagrangian(s, **kwargs)
acceleration of the system in velocity coordinates.""" + def local_lagrangian(*q_and_q_dot): + # We take the sum so we can easily take gradients + return jnp.sum(self.lagrangian( + phase_space.PhaseSpace(*q_and_q_dot), **kwargs)) + + grad_q = jax.grad(local_lagrangian, 0)(q, q_dot) + grad_q_dot_func = jax.grad(local_lagrangian, 1) + _, grad_q_dot_grad_q_times_q_dot = jax.jvp(grad_q_dot_func, (q, q_dot), + (q_dot, jnp.zeros_like(q_dot))) + pre_acc_vector = grad_q - grad_q_dot_grad_q_times_q_dot + if self.kinetic_func_form in ("pure_quad", "matrix_diag_quad", + "matrix_quad", "matrix_dep_diag_quad", + "matrix_dep_quad"): + q_dot_dot = self.mass_matrix_inv_mul(q, pre_acc_vector, **kwargs) + else: + hess_q_dot = jax.vmap(jax.hessian(local_lagrangian, 1))(q, q_dot) + q_dot_dot = jnp.linalg.solve(hess_q_dot, pre_acc_vector) + return phase_space.TangentPhaseSpace(q_dot, q_dot_dot) + + def simulate( + self, + y0: phase_space.PhaseSpace, + dt: Union[float, jnp.ndarray], + num_steps_forward: int, + num_steps_backward: int, + include_y0: bool, + return_stats: bool = True, + **nets_kwargs + ) -> _PhysicsSimulationOutput: + """Simulates the continuous dynamics of the physical system. + + Args: + y0: Initial state of the system. + dt: The size of the time intervals at which to evolve the system. + num_steps_forward: Number of steps to make into the future. + num_steps_backward: Number of steps to make into the past. + include_y0: Whether to include the initial state in the result. + return_stats: Whether to return additional statistics. + **nets_kwargs: Keyword arguments to pass to the networks. + + Returns: + * The state of the system evolved as many steps as specified by the + arguments into the past and future, all in chronological order. + * Optionally return a dictionary of additional statistics. For the moment + this only returns the energy of the system at each evaluation point. 
+ """ + # Define the dynamics + if self.simulation_space == "velocity": + dy_dt = lambda t_, y: self.velocity_and_acceleration( # pylint: disable=g-long-lambda + y.q, y.p, **nets_kwargs) + # Special Haiku magic to avoid tracer issues + if hk.running_init(): + return self.lagrangian(y0, **nets_kwargs) + else: + hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs) + dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian) + if hk.running_init(): + return self.hamiltonian(y0, **nets_kwargs) + + # Optionally switch coordinate frame + if self.input_space == "velocity" and self.simulation_space == "momentum": + p = self.momentum_from_velocity(y0.q, y0.p, **nets_kwargs) + y0 = phase_space.PhaseSpace(y0.q, p) + if self.input_space == "momentum" and self.simulation_space == "velocity": + q_dot = self.velocity_from_momentum(y0.q, y0.p, **nets_kwargs) + y0 = phase_space.PhaseSpace(y0.q, q_dot) + + yt = integrators.solve_ivp_dt_two_directions( + fun=dy_dt, + y0=y0, + t0=0.0, + dt=dt, + method=self.integrator_method, + num_steps_forward=num_steps_forward, + num_steps_backward=num_steps_backward, + include_y0=include_y0, + steps_per_dt=self.steps_per_dt, + ode_int_kwargs=self.ode_int_kwargs + ) + # Make time axis second + yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt) + + # Compute energies for the full trajectory + yt_energy = jax.tree_map(utils.merge_first_dims, yt) + if self.simulation_space == "momentum": + energy = self.energy_from_momentum(yt_energy, **nets_kwargs) + else: + energy = self.energy_from_velocity(yt_energy, **nets_kwargs) + energy = energy.reshape(yt.q.shape[:2]) + + # Optionally switch back to input coordinate frame + if self.input_space == "velocity" and self.simulation_space == "momentum": + q_dot = self.velocity_from_momentum(yt.q, yt.p, **nets_kwargs) + yt = phase_space.PhaseSpace(yt.q, q_dot) + if self.input_space == "momentum" and self.simulation_space == "velocity": + p = self.momentum_from_velocity(yt.q, yt.p, **nets_kwargs) + 
yt = phase_space.PhaseSpace(yt.q, p) + + # Compute energy deficit + t = energy.shape[-1] + non_zero_diffs = float((t * (t - 1)) // 2) + energy_deficits = jnp.abs(energy[..., None, :] - energy[..., None]) + avg_deficit = jnp.sum(energy_deficits, axis=(-2, -1)) / non_zero_diffs + max_deficit = jnp.max(energy_deficits) + + # Return the states and energies + if return_stats: + return yt, dict(avg_energy_deficit=avg_deficit, + max_energy_deficit=max_deficit) + else: + return yt + + def __call__(self, *args, **kwargs): + return self.simulate(*args, **kwargs) + + +class OdeNetwork(hk.Module): + """A simple haiku module for constructing a NeuralODE.""" + + def __init__( + self, + system_dim: int, + net_kwargs: Mapping[str, Any], + integrator_method: Optional[str] = None, + steps_per_dt: int = 1, + ode_int_kwargs: Optional[Mapping[str, float]] = None, + use_scan: bool = True, + network_creation_func=networks.make_flexible_net, + name: Optional[str] = None, + ): + super().__init__(name=name) + ode_int_kwargs = dict(ode_int_kwargs or {}) + ode_int_kwargs.setdefault("rtol", 1e-6) + ode_int_kwargs.setdefault("atol", 1e-6) + ode_int_kwargs.setdefault("mxstep", 50) + + self.system_dim = system_dim + self.integrator_method = integrator_method or "adaptive" + self.steps_per_dt = steps_per_dt + self.ode_int_kwargs = ode_int_kwargs + self.net_kwargs = net_kwargs + self.use_scan = use_scan + + self.core = network_creation_func( + output_dims=system_dim, name="Net", **net_kwargs) + + def simulate( + self, + y0: jnp.ndarray, + dt: Union[float, jnp.ndarray], + num_steps_forward: int, + num_steps_backward: int, + include_y0: bool, + return_stats: bool = True, + **nets_kwargs + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]]: + """Simulates the continuous dynamics of the ODE specified by the network. + + Args: + y0: Initial state of the system. + dt: The size of the time intervals at which to evolve the system. 
+ num_steps_forward: Number of steps to make into the future. + num_steps_backward: Number of steps to make into the past. + include_y0: Whether to include the initial state in the result. + return_stats: Whether to return additional statistics. + **nets_kwargs: Keyword arguments to pass to the networks. + + Returns: + * The state of the system evolved as many steps as specified by the + arguments into the past and future, all in chronological order. + * Optionally return a dictionary of additional statistics. For the moment + this is just an empty dictionary. + """ + if hk.running_init(): + return self.core(y0, **nets_kwargs) + yt = integrators.solve_ivp_dt_two_directions( + fun=lambda t, y: self.core(y, **nets_kwargs), + y0=y0, + t0=0.0, + dt=dt, + method=self.integrator_method, + num_steps_forward=num_steps_forward, + num_steps_backward=num_steps_backward, + include_y0=include_y0, + steps_per_dt=self.steps_per_dt, + ode_int_kwargs=self.ode_int_kwargs + ) + # Make time axis second + yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt) + if return_stats: + return yt, dict() + else: + return yt + + def __call__(self, *args, **kwargs): + return self.simulate(*args, **kwargs) + + +class DiscreteDynamicsNetwork(hk.Module): + """A simple haiku module for constructing a discrete dynamics network.""" + + def __init__( + self, + system_dim: int, + residual: bool, + net_kwargs: Mapping[str, Any], + use_scan: bool = True, + network_creation_func=networks.make_flexible_net, + name: Optional[str] = None, + ): + super().__init__(name=name) + self.system_dim = system_dim + self.residual = residual + self.net_kwargs = net_kwargs + self.use_scan = use_scan + self.core = network_creation_func( + output_dims=system_dim, name="Net", **net_kwargs) + + def simulate( + self, + y0: jnp.ndarray, + num_steps_forward: int, + include_y0: bool, + return_stats: bool = True, + **nets_kwargs + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]]: + """Simulates the dynamics 
of the discrete system. + + Args: + y0: Initial state of the system. + num_steps_forward: Number of steps to make into the future. + include_y0: Whether to include the initial state in the result. + return_stats: Whether to return additional statistics. + **nets_kwargs: Keyword arguments to pass to the networks. + + Returns: + * The state of the system evolved as many steps as specified by the + arguments into the past and future, all in chronological order. + * Optionally return a dictionary of additional statistics. For the moment + this is just an empty dictionary. + """ + if num_steps_forward < 0: + raise ValueError("It is required to unroll at least one step.") + nets_kwargs.pop("dt", None) + nets_kwargs.pop("num_steps_backward", None) + if hk.running_init(): + return self.core(y0, **nets_kwargs) + + def step(*args): + y, _ = args + if self.residual: + y_next = y + self.core(y, **nets_kwargs) + else: + y_next = self.core(y, **nets_kwargs) + return y_next, y_next + + if self.use_scan: + _, yt = jax.lax.scan(step, init=y0, xs=None, length=num_steps_forward) + if include_y0: + yt = jnp.concatenate([y0[None], yt], axis=0) + # Make time axis second + yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt) + else: + yt = [y0] + for _ in range(num_steps_forward): + yt.append(step(yt[-1], None)[0]) + if not include_y0: + yt = yt[1:] + if len(yt) == 1: + yt = yt[0][:, None] + else: + yt = jax.tree_multimap(lambda args: jnp.stack(args, 1), yt) + if return_stats: + return yt, dict() + else: + return yt + + def __call__(self, *args, **kwargs): + return self.simulate(*args, **kwargs) diff --git a/physics_inspired_models/models/networks.py b/physics_inspired_models/models/networks.py new file mode 100644 index 0000000..e7f601b --- /dev/null +++ b/physics_inspired_models/models/networks.py @@ -0,0 +1,494 @@ +# Copyright 2020 DeepMind Technologies Limited. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module containing all of the networks as Haiku modules.""" +from typing import Any, Callable, Mapping, Optional, Sequence, Union + +from absl import logging +import distrax +import haiku as hk +import jax.numpy as jnp + +from physics_inspired_models import utils + +Activation = Union[str, Callable[[jnp.ndarray], jnp.ndarray]] + + +class DenseNet(hk.Module): + """A feed forward network (MLP).""" + + def __init__( + self, + num_units: Sequence[int], + activate_final: bool = False, + activation: Activation = "leaky_relu", + name: Optional[str] = None): + super().__init__(name=name) + self.num_units = num_units + self.num_layers = len(self.num_units) + self.activate_final = activate_final + self.activation = utils.get_activation(activation) + + self.linear_modules = [] + for i in range(self.num_layers): + self.linear_modules.append( + hk.Linear( + output_size=self.num_units[i], + name=f"ff_{i}" + ) + ) + + def __call__(self, inputs: jnp.ndarray, is_training: bool): + net = inputs + for i, linear in enumerate(self.linear_modules): + net = linear(net) + if i < self.num_layers - 1 or self.activate_final: + net = self.activation(net) + return net + + +class Conv2DNet(hk.Module): + """Convolutional Network.""" + + def __init__( + self, + output_channels: Sequence[int], + kernel_shapes: Union[int, Sequence[int]] = 3, + strides: Union[int, Sequence[int]] = 1, + padding: Union[str, Sequence[str]] = "SAME", + data_format: str 
= "NHWC", + with_batch_norm: bool = False, + activate_final: bool = False, + activation: Activation = "leaky_relu", + name: Optional[str] = None): + super().__init__(name=name) + self.output_channels = tuple(output_channels) + self.num_layers = len(self.output_channels) + self.kernel_shapes = utils.bcast_if(kernel_shapes, int, self.num_layers) + self.strides = utils.bcast_if(strides, int, self.num_layers) + self.padding = utils.bcast_if(padding, str, self.num_layers) + self.data_format = data_format + self.with_batch_norm = with_batch_norm + self.activate_final = activate_final + self.activation = utils.get_activation(activation) + + if len(self.kernel_shapes) != self.num_layers: + raise ValueError(f"Kernel shapes is of size {len(self.kernel_shapes)}, " + f"while output_channels is of size{self.num_layers}.") + if len(self.strides) != self.num_layers: + raise ValueError(f"Strides is of size {len(self.kernel_shapes)}, while " + f"output_channels is of size{self.num_layers}.") + if len(self.padding) != self.num_layers: + raise ValueError(f"Padding is of size {len(self.padding)}, while " + f"output_channels is of size{self.num_layers}.") + + self.conv_modules = [] + self.bn_modules = [] + for i in range(self.num_layers): + self.conv_modules.append( + hk.Conv2D( + output_channels=self.output_channels[i], + kernel_shape=self.kernel_shapes[i], + stride=self.strides[i], + padding=self.padding[i], + data_format=data_format, + name=f"conv_2d_{i}") + ) + if with_batch_norm: + self.bn_modules.append( + hk.BatchNorm( + create_offset=True, + create_scale=False, + decay_rate=0.999, + name=f"batch_norm_{i}") + ) + else: + self.bn_modules.append(None) + + def __call__(self, inputs: jnp.ndarray, is_training: bool): + assert inputs.ndim == 4 + net = inputs + for i, (conv, bn) in enumerate(zip(self.conv_modules, self.bn_modules)): + net = conv(net) + # Batch norm + if bn is not None: + net = bn(net, is_training=is_training) + if i < self.num_layers - 1 or self.activate_final: + net = 
self.activation(net) + return net + + +class SpatialConvEncoder(hk.Module): + """Spatial Convolutional Encoder for learning the Hamiltonian.""" + + def __init__( + self, + latent_dim: int, + conv_channels: Union[Sequence[int], int], + num_blocks: int, + blocks_depth: int = 2, + distribution_name: str = "diagonal_normal", + aggregation_type: Optional[str] = None, + data_format: str = "NHWC", + activation: Activation = "leaky_relu", + scale_factor: int = 2, + kernel_shapes: Union[Sequence[int], int] = 3, + padding: Union[Sequence[str], str] = "SAME", + name: Optional[str] = None): + super().__init__(name=name) + if aggregation_type not in (None, "max", "mean", "linear_projection"): + raise ValueError(f"Unrecognized aggregation_type={aggregation_type}.") + self.latent_dim = latent_dim + self.conv_channels = conv_channels + self.num_blocks = num_blocks + self.scale_factor = scale_factor + self.data_format = data_format + self.distribution_name = distribution_name + self.aggregation_type = aggregation_type + + # Compute the required size of the output + if distribution_name is None: + self.output_dim = latent_dim + elif distribution_name == "diagonal_normal": + self.output_dim = 2 * latent_dim + else: + raise ValueError(f"Unrecognized distribution_name={distribution_name}.") + + if isinstance(conv_channels, int): + conv_channels = [[conv_channels] * blocks_depth + for _ in range(num_blocks)] + conv_channels[-1] += [self.output_dim] + else: + assert isinstance(conv_channels, (list, tuple)) + assert len(conv_channels) == num_blocks + conv_channels = list(list(c) for c in conv_channels) + conv_channels[-1].append(self.output_dim) + + if isinstance(kernel_shapes, tuple): + kernel_shapes = list(kernel_shapes) + + # Convolutional blocks + self.blocks = [] + for i, channels in enumerate(conv_channels): + if isinstance(kernel_shapes, int): + extra_kernel_shapes = 0 + else: + extra_kernel_shapes = [3] * (len(channels) - len(kernel_shapes)) + + self.blocks.append(Conv2DNet( + 
output_channels=channels, + kernel_shapes=kernel_shapes + extra_kernel_shapes, + strides=[self.scale_factor] + [1] * (len(channels) - 1), + padding=padding, + data_format=data_format, + with_batch_norm=False, + activate_final=i < num_blocks - 1, + activation=activation, + name=f"block_{i}" + )) + + def spatial_aggregation(self, x: jnp.ndarray) -> jnp.ndarray: + if self.aggregation_type is None: + return x + axis = (1, 2) if self.data_format == "NHWC" else (2, 3) + if self.aggregation_type == "max": + return jnp.max(x, axis=axis) + if self.aggregation_type == "mean": + return jnp.mean(x, axis=axis) + if self.aggregation_type == "linear_projection": + x = x.reshape(x.shape[:-3] + (-1,)) + return hk.Linear(self.output_dim, name="LinearProjection")(x) + raise NotImplementedError() + + def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution: + if self.distribution_name is None: + return net_output + elif self.distribution_name == "diagonal_normal": + if self.aggregation_type is None: + split_axis, num_axes = self.data_format.index("C"), 3 + else: + split_axis, num_axes = 1, 1 + # Add an extra axis if the input has more than 1 batch dimension + split_axis += net_output.ndim - num_axes - 1 + loc, log_scale = jnp.split(net_output, 2, axis=split_axis) + return distrax.Normal(loc, jnp.exp(log_scale)) + else: + raise NotImplementedError() + + def __call__( + self, + inputs: jnp.ndarray, + is_training: bool + ) -> Union[jnp.ndarray, distrax.Distribution]: + # Treat any extra dimensions (like time) as the batch + batched_shape = inputs.shape[:-3] + net = jnp.reshape(inputs, (-1,) + inputs.shape[-3:]) + + # Apply all blocks in sequence + for block in self.blocks: + net = block(net, is_training=is_training) + + # Final projection + net = self.spatial_aggregation(net) + + # Reshape back to correct dimensions (like batch + time) + net = jnp.reshape(net, batched_shape + net.shape[1:]) + + # Return a distribution over the observations + return 
self.make_distribution(net) + + +class SpatialConvDecoder(hk.Module): + """Spatial Convolutional Decoder for learning the Hamiltonian.""" + + def __init__( + self, + initial_spatial_shape: Sequence[int], + conv_channels: Union[Sequence[int], int], + num_blocks: int, + max_de_aggregation_dims: int, + blocks_depth: int = 2, + scale_factor: int = 2, + output_channels: int = 3, + h_const_channels: int = 2, + data_format: str = "NHWC", + activation: Activation = "leaky_relu", + learned_sigma: bool = False, + de_aggregation_type: Optional[str] = None, + final_activation: Activation = "sigmoid", + discard_half_de_aggregated: bool = False, + kernel_shapes: Union[Sequence[int], int] = 3, + padding: Union[Sequence[str], str] = "SAME", + name: Optional[str] = None): + super().__init__(name=name) + if de_aggregation_type not in (None, "tile", "linear_projection"): + raise ValueError(f"Unrecognized de_aggregation_type=" + f"{de_aggregation_type}.") + self.num_blocks = num_blocks + self.scale_factor = scale_factor + self.h_const_channels = h_const_channels + self.data_format = data_format + self.learned_sigma = learned_sigma + self.initial_spatial_shape = tuple(initial_spatial_shape) + self.final_activation = utils.get_activation(final_activation) + self.de_aggregation_type = de_aggregation_type + self.max_de_aggregation_dims = max_de_aggregation_dims + self.discard_half_de_aggregated = discard_half_de_aggregated + + if isinstance(conv_channels, int): + conv_channels = [[conv_channels] * blocks_depth + for _ in range(num_blocks)] + conv_channels[-1] += [output_channels] + else: + assert isinstance(conv_channels, (list, tuple)) + assert len(conv_channels) == num_blocks + conv_channels = list(list(c) for c in conv_channels) + conv_channels[-1].append(output_channels) + + # Convolutional blocks + self.blocks = [] + for i, channels in enumerate(conv_channels): + is_final_block = i == num_blocks - 1 + self.blocks.append( + Conv2DNet( # pylint: disable=g-complex-comprehension + 
output_channels=channels, + kernel_shapes=kernel_shapes, + strides=1, + padding=padding, + data_format=data_format, + with_batch_norm=False, + activate_final=not is_final_block, + activation=activation, + name=f"block_{i}" + )) + + def spatial_de_aggregation(self, x: jnp.ndarray) -> jnp.ndarray: + if self.de_aggregation_type is None: + assert x.ndim >= 4 + if self.data_format == "NHWC": + assert x.shape[1:3] == self.initial_spatial_shape + elif self.data_format == "NCHW": + assert x.shape[2:4] == self.initial_spatial_shape + return x + elif self.de_aggregation_type == "linear_projection": + assert x.ndim == 2 + n, d = x.shape + d = min(d, self.max_de_aggregation_dims or d) + out_d = d * self.initial_spatial_shape[0] * self.initial_spatial_shape[1] + x = hk.Linear(out_d, name="LinearProjection")(x) + if self.data_format == "NHWC": + shape = (n,) + self.initial_spatial_shape + (d,) + else: + shape = (n, d) + self.initial_spatial_shape + return x.reshape(shape) + elif self.de_aggregation_type == "tile": + assert x.ndim == 2 + if self.data_format == "NHWC": + repeats = (1,) + self.initial_spatial_shape + (1,) + x = x[:, None, None, :] + else: + repeats = (1, 1) + self.initial_spatial_shape + x = x[:, :, None, None] + return jnp.tile(x, repeats) + else: + raise NotImplementedError() + + def add_constant_channels(self, inputs: jnp.ndarray) -> jnp.ndarray: + # -------------------------------------------- + # This is purely for TF compatibility purposes + if self.discard_half_de_aggregated: + axis = self.data_format.index("C") + inputs, _ = jnp.split(inputs, 2, axis=axis) + # -------------------------------------------- + + # An extra constant channels + if self.data_format == "NHWC": + h_shape = self.initial_spatial_shape + (self.h_const_channels,) + else: + h_shape = (self.h_const_channels,) + self.initial_spatial_shape + h_const = hk.get_parameter("h", h_shape, dtype=inputs.dtype, + init=hk.initializers.Constant(1)) + h_const = jnp.tile(h_const, reps=[inputs.shape[0], 
1, 1, 1]) + return jnp.concatenate([h_const, inputs], axis=self.data_format.index("C")) + + def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution: + if self.learned_sigma: + init = hk.initializers.Constant(- jnp.log(2.0) / 2.0) + log_scale = hk.get_parameter("log_scale", shape=(), + dtype=net_output.dtype, init=init) + scale = jnp.full_like(net_output, jnp.exp(log_scale)) + else: + scale = jnp.full_like(net_output, 1 / jnp.sqrt(2.0)) + + return distrax.Normal(net_output, scale) + + def __call__( + self, + inputs: jnp.ndarray, + is_training: bool + ) -> distrax.Distribution: + # Apply the spatial de-aggregation + inputs = self.spatial_de_aggregation(inputs) + + # Add the parameterized constant channels + net = self.add_constant_channels(inputs) + + # Apply all the blocks + for block in self.blocks: + # Up-sample the image + net = utils.nearest_neighbour_upsampling(net, self.scale_factor) + # Apply the convolutional block + net = block(net, is_training=is_training) + + # Apply any specific output nonlinearity + net = self.final_activation(net) + + # Construct the distribution over the observations + return self.make_distribution(net) + + +def make_flexible_net( + net_type: str, + output_dims: int, + conv_channels: Union[Sequence[int], int], + num_units: Union[Sequence[int], int], + num_layers: Optional[int], + activation: Activation, + activate_final: bool = False, + kernel_shapes: Union[Sequence[int], int] = 3, + strides: Union[Sequence[int], int] = 1, + padding: Union[Sequence[str], str] = "SAME", + name: Optional[str] = None, + **unused_kwargs: Mapping[str, Any] +): + """Commonly used for creating a flexible network.""" + if unused_kwargs: + logging.warning("Unused kwargs of `make_flexible_net`: %s", + str(unused_kwargs)) + if net_type == "mlp": + if isinstance(num_units, int): + assert num_layers is not None + num_units = [num_units] * (num_layers - 1) + [output_dims] + else: + num_units = list(num_units) + [output_dims] + return DenseNet( 
+ num_units=num_units, + activation=activation, + activate_final=activate_final, + name=name + ) + elif net_type == "conv": + if isinstance(conv_channels, int): + assert num_layers is not None + conv_channels = [conv_channels] * (num_layers - 1) + [output_dims] + else: + conv_channels = list(conv_channels) + [output_dims] + return Conv2DNet( + output_channels=conv_channels, + kernel_shapes=kernel_shapes, + strides=strides, + padding=padding, + activation=activation, + activate_final=activate_final, + name=name + ) + elif net_type == "transformer": + raise NotImplementedError() + else: + raise ValueError(f"Unrecognized net_type={net_type}.") + + +def make_flexible_recurrent_net( + core_type: str, + net_type: str, + output_dims: int, + num_units: Union[Sequence[int], int], + num_layers: Optional[int], + activation: Activation, + activate_final: bool = False, + name: Optional[str] = None, + **unused_kwargs +): + """Commonly used for creating a flexible recurrences.""" + if net_type != "mlp": + raise ValueError("We do not support convolutional recurrent nets atm.") + if unused_kwargs: + logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s", + str(unused_kwargs)) + + if isinstance(num_units, (list, tuple)): + num_units = list(num_units) + [output_dims] + num_layers = len(num_units) + else: + assert num_layers is not None + num_units = [num_units] * (num_layers - 1) + [output_dims] + name = name or f"{core_type.upper()}" + + activation = utils.get_activation(activation) + core_list = [] + for i, n in enumerate(num_units): + if core_type.lower() == "vanilla": + core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}")) + elif core_type.lower() == "lstm": + core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}")) + elif core_type.lower() == "gru": + core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}")) + else: + raise ValueError(f"Unrecognized core_type={core_type}.") + if i != num_layers - 1: + core_list.append(activation) + if 
activate_final: + core_list.append(activation) + + return hk.DeepRNN(core_list, name="RNN") diff --git a/physics_inspired_models/requirements.txt b/physics_inspired_models/requirements.txt new file mode 100644 index 0000000..72c25d4 --- /dev/null +++ b/physics_inspired_models/requirements.txt @@ -0,0 +1,10 @@ +git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite@main#egg=dm_hamiltonian_dynamics_suite +absl-py==0.12.0 +numpy>=1.16.4 +scikit-learn>=1.0 +typing>=3.7.4.3 +jax==0.2.20 +jaxline==0.0.3 +distrax==0.0.2 +optax==0.0.6 +dm-haiku==0.0.3 diff --git a/physics_inspired_models/setup.py b/physics_inspired_models/setup.py new file mode 100644 index 0000000..0c4a65b --- /dev/null +++ b/physics_inspired_models/setup.py @@ -0,0 +1,54 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Setup for pip package.""" +from setuptools import setup + +REQUIRED_PACKAGES = ( + "dm_hamiltonian_dynamics_suite@git+https://github.com/deepmind/dm_hamiltonian_dynamics_suite", # pylint: disable=line-too-long. 
+ "absl-py>=0.12.0", + "numpy>=1.16.4", + "scikit-learn>=1.0", + "typing>=3.7.4.3", + "jax==0.2.20", + "jaxline==0.0.3", + "distrax==0.0.2", + "optax==0.0.6", + "dm-haiku==0.0.3", +) + +LONG_DESCRIPTION = "\n".join([ + "A codebase containing the implementation of the following models:", + "Hamiltonian Generative Network (HGN)", + "Lagrangian Generative Network (LGN)", + "Neural ODE", + "Recurrent Generative Network (RGN)", + "and RNN, LSTM and GRU.", + "This is code accompanying the publication of:" +]) + + +setup( + name="physics_inspired_models", + version="0.0.1", + description="Implementation of multiple physically inspired models.", + long_description=LONG_DESCRIPTION, + url="https://github.com/deepmind/deepmind-research/physics_inspired_models", + author="DeepMind", + package_dir={"physics_inspired_models": "."}, + packages=["physics_inspired_models", "physics_inspired_models.models"], + install_requires=REQUIRED_PACKAGES, + platforms=["any"], + license="Apache License, Version 2.0", +) diff --git a/physics_inspired_models/utils.py b/physics_inspired_models/utils.py new file mode 100644 index 0000000..bde4af4 --- /dev/null +++ b/physics_inspired_models/utils.py @@ -0,0 +1,274 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Utilities functions for Jax."""
import collections
import collections.abc
import functools
from typing import Any, Callable, Dict, Mapping, Union

import distrax
import jax
from jax import core
from jax import lax
from jax import nn
import jax.numpy as jnp
from jax.tree_util import register_pytree_node
from jaxline import utils
import numpy as np

HaikuParams = Mapping[str, Mapping[str, jnp.ndarray]]
Params = Union[Mapping[str, jnp.ndarray], HaikuParams, jnp.ndarray]
_Activation = Callable[[jnp.ndarray], jnp.ndarray]

# Leaky ReLU with a fixed negative slope of 0.2.
tf_leaky_relu = functools.partial(nn.leaky_relu, negative_slope=0.2)


def filter_only_scalar_stats(stats):
  """Returns the entries of `stats` whose value has exactly one element."""
  return {k: v for k, v in stats.items() if v.size == 1}


def to_numpy(obj):
  """Applies `np.array` to every leaf of the pytree `obj`."""
  return jax.tree_util.tree_map(np.array, obj)


@jax.custom_gradient
def geco_lagrange_product(lagrange_multiplier, constraint_ema, constraint_t):
  """Modifies the gradients so that they work as described in GECO.

  The evaluation gives:
    lagrange * C_ema
  The gradient w.r.t lagrange:
    - g * C_t
  The gradient w.r.t constraint_ema:
    0.0
  The gradient w.r.t constraint_t:
    g * lagrange

  Note that if you pass the same value for `constraint_ema` and `constraint_t`
  this would only flip the gradient for the lagrange multiplier.

  Args:
    lagrange_multiplier: The lagrange multiplier
    constraint_ema: The moving average of the constraint
    constraint_t: The current constraint

  Returns:
    The product `lagrange_multiplier * constraint_ema`, whose gradients are
    replaced by the custom rule described above.
  """
  def grad(gradient):
    # One cotangent per positional argument, in the same order.
    return (- gradient * constraint_t,
            jnp.zeros_like(constraint_ema),
            gradient * lagrange_multiplier)
  return lagrange_multiplier * constraint_ema, grad


def bcast_if(x, t, n):
  """Returns a list of `n` copies of `x` if it is of type `t`, else `x`."""
  return [x] * n if isinstance(x, t) else x


def stack_time_into_channels(
    images: jnp.ndarray,
    data_format: str
) -> jnp.ndarray:
  """Removes axis 1 of `images` by concatenating its slices along channels.

  Args:
    images: Array that is split into `images.shape[1]` slices along axis 1.
    data_format: String containing "C"; the position of "C" gives the axis
      (after axis 1 has been squeezed out) along which slices are concatenated.

  Returns:
    The concatenation of all slices, with one fewer dimension than `images`.
  """
  axis = data_format.index("C")
  list_of_time = [jnp.squeeze(v, axis=1) for v in
                  jnp.split(images, images.shape[1], axis=1)]
  return jnp.concatenate(list_of_time, axis)


def stack_device_dim_into_batch(obj):
  """Merges the two leading axes of every leaf in the pytree `obj`."""
  return jax.tree_util.tree_map(
      lambda x: x.reshape((-1,) + x.shape[2:]), obj)


def nearest_neighbour_upsampling(x, scale, data_format="NHWC"):
  """Performs nearest-neighbour upsampling.

  Args:
    x: A rank 4 array laid out according to `data_format`.
    scale: Integer factor by which both spatial dimensions are enlarged.
    data_format: Either "NHWC" or "NCHW".

  Returns:
    An array with both spatial dimensions multiplied by `scale`.

  Raises:
    ValueError: If `data_format` is neither "NHWC" nor "NCHW".
  """
  # Each pixel gets two singleton axes, which broadcasting against `ones`
  # expands to a `scale x scale` patch before folding back to rank 4.
  if data_format == "NCHW":
    b, c, h, w = x.shape
    x = jnp.reshape(x, [b, c, h, 1, w, 1])
    ones = jnp.ones([1, 1, 1, scale, 1, scale], dtype=x.dtype)
    return jnp.reshape(x * ones, [b, c, scale * h, scale * w])
  elif data_format == "NHWC":
    b, h, w, c = x.shape
    x = jnp.reshape(x, [b, h, 1, w, 1, c])
    ones = jnp.ones([1, 1, scale, 1, scale, 1], dtype=x.dtype)
    return jnp.reshape(x * ones, [b, scale * h, scale * w, c])
  else:
    raise ValueError(f"Unrecognized data_format={data_format}.")


def get_activation(arg: Union[_Activation, str]) -> _Activation:
  """Returns an activation from provided string.

  Args:
    arg: Either a callable, which is returned unchanged, or the name of an
      attribute looked up in this module, then `jax.nn`, then `jax.numpy`.

  Returns:
    The resolved activation.

  Raises:
    ValueError: If the name can not be resolved, or if `arg` is neither a
      string nor a callable.
  """
  if isinstance(arg, str):
    # Try fetch in order - [this module, jax.nn, jax.numpy]
    if arg in globals():
      return globals()[arg]
    if hasattr(nn, arg):
      return getattr(nn, arg)
    elif hasattr(jnp, arg):
      return getattr(jnp, arg)
    else:
      raise ValueError(f"Unrecognized activation with name {arg}.")
  if not callable(arg):
    raise ValueError(f"Expected a callable, but got {type(arg)}")
  return arg


def merge_first_dims(x: jnp.ndarray, num_dims_to_merge: int = 2) -> jnp.ndarray:
  """Flattens the first `num_dims_to_merge` axes of `x` into a single axis."""
  return x.reshape((-1,) + x.shape[num_dims_to_merge:])


def extract_image(
    inputs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
) -> jnp.ndarray:
  """Extracts the image tensor from the inputs.

  Args:
    inputs: Either a `dict`, from which the entry under the key `image` (or
      `x_image` when `image` is absent) is returned, or an array, which is
      returned as is.

  Returns:
    The image array.

  Raises:
    NotImplementedError: If `inputs` is neither a `dict` nor a Jax array.
  """
  if isinstance(inputs, dict):
    if "image" in inputs:
      return inputs["image"]
    else:
      return inputs["x_image"]
  elif isinstance(inputs, jnp.ndarray):
    return inputs
  raise NotImplementedError(f"Not implemented of inputs of type"
                            f" {type(inputs)}.")


def extract_gt_state(inputs: Any) -> jnp.ndarray:
  """Returns `inputs["x"]` for a `dict`, or `inputs` itself for a Jax array.

  Args:
    inputs: A `dict` with key `x`, or a Jax array.

  Returns:
    The extracted array.

  Raises:
    NotImplementedError: If `inputs` is neither a `dict` nor a Jax array.
  """
  if isinstance(inputs, dict):
    return inputs["x"]
  elif not isinstance(inputs, jnp.ndarray):
    raise NotImplementedError(f"Not implemented of inputs of type"
                              f" {type(inputs)}.")
  return inputs


def reshape_latents_conv_to_flat(conv_latents, axis_n_to_keep=1):
  """Flattens each half of the latents, keeping the leading axes.

  Args:
    conv_latents: Array that is split into two equal halves along its last
      axis.
    axis_n_to_keep: Number of leading axes that are left untouched.

  Returns:
    The two halves, each with its trailing axes flattened into one, joined
    back together along the last axis.
  """
  q, p = jnp.split(conv_latents, 2, axis=-1)
  q = jax.tree_util.tree_map(
      lambda x: x.reshape(x.shape[:axis_n_to_keep] + (-1,)), q)
  p = jax.tree_util.tree_map(
      lambda x: x.reshape(x.shape[:axis_n_to_keep] + (-1,)), p)
  flat_latents = jnp.concatenate([q, p], axis=-1)

  return flat_latents


def triu_matrix_from_v(x, ndim):
  """Scatters the last axis of `x` into upper triangular matrices.

  Args:
    x: Array whose last axis has size `ndim * (ndim + 1) // 2`.
    ndim: Number of rows and columns of the produced matrices.

  Returns:
    Array of shape `x.shape[:-1] + (ndim, ndim)`, zero everywhere except at
    the positions given by `jnp.triu_indices(ndim)`, which hold `x`.
  """
  assert x.shape[-1] == (ndim * (ndim + 1)) // 2
  matrix = jnp.zeros(x.shape[:-1] + (ndim, ndim))
  idx = jnp.triu_indices(ndim)
  index_update = lambda x, idx, y: x.at[idx].set(y)
  # One `vmap` per leading axis, so the update acts on a single matrix.
  for _ in range(x.ndim - 1):
    index_update = jax.vmap(index_update, in_axes=(0, None, 0))
  return index_update(matrix, idx, x)


def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
  """Flattens a nested mapping into a single level dictionary.

  Args:
    d: The mapping to flatten.
    parent_key: Prefix that is prepended, followed by `sep`, to every key.
    sep: Separator placed between the keys of successive nesting levels.

  Returns:
    A dictionary in which the key of every non-mapping value is the `sep`
    joined path of keys leading to it.
  """
  items = []
  for k, v in d.items():
    new_key = parent_key + sep + k if parent_key else k
    # The ABC aliases in `collections` were removed in Python 3.10, hence the
    # abstract class must be taken from `collections.abc`.
    if isinstance(v, collections.abc.MutableMapping):
      items.extend(flatten_dict(v, new_key, sep=sep).items())
    else:
      items.append((new_key, v))
  return dict(items)


def convert_to_pytype(target, reference):
  """Makes target the same pytype as reference, by jax.tree_flatten."""
  _, pytree = jax.tree_util.tree_flatten(reference)
  leaves, _ = jax.tree_util.tree_flatten(target)
  return jax.tree_util.tree_unflatten(pytree, leaves)


def func_if_not_scalar(func):
  """Makes a function that uses func only on non-scalar values."""
  @functools.wraps(func)
  def wrapped(array, axis=0):
    if array.ndim == 0:
      return array
    return func(array, axis=axis)
  return wrapped


mean_if_not_scalar = func_if_not_scalar(jnp.mean)


class MultiBatchAccumulator(object):
  """Class for abstracting statistics accumulation over multiple batches."""

  def __init__(self):
    # Sum of the added values, each weighted by its number of samples.
    self._obj = None
    # Element-wise maximum and minimum over all added values.
    self._obj_max = None
    self._obj_min = None
    # Total number of samples added so far.
    self._num_samples = None

  def add(self, averaged_values, num_samples):
    """Adds an element to the moving average, the max and the min.

    Args:
      averaged_values: Pytree of values, weighted by `num_samples` when added
        to the running sum.
      num_samples: The weight of `averaged_values`.
    """
    if self._obj is None:
      self._obj_max = jax.tree_util.tree_map(
          lambda y: y * 1.0, averaged_values)
      self._obj_min = jax.tree_util.tree_map(
          lambda y: y * 1.0, averaged_values)
      self._obj = jax.tree_util.tree_map(
          lambda y: y * num_samples, averaged_values)
      self._num_samples = num_samples
    else:
      self._obj_max = jax.tree_util.tree_map(
          jnp.maximum, self._obj_max, averaged_values)
      self._obj_min = jax.tree_util.tree_map(
          jnp.minimum, self._obj_min, averaged_values)
      self._obj = jax.tree_util.tree_map(
          lambda x, y: x + y * num_samples, self._obj, averaged_values)
      self._num_samples += num_samples

  def value(self):
    """Returns the accumulated sum divided by the total number of samples."""
    return jax.tree_util.tree_map(
        lambda x: x / self._num_samples, self._obj)

  def max(self):
    """Returns the element-wise maximum of the added values as floats."""
    return jax.tree_util.tree_map(float, self._obj_max)

  def min(self):
    """Returns the element-wise minimum of the added values as floats."""
    return jax.tree_util.tree_map(float, self._obj_min)

  def sum(self):
    """Returns the accumulated sample-weighted sum."""
    return self._obj


# Lets `distrax.Normal` be handled as a pytree whose leaves are its `loc` and
# `scale`.
register_pytree_node(
    distrax.Normal,
    lambda instance: ([instance.loc, instance.scale], None),
    lambda _, args: distrax.Normal(*args)
)


def inner_product(x: Any, y: Any) -> jnp.ndarray:
  """Returns the sum over all leaves of the element-wise product of x and y."""
  products = jax.tree_util.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
  return sum(jax.tree_util.tree_leaves(products))


get_first = utils.get_first
bcast_local_devices = utils.bcast_local_devices
py_prefetch = utils.py_prefetch
p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
                   static_broadcasted_argnums=1)


def wrap_if_pmap(p_func):
  """Wraps `p_func` so it is only applied when `axis_name` is bound.

  Args:
    p_func: A function called as `p_func(obj, axis_name)`.

  Returns:
    A function of `(obj, axis_name)` returning `p_func(obj, axis_name)` if
    `core.axis_frame(axis_name)` succeeds, and `obj` unchanged if it raises a
    `NameError`.
  """
  def p_func_if_pmap(obj, axis_name):
    # Only the probe is guarded, so that a `NameError` raised by `p_func`
    # itself is not silently swallowed.
    try:
      core.axis_frame(axis_name)
    except NameError:
      return obj
    return p_func(obj, axis_name)
  return p_func_if_pmap


pmean_if_pmap = wrap_if_pmap(lax.pmean)
psum_if_pmap = wrap_if_pmap(lax.psum)