Files
deepmind-research/physics_inspired_models/integrators.py
T
Alex Botev 2c7c401024 Open sourcing the physics inspired models code.
PiperOrigin-RevId: 408640606
2021-11-09 22:56:08 +00:00

1042 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the implementations of the various numerical integrators.
Higher order methods mostly taken from [1].
References:
[1] Leimkuhler, Benedict and Sebastian Reich. Simulating hamiltonian dynamics.
Vol. 14. Cambridge university press, 2004.
[2] Forest, Etienne and Ronald D. Ruth. Fourth-order symplectic integration.
Physica D: Nonlinear Phenomena 43.1 (1990): 105-117.
[3] Blanes, Sergio and Per Christian Moan. Practical symplectic partitioned
Runge-Kutta and Runge-Kutta-Nyström methods. Journal of Computational and
Applied Mathematics 142.2 (2002): 313-330.
[4] McLachlan, Robert I. On the numerical integration of ordinary differential
equations by symmetric composition methods. SIAM Journal on Scientific
Computing 16.1 (1995): 151-168.
[5] Yoshida, Haruo. Construction of higher order symplectic integrators.
Physics letters A 150.5-7 (1990): 262-268.
[6] Süli, Endre; Mayers, David (2003), An Introduction to Numerical Analysis,
Cambridge University Press, ISBN 0-521-00794-1.
[7] Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), Solving
ordinary differential equations I: Nonstiff problems, Berlin, New York:
Springer-Verlag, ISBN 978-3-540-56670-0.
"""
from typing import Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
import jax
from jax import lax
from jax.experimental import ode
import jax.numpy as jnp
import numpy as np
# Generic placeholder for the state type handled by the integrators (`y`).
M = TypeVar("M")
# Generic placeholder for the type returned by tangent functions (`dy_dt`).
TM = TypeVar("TM")
# A time interval, given either as an array or as a pair of floats.
TimeInterval = Union[jnp.ndarray, Tuple[float, float]]
# _____ _
# / ____| | |
# | | __ ___ _ __ ___ _ __ __ _| |
# | | |_ |/ _ \ '_ \ / _ \ '__/ _` | |
# | |__| | __/ | | | __/ | | (_| | |
# \_____|\___|_| |_|\___|_| \__,_|_|
# _____ _ _ _
# |_ _| | | | | (_)
# | | _ __ | |_ ___ __ _ _ __ __ _| |_ _ ___ _ __
# | | | '_ \| __/ _ \/ _` | '__/ _` | __| |/ _ \| '_ \
# _| |_| | | | || __/ (_| | | | (_| | |_| | (_) | | | |
# |_____|_| |_|\__\___|\__, |_| \__,_|\__|_|\___/|_| |_|
# __/ |
# |___/
# Signature of a tangent function `fun(t, y) -> dy_dt`; `t` may be None, a
# float or an array.
GeneralTangentFunction = Callable[
    [
        Optional[Union[float, jnp.ndarray]],  # t
        M  # y
    ],
    TM  # dy_dt
]
# Signature of a single integration step `integrator(fun, t, y, dt) -> y_next`.
GeneralIntegrator = Callable[
    [
        GeneralTangentFunction,
        Optional[Union[float, jnp.ndarray]],  # t
        M,  # y
        jnp.ndarray,  # dt
    ],
    M  # y_next
]
def solve_ivp_dt(
    fun: GeneralTangentFunction,
    y0: M,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, GeneralIntegrator],
    num_steps: Optional[int] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, M]:
  """Solve an initial value problem for a system of ODEs from time increments.

  This function numerically integrates a system of ordinary differential
  equations given an initial value::

      dy / dt = f(t, y)
      y(t0) = y0

  Here t is a one-dimensional independent variable (time), y(t) is the state
  and f(t, y) determines the differential equations. The goal is to find y(t)
  approximately satisfying the differential equations, given y(t0) = y0.

  Apart from `method="adaptive"`, which delegates to
  `jax.experimental.ode.odeint`, every step is taken by the integrator
  returned from `get_integrator(method)` with a fixed step size.

  Args:
    fun: callable
      Right-hand side of the system. The calling signature is ``fun(t, y)``.
      Here `t` is the time instance and `y` can be any type `M`, including a
      flat array, that is registered as a pytree. The return value is an
      element of the tangent type `TM`, which must support multiplication by
      arrays and scalars, addition to other `TM` instances and right addition
      to an element of `M`. In adaptive mode `fun` is always called with
      `t=None`.
    y0: an instance of `M`
      Initial state at `t0`.
    t0: float or array.
      The initial time point of integration.
    dt: float or array
      The consecutive increments in time. When `num_steps` is None the size of
      this array along axis 0 defines the number of steps taken.
    method: string or `GeneralIntegrator`
      The integrator to use: "adaptive", any name accepted by
      `get_integrator` (e.g. general_euler, rk2, rk4, rk38) or a callable.
    num_steps: Optional int.
      If provided, `dt` is treated as the same per-step time interval applied
      for this many steps, i.e. it is replicated `num_steps` times along a new
      leading axis. Required when `method="adaptive"`.
    steps_per_dt: int
      Number of integrator sub-steps taken per increment; each sub-step has
      size `dt / steps_per_dt`. Not used in adaptive mode.
    use_scan: bool
      Whether to run the loop with `lax.scan` or with a python loop.
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive".

  Returns:
    t: array
      Time points at which the solution is evaluated (excluding `t0`).
    y : an instance of M
      Values of the solution at `t`, stacked along axis 0. The initial point
      is not included.

  Raises:
    ValueError: in adaptive mode, if `num_steps` is missing, if the entries of
      a numpy `dt` differ in absolute value, or if `dt` is a non-scalar jax
      array.
  """
  if method == "adaptive":
    if num_steps is None:
      # Previously this surfaced as an opaque `None + 1` TypeError below.
      raise ValueError("`num_steps` must be provided when method='adaptive'.")
    ndim = y0.q.ndim if isinstance(y0, phase_space.PhaseSpace) else y0.ndim
    # The direction of time is applied by multiplying the tangent with the sign
    # of `dt`, while `odeint` itself is always run over non-negative times.
    signs = jnp.asarray(jnp.sign(dt))
    signs = signs.reshape([-1] + [1] * (ndim - 1))
    if isinstance(dt, float) or dt.ndim == 0:
      true_t_eval = t0 + dt * np.arange(1, num_steps + 1)
    else:
      true_t_eval = t0 + dt[None] * np.arange(1, num_steps + 1)[:, None]
    if isinstance(dt, float):
      dt = np.asarray(dt)
    if isinstance(dt, np.ndarray) and dt.ndim > 0:
      # `np.any` (not `np.all`): element 0 always equals itself, so an
      # all-different test could never fire.
      if np.any(np.abs(dt) != np.abs(dt[0])):
        raise ValueError("Not all values of `dt` where the same.")
    elif isinstance(dt, jnp.ndarray) and dt.ndim > 0:
      raise ValueError("The code here works only when `dy_dt` is time "
                       "independent and `np.abs(dt)` is the same. For this we "
                       "allow calling this only with numpy (not jax.numpy) "
                       "arrays.")
    dt = jnp.abs(jnp.asarray(dt))
    dt = dt.reshape([-1])[0]
    t_eval = t0 + dt * np.arange(num_steps + 1)
    outputs = ode.odeint(
        func=lambda y_, t_: fun(None, y_) * signs,
        y0=y0,
        t=jnp.abs(t_eval - t0),
        **(ode_int_kwargs or dict())
    )
    # Note that we do not return the initial point
    return true_t_eval, jax.tree_util.tree_map(lambda x: x[1:], outputs)

  method = get_integrator(method)
  if num_steps is not None:
    dt = jnp.repeat(jnp.asarray(dt)[None], repeats=num_steps, axis=0)
  t_eval = t0 + jnp.cumsum(dt, axis=0)
  # Start time of every step. Steps are laid out along axis 0 (the axis of the
  # cumulative sum), so the shift by one step must also happen along axis 0.
  t_start = jnp.ones_like(t_eval[:1]) * t0
  t = jnp.concatenate([t_start, t_eval[:-1]], axis=0)

  def loop_body(y_: M, t_dt: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[M, M]:
    # Advances the state over one increment using `steps_per_dt` sub-steps.
    t_, dt_ = t_dt
    dt_ = dt_ / steps_per_dt
    for _ in range(steps_per_dt):
      y_ = method(fun, t_, y_, dt_)
      t_ = t_ + dt_
    return y_, y_

  if use_scan:
    return t_eval, lax.scan(loop_body, init=y0, xs=(t, dt))[1]
  else:
    y = [y0]
    for t_and_dt_i in zip(t, dt):
      y.append(loop_body(y[-1], t_and_dt_i)[0])
    # Note that we do not return the initial point
    return t_eval, jax.tree_util.tree_map(
        lambda *args: jnp.stack(args, axis=0), *y[1:])
def solve_ivp_dt_two_directions(
    fun: GeneralTangentFunction,
    y0: M,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, GeneralIntegrator],
    num_steps_forward: int,
    num_steps_backward: int,
    include_y0: bool = True,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> M:
  """Unrolls `solve_ivp_dt` for a fixed number of steps in both time directions.

  The backward trajectory is computed with `-dt` and flipped along axis 0, so
  that the result is ordered as [backward steps, y0 (optional), forward steps].

  Args:
    fun: Right-hand side of the system, see `solve_ivp_dt`.
    y0: Initial state at `t0`.
    t0: The initial time point of integration.
    dt: The per-step time increment, see `solve_ivp_dt`.
    method: The integrator method to use, see `solve_ivp_dt`.
    num_steps_forward: Number of steps taken with `dt`; skipped if not
      positive.
    num_steps_backward: Number of steps taken with `-dt`; skipped if not
      positive.
    include_y0: Whether to include `y0` (with a new leading axis) in the
      output.
    steps_per_dt: Number of integrator sub-steps per increment.
    use_scan: Whether to use `lax.scan` or a python loop.
    ode_int_kwargs: Extra arguments for `ode.odeint` when method="adaptive".

  Returns:
    The states concatenated along axis 0.

  Raises:
    ValueError: if no steps are requested in either direction and
      `include_y0` is False, since there is nothing to return.
  """
  shared_kwargs = dict(
      fun=fun,
      y0=y0,
      t0=t0,
      method=method,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs,
  )
  yt = []
  if num_steps_backward > 0:
    yt_bck = solve_ivp_dt(
        dt=-dt, num_steps=num_steps_backward, **shared_kwargs)[1]
    # Reverse so that the earliest time comes first.
    yt.append(jax.tree_util.tree_map(lambda x: jnp.flip(x, axis=0), yt_bck))
  if include_y0:
    yt.append(jax.tree_util.tree_map(lambda x: x[None], y0))
  if num_steps_forward > 0:
    yt_fwd = solve_ivp_dt(
        dt=dt, num_steps=num_steps_forward, **shared_kwargs)[1]
    yt.append(yt_fwd)
  if not yt:
    raise ValueError("Nothing to return: no steps were requested in either "
                     "direction and `include_y0` is False.")
  if len(yt) > 1:
    return jax.tree_util.tree_map(
        lambda *a: jnp.concatenate(a, axis=0), *yt)
  return yt[0]
def solve_ivp_t_eval(
    fun: GeneralTangentFunction,
    t_span: TimeInterval,
    y0: M,
    method: Union[str, GeneralIntegrator],
    t_eval: Optional[jnp.ndarray] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, M]:
  """Solve an initial value problem for a system of ODEs at given time points.

  Numerically integrates::

      dy / dt = f(t, y)
      y(t0) = y0

  where t is the time, y(t) is the state and f(t, y) is the right-hand side.
  The requested evaluation times are validated against `t_span`, converted to
  consecutive time increments and handed to `solve_ivp_dt`.

  Args:
    fun: callable
      Right-hand side of the system. The calling signature is ``fun(t, y)``.
      `y` can be any type `M`, including a flat array, that is registered as a
      pytree. The return value is an element of the tangent type `TM`, which
      must support multiplication by arrays and scalars, addition to other
      `TM` instances and right addition to an element of `M`.
    t_span: 2-tuple of floats or arrays
      Interval of integration (t0, tf). The solver starts with t=t0 and
      integrates until it reaches t=tf. If tf is not larger than t0 the
      integration is treated as going backward in time.
    y0: an instance of `M`
      Initial state at `t_span[0]`.
    method: string or `GeneralIntegrator`
      The integrator method to use. Possible values for string are:
        * general_euler - see `GeneralEuler`
        * rk2 - see `RungaKutta2`
        * rk4 - see `RungaKutta4`
        * rk38 - see `RungaKutta38`
    t_eval: array or None.
      Times at which to store the computed solution. Must be strictly sorted
      in the direction of integration and lie within `t_span`. If None then
      t_eval = [t_span[-1]].
    steps_per_dt: int
      Number of integrator sub-steps between any two values of `t_eval`, i.e.
      the step size is `(t_eval[i+1] - t_eval[i]) / steps_per_dt`.
    use_scan: bool
      Whether to run the loop with `lax.scan` or with a python loop.
    ode_int_kwargs: dict
      Extra arguments forwarded to `solve_ivp_dt`.
      NOTE(review): no `num_steps` is forwarded from here, which the
      method="adaptive" path of `solve_ivp_dt` relies on - confirm usage.

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of M
      Values of the solution at `t`.

  Raises:
    ValueError: if `t_eval` is outside of `t_span`, is not properly sorted, or
      does not have rank 2 when `t_span` has rank 2.
  """
  if t_eval is None:
    t_eval = np.asarray([t_span[-1]])

  # Normalise `t_span` to a single array whose axis 0 is (start, end).
  start_is_float = isinstance(t_span[0], float)
  end_is_float = isinstance(t_span[1], float)
  if start_is_float and end_is_float:
    t_span = np.asarray(t_span)
  else:
    if start_is_float and isinstance(t_span[1], jnp.ndarray):
      # Expand the scalar start to match the array-valued end.
      t_span = (np.full_like(t_span[1], t_span[0]), t_span[1])
    elif end_is_float and isinstance(t_span[0], jnp.ndarray):
      # Expand the scalar end to match the array-valued start.
      t_span = (t_span[0], jnp.full_like(t_span[0], t_span[1]))
    t_span = np.stack(t_span, axis=0)

  def check_span(span, ts):
    # Verifies that `ts` lies inside `span` and is strictly monotone in the
    # direction given by `span`.
    begin, end = span[0], span[1]
    forward = begin < end
    if forward:
      inside = np.logical_and(begin <= ts, ts <= end)
    else:
      inside = np.logical_and(begin >= ts, ts >= end)
    if not np.all(inside):
      raise ValueError("Values in `t_eval` are not within `t_span`.")
    if forward:
      monotone = ts[:-1] < ts[1:]
    else:
      monotone = ts[:-1] > ts[1:]
    if not np.all(monotone):
      raise ValueError("Values in `t_eval` are not properly sorted.")

  span_rank = t_span.ndim
  if span_rank == 1:
    check_span(t_span, t_eval)
  elif span_rank == 2:
    if t_eval.ndim != 2:
      raise ValueError("t_eval should have rank 2.")
    # Every column is validated against its own interval.
    for column in range(t_span.shape[1]):
      check_span(t_span[:, column], t_eval[:, column])

  # Start time of each increment: the interval start followed by all but the
  # last evaluation time.
  step_starts = np.concatenate([t_span[:1], t_eval[:-1]], axis=0)
  increments = t_eval - step_starts
  return solve_ivp_dt(
      fun=fun,
      y0=y0,
      t0=t_span[0],
      dt=increments,
      method=method,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs
  )
class RungaKutta(GeneralIntegrator):
  """A general explicit Runge-Kutta integrator defined by a Butcher tableau.

  The first stage is always evaluated at `(t, y)`. Each subsequent stage `n`
  is evaluated at `t + c_n * dt` and `y + dt * sum_i a_{n,i} * k_i`, where
  `k_i` are the tangents of the previous stages. The final update is
  `y + dt * sum_i b_i * k_i`.
  """

  def __init__(
      self,
      a_tableau: Sequence[Sequence[float]],
      b_tableau: Sequence[float],
      c_tableau: Sequence[float],
      order: int):
    """Initializes the integrator.

    Args:
      a_tableau: One row of stage coefficients per stage after the first.
      b_tableau: The weights of all stages in the final update.
      c_tableau: One time offset (as a fraction of `dt`) per stage after the
        first.
      order: The order of the method; only stored as `self.order`.

    Raises:
      ValueError: if `b_tableau` does not have exactly one more entry than
        `c_tableau` and `a_tableau`.
    """
    if len(b_tableau) != len(c_tableau) + 1:
      raise ValueError("The length of b_tableau should be exactly one more than"
                       " the length of c_tableau.")
    if len(b_tableau) != len(a_tableau) + 1:
      raise ValueError("The length of b_tableau should be exactly one more than"
                       " the length of a_tableau.")
    self.a_tableau = a_tableau
    self.b_tableau = b_tableau
    self.c_tableau = c_tableau
    self.order = order

  def __call__(
      self,
      tangent_func: GeneralTangentFunction,
      t: jnp.ndarray,
      y: M,
      dt: jnp.ndarray
  ) -> M:  # pytype: disable=invalid-annotation
    """Takes a single step of size `dt` from the state `y` at time `t`."""
    k = [tangent_func(t, y)]
    # `jax.tree_util.tree_map` is used instead of the deprecated top-level
    # `jax.tree_map` alias.
    zero = jax.tree_util.tree_map(jnp.zeros_like, k[0])
    # We always broadcast opposite to numpy (e.g. leading dims (batch) count)
    if dt.ndim > 0:
      dt = dt.reshape(dt.shape + (1,) * (y.ndim - dt.ndim))
    if t.ndim > 0:
      t = t.reshape(t.shape + (1,) * (y.ndim - t.ndim))
    for c_n, a_n_row in zip(self.c_tableau, self.a_tableau):
      t_n = t + dt * c_n
      # Zero coefficients are skipped, `zero` is the start value of the sum.
      products = [a_i * k_i for a_i, k_i in zip(a_n_row, k) if a_i != 0.0]
      delta_n = sum(products, zero)
      y_n = y + dt * delta_n
      k.append(tangent_func(t_n, y_n))
    products = [b_i * k_i for b_i, k_i in zip(self.b_tableau, k) if b_i != 0.0]
    delta = sum(products, zero)
    return y + dt * delta
class GeneralEuler(RungaKutta):
  """The explicit Euler method for general ODE problems (first order)."""

  def __init__(self):
    # A single stage with unit weight: there are no intermediate stages, hence
    # the `a` and `c` tableaus are empty.
    super().__init__(
        order=1,
        b_tableau=[1.0],
        a_tableau=[],
        c_tableau=[],
    )
class RungaKutta2(RungaKutta):
  """Second order Runge-Kutta method given by the mid-point rule."""

  def __init__(self):
    half = 1.0 / 2.0
    # The update uses only the tangent evaluated at the mid-point.
    super().__init__(
        order=2,
        c_tableau=[half],
        a_tableau=[[half]],
        b_tableau=[0.0, 1.0],
    )
class RungaKutta4(RungaKutta):
  """The fourth order Runge-Kutta method from [6]."""

  def __init__(self):
    half = 1.0 / 2.0
    third = 1.0 / 3.0
    sixth = 1.0 / 6.0
    stage_rows = [
        [half],
        [0.0, half],
        [0.0, 0.0, 1.0],
    ]
    super().__init__(
        order=4,
        c_tableau=[half, half, 1.0],
        a_tableau=stage_rows,
        b_tableau=[sixth, third, third, sixth],
    )
class RungaKutta38(RungaKutta):
  """The fourth order 3/8 rule Runge-Kutta method from [7]."""

  def __init__(self):
    eighth = 1.0 / 8.0
    three_eighths = 3.0 / 8.0
    stage_rows = [
        [1.0 / 3.0],
        [-1.0 / 3.0, 1.0],
        [1.0, -1.0, 1.0],
    ]
    super().__init__(
        order=4,
        c_tableau=[1.0 / 3.0, 2.0 / 3.0, 1.0],
        a_tableau=stage_rows,
        b_tableau=[eighth, three_eighths, three_eighths, eighth],
    )
# _____ _ _ _
# / ____| | | | | (_)
# | (___ _ _ _ __ ___ _ __ | | ___ ___| |_ _ ___
# \___ \| | | | '_ ` _ \| '_ \| |/ _ \/ __| __| |/ __|
# ____) | |_| | | | | | | |_) | | __/ (__| |_| | (__
# |_____/ \__, |_| |_| |_| .__/|_|\___|\___|\__|_|\___|
# __/ | | |
# |___/ |_|
# _____ _ _ _
# |_ _| | | | | (_)
# | | _ __ | |_ ___ __ _ _ __ __ _| |_ _ ___ _ __
# | | | '_ \| __/ _ \/ _` | '__/ _` | __| |/ _ \| '_ \
# _| |_| | | | || __/ (_| | | | (_| | |_| | (_) | | | |
# |_____|_| |_|\__\___|\__, |_| \__,_|\__|_|\___/|_| |_|
# __/ |
# |___/
# Signature of a single symplectic integration step, mapping a tangent
# function, the time, the current phase space state and the step size to the
# next phase space state.
SymplecticIntegrator = Callable[
    [
        phase_space.SymplecticTangentFunction,
        jnp.ndarray,  # t
        phase_space.PhaseSpace,  # (q, p)
        jnp.ndarray,  # dt
    ],
    phase_space.PhaseSpace  # (q_next, p_next)
]
def solve_hamiltonian_ivp_dt(
    hamiltonian: phase_space.HamiltonianFunction,
    y0: phase_space.PhaseSpace,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, SymplecticIntegrator],
    num_steps: Optional[int] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, phase_space.PhaseSpace]:
  """Solve an initial value problem for a Hamiltonian system from increments.

  This function numerically integrates a Hamiltonian system given an
  initial value::

      dq / dt = dH / dp
      dp / dt = - dH / dq
      q(t0), p(t0) = y0.q, y0.p

  Here t is the time and H(t, q, p) is the Hamiltonian. The tangent function
  is built from the Hamiltonian with
  `phase_space.poisson_bracket_with_q_and_p` and the integration itself is
  delegated to `solve_ivp_dt`.

  Args:
    hamiltonian: callable
      The Hamiltonian function. The calling signature is ``h(t, s)``, where
      `s` is an instance of `PhaseSpace`.
    y0: an instance of `PhaseSpace`
      Initial state at t=t0.
    t0: float or array.
      The initial time point of integration.
    dt: float or array
      The consecutive increments in time. When `num_steps` is None the size of
      this array along axis 0 defines the number of steps taken.
    method: string or `SymplecticIntegrator`
      The integrator method to use. Possible values for string are:
        * symp_euler - see `SymplecticEuler`
        * symp_euler_q - a `SymplecticEuler` with position_first=True
        * symp_euler_p - a `SymplecticEuler` with position_first=False
        * leap_frog - see `LeapFrog`
        * leap_frog_q - a `LeapFrog` with position_first=True
        * leap_frog_p - a `LeapFrog` with position_first=False
        * stormer_verlet - same as leap_frog
        * stormer_verlet_q - same as leap_frog_q
        * stormer_verlet_p - same as leap_frog_p
        * ruth4 - see `Ruth4`,
        * sym4 - see `Symmetric4`
        * sym6 - see `Symmetric6`
        * so4 - see `SymmetricSo4`
        * so4_q - a `SymmetricSo4` with position_first=True
        * so4_p - a `SymmetricSo4` with position_first=False
        * so6 - see `SymmetricSo6`
        * so6_q - a `SymmetricSo6` with position_first=True
        * so6_p - a `SymmetricSo6` with position_first=False
        * so8 - see `SymmetricSo8`
        * so8_q - a `SymmetricSo8` with position_first=True
        * so8_p - a `SymmetricSo8` with position_first=False
    num_steps: Optional int.
      If provided, `dt` is treated as the same per-step time interval applied
      for this many steps.
    steps_per_dt: int
      Number of integrator sub-steps taken per increment.
    use_scan: bool
      Whether to run the loop with `lax.scan` or with a python loop.
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive".

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of `PhaseSpace`
      Values of the solution at `t`.

  Raises:
    ValueError: if `y0` is not an instance of `PhaseSpace`.
  """
  if isinstance(y0, phase_space.PhaseSpace):
    tangent_fn = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
    return solve_ivp_dt(
        fun=tangent_fn,
        y0=y0,
        t0=t0,
        dt=dt,
        method=method,
        num_steps=num_steps,
        steps_per_dt=steps_per_dt,
        use_scan=use_scan,
        ode_int_kwargs=ode_int_kwargs,
    )
  raise ValueError("The initial state must be an instance of `PhaseSpace`.")
def solve_hamiltonian_ivp_t_eval(
    hamiltonian: phase_space.HamiltonianFunction,
    t_span: TimeInterval,
    y0: phase_space.PhaseSpace,
    method: Union[str, SymplecticIntegrator],
    t_eval: Optional[jnp.ndarray] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, phase_space.PhaseSpace]:
  """Solve an initial value problem for a Hamiltonian system at given times.

  This function numerically integrates a Hamiltonian system given an
  initial value::

      dq / dt = dH / dp
      dp / dt = - dH / dq
      q(t0), p(t0) = y0.q, y0.p

  Here t is the time and H(t, q, p) is the Hamiltonian. The tangent function
  is built from the Hamiltonian with
  `phase_space.poisson_bracket_with_q_and_p` and the integration itself is
  delegated to `solve_ivp_t_eval`.

  Args:
    hamiltonian: callable
      The Hamiltonian function. The calling signature is ``h(t, s)``, where
      `s` is an instance of `PhaseSpace`.
    t_span: 2-tuple of floats
      Interval of integration (t0, tf). The solver starts with t=t0 and
      integrates until it reaches t=tf.
    y0: an instance of `PhaseSpace`
      Initial state at `t_span[0]`.
    method: string or `SymplecticIntegrator`
      The integrator method to use. Possible values for string are:
        * symp_euler - see `SymplecticEuler`
        * symp_euler_q - a `SymplecticEuler` with position_first=True
        * symp_euler_p - a `SymplecticEuler` with position_first=False
        * leap_frog - see `LeapFrog`
        * leap_frog_q - a `LeapFrog` with position_first=True
        * leap_frog_p - a `LeapFrog` with position_first=False
        * stormer_verlet - same as leap_frog
        * stormer_verlet_q - same as leap_frog_q
        * stormer_verlet_p - same as leap_frog_p
        * ruth4 - see `Ruth4`,
        * sym4 - see `Symmetric4`
        * sym6 - see `Symmetric6`
        * so4 - see `SymmetricSo4`
        * so4_q - a `SymmetricSo4` with position_first=True
        * so4_p - a `SymmetricSo4` with position_first=False
        * so6 - see `SymmetricSo6`
        * so6_q - a `SymmetricSo6` with position_first=True
        * so6_p - a `SymmetricSo6` with position_first=False
        * so8 - see `SymmetricSo8`
        * so8_q - a `SymmetricSo8` with position_first=True
        * so8_p - a `SymmetricSo8` with position_first=False
    t_eval: array or None.
      Times at which to store the computed solution. Must be sorted and lie
      within `t_span`. If None then t_eval = [t_span[-1]]
    steps_per_dt: int
      Number of integrator sub-steps between any two values of `t_eval`.
    use_scan: bool
      Whether to run the loop with `lax.scan` or with a python loop.
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive".

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of `PhaseSpace`
      Values of the solution at `t`.

  Raises:
    ValueError: if `y0` is not an instance of `PhaseSpace`.
  """
  if not isinstance(y0, phase_space.PhaseSpace):
    raise ValueError("The initial state must be an instance of `PhaseSpace`.")
  tangent_fn = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
  if method == "adaptive":
    # The adaptive solver gets a transformed tangent function; presumably one
    # operating on arrays, going by the helper's name - confirm against
    # `phase_space`.
    tangent_fn = phase_space.transform_symplectic_tangent_function_using_array(
        tangent_fn)
  solver_kwargs = dict(
      fun=tangent_fn,
      t_span=t_span,
      y0=y0,
      method=method,
      t_eval=t_eval,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs,
  )
  return solve_ivp_t_eval(**solver_kwargs)
class CompositionSymplectic(SymplecticIntegrator):
  """A generalized symplectic integrator based on compositions.

  Simulates Hamiltonian dynamics using a composition of symplectic steps:

    q_{0} = q_init, p_{0} = p_init
    for i in [1, n]:
      p_{i+1} = p_{i} - c_{i} * dH/dq(q_{i}) * dt
      q_{i+1} = q_{i} + d_{i} * dH/dp(p_{i+1}) * dt
    q_next = q_{n}, p_next = p_{n}

  Within every pair of coefficients the momentum is updated before the
  position; updates whose coefficient is exactly zero are skipped.
  The order argument is used mainly for testing to estimate the error when
  integrating various systems.
  """

  def __init__(
      self,
      momentum_coefficients: Sequence[float],
      position_coefficients: Sequence[float],
      order: int):
    """Initializes the integrator.

    Args:
      momentum_coefficients: The coefficients `c_i` of the momentum updates.
      position_coefficients: The coefficients `d_i` of the position updates.
      order: The order of the method; only stored as `self.order`.

    Raises:
      ValueError: if the two sequences differ in length or if either of them
        does not sum (approximately) to 1.
    """
    if len(position_coefficients) != len(momentum_coefficients):
      raise ValueError("The number of momentum_coefficients and "
                       "position_coefficients must be the same.")
    # Position coefficients are validated first, then the momentum ones.
    named_coefficients = (
        ("position_coefficients", position_coefficients),
        ("momentum_coefficients", momentum_coefficients),
    )
    for label, values in named_coefficients:
      if not np.allclose(sum(values), 1.0):
        raise ValueError(f"The sum of the {label} must be equal to 1.")
    self.momentum_coefficients = momentum_coefficients
    self.position_coefficients = position_coefficients
    self.order = order

  def __call__(
      self,
      tangent_func: phase_space.SymplecticTangentFunction,
      t: jnp.ndarray,
      y: phase_space.PhaseSpace,
      dt: jnp.ndarray
  ) -> phase_space.PhaseSpace:
    """Takes a single step of size `dt` from the state `y` at time `t`."""
    q, p = y.q, y.p
    # This is intentional to prevent a bug where one uses y later
    del y
    # We always broadcast opposite to numpy (e.g. leading dims (batch) count)
    if dt.ndim > 0:
      dt = dt.reshape(dt.shape + (1,) * (q.ndim - dt.ndim))
    if t.ndim > 0:
      t = t.reshape(t.shape + (1,) * (q.ndim - t.ndim))
    # Position and momentum each keep their own running time.
    position_time = t
    momentum_time = t
    stages = zip(self.momentum_coefficients, self.position_coefficients)
    for c, d in stages:
      if c != 0.0:
        # Momentum update, using the latest position.
        momentum_step = c * dt
        state = phase_space.PhaseSpace(q, p)
        p = p + momentum_step * tangent_func(momentum_time, state).p
        momentum_time = momentum_time + momentum_step
      if d != 0.0:
        # Position update, using the latest momentum.
        position_step = d * dt
        state = phase_space.PhaseSpace(q, p)
        q = q + position_step * tangent_func(position_time, state).q
        position_time = position_time + position_step
    return phase_space.PhaseSpace(position=q, momentum=p)
class SymplecticEuler(CompositionSymplectic):
  """The symplectic Euler method (for Hamiltonian systems).

  If position_first = True:
    q_{t+1} = q_{t} + dH/dp(p_{t}) * dt
    p_{t+1} = p_{t} - dH/dq(q_{t+1}) * dt
  else:
    p_{t+1} = p_{t} - dH/dq(q_{t}) * dt
    q_{t+1} = q_{t} + dH/dp(p_{t+1}) * dt
  """

  def __init__(self, position_first=True):
    if position_first:
      # The base class updates the momentum first within each pair, so a
      # leading zero momentum coefficient makes the position go first.
      momentum, position = [0.0, 1.0], [1.0, 0.0]
    else:
      momentum, position = [1.0], [1.0]
    super().__init__(
        momentum_coefficients=momentum,
        position_coefficients=position,
        order=1,
    )
class SymmetricCompositionSymplectic(CompositionSymplectic):
  """A generalized composition integrator that is symmetric.

  The integrators produced are always of the form:
    [update_q, update_p, ..., update_p, update_q]
  or
    [update_p, update_q, ..., update_q, update_p]
  based on the position_first argument. The method will expect which ever is
  updated first to have one more coefficient.
  """

  def __init__(
      self,
      momentum_coefficients: Sequence[float],
      position_coefficients: Sequence[float],
      position_first: bool,
      order: int):
    """Initializes the integrator.

    Args:
      momentum_coefficients: The coefficients of the momentum updates.
      position_coefficients: The coefficients of the position updates.
      position_first: Whether the sequence starts (and ends) with a position
        update rather than a momentum update.
      order: The order of the method, forwarded to the base class.

    Raises:
      ValueError: if the sequence that is updated first does not have exactly
        one more coefficient than the other one.
    """
    position_coefficients = list(position_coefficients)
    momentum_coefficients = list(momentum_coefficients)
    if position_first:
      if len(position_coefficients) != len(momentum_coefficients) + 1:
        raise ValueError("The number of position_coefficients must be one more "
                         "than momentum_coefficients when position_first=True.")
      # The base class updates the momentum first, so pad with a no-op
      # momentum update at the start.
      momentum_coefficients = [0.0] + momentum_coefficients
    else:
      if len(position_coefficients) + 1 != len(momentum_coefficients):
        # This branch is taken for position_first=False; the message now says
        # so (it previously claimed position_first=True).
        raise ValueError("The number of momentum_coefficients must be one more "
                         "than position_coefficients when position_first=False.")
      # Pad with a no-op position update at the end.
      position_coefficients = position_coefficients + [0.0]
    super().__init__(
        position_coefficients=position_coefficients,
        momentum_coefficients=momentum_coefficients,
        order=order
    )
def symmetrize_coefficients(
    coefficients: Sequence[float],
    odd_number: bool
) -> Sequence[float]:
  """Extends the leading coefficients to a full symmetric sequence.

  Args:
    coefficients: The leading coefficients of the sequence.
    odd_number: If True a single middle coefficient is inserted such that the
      whole sequence sums to 1. Otherwise two equal middle coefficients are
      inserted, chosen so that together with `coefficients` they sum to 0.5.

  Returns:
    The list `coefficients + middle + reversed(coefficients)`.
  """
  leading = list(coefficients)
  partial_sum = sum(leading)
  if odd_number:
    middle = [1.0 - 2.0 * partial_sum]
  else:
    middle = [0.5 - partial_sum] * 2
  return leading + middle + leading[::-1]
class LeapFrog(SymmetricCompositionSymplectic):
  """The standard Leap-Frog method (also known as Stormer-Verlet).

  If position_first = True:
    q_half = q_{t} + dH/dp(p_{t}) * dt / 2
    p_{t+1} = p_{t} - dH/dq(q_half) * dt
    q_{t+1} = q_half + dH/dp(p_{t+1}) * dt / 2
  else:
    p_half = p_{t} - dH/dq(q_{t}) * dt / 2
    q_{t+1} = q_{t} + dH/dp(p_half) * dt
    p_{t+1} = p_half - dH/dq(q_{t+1}) * dt / 2
  """

  def __init__(self, position_first=False):
    two_half_steps = [0.5, 0.5]
    one_full_step = [1.0]
    # Whichever variable is updated first is the one taking the half steps.
    if position_first:
      position, momentum = two_half_steps, one_full_step
    else:
      position, momentum = one_full_step, two_half_steps
    super().__init__(
        position_coefficients=position,
        momentum_coefficients=momentum,
        position_first=bool(position_first),
        order=2,
    )
class Ruth4(SymmetricCompositionSymplectic):
  """The Fourth order method from [2]."""

  def __init__(self):
    cbrt_2 = float(np.cbrt(2.0))
    # 3 momentum coefficients: [c1, 1.0 - 2*c1, c1]
    momentum = symmetrize_coefficients(
        [1.0 / (2.0 - cbrt_2)], odd_number=True)
    # 4 position coefficients: [d1, 0.5 - d1, 0.5 - d1, d1]
    position = symmetrize_coefficients(
        [1.0 / (4.0 - 2.0 * cbrt_2)], odd_number=False)
    super().__init__(
        momentum_coefficients=momentum,
        position_coefficients=position,
        position_first=True,
        order=4,
    )
class Symmetric4(SymmetricCompositionSymplectic):
  """The fourth order method from Table 6.1 in [1] (originally from [3])."""

  def __init__(self):
    # 7 momentum coefficients:
    # [c1, c2, c3, 1.0 - 2 * (c1 + c2 + c3), c3, c2, c1]
    momentum = symmetrize_coefficients(
        [0.0792036964311957, 0.353172906049774, -0.0420650803577195],
        odd_number=True)
    # 6 position coefficients:
    # [d1, d2, 0.5 - d1 - d2, 0.5 - d1 - d2, d2, d1]
    position = symmetrize_coefficients(
        [0.209515106613362, -0.143851773179818],
        odd_number=False)
    super().__init__(
        momentum_coefficients=momentum,
        position_coefficients=position,
        position_first=False,
        order=4,
    )
class Symmetric6(SymmetricCompositionSymplectic):
  """The sixth order method from Table 6.1 in [1] (originally from [3])."""

  def __init__(self):
    c = [0.0502627644003922, 0.413514300428344, 0.0450798897943977,
         -0.188054853819569, 0.541960678450780]
    # 11 : [c1, c2, c3, c4, c5, 1.0 - 2 * sum(ci), c5, c4, c3, c2, c1]
    c = symmetrize_coefficients(c, odd_number=True)
    d = [0.148816447901042, -0.132385865767784, 0.067307604692185,
         0.432666402578175]
    # 10: [d1, d2, d3, d4, 0.5 - sum(di), 0.5 - sum(di), d4, d3, d2, d1]
    d = symmetrize_coefficients(d, odd_number=False)
    super().__init__(
        position_coefficients=d,
        momentum_coefficients=c,
        position_first=False,
        # This is the sixth order scheme, so report order 6 (the value was
        # previously 4, contradicting the class name and docstring).
        order=6
    )
def coefficients_based_on_composing_second_order(
    weights: Sequence[float]
) -> Tuple[Sequence[float], Sequence[float]]:
  """Constructs the coefficients for methods based on second-order schemes.

  Args:
    weights: The weights of the composed steps; must not be empty.

  Returns:
    A pair of lists. The first has `len(weights) + 1` entries: half of the
    first weight, the average of every pair of neighbouring weights and half
    of the last weight. The second contains the weights themselves.
  """
  first_half = weights[0] / 2.0
  last_half = weights[-1] / 2.0
  # Neighbouring half steps are merged into a single coefficient.
  merged = [(left + right) / 2.0
            for left, right in zip(weights[:-1], weights[1:])]
  half_step_coefficients = [first_half] + merged + [last_half]
  full_step_coefficients = list(weights)
  return half_step_coefficients, full_step_coefficients
class SymmetricSo4(SymmetricCompositionSymplectic):
  """The fourth order method from Table 6.2 in [1] (originally from [4])."""

  def __init__(self, position_first: bool = False):
    # 5 weights after symmetrization.
    stage_weights = symmetrize_coefficients(
        [0.28, 0.62546642846767004501], odd_number=True)
    half_steps, full_steps = coefficients_based_on_composing_second_order(
        stage_weights)
    # The variable that is updated first takes the (longer) half-step list.
    if position_first:
      c_q, c_p = half_steps, full_steps
    else:
      c_q, c_p = full_steps, half_steps
    super().__init__(
        momentum_coefficients=c_p,
        position_coefficients=c_q,
        position_first=position_first,
        order=4,
    )
class SymmetricSo6(SymmetricCompositionSymplectic):
  """The sixth order method from Table 6.2 in [1] (originally from [5])."""

  def __init__(self, position_first: bool = False):
    # 7 weights after symmetrization.
    stage_weights = symmetrize_coefficients(
        [0.78451361047755726382, 0.23557321335935813368,
         -1.17767998417887100695],
        odd_number=True)
    half_steps, full_steps = coefficients_based_on_composing_second_order(
        stage_weights)
    # The variable that is updated first takes the (longer) half-step list.
    if position_first:
      c_q, c_p = half_steps, full_steps
    else:
      c_q, c_p = full_steps, half_steps
    super().__init__(
        momentum_coefficients=c_p,
        position_coefficients=c_q,
        position_first=position_first,
        order=6,
    )
class SymmetricSo8(SymmetricCompositionSymplectic):
  """The eighth order method from Table 6.2 in [1] (originally from [4])."""

  def __init__(self, position_first: bool = False):
    # 15 weights after symmetrization.
    stage_weights = symmetrize_coefficients(
        [0.74167036435061295345, -0.40910082580003159400,
         0.19075471029623837995, -0.57386247111608226666,
         0.29906418130365592384, 0.33462491824529818378,
         0.31529309239676659663],
        odd_number=True)
    half_steps, full_steps = coefficients_based_on_composing_second_order(
        stage_weights)
    # The variable that is updated first takes the (longer) half-step list.
    if position_first:
      c_q, c_p = half_steps, full_steps
    else:
      c_q, c_p = full_steps, half_steps
    super().__init__(
        momentum_coefficients=c_p,
        position_coefficients=c_q,
        position_first=position_first,
        order=8,
    )
# Registry of the general purpose integrators by name.
general_integrators = {
    "general_euler": GeneralEuler(),
    "rk2": RungaKutta2(),
    "rk4": RungaKutta4(),
    "rk38": RungaKutta38(),
}
# Registry of the symplectic integrators by name. The `_q` suffix selects the
# position-first variant and the `_p` suffix the momentum-first variant.
symplectic_integrators = {
    "symp_euler": SymplecticEuler(position_first=True),
    "symp_euler_q": SymplecticEuler(position_first=True),
    "symp_euler_p": SymplecticEuler(position_first=False),
    "leap_frog": LeapFrog(position_first=False),
    "leap_frog_q": LeapFrog(position_first=True),
    "leap_frog_p": LeapFrog(position_first=False),
    "stormer_verlet": LeapFrog(position_first=False),
    "stormer_verlet_q": LeapFrog(position_first=True),
    "stormer_verlet_p": LeapFrog(position_first=False),
    "ruth4": Ruth4(),
    "sym4": Symmetric4(),
    "sym6": Symmetric6(),
    "so4": SymmetricSo4(position_first=False),
    "so4_q": SymmetricSo4(position_first=True),
    "so4_p": SymmetricSo4(position_first=False),
    "so6": SymmetricSo6(position_first=False),
    "so6_q": SymmetricSo6(position_first=True),
    "so6_p": SymmetricSo6(position_first=False),
    "so8": SymmetricSo8(position_first=False),
    "so8_q": SymmetricSo8(position_first=True),
    "so8_p": SymmetricSo8(position_first=False),
}
def get_integrator(
    name_or_callable: Union[str, GeneralIntegrator]
) -> GeneralIntegrator:
  """Resolves an integrator from its name, or passes a callable through.

  Args:
    name_or_callable: Either the name of a registered integrator (looked up
      first in `general_integrators`, then in `symplectic_integrators`) or a
      callable, which is returned unchanged.

  Returns:
    The integrator.

  Raises:
    ValueError: if the name is not registered, or if the argument is neither
      a string nor a callable.
  """
  if isinstance(name_or_callable, str):
    for registry in (general_integrators, symplectic_integrators):
      if name_or_callable in registry:
        return registry[name_or_callable]
    raise ValueError(f"Unrecognized integrator with name {name_or_callable}.")
  if callable(name_or_callable):
    return name_or_callable
  raise ValueError(f"Expected a callable, but got {type(name_or_callable)}.")