Ensure pip is up to date in kfac_ferminet_alpha/run.sh

Also create the `venv` in `/tmp/` rather than messing with the source tree.
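For reference, the relevant lines of the updated script (shown in full in `kfac_ferminet_alpha/run.sh` later in this diff) are:

```shell
python3 -m venv /tmp/kfac_ferminet_alpha_example
source /tmp/kfac_ferminet_alpha_example/bin/activate
pip3 install -U pip
```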

PiperOrigin-RevId: 368225759
This commit is contained in:
Alistair Muldal
2021-04-13 17:02:33 +01:00
committed by Diego de Las Casas
parent ce4db84f12
commit d029e06aa2
23 changed files with 5814 additions and 0 deletions

View File

@@ -0,0 +1,38 @@
# Accompanying code for Better, Faster Fermionic Neural Networks
All package requirements are listed in `requirements.txt`.
## Contributing
This is purely research code, provided with no intention of further support and
no guarantee of backward compatibility.
## Installation
```shell
git clone git@github.com:deepmind/deepmind-research.git
pip install deepmind-research/kfac_ferminet_alpha/
```
## Usage
You can find examples of how to use the codebase through the [FermiNet project].
We also provide an [example training script].
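A minimal sketch of the training loop is shown below; it mirrors the [example training script], and the toy model, `loss_fn`, `data_generator`, and hyper-parameter values are illustrative placeholders rather than part of the package API.
```python
# Minimal sketch based on example.py in this package; model and data are toys.
import jax
import jax.numpy as jnp

import kfac_ferminet_alpha


def loss_fn(params, batch):
  # Toy model: a single dense layer trained as an identity autoencoder.
  w, b = params
  preds = jnp.matmul(batch, w) + b[None]
  # Register the predictive distribution so the optimizer can find the loss.
  preds, _ = kfac_ferminet_alpha.register_normal_predictive_distribution(
      preds, targets=batch)
  mse = jnp.mean(jnp.sum((preds - batch) ** 2, axis=-1))
  return mse, dict(mean_squared_error=mse)


def data_generator(key, shape):
  while True:
    key, subkey = jax.random.split(key)
    yield jax.random.normal(subkey, shape)


params_key, opt_key, step_key, data_key = jax.random.split(
    jax.random.PRNGKey(0), 4)
dataset = data_generator(data_key, (128, 20))
example_batch = next(dataset)
params = (0.1 * jax.random.normal(params_key, [20, 20]), jnp.zeros([20]))

optimizer = kfac_ferminet_alpha.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn, has_aux=True),
    l2_reg=1e-3,
    value_func_has_aux=True,
    value_func_has_state=False,
    value_func_has_rng=False,
    learning_rate_schedule=None,
    momentum_schedule=None,
    damping_schedule=None,
    norm_constraint=1.0,
    num_burnin_steps=10,
)
opt_state = optimizer.init(params, opt_key, example_batch, None)

for t in range(100):
  step_key, key_t = jax.random.split(step_key)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key_t, dataset,
      learning_rate=jnp.asarray([1e-3]),
      momentum=jnp.asarray([0.8]),
      damping=jnp.asarray([1e-2]))
  print(t, stats["loss"])
```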
## Reference
**Better, Faster Fermionic Neural Networks**
James S. Spencer, David Pfau, Aleksandar Botev, and W. M. C. Foulkes.
URL: https://arxiv.org/abs/2011.07125.
**Optimizing Neural Networks with Kronecker-factored Approximate Curvature**
James Martens, Roger Grosse
URL: https://arxiv.org/abs/1503.05671
[FermiNet Project]: https://github.com/deepmind/ferminet/
[example training script]: https://github.com/deepmind/deepmind-research/blob/master/kfac_ferminet_alpha/example.py

View File

@@ -0,0 +1,19 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for anything that an end user would use."""
from kfac_ferminet_alpha.loss_functions import register_normal_predictive_distribution
from kfac_ferminet_alpha.loss_functions import register_squared_error_loss
from kfac_ferminet_alpha.optimizer import Optimizer

View File

@@ -0,0 +1,496 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all of the different curvature blocks."""
import abc
from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Sequence, Union
import jax
from jax import core
import jax.numpy as jnp
from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import utils
_Arrays = Sequence[jnp.ndarray]
_BlockInfo = Mapping[str, Any]
class CurvatureBlock(utils.Stateful, abc.ABC):
"""Top level class."""
def __init__(self, layer_tag_eq: tgm.jax_core.JaxprEqn):
super(CurvatureBlock, self).__init__()
self._layer_tag_eq = layer_tag_eq
@property
def layer_tag_primitive(self) -> tgm.tags.LayerTag:
assert isinstance(self._layer_tag_eq.primitive, tgm.tags.LayerTag)
return self._layer_tag_eq.primitive
@property
def outputs_shapes(self) -> Sequence[Sequence[int]]:
output_vars = self.layer_tag_primitive.split_all_inputs(
self._layer_tag_eq.invars)[0]
return jax.tree_map(lambda x: x.aval.shape, output_vars)
@property
def inputs_shapes(self) -> Sequence[Sequence[int]]:
input_vars = self.layer_tag_primitive.split_all_inputs(
self._layer_tag_eq.invars)[1]
return jax.tree_map(lambda x: x.aval.shape, input_vars)
@property
def params_shapes(self) -> Sequence[Sequence[int]]:
params_vars = self.layer_tag_primitive.split_all_inputs(
self._layer_tag_eq.invars)[2]
return jax.tree_map(lambda x: x.aval.shape, params_vars)
@abc.abstractmethod
def init(self, rng: jnp.ndarray) -> MutableMapping[str, Any]:
"""This initializes/creates all of the arrays for the state of the block.
Usually this would include the arrays used for storing the curvature
approximation, as well as the arrays for storing any approximate
inverses/powers of the curvature block.
Args:
rng: The Jax PRNG key to use if any of the state is supposed to be
initialized randomly.
Returns:
A mutable mapping of the state.
"""
@abc.abstractmethod
def update_curvature_matrix_estimate(
self,
info: _BlockInfo,
batch_size: int,
ema_old: Union[float, jnp.ndarray],
ema_new: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
pass
@abc.abstractmethod
def update_curvature_inverse_estimate(
self,
diagonal_weight: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
pass
@abc.abstractmethod
def multiply_matpower(
self,
vec: _Arrays,
exp: Union[float, int],
diagonal_weight: Union[float, jnp.ndarray]
) -> _Arrays:
pass
CurvatureBlockCtor = Callable[[core.JaxprEqn], CurvatureBlock]
@utils.Stateful.infer_class_state
class NaiveDiagonal(CurvatureBlock):
"""The naively estimated diagonal block."""
diagonal_factor: utils.WeightedMovingAverage
def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
del rng
return dict(
diagonal_factor=utils.WeightedMovingAverage.zero(
self.outputs_shapes[0])
)
def update_curvature_matrix_estimate(
self,
info: _BlockInfo,
batch_size: int,
ema_old: Union[float, jnp.ndarray],
ema_new: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
dw, = info["outputs_tangent"]
diagonal_update = dw * dw / batch_size
self.diagonal_factor.update(diagonal_update, ema_old, ema_new)
self.diagonal_factor.sync(pmap_axis_name)
def update_curvature_inverse_estimate(
self,
diagonal_weight: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
pass
def multiply_matpower(
self,
vec: _Arrays,
exp: Union[float, int],
diagonal_weight: Union[float, jnp.ndarray]
) -> _Arrays:
w, = vec
if exp == 1:
return w * (self.diagonal_factor.value + diagonal_weight),
elif exp == -1:
return w / (self.diagonal_factor.value + diagonal_weight),
else:
raise NotImplementedError()
@utils.Stateful.infer_class_state
class TwoKroneckerFactored(CurvatureBlock, abc.ABC):
"""A factor that is the Kronecker product of two matrices."""
inputs_factor: utils.WeightedMovingAverage
inputs_factor_inverse: jnp.ndarray
outputs_factor: utils.WeightedMovingAverage
outputs_factor_inverse: jnp.ndarray
extra_scale: Optional[Union[int, float, jnp.ndarray]]
@property
def has_bias(self) -> bool:
return len(self._layer_tag_eq.invars) == 4
@abc.abstractmethod
def input_size(self) -> int:
pass
@abc.abstractmethod
def output_size(self) -> int:
pass
def compute_extra_scale(self) -> Optional[Union[int, float, jnp.ndarray]]:
return 1
def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
# The extra scale is technically a constant, but in general it could be
# useful for anyone examining the state to know it explicitly,
# hence we actually keep it as part of the state.
d_in = self.input_size()
d_out = self.output_size()
return dict(
inputs_factor=utils.WeightedMovingAverage.zero([d_in, d_in]),
inputs_factor_inverse=jnp.zeros([d_in, d_in]),
outputs_factor=utils.WeightedMovingAverage.zero([d_out, d_out]),
outputs_factor_inverse=jnp.zeros([d_out, d_out]),
extra_scale=self.compute_extra_scale()
)
def update_curvature_inverse_estimate(
self,
diagonal_weight: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
self.inputs_factor.sync(pmap_axis_name)
self.outputs_factor.sync(pmap_axis_name)
# This computes the approximate inverse factor using the pi-adjusted
# inversion from the original KFAC paper.
# Note that the damping is divided by extra_scale since:
# (s * A kron B + lambda I)^-1 = s^-1 (A kron B + s^-1 * lambda I)^-1
# And the extra division by the scale is included in `multiply_matpower`.
(self.inputs_factor_inverse,
self.outputs_factor_inverse) = utils.pi_adjusted_inverse(
factor_0=self.inputs_factor.value,
factor_1=self.outputs_factor.value,
damping=diagonal_weight / self.extra_scale,
pmap_axis_name=pmap_axis_name)
def multiply_matpower(
self,
vec: _Arrays,
exp: Union[float, int],
diagonal_weight: Union[float, jnp.ndarray]
) -> _Arrays:
if self.has_bias:
w, b = vec
vec = jnp.concatenate([w.reshape([-1, w.shape[-1]]), b[None]], axis=0)
else:
w, = vec
vec = w.reshape([-1, w.shape[-1]])
if exp == 1:
inputs_factor, outputs_factor = (self.inputs_factor.value,
self.outputs_factor.value)
scale = self.extra_scale
elif exp == -1:
inputs_factor, outputs_factor = (self.inputs_factor_inverse,
self.outputs_factor_inverse)
scale = 1.0 / self.extra_scale
diagonal_weight = 0
else:
raise NotImplementedError()
result = jnp.matmul(inputs_factor, vec)
result = jnp.matmul(result, outputs_factor)
result = result * scale + diagonal_weight * vec
if self.has_bias:
w_new, b_new = result[:-1], result[-1]
return w_new.reshape(w.shape), b_new
else:
return result.reshape(w.shape),
class DenseTwoKroneckerFactored(TwoKroneckerFactored):
"""Factor for a standard dense layer."""
def input_size(self) -> int:
if self.has_bias:
return self.params_shapes[0][0] + 1
else:
return self.params_shapes[0][0]
def output_size(self) -> int:
return self.params_shapes[0][1]
def update_curvature_matrix_estimate(
self,
info: _BlockInfo,
batch_size: int,
ema_old: Union[float, jnp.ndarray],
ema_new: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
del pmap_axis_name
(x,), (dy,) = info["inputs"], info["outputs_tangent"]
utils.check_first_dim_is_batch_size(batch_size, x, dy)
if self.has_bias:
x_one = jnp.ones_like(x[:, :1])
x = jnp.concatenate([x, x_one], axis=1)
input_stats = jnp.matmul(x.T, x) / batch_size
output_stats = jnp.matmul(dy.T, dy) / batch_size
self.inputs_factor.update(input_stats, ema_old, ema_new)
self.outputs_factor.update(output_stats, ema_old, ema_new)
@utils.Stateful.infer_class_state
class ScaleAndShiftDiagonal(CurvatureBlock):
"""A scale and shift block with a diagonal approximation to the curvature."""
scale_factor: Optional[utils.WeightedMovingAverage]
shift_factor: Optional[utils.WeightedMovingAverage]
@property
def has_scale(self) -> bool:
return self._layer_tag_eq.params["has_scale"]
@property
def has_shift(self) -> bool:
return self._layer_tag_eq.params["has_shift"]
def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
del rng
if self.has_scale and self.has_shift:
return dict(
scale_factor=utils.WeightedMovingAverage.zero(
self.params_shapes[0]
),
shift_factor=utils.WeightedMovingAverage.zero(
self.params_shapes[1]
)
)
elif self.has_scale:
return dict(
scale_factor=utils.WeightedMovingAverage.zero(
self.params_shapes[0]
),
shift_factor=None
)
elif self.has_shift:
return dict(
scale_factor=None,
shift_factor=utils.WeightedMovingAverage.zero(
self.params_shapes[0]
),
)
else:
raise ValueError("Neither `has_scale` nor `has_shift`.")
def update_curvature_matrix_estimate(
self,
info: _BlockInfo,
batch_size: int,
ema_old: Union[float, jnp.ndarray],
ema_new: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
(x,), (dy,) = info["inputs"], info["outputs_tangent"]
utils.check_first_dim_is_batch_size(batch_size, x, dy)
if self.has_scale:
assert self.scale_factor is not None
scale_shape = info["params"][0].shape
full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape
axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
d_scale = jnp.sum(x * dy, axis=axis)
scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
self.scale_factor.update(scale_diag_update, ema_old, ema_new)
self.scale_factor.sync(pmap_axis_name)
if self.has_shift:
assert self.shift_factor is not None
shift_shape = info["params"][1].shape
full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape
axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0]
d_shift = jnp.sum(dy, axis=axis)
shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
self.shift_factor.update(shift_diag_update, ema_old, ema_new)
self.shift_factor.sync(pmap_axis_name)
def update_curvature_inverse_estimate(
self,
diagonal_weight: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
pass
def multiply_matpower(
self,
vec: _Arrays,
exp: Union[float, int],
diagonal_weight: Union[float, jnp.ndarray]
) -> _Arrays:
if self.has_scale and self.has_shift:
factors = (self.scale_factor.value, self.shift_factor.value)
elif self.has_scale:
factors = (self.scale_factor.value,)
elif self.has_shift:
factors = (self.shift_factor.value,)
else:
raise ValueError("Neither `has_scale` nor `has_shift`.")
factors = jax.tree_map(lambda x: x + diagonal_weight, factors)
if exp == 1:
return jax.tree_multimap(jnp.multiply, vec, factors)
elif exp == -1:
return jax.tree_multimap(jnp.divide, vec, factors)
else:
raise NotImplementedError()
@utils.Stateful.infer_class_state
class ScaleAndShiftFull(CurvatureBlock):
"""A scale and shift block with full approximation to the curvature."""
factor: utils.WeightedMovingAverage
inverse_factor: jnp.ndarray
@property
def _has_scale(self) -> bool:
return self._layer_tag_eq.params["has_scale"]
@property
def _has_shift(self) -> bool:
return self._layer_tag_eq.params["has_shift"]
def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
del rng
dims = sum(utils.product(shape) for shape in self.params_shapes)
return dict(
factor=utils.WeightedMovingAverage.zero([dims, dims]),
inverse_factor=jnp.zeros([dims, dims])
)
def update_curvature_matrix_estimate(
self,
info: _BlockInfo,
batch_size: int,
ema_old: Union[float, jnp.ndarray],
ema_new: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
del pmap_axis_name
(x,), (dy,) = info["inputs"], info["outputs_tangent"]
utils.check_first_dim_is_batch_size(batch_size, x, dy)
grads = list()
if self._has_scale:
# Scale gradients
scale_shape = info["params"][0].shape
full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape
axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
d_scale = jnp.sum(x * dy, axis=axis)
d_scale = d_scale.reshape([batch_size, -1])
grads.append(d_scale)
if self._has_shift:
# Shift gradients
shift_shape = info["params"][1].shape
full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape
axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0]
d_shift = jnp.sum(dy, axis=axis)
d_shift = d_shift.reshape([batch_size, -1])
grads.append(d_shift)
grads = jnp.concatenate(grads, axis=1)
factor_update = jnp.matmul(grads.T, grads) / batch_size
self.factor.update(factor_update, ema_old, ema_new)
def update_curvature_inverse_estimate(
self,
diagonal_weight: Union[float, jnp.ndarray],
pmap_axis_name: str
) -> None:
self.factor.sync(pmap_axis_name)
self.inverse_factor = utils.psd_inv_cholesky(self.factor.value,
diagonal_weight)
def multiply_matpower(
self,
vec: _Arrays,
exp: Union[float, int],
diagonal_weight: Union[float, jnp.ndarray]
) -> _Arrays:
# Remember the vector is a tuple of all parameters
if self._has_scale and self._has_shift:
flat_vec = jnp.concatenate([v.flatten() for v in vec])
else:
flat_vec = vec[0].flatten()
if exp == 1:
flat_result = (
jnp.matmul(self.factor.value, flat_vec) + diagonal_weight * flat_vec)
elif exp == -1:
flat_result = jnp.matmul(self.inverse_factor, flat_vec)
else:
raise NotImplementedError()
if self._has_scale and self._has_shift:
scale_dims = int(vec[0].size)
scale_result = flat_result[:scale_dims].reshape(vec[0].shape)
shift_result = flat_result[scale_dims:].reshape(vec[1].shape)
return scale_result, shift_result
else:
return flat_result.reshape(vec[0].shape),
_default_tag_to_block: MutableMapping[str, CurvatureBlockCtor] = dict(
dense_tag=DenseTwoKroneckerFactored,
generic_tag=NaiveDiagonal,
scale_and_shift_tag=ScaleAndShiftDiagonal,
)
def copy_default_tag_to_block() -> MutableMapping[str, CurvatureBlockCtor]:
return dict(_default_tag_to_block)
def get_default_tag_to_block(tag_name: str) -> CurvatureBlockCtor:
return _default_tag_to_block[tag_name]
def set_default_tag_to_block(
tag_name: str,
block_class: CurvatureBlockCtor,
) -> None:
_default_tag_to_block[tag_name] = block_class

View File

@@ -0,0 +1,75 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all distribution implementations needed for the loss functions."""
import math
import jax
import jax.numpy as jnp
class MultivariateNormalDiag:
"""Multivariate normal distribution on `R^k`."""
def __init__(
self,
loc: jnp.ndarray,
scale_diag: jnp.ndarray):
"""Initializes a MultivariateNormalDiag distribution.
Args:
loc: Mean vector of the distribution. Can also be a batch of vectors.
scale_diag: Vector of standard deviations.
"""
super().__init__()
self._loc = loc
self._scale_diag = scale_diag
@property
def loc(self) -> jnp.ndarray:
"""Mean of the distribution."""
return self._loc
@property
def scale_diag(self) -> jnp.ndarray:
"""Scale of the distribution."""
return self._scale_diag
def _num_dims(self) -> int:
"""Dimensionality of the events."""
return self._scale_diag.shape[-1]
def _standardize(self, value: jnp.ndarray) -> jnp.ndarray:
return (value - self._loc) / self._scale_diag
def log_prob(self, value: jnp.ndarray) -> jnp.ndarray:
"""See `Distribution.log_prob`."""
log_unnormalized = -0.5 * jnp.square(self._standardize(value))
log_normalization = 0.5 * math.log(2 * math.pi) + jnp.log(self._scale_diag)
return jnp.sum(log_unnormalized - log_normalization, axis=-1)
def mean(self) -> jnp.ndarray:
"""Calculates the mean."""
return self.loc
def sample(self, seed: jnp.ndarray) -> jnp.ndarray:
"""Samples an event.
Args:
seed: PRNG key.
Returns:
A sample.
"""
eps = jax.random.normal(seed, self.loc.shape)
return self.loc + eps * self.scale_diag

View File

@@ -0,0 +1,340 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the high-level Fisher estimator class."""
import collections
from typing import Any, Callable, Mapping, Optional, Sequence, Union, TypeVar
import jax
import jax.numpy as jnp
import jax.random as jnr
import numpy as np
from kfac_ferminet_alpha import curvature_blocks
from kfac_ferminet_alpha import tracer
from kfac_ferminet_alpha import utils
_CurvatureBlock = curvature_blocks.CurvatureBlock
TagMapping = Mapping[str, curvature_blocks.CurvatureBlockCtor]
BlockVector = Sequence[jnp.ndarray]
_StructureT = TypeVar("_StructureT")
_OptionalStateT = TypeVar("_OptionalStateT", bound=Optional[Mapping[str, Any]])
@utils.Stateful.infer_class_state
class CurvatureEstimator(utils.Stateful):
"""Curvature estimator class supporting various curvature approximations."""
blocks: "collections.OrderedDict[str, _CurvatureBlock]"
damping: Optional[jnp.ndarray]
def __init__(self,
tagged_func: Callable[[Any], jnp.ndarray],
func_args: Sequence[Any],
l2_reg: Union[float, jnp.ndarray],
estimation_mode: str = "fisher_gradients",
params_index: int = 0,
layer_tag_to_block_cls: Optional[TagMapping] = None):
"""Create a FisherEstimator object.
Args:
tagged_func: The function which evaluates the model, in which layer and
loss tags have already been registered.
func_args: Arguments to trace the function for layer and loss tags.
l2_reg: Scalar. The L2 regularization coefficient, which represents
the following regularization function: `coefficient/2 ||theta||^2`.
estimation_mode: The type of curvature estimator to use. One of:
* 'fisher_gradients' - the basic estimation approach from the original
K-FAC paper (default).
* 'fisher_curvature_prop' - estimates the Fisher using self-products of
random +1/-1 vectors times "half-factors" of the Fisher, as described in
https://arxiv.org/abs/1206.6464.
* 'fisher_exact' - the obvious generalization of Curvature Propagation to
compute the exact Fisher (modulo any additional diagonal or Kronecker
approximations) by looping over one-hot vectors for each coordinate of
the output instead of using +1/-1 vectors. It is more expensive to
compute than the other options by a factor roughly equal to the output
dimension.
* 'fisher_empirical' - computes the 'empirical' Fisher information matrix
(which uses the data's distribution for the targets, as opposed to the
true Fisher, which uses the model's distribution) and requires that each
registered loss have specified targets.
* 'ggn_curvature_prop' - analogous to fisher_curvature_prop, but estimates
the Generalized Gauss-Newton matrix (GGN).
* 'ggn_exact' - analogous to fisher_exact, but estimates the Generalized
Gauss-Newton matrix (GGN).
params_index: The index of the arguments accepted by `func` which
correspond to parameters.
layer_tag_to_block_cls: An optional dict mapping tags to specific classes
of block approximations, which override the default ones.
"""
if estimation_mode not in ("fisher_gradients", "fisher_empirical",
"fisher_exact", "fisher_curvature_prop",
"ggn_exact", "ggn_curvature_prop"):
raise ValueError(f"Unrecognised estimation_mode={estimation_mode}.")
super().__init__()
self.tagged_func = tagged_func
self.l2_reg = l2_reg
self.estimation_mode = estimation_mode
self.params_index = params_index
self.vjp = tracer.trace_estimator_vjp(self.tagged_func)
# Figure out the mapping from layer tags to block classes
self.layer_tag_to_block_cls = curvature_blocks.copy_default_tag_to_block()
if layer_tag_to_block_cls is None:
layer_tag_to_block_cls = dict()
layer_tag_to_block_cls = dict(**layer_tag_to_block_cls)
self.layer_tag_to_block_cls.update(layer_tag_to_block_cls)
# Create the blocks
self._in_tree = jax.tree_structure(func_args)
self._jaxpr = jax.make_jaxpr(self.tagged_func)(*func_args).jaxpr
self._layer_tags, self._loss_tags = tracer.extract_tags(self._jaxpr)
self.blocks = collections.OrderedDict()
counters = dict()
for eqn in self._layer_tags:
cls = self.layer_tag_to_block_cls[eqn.primitive.name]
c = counters.get(cls.__name__, 0)
self.blocks[cls.__name__ + "_" + str(c)] = cls(eqn)
counters[cls.__name__] = c + 1
@property
def diagonal_weight(self) -> jnp.ndarray:
return self.l2_reg + self.damping
def vectors_to_blocks(
self,
parameter_structured_vector: Any,
) -> Sequence[BlockVector]:
"""Splits the parameters to values for the corresponding blocks."""
in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars)
params_vars = in_vars[self.params_index]
params_vars_flat = jax.tree_flatten(params_vars)[0]
params_values_flat = jax.tree_flatten(parameter_structured_vector)[0]
assert len(params_vars_flat) == len(params_values_flat)
params_dict = dict(zip(params_vars_flat, params_values_flat))
per_block_vectors = []
for eqn in self._layer_tags:
if eqn.primitive.name == "generic_tag":
block_vars = eqn.invars
else:
block_vars = eqn.primitive.split_all_inputs(eqn.invars)[2]
per_block_vectors.append(tuple(params_dict.pop(v) for v in block_vars))
if params_dict:
raise ValueError(f"From the parameters the following structure is not "
f"assigned to any block: {params_dict}. Most likely "
f"this part of the parameters is not part of the graph "
f"reaching the losses.")
return tuple(per_block_vectors)
def blocks_to_vectors(self, per_block_vectors: Sequence[BlockVector]) -> Any:
"""Reverses the function self.vectors_to_blocks."""
in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars)
params_vars = in_vars[self.params_index]
assigned_dict = dict()
for eqn, block_values in zip(self._layer_tags, per_block_vectors):
if eqn.primitive.name == "generic_tag":
block_params = eqn.invars
else:
block_params = eqn.primitive.split_all_inputs(eqn.invars)[2]
assigned_dict.update(zip(block_params, block_values))
params_vars_flat, params_tree = jax.tree_flatten(params_vars)
params_values_flat = [assigned_dict[v] for v in params_vars_flat]
assert len(params_vars_flat) == len(params_values_flat)
return jax.tree_unflatten(params_tree, params_values_flat)
def init(
self,
rng: jnp.ndarray,
init_damping: Optional[jnp.ndarray],
) -> Mapping[str, Any]:
"""Returns an initialized variables for the curvature approximations and the inverses.."""
return dict(
blocks=collections.OrderedDict(
(name, block.init(block_rng)) #
for (name, block), block_rng #
in zip(self.blocks.items(), jnr.split(rng, len(self.blocks)))),
damping=init_damping)
@property
def mat_type(self) -> str:
return self.estimation_mode.split("_")[0]
def vec_block_apply(
self,
func: Callable[[_CurvatureBlock, BlockVector], BlockVector],
parameter_structured_vector: Any,
) -> Any:
"""Executes func for each approximation block on vectors."""
per_block_vectors = self.vectors_to_blocks(parameter_structured_vector)
assert len(per_block_vectors) == len(self.blocks)
results = jax.tree_multimap(func, tuple(self.blocks.values()),
per_block_vectors)
parameter_structured_result = self.blocks_to_vectors(results)
utils.check_structure_shapes_and_dtype(parameter_structured_vector,
parameter_structured_result)
return parameter_structured_result
def multiply_inverse(self, parameter_structured_vector: Any) -> Any:
"""Multiplies the vectors by the corresponding (damped) inverses of the blocks.
Args:
parameter_structured_vector: Structure equivalent to the parameters of the
model.
Returns:
A structure identical to `vectors` containing the product.
"""
return self.multiply_matpower(parameter_structured_vector, -1)
def multiply(self, parameter_structured_vector: Any) -> Any:
"""Multiplies the vectors by the corresponding (damped) blocks.
Args:
parameter_structured_vector: A vector in the same structure as the
parameters of the model.
Returns:
A structure identical to `vectors` containing the product.
"""
return self.multiply_matpower(parameter_structured_vector, 1)
def multiply_matpower(
self,
parameter_structured_vector: _StructureT,
exp: int,
) -> _StructureT:
"""Multiplies the vectors by the corresponding matrix powers of the blocks.
Args:
parameter_structured_vector: A vector in the same structure as the
parameters of the model.
exp: The power to raise the blocks to before multiplying them by the vector.
Returns:
A structure identical to `vectors` containing the product.
"""
def func(block: _CurvatureBlock, vec: BlockVector) -> BlockVector:
return block.multiply_matpower(vec, exp, self.diagonal_weight)
return self.vec_block_apply(func, parameter_structured_vector)
def update_curvature_matrix_estimate(
self,
ema_old: Union[float, jnp.ndarray],
ema_new: Union[float, jnp.ndarray],
batch_size: int,
rng: jnp.ndarray,
func_args: Sequence[Any],
pmap_axis_name: str,
) -> None:
"""Updates the curvature estimate."""
# Compute the losses and the VJP function from the function inputs
losses, losses_vjp = self.vjp(func_args)
# Helper function that updates the blocks given a vjp vector
def _update_blocks(vjp_vec_, ema_old_, ema_new_):
blocks_info_ = losses_vjp(vjp_vec_)
for block_, block_info_ in zip(self.blocks.values(), blocks_info_):
block_.update_curvature_matrix_estimate(
info=block_info_,
batch_size=batch_size,
ema_old=ema_old_,
ema_new=ema_new_,
pmap_axis_name=pmap_axis_name)
if self.estimation_mode == "fisher_gradients":
keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng]
vjp_vec = tuple(
loss.grad_of_evaluate_on_sample(key, coefficient_mode="sqrt")
for loss, key in zip(losses, keys))
_update_blocks(vjp_vec, ema_old, ema_new)
elif self.estimation_mode in ("fisher_curvature_prop",
"ggn_curvature_prop"):
keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng]
vjp_vec = []
for loss, key in zip(losses, keys):
if self.estimation_mode == "fisher_curvature_prop":
random_b = jnr.bernoulli(key, shape=loss.fisher_factor_inner_shape())
vjp_vec.append(loss.multiply_fisher_factor(random_b * 2.0 - 1.0))
else:
random_b = jnr.bernoulli(key, shape=loss.ggn_factor_inner_shape())
vjp_vec.append(loss.multiply_ggn_factor(random_b * 2.0 - 1.0))
_update_blocks(tuple(vjp_vec), ema_old, ema_new)
elif self.estimation_mode in ("fisher_exact", "ggn_exact"):
# We use the following trick to simulate summation. The equation is:
# estimate = ema_old * estimate + ema_new * (sum_i estimate_index_i)
# weight = ema_old * weight + ema_new
# Instead we update the estimate n times with the following updates:
# for k = 1
# estimate_k = ema_old * estimate + (ema_new/n) * (n*estimate_index_k)
# weight_k = ema_old * weight + (ema_new/n)
# for k > 1:
# estimate_k = 1.0 * estimate_k-1 + (ema_new/n) * (n*estimate_index_k)
# weight_k = 1.0 * weight_k-1 + (ema_new/n)
# Which is mathematically equivalent to the original version.
zero_tangents = jax.tree_map(jnp.zeros_like,
list(loss.inputs for loss in losses))
if self.estimation_mode == "fisher_exact":
num_indices = [
(l, int(np.prod(l.fisher_factor_inner_shape[1:]))) for l in losses
]
else:
num_indices = [
(l, int(np.prod(l.ggn_factor_inner_shape()))) for l in losses
]
total_num_indices = sum(n for _, n in num_indices)
for i, (loss, loss_num_indices) in enumerate(num_indices):
for index in range(loss_num_indices):
vjp_vec = zero_tangents.copy()
if self.estimation_mode == "fisher_exact":
vjp_vec[i] = loss.multiply_fisher_factor_replicated_one_hot([index])
else:
vjp_vec[i] = loss.multiply_ggn_factor_replicated_one_hot([index])
if isinstance(vjp_vec[i], jnp.ndarray):
# In the special case of only one parameter, it still needs to be a
# tuple for the tangents.
vjp_vec[i] = (vjp_vec[i],)
vjp_vec[i] = jax.tree_map(lambda x: x * total_num_indices, vjp_vec[i])
_update_blocks(tuple(vjp_vec), ema_old, ema_new / total_num_indices)
ema_old = 1.0
elif self.estimation_mode == "fisher_empirical":
raise NotImplementedError()
else:
raise ValueError(f"Unrecognised estimation_mode={self.estimation_mode}")
def update_curvature_estimate_inverse(
self,
pmap_axis_name: str,
state: _OptionalStateT,
) -> _OptionalStateT:
if state is not None:
old_state = self.get_state()
self.set_state(state)
for block in self.blocks.values():
block.update_curvature_inverse_estimate(self.diagonal_weight,
pmap_axis_name)
if state is None:
return None
else:
state = self.pop_state()
self.set_state(old_state)
return state

View File

@@ -0,0 +1,171 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of running KFAC."""
from absl import app
from absl import flags
import jax
import jax.numpy as jnp
import numpy as np
import kfac_ferminet_alpha as kfac_ferminet_alpha
from kfac_ferminet_alpha import utils
TRAINING_STEPS = flags.DEFINE_integer(
name="training_steps",
default=100,
help="Number of training steps to perform")
BATCH_SIZE = flags.DEFINE_integer(
name="batch_size", default=128, help="Batch size")
LEARNING_RATE = flags.DEFINE_float(
name="learning_rate", default=1e-3, help="Learning rate")
L2_REG = flags.DEFINE_float(
name="l2_reg", default=1e-3, help="L2 regularization coefficient")
MOMENTUM = flags.DEFINE_float(
name="momentum", default=0.8, help="Momentum coefficient")
DAMPING = flags.DEFINE_float(
name="damping", default=1e-2, help="Damping coefficient")
MULTI_DEVICE = flags.DEFINE_bool(
name="multi_device",
default=False,
help="Whether the computation should be replicated across multiple devices")
SEED = flags.DEFINE_integer(name="seed", default=12412321, help="JAX RNG seed")
def glorot_uniform(shape, key):
dim_in = np.prod(shape[:-1])
dim_out = shape[-1]
c = jnp.sqrt(6 / (dim_in + dim_out))
return jax.random.uniform(key, shape=shape, minval=-c, maxval=c)
def fully_connected_layer(params, x):
w, b = params
return jnp.matmul(x, w) + b[None]
def model_init(rng_key, batch, encoder_sizes=(1000, 500, 250, 30)):
"""Initialize the standard autoencoder."""
x_size = batch.shape[-1]
decoder_sizes = encoder_sizes[len(encoder_sizes) - 2::-1]
sizes = (x_size,) + encoder_sizes + decoder_sizes + (x_size,)
keys = jax.random.split(rng_key, len(sizes) - 1)
params = []
for rng_key, dim_in, dim_out in zip(keys, sizes, sizes[1:]):
# Glorot uniform initialization
w = glorot_uniform((dim_in, dim_out), rng_key)
b = jnp.zeros([dim_out])
params.append((w, b))
return params, None
def model_loss(params, inputs, l2_reg):
"""Evaluate the standard autoencoder."""
h = inputs.reshape([inputs.shape[0], -1])
for i, layer_params in enumerate(params):
h = fully_connected_layer(layer_params, h)
# Last layer does not have a nonlinearity
if i % 4 != 3:
h = jnp.tanh(h)
l2_value = 0.5 * sum(jnp.square(p).sum() for p in jax.tree_leaves(params))
error = jax.nn.sigmoid(h) - inputs.reshape([inputs.shape[0], -1])
mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0)
regularized_loss = mean_squared_error + l2_reg * l2_value
return regularized_loss, dict(mean_squared_error=mean_squared_error)
def random_data(multi_device, batch_shape, rng):
if multi_device:
shape = (multi_device,) + tuple(batch_shape)
else:
shape = tuple(batch_shape)
while True:
rng, key = jax.random.split(rng)
yield jax.random.normal(key, shape)
def main(argv):
del argv # Unused.
learning_rate = jnp.asarray([LEARNING_RATE.value])
momentum = jnp.asarray([MOMENTUM.value])
damping = jnp.asarray([DAMPING.value])
# RNG keys
global_step = jnp.zeros([])
rng = jax.random.PRNGKey(SEED.value)
params_key, opt_key, step_key, data_key = jax.random.split(rng, 4)
dataset = random_data(MULTI_DEVICE.value, (BATCH_SIZE.value, 20), data_key)
example_batch = next(dataset)
if MULTI_DEVICE.value:
global_step = utils.replicate_all_local_devices(global_step)
learning_rate = utils.replicate_all_local_devices(learning_rate)
momentum = utils.replicate_all_local_devices(momentum)
damping = utils.replicate_all_local_devices(damping)
params_key, opt_key = utils.replicate_all_local_devices(
(params_key, opt_key))
step_key = utils.make_different_rng_key_on_all_devices(step_key)
split_key = jax.pmap(lambda x: tuple(jax.random.split(x)))
jit_init_parameters_func = jax.pmap(model_init)
else:
split_key = jax.random.split
jit_init_parameters_func = jax.jit(model_init)
# Initialize or load parameters
params, func_state = jit_init_parameters_func(params_key, example_batch)
# Make optimizer
optim = kfac_ferminet_alpha.Optimizer(
value_and_grad_func=jax.value_and_grad(
lambda p, x: model_loss(p, x, L2_REG.value), has_aux=True),
l2_reg=L2_REG.value,
value_func_has_aux=True,
value_func_has_state=False,
value_func_has_rng=False,
learning_rate_schedule=None,
momentum_schedule=None,
damping_schedule=None,
norm_constraint=1.0,
num_burnin_steps=10,
)
# Initialize optimizer
opt_state = optim.init(params, opt_key, example_batch, func_state)
for t in range(TRAINING_STEPS.value):
step_key, key_t = split_key(step_key)
params, opt_state, stats = optim.step(
params,
opt_state,
key_t,
dataset,
learning_rate=learning_rate,
momentum=momentum,
damping=damping)
global_step = global_step + 1
# Log any of the statistics
print(f"iteration: {t}")
print(f"mini-batch loss = {stats['loss']}")
if "aux" in stats:
for k, v in stats["aux"].items():
print(f"{k} = {v}")
print("----")
if __name__ == "__main__":
app.run(main)

View File

@@ -0,0 +1,354 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module for registering already known functions for tagging patterns."""
import functools
from typing import Sequence, Tuple, TypeVar
import jax
from jax import core as jax_core
from jax import lax
from jax import lib as jax_lib
from jax.interpreters import batching as jax_batching
import jax.numpy as jnp
_T = TypeVar("_T")
class LossTag(jax_core.Primitive):
"""A tagging primitive specifically for losses."""
multiple_results = True
def __init__(self, cls, num_inputs: int, num_targets: int = 1):
super().__init__(cls.__name__ + "_tag")
self._cls = cls
self._num_inputs = num_inputs
self._num_targets = num_targets
jax.xla.translations[self] = self.xla_translation
jax.ad.primitive_jvps[self] = self.jvp
# This line defines how the tag behaves under vmap. It is required for any
# primitive that can be used inside a vmap. The reason we want to allow this
# is twofold: one, to not break user code when the tags are not used at all;
# and two, to be able to define a network with code for a single example,
# which is then vmap-ed for a batch.
jax_batching.primitive_batchers[self] = self.batching
@property
def num_inputs(self) -> int:
return self._num_inputs
@property
def num_targets(self) -> int:
return self._num_targets
def loss(self, *args, weight: float = 1.0, **kwargs):
return self._cls(*args, weight=weight, **kwargs)
def loss_evaluate(self, *args, weight: float = 1.0, **kwargs):
return self.loss(*args, weight=weight, **kwargs).evaluate()
def get_outputs(self, *args, weight: float, return_loss: bool, **kwargs):
if len(args) < self.num_inputs:
raise ValueError("Inputs to the tag are not enough.")
if len(args) < self.num_inputs + self.num_targets:
if len(args) != self.num_inputs:
raise ValueError("Inputs to the tag are not quite enough.")
if return_loss:
raise ValueError("Can not have return_loss=True when there are no "
"targets.")
return args
if len(args) > self.num_inputs + self.num_targets:
raise ValueError("Inputs to the tag are too many.")
if return_loss:
return self.loss(*args, weight=weight, **kwargs).evaluate()
else:
return args
def impl(self, *args, weight: float, return_loss: bool, **kwargs):
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
def xla_translation(
self,
c,
*args,
weight: float = 1.0,
return_loss: bool = False,
**kwargs,
):
outputs = self.get_outputs(
*args, weight=weight, return_loss=return_loss, **kwargs)
if isinstance(outputs, tuple):
return jax_lib.xla_client.ops.Tuple(c, outputs)
return outputs
def jvp(
self,
arg_values,
arg_tangents,
weight: float,
return_loss: bool,
**kwargs,
):
if len(arg_values) != len(arg_tangents):
raise ValueError("Values and tangents are not the same length.")
primal_output = self.bind(
*arg_values, weight=weight, return_loss=return_loss, **kwargs)
if len(arg_values) == self.num_inputs:
tangents_out = self.get_outputs(
*arg_tangents, weight=weight, return_loss=return_loss, **kwargs)
elif return_loss:
tangents_out = jax.jvp(
functools.partial(self.loss_evaluate, weight=weight, **kwargs),
arg_tangents, arg_tangents)[1]
else:
tangents_out = arg_tangents
return primal_output, tangents_out
def batching(self, batched_args, batched_dims, **kwargs):
return self.bind(*batched_args, **kwargs), batched_dims[0]
class LayerTag(jax_core.Primitive):
"""A tagging primitive that is used to mark/tag computation."""
def __init__(self, name: str, num_inputs: int, num_outputs: int):
super().__init__(name)
if num_outputs > 1:
raise NotImplementedError(
f"Only single outputs are supported, got: num_outputs={num_outputs}")
self._num_outputs = num_outputs
self._num_inputs = num_inputs
jax.xla.translations[self] = self.xla_translation
jax.ad.deflinear(self, self.transpose)
jax.ad.primitive_transposes[self] = self.transpose
# This line defines how the tag behaves under vmap. It is required for any
# primitive that can be used inside a vmap. The reason we want to allow this
# is twofold: one, to not break user code when the tags are not used at all;
# and two, to be able to define a network with code for a single example,
# which is then vmap-ed for a batch.
jax_batching.primitive_batchers[self] = self.batching
@property
def num_outputs(self) -> int:
return self._num_outputs
@property
def num_inputs(self) -> int:
return self._num_inputs
def split_all_inputs(
self,
all_inputs: Sequence[_T],
) -> Tuple[Sequence[_T], Sequence[_T], Sequence[_T]]:
outputs = tuple(all_inputs[:self.num_outputs])
inputs = tuple(all_inputs[self.num_outputs:self.num_outputs +
self.num_inputs])
params = tuple(all_inputs[self.num_outputs + self.num_inputs:])
return outputs, inputs, params
def get_outputs(self, *operands: _T, **kwargs) -> _T:
assert self.num_outputs == 1
return operands[0]
def xla_translation(self, c, *operands: _T, **kwargs) -> _T:
return self.get_outputs(*operands, **kwargs)
@staticmethod
def transpose(cotangent, *operands, **kwargs):
return (cotangent,) + (None,) * (len(operands) - 1)
def impl(self, *operands, **kwargs):
return self.get_outputs(*operands, **kwargs)
def abstract_eval(self, *abstract_operands, **kwargs):
return self.get_outputs(*abstract_operands, **kwargs)
def batching(self, batched_operands, batched_dims, **kwargs):
return self.bind(*batched_operands, **kwargs), batched_dims[0]
# _____ _
# / ____| (_)
# | | __ ___ _ __ ___ _ __ _ ___
# | | |_ |/ _ \ '_ \ / _ \ '__| |/ __|
# | |__| | __/ | | | __/ | | | (__
# \_____|\___|_| |_|\___|_| |_|\___|
#
#
generic_tag = LayerTag(name="generic_tag", num_inputs=0, num_outputs=1)
def register_generic(parameter: _T) -> _T:
return generic_tag.bind(parameter)
# _____
# | __ \
# | | | | ___ _ __ ___ ___
# | | | |/ _ \ '_ \/ __|/ _ \
# | |__| | __/ | | \__ \ __/
# |_____/ \___|_| |_|___/\___|
#
dense_tag = LayerTag(name="dense_tag", num_inputs=1, num_outputs=1)
def register_dense(y, x, w, b=None):
if b is None:
return dense_tag.bind(y, x, w)
return dense_tag.bind(y, x, w, b)
def dense_func(x, params):
"""Example of a dense layer function."""
w = params[0]
y = jnp.matmul(x, w)
if len(params) == 1:
# No bias
return y
# Add bias
return y + params[1]
def dense_tagging(jaxpr, inverse_map, values_map):
"""Correctly registers a dense layer pattern."""
del inverse_map
in_values = [values_map[v] for v in jaxpr.invars]
out_values = [values_map[v] for v in jaxpr.outvars]
return register_dense(out_values[0], *in_values)
# ___ _____ _____ _ _ _
# |__ \| __ \ / ____| | | | | (_)
# ) | | | | | | ___ _ ____ _____ | |_ _| |_ _ ___ _ __
# / /| | | | | | / _ \| '_ \ \ / / _ \| | | | | __| |/ _ \| "_ \
# / /_| |__| | | |___| (_) | | | \ V / (_) | | |_| | |_| | (_) | | | |
# |____|_____/ \_____\___/|_| |_|\_/ \___/|_|\__,_|\__|_|\___/|_| |_|
#
conv2d_tag = LayerTag(name="conv2d_tag", num_inputs=1, num_outputs=1)
def register_conv2d(y, x, w, b=None, **kwargs):
if b is None:
return conv2d_tag.bind(y, x, w, **kwargs)
return conv2d_tag.bind(y, x, w, b, **kwargs)
def conv2d_func(x, params):
"""Example of a conv2d layer function."""
w = params[0]
y = lax.conv_general_dilated(
x,
w,
window_strides=(2, 2),
padding="SAME",
dimension_numbers=("NHWC", "HWIO", "NHWC"))
if len(params) == 1:
# No bias
return y
# Add bias
return y + params[1][None, None, None]
def conv2d_tagging(jaxpr, inverse_map, values_map):
"""Correctly registers a conv2d layer pattern."""
in_values = [values_map[v] for v in jaxpr.invars]
out_values = [values_map[v] for v in jaxpr.outvars]
keys = [k for k in inverse_map.keys() if isinstance(k, str)]
keys = [k for k in keys if k.startswith("conv_general_dilated")]
if len(keys) != 1:
raise ValueError("Did not find any conv_general_dilated!")
kwargs = inverse_map[keys[0]].params
return register_conv2d(out_values[0], *in_values, **kwargs)
# _____ _ _ _____ _ _ __ _
# / ____| | | | | / ____| | (_)/ _| |
# | (___ ___ __ _| | ___ __ _ _ __ __| | | (___ | |__ _| |_| |_
# \___ \ / __/ _` | |/ _ \ / _` | '_ \ / _` | \___ \| '_ \| | _| __|
# ____) | (_| (_| | | __/ | (_| | | | | (_| | ____) | | | | | | | |_
# |_____/ \___\__,_|_|\___| \__,_|_| |_|\__,_| |_____/|_| |_|_|_| \__|
#
scale_and_shift_tag = LayerTag(
name="scale_and_shift_tag", num_inputs=1, num_outputs=1)
def register_scale_and_shift(y, args, has_scale: bool, has_shift: bool):
assert has_scale or has_shift
x, args = args[0], args[1:]
return scale_and_shift_tag.bind(
y, x, *args, has_scale=has_scale, has_shift=has_shift)
def scale_and_shift_func(x, params, has_scale: bool, has_shift: bool):
"""Example of a scale and shift function."""
if has_scale and has_shift:
scale, shift = params
return x * scale + shift
elif has_scale:
return x * params[0]
elif has_shift:
return x + params[0]
else:
raise ValueError()
def scale_and_shift_tagging(
jaxpr,
inverse_map,
values_map,
has_scale: bool,
has_shift: bool,
):
"""Correctly registers a scale and shift layer pattern."""
del inverse_map
in_values = [values_map[v] for v in jaxpr.invars]
out_values = [values_map[v] for v in jaxpr.outvars]
return register_scale_and_shift(out_values[0], in_values, has_scale,
has_shift)
def batch_norm_func(
inputs: Tuple[jnp.ndarray, jnp.ndarray],
params: Tuple[jnp.ndarray, jnp.ndarray],
) -> jnp.ndarray:
"""Example of batch norm as is defined in Haiku."""
x, y = inputs
scale, shift = params
inv = scale * y
return x * inv + shift
def batch_norm_tagging_func(
jaxpr,
inverse_map,
values_map,
has_scale: bool,
has_shift: bool,
):
"""Correctly registers a batch norm layer pattern as is defined in Haiku."""
del inverse_map
in_values = [values_map[v] for v in jaxpr.invars]
out_values = [values_map[v] for v in jaxpr.outvars]
# The first two inputs are both multiplied together with the scale, so we
# merge them into a single input.
in_values = [in_values[0] * in_values[1]] + in_values[2:]
return register_scale_and_shift(out_values[0], in_values, has_scale,
has_shift)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,7 @@
jax>=0.2.10
dataclasses>=0.6
networkx>=2.1
numpy>=1.16.4
typing>=3.7.4.3
ordered-set>=4.0.2
absl-py>=0.12.0

kfac_ferminet_alpha/run.sh Executable file
View File

@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script installs kfac_ferminet_alpha in a clean virtualenv and runs an
# example training loop. It is designed to be run from the parent directory,
# e.g.:
#
# git clone git@github.com:deepmind/deepmind-research.git
# cd deepmind-research
# kfac_ferminet_alpha/run.sh
python3 -m venv /tmp/kfac_ferminet_alpha_example
source /tmp/kfac_ferminet_alpha_example/bin/activate
pip3 install -U pip
pip3 install -r kfac_ferminet_alpha/requirements.txt
# For a GPU you have to do:
# pip3 install --upgrade jax jaxlib==0.1.64+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip3 install jaxlib
pip3 install kfac_ferminet_alpha/
python3 kfac_ferminet_alpha/example.py

View File

@@ -0,0 +1,52 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup for pip package."""
from setuptools import setup
REQUIRED_PACKAGES = (
"absl-py",
"dataclasses",
"jax",
"networkx",
"numpy",
"ordered-set",
"typing",
)
LONG_DESCRIPTION = "\n".join([
"Kronecker-Factored Approximate Curvature (K-FAC) optimizer implemented in "
"JAX.",
"",
"Accompanying code for 'Better, Faster Fermionic Neural Networks'",
"James S. Spencer, David Pfau, Aleksandar Botev, and W. M. C. Foulkes.",
"https://arxiv.org/abs/2011.07125.",
])
setup(
name="kfac_ferminet_alpha",
version="0.0.1",
description="A K-FAC optimizer implemented in JAX",
long_description=LONG_DESCRIPTION,
url="https://github.com/deepmind/deepmind-research/kfac_ferminet_alpha",
author="DeepMind",
package_dir={"kfac_ferminet_alpha": "."},
packages=["kfac_ferminet_alpha"],
install_requires=REQUIRED_PACKAGES,
platforms=["any"],
license="Apache License, Version 2.0",
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,76 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common functions used across more than one test."""
import jax
import jax.numpy as jnp
import jax.random as jnr
from kfac_ferminet_alpha import loss_functions
def fully_connected_layer(params, x):
w, b = params
return jnp.matmul(x, w) + b[None]
def init_autoencoder(key, data_shape):
"""Initialize the standard autoencoder."""
assert len(data_shape) == 1
x_size = data_shape[0]
sizes = [x_size, 1000, 500, 250, 30, 250, 500, 1000, x_size]
keys = jnr.split(key, len(sizes) - 1)
params = []
for key, dim_in, dim_out in zip(keys, sizes, sizes[1:]):
# Glorot uniform initialization
c = jnp.sqrt(6 / (dim_in + dim_out))
w = jax.random.uniform(key, shape=(dim_in, dim_out), minval=-c, maxval=c)
b = jnp.zeros([dim_out])
params.append((w, b))
return params
def autoencoder(all_params, x_in):
"""Evaluate the standard autoencoder.
Note that the objective of this autoencoder is not standard, but rather a sum
of the standard sigmoid crossentropy and squared loss. The reason for this is
to test handling of multiple losses.
Args:
all_params: All parameter values.
x_in: Inputs to the network.
Returns:
The value of the two losses and intermediate layer values.
"""
h_in = x_in
layers_values = []
for i, params in enumerate(all_params):
h_out = fully_connected_layer(params, h_in)
layers_values.append((h_out, h_in))
# Last layer does not have a nonlinearity
if i % 4 != 3:
# h_in = nn.leaky_relu(h_out)
h_in = jnp.tanh(h_out)
else:
h_in = h_out
h1, _ = loss_functions.register_normal_predictive_distribution(h_in, x_in)
h2, _ = loss_functions.register_normal_predictive_distribution(
h_in, targets=x_in, weight=0.1)
l1 = (h1 - x_in)**2 + jnp.log(jnp.pi) / 2
l1 = jnp.sum(l1, axis=-1)
l2 = (h2 - x_in)**2 + jnp.log(jnp.pi) / 2
l2 = jnp.sum(l2, axis=-1)
return [l1, l2 * 0.1], layers_values

View File

@@ -0,0 +1,85 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.test_util as jtu
from kfac_ferminet_alpha import layers_and_loss_tags
from kfac_ferminet_alpha import loss_functions
from kfac_ferminet_alpha import tag_graph_matcher
from kfac_ferminet_alpha.tests import common
def tagged_autoencoder(all_params, x_in):
h_in = x_in
layers_values = []
for i, params in enumerate(all_params):
h_out = common.fully_connected_layer(params, h_in)
h_out = layers_and_loss_tags.register_dense(h_out, h_in, params[0],
params[1],)
layers_values.append((h_out, h_in))
# Last layer does not have a nonlinearity
if i % 4 != 3:
h_in = jnp.tanh(h_out)
else:
h_in = h_out
h1, _ = loss_functions.register_normal_predictive_distribution(
h_in, targets=x_in, weight=1.0)
h2, t2 = loss_functions.register_normal_predictive_distribution(
h_in, targets=x_in, weight=0.1)
return [[h1, t2], [h2, t2]]
class TestGraphMatcher(jtu.JaxTestCase):
"""Class for running all of the tests for integrating the systems."""
def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape):
data_shape = tuple(data_shape)
rng_key = jnr.PRNGKey(12345)
init_key, data_key = jnr.split(rng_key)
params = init_func(init_key, data_shape)
data = jnr.normal(data_key, (11,) + data_shape)
func = tag_graph_matcher.auto_register_tags(model_func, (params, data))
jaxpr = jax.make_jaxpr(func)(params, data).jaxpr
tagged_jaxpr = jax.make_jaxpr(tagged_model)(params, data).jaxpr
self.assertEqual(len(jaxpr.invars), len(tagged_jaxpr.invars))
self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars))
self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars))
for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
eq_in_vars = [v for v in eq.invars if not isinstance(v, jax.core.UnitVar)]
tagged_in_vars = [
v for v in tagged_eq.invars if not isinstance(v, jax.core.UnitVar)
]
self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
self.assertEqual(eq.primitive, tagged_eq.primitive)
for variable, t_variable in zip(eq_in_vars + eq.outvars,
tagged_in_vars + tagged_eq.outvars):
if isinstance(variable, jax.core.Literal):
self.assertEqual(variable.aval, t_variable.aval)
else:
if variable.count != t_variable.count:
print("0")
self.assertEqual(variable.count, t_variable.count)
def test_autoencoder(self):
self._test_jaxpr(common.init_autoencoder, common.autoencoder,
tagged_autoencoder, [784])
if __name__ == "__main__":
absltest.main()

View File

@@ -0,0 +1,198 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.test_util as jtu
from kfac_ferminet_alpha import loss_functions
from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import tracer
from kfac_ferminet_alpha import utils
from kfac_ferminet_alpha.tests import common
def autoencoder_aux(all_aux, all_params, x_in):
h_in = x_in
layers_values = []
for i, (params, aux) in enumerate(zip(all_params, all_aux)):
h_out = common.fully_connected_layer(params, h_in + aux[1]) + aux[0]
layers_values.append((h_out, h_in))
# Last layer does not have a nonlinearity
if i % 4 != 3:
h_in = jnp.tanh(h_out)
else:
h_in = h_out
h1, _ = loss_functions.register_normal_predictive_distribution(h_in, x_in)
h2, _ = loss_functions.register_normal_predictive_distribution(
h_in, targets=x_in, weight=0.1)
l1 = (h1 - x_in)**2 + jnp.log(jnp.pi) / 2
l2 = (h2 - x_in)**2 + jnp.log(jnp.pi) / 2
return [l1, l2 * 0.1], layers_values
class TestTracer(jtu.JaxTestCase):
"""Class for running all of the tests for integrating the systems."""
@staticmethod
def generate_data(init_func, func, data_shape, rng_key):
n = 3
rng_key, key = jnr.split(rng_key)
params = init_func(key, data_shape)
rng_key, key = jnr.split(rng_key)
p_tangents = init_func(key, data_shape)
rng_key, key = jnr.split(rng_key)
data = jnr.normal(key, [n] + data_shape)
loss_vals, layer_vals = func(params, data)
h = layer_vals[-1][0]
keys = jnr.split(key, len(loss_vals))
h_tangents = tuple(jnr.normal(key, shape=h.shape) for key in keys)
return params, data, p_tangents, h_tangents
def assertStructureAllClose(self, x, y, **kwargs):
x_v, x_tree = jax.tree_flatten(x)
y_v, y_tree = jax.tree_flatten(y)
self.assertEqual(x_tree, y_tree)
for xi, yi in zip(x_v, y_v):
self.assertEqual(xi.shape, yi.shape)
self.assertAllClose(xi, yi, check_dtypes=True, **kwargs)
  def test_tracer_jvp(self):
init_func = common.init_autoencoder
func = common.autoencoder
data_shape = [784]
rng_key = jnr.PRNGKey(12345)
params, data, p_tangents, _ = self.generate_data(init_func, func,
data_shape, rng_key)
def no_data_func(args):
outputs = func(args, data)
return outputs[0], outputs[1][-1][0]
# True computation
(primals_out, tangents_out) = jax.jvp(no_data_func, [params], [p_tangents])
loss_vals, _ = primals_out
_, h_tangents = tangents_out
loss_tangents = ((h_tangents,),) * len(loss_vals)
# Tracer computation
tracer_jvp = tracer.trace_losses_matrix_vector_jvp(func)
tracer_losses, tracer_loss_tangents = tracer_jvp((params, data), p_tangents)
tracer_losses = [loss.evaluate(None) for loss in tracer_losses]
self.assertStructureAllClose(loss_vals, tracer_losses)
self.assertStructureAllClose(loss_tangents, tracer_loss_tangents)
def test_tracer_vjp(self):
init_func = common.init_autoencoder
func = common.autoencoder
data_shape = [784]
rng_key = jnr.PRNGKey(12345)
params, data, _, h_tangents = self.generate_data(init_func, func,
data_shape, rng_key)
def no_data_func(args):
outputs = func(args, data)
return outputs[0], outputs[1][-1][0]
# True computation
(loss_vals, _), vjp_func = jax.vjp(no_data_func, params)
loss_tangents = jax.tree_map(jnp.zeros_like, loss_vals)
summed_h_tangents = sum(jax.tree_flatten(h_tangents)[0])
p_tangents = vjp_func((loss_tangents, summed_h_tangents))
# Tracer computation
trace_vjp = tracer.trace_losses_matrix_vector_vjp(func)
tracer_losses, tracer_vjp_func = trace_vjp(params, data)
tracer_losses = [loss.evaluate(None) for loss in tracer_losses]
tracer_p_tangents = tracer_vjp_func(h_tangents)
self.assertStructureAllClose(loss_vals, tracer_losses)
self.assertStructureAllClose(p_tangents, tracer_p_tangents)
def test_tracer_hvp(self):
init_func = common.init_autoencoder
func = common.autoencoder
data_shape = [784]
rng_key = jnr.PRNGKey(12345)
params, data, p_tangents, _ = self.generate_data(init_func, func,
data_shape, rng_key)
def no_data_func(args):
outputs = func(args, data)
return sum(jax.tree_map(jnp.sum, outputs[0]))
# True computation
grad_func = jax.grad(no_data_func)
def grad_time_tangents(args):
return utils.inner_product(grad_func(args), p_tangents)
hvp = jax.grad(grad_time_tangents)
hvp_vectors = hvp(params)
# Tracer computation
tracer_hvp = tracer.trace_losses_matrix_vector_hvp(func)
tracer_hvp_vectors = tracer_hvp((params, data), p_tangents)
self.assertStructureAllClose(hvp_vectors, tracer_hvp_vectors, atol=1e-4)
def test_trace_estimator(self):
init_func = common.init_autoencoder
func = common.autoencoder
aux_func = autoencoder_aux
data_shape = [784]
rng_key = jnr.PRNGKey(12345)
params, data, _, h_tangents = self.generate_data(init_func, func,
data_shape, rng_key)
def aux_last_layer(aux, args):
outs = aux_func(aux, args, data)
return outs[1][-1][0]
# True computation
loss_vals, layer_vals = func(params, data)
aux_vals = jax.tree_map(jnp.zeros_like, layer_vals)
_, vjp = jax.vjp(aux_last_layer, aux_vals, params)
summed_h_tangents = sum(jax.tree_flatten(h_tangents)[0])
aux_tangents, p_tangents = vjp(summed_h_tangents)
layers_info = []
for aux_p, p_p in zip(layer_vals, params):
info = dict()
info["outputs"] = (aux_p[0],)
info["inputs"] = (aux_p[1],)
info["params"] = (p_p[0], p_p[1])
layers_info.append(info)
for i, (aux_t, p_t) in enumerate(zip(aux_tangents, p_tangents)):
info = dict()
info["outputs_tangent"] = (aux_t[0],)
info["inputs_tangent"] = (aux_t[1],)
info["params_tangent"] = (p_t[0], p_t[1])
layers_info[i].update(info)
layers_info = tuple(layers_info)
func = tgm.auto_register_tags(func, (params, data))
tracer_vjp = tracer.trace_estimator_vjp(func)
tracer_losses, tracer_vjp_func = tracer_vjp((params, data))
tracer_losses = [loss.evaluate(None) for loss in tracer_losses]
tracer_outputs = tracer_vjp_func((h_tangents[:1], h_tangents[1:]))
self.assertStructureAllClose(loss_vals, tracer_losses)
self.assertStructureAllClose(tracer_outputs, layers_info)
if __name__ == "__main__":
absltest.main()

View File

@@ -0,0 +1,327 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the Jax tracer functionality for tags."""
import functools
from typing import Any, Callable, Sequence, Tuple
import jax
from jax import core
from jax import util as jax_util
import jax.numpy as jnp
from kfac_ferminet_alpha import layers_and_loss_tags as tags
from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import utils
_Function = Callable[[Any], Any]
_Loss = tags.LossTag
def extract_tags(
jaxpr: core.Jaxpr
) -> Tuple[Sequence[core.JaxprEqn], Sequence[core.JaxprEqn]]:
"""Extracts all of the tag equations."""
# Loop through equations and evaluate primitives using `bind`
layer_tags = []
loss_tags = []
for eqn in jaxpr.eqns:
if isinstance(eqn.primitive, tags.LossTag):
loss_tags.append(eqn)
elif isinstance(eqn.primitive, tags.LayerTag):
layer_tags.append(eqn)
return tuple(layer_tags), tuple(loss_tags)
def construct_compute_losses_inputs(
jaxpr: core.Jaxpr,
consts: Tuple[Any],
num_losses: int,
primals: Any,
params_index: int) -> Callable[[Any], Sequence[Sequence[jnp.ndarray]]]:
"""Constructs a function that computes all of the inputs to all losses."""
primals_ = list(primals)
def forward_compute_losses(
params_primals: Any,
) -> Sequence[Sequence[jnp.ndarray]]:
primals_[params_index] = params_primals
flat_args = jax.tree_flatten(primals_)[0]
# Mapping from variable -> value
env = dict()
read = functools.partial(tgm.read_env, env)
write = functools.partial(tgm.write_env, env)
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, flat_args)
jax_util.safe_map(write, jaxpr.constvars, consts)
# Loop through equations and evaluate primitives using `bind`
losses_so_far = 0
loss_tags = []
for eqn in jaxpr.eqns:
tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
if isinstance(eqn.primitive, tags.LossTag):
loss_tags.append(eqn)
losses_so_far += 1
if num_losses is not None and losses_so_far == num_losses:
break
return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags)
# return tuple(jax_util.safe_map(read, tag.invars) for tag in loss_tags)
return forward_compute_losses
# We know that `.primitive` will be either a `LossTag` or a `LayerTag`, however
# pytype cannot infer its subclass, so we need to unbox it.
def _unbox_loss_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LossTag:
assert isinstance(jaxpr_eqn.primitive, tags.LossTag)
return jaxpr_eqn.primitive
def _unbox_layer_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LayerTag:
assert isinstance(jaxpr_eqn.primitive, tags.LayerTag)
return jaxpr_eqn.primitive
def trace_losses_matrix_vector_vjp(tagged_func: _Function,
params_index: int = 0):
"""Returns the Jacobian-transposed vector product (backward mode) function in equivalent form to jax.vjp."""
def vjp(*primals):
typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
_, loss_jaxpr_eqns = extract_tags(jaxpr)
n = len(loss_jaxpr_eqns)
losses_func = construct_compute_losses_inputs(
jaxpr, consts, n, primals, params_index)
losses_inputs, full_vjp_func = jax.vjp(losses_func, primals[params_index])
losses = []
for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
loss_tag = _unbox_loss_tag(jaxpr_eqn)
losses.append(loss_tag.loss(*inputs, weight=jaxpr_eqn.params["weight"]))
losses = tuple(losses)
def vjp_func(tangents):
flat_tangents = jax.tree_flatten(tangents)[0]
loss_invars = []
loss_targets = []
for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs
loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs]))
loss_targets.append(inputs[num_inputs:])
treedef = jax.tree_structure(loss_invars)
tangents = jax.tree_unflatten(treedef, flat_tangents)
      # Since the losses could also take targets as inputs, and we don't want
      # this function to compute the vjp w.r.t. those (i.e. the user should not
      # provide tangent vectors for the targets, only for the inputs), we have
      # to manually fill in these "extra" tangents with zeros.
targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets)
tangents = tuple(ti + tti for ti, tti in zip(tangents, targets_tangents))
input_tangents = full_vjp_func(tangents)[0]
return input_tangents,
return losses, vjp_func
return vjp
def trace_losses_matrix_vector_jvp(
tagged_func: _Function,
params_index: int = 0):
"""Returns the Jacobian vector product (forward mode) function in equivalent form to jax.jvp."""
def jvp(primals, params_tangents):
typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
_, loss_tags = extract_tags(jaxpr)
n = len(loss_tags)
losses_func = construct_compute_losses_inputs(jaxpr, consts, n,
primals, params_index)
primals = (primals[params_index],)
tangents = (params_tangents,)
(primals_out, tangents_out) = jax.jvp(losses_func, primals, tangents)
tangents_out = tuple(tuple(t[:tag.primitive.num_inputs])
for t, tag in zip(tangents_out, loss_tags))
losses = tuple(tag.primitive.loss(*inputs, weight=tag.params["weight"])
for tag, inputs in zip(loss_tags, primals_out))
return losses, tangents_out
return jvp
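# A minimal usage sketch, mirroring the accompanying tracer tests: both
# wrappers take the full primals (e.g. `(params, data)`) but differentiate only
# w.r.t. the entry at `params_index`. The returned losses are loss objects,
# whose values are obtained via `loss.evaluate(None)`.
#
#   jvp = trace_losses_matrix_vector_jvp(tagged_func)
#   losses, loss_tangents = jvp((params, data), params_tangents)
#
#   vjp = trace_losses_matrix_vector_vjp(tagged_func)
#   losses, vjp_func = vjp(params, data)
#   (params_tangents,) = vjp_func(loss_input_tangents)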
def trace_losses_matrix_vector_hvp(tagged_func, params_index=0):
"""Returns the Hessian vector product function of **the tagged losses**, rather than the output value of `tagged_func`."""
# The function uses backward-over-forward mode.
def hvp(primals, params_tangents):
typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
_, loss_tags = extract_tags(jaxpr)
n = len(loss_tags)
losses_func = construct_compute_losses_inputs(
jaxpr, consts, n, primals, params_index)
def losses_sum(param_primals):
loss_inputs = losses_func(param_primals)
losses = [
_unbox_loss_tag(jaxpr_eqn).loss(
*inputs, weight=jaxpr_eqn.params["weight"])
for jaxpr_eqn, inputs in zip(loss_tags, loss_inputs)
]
      # This computes the sum of the evaluated losses, which lets us use
      # jax.grad rather than jax.vjp for taking derivatives.
return sum(jnp.sum(loss.evaluate(None)) for loss in losses)
def grads_times_tangents(params_primals):
grads = jax.grad(losses_sum)(params_primals)
return utils.inner_product(grads, params_tangents)
return jax.grad(grads_times_tangents)(primals[params_index])
return hvp
def trace_estimator_vjp(tagged_func: _Function) -> _Function:
"""Creates the function needed for an estimator of curvature matrices.
Args:
    tagged_func: A function that has been annotated with tags for both layers
      and losses.
  Returns:
    A function with the same signature as `tagged_func`, which when provided
    with inputs returns two things:
    1. The instances of all loss objects that are tagged.
    2. A second function, which when provided with tangent vectors for each
      of the loss instances' inputs, returns for every tagged layer a
      dictionary containing the following elements:
        inputs - The primal values of the inputs to the layer.
        outputs - The primal values of the outputs of the layer.
        params - The primal values of the layer parameters.
        inputs_tangent - The tangent values of the layer inputs, given the
          provided tangents of the losses.
        outputs_tangent - The tangent values of the layer outputs, given the
          provided tangents of the losses.
        params_tangent - The tangent values of the layer parameters, given the
          provided tangents of the losses.
  """
def full_vjp_func(func_args):
# Trace the tagged function
typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args)
jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
layer_tags, loss_tags = extract_tags(jaxpr)
layer_vars_flat = jax.tree_flatten([tag.invars for tag in layer_tags])[0]
layer_input_vars = tuple(set(layer_vars_flat))
def forward():
own_func_args = func_args
# Mapping from variable -> value
env = dict()
read = functools.partial(tgm.read_env, env)
write = functools.partial(tgm.write_env, env)
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
jax_util.safe_map(write, jaxpr.constvars, consts)
# Loop through equations and evaluate primitives using `bind`
num_losses_passed = 0
for eqn in jaxpr.eqns:
tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
if isinstance(eqn.primitive, tags.LossTag):
num_losses_passed += 1
if num_losses_passed == len(loss_tags):
break
if num_losses_passed != len(loss_tags):
raise ValueError("This should be unreachable.")
return jax_util.safe_map(read, layer_input_vars)
def forward_aux(aux):
own_func_args = func_args
# Mapping from variable -> value
env = dict()
read = functools.partial(tgm.read_env, env)
def write(var, val):
if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
val = val + aux[var] if var in aux else val
env[var] = val
# Bind args and consts to environment
write(jax.core.unitvar, jax.core.unit)
jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
jax_util.safe_map(write, jaxpr.constvars, consts)
# Loop through equations and evaluate primitives using `bind`
num_losses_passed = 0
losses_inputs_values = []
losses_kwargs_values = []
for eqn in jaxpr.eqns:
input_values = jax_util.safe_map(read, eqn.invars)
tgm.evaluate_eqn(eqn, input_values, write)
if isinstance(eqn.primitive, tags.LossTag):
loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"])
losses_inputs_values.append(loss.inputs)
losses_kwargs_values.append(dict(
targets=loss.targets,
weight=eqn.params["weight"]
))
num_losses_passed += 1
if num_losses_passed == len(loss_tags):
break
if num_losses_passed != len(loss_tags):
raise ValueError("This should be unreachable.")
# Read the inputs to the loss functions, but also return the target values
return tuple(losses_inputs_values), tuple(losses_kwargs_values)
layer_input_values = forward()
primals_dict = dict(zip(layer_input_vars, layer_input_values))
primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0]))
aux_values = jax.tree_map(jnp.zeros_like, layer_input_values)
aux_dict = dict(zip(layer_input_vars, aux_values))
losses_args, aux_vjp, losses_kwargs = jax.vjp(forward_aux, aux_dict,
has_aux=True)
losses = tuple(tag.primitive.loss(*inputs, **kwargs)
for tag, inputs, kwargs in
zip(loss_tags, losses_args, losses_kwargs))
def vjp_func(tangents):
all_tangents = aux_vjp(tangents)
tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
tangents_dict.update(zip(jaxpr.invars, inputs_tangents))
read_primals = functools.partial(tgm.read_env, primals_dict)
read_tangents = functools.partial(tgm.read_env, tangents_dict)
layers_info = []
for jaxpr_eqn in layer_tags:
layer_tag = _unbox_layer_tag(jaxpr_eqn)
info = dict()
primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars))
(
info["outputs"],
info["inputs"],
info["params"],
) = layer_tag.split_all_inputs(primals)
tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars))
(
info["outputs_tangent"],
info["inputs_tangent"],
info["params_tangent"],
) = layer_tag.split_all_inputs(tangents)
layers_info.append(info)
return tuple(layers_info)
return losses, vjp_func
return full_vjp_func
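# A minimal usage sketch of the estimator vjp (following the accompanying
# tracer tests): the function must first have its layer and loss tags
# registered, and the returned vjp maps loss tangents to per-layer
# dictionaries of primal and tangent values.
#
#   func = tgm.auto_register_tags(func, (params, data))
#   losses, vjp_func = trace_estimator_vjp(func)((params, data))
#   loss_values = [loss.evaluate(None) for loss in losses]
#   layers_info = vjp_func(loss_tangents)  # tuple of dicts, one per layer tag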

View File

@@ -0,0 +1,455 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to multi-device operations."""
import collections
from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union
import dataclasses
import jax
from jax import core
from jax import lax
import jax.numpy as jnp
from jax.scipy import linalg
import jax.tree_util as tree_util
T = TypeVar("T")
def wrap_if_pmap(p_func):
def p_func_if_pmap(obj, axis_name):
try:
core.axis_frame(axis_name)
return p_func(obj, axis_name)
except NameError:
return obj
return p_func_if_pmap
pmean_if_pmap = wrap_if_pmap(lax.pmean)
psum_if_pmap = wrap_if_pmap(lax.psum)
compute_mean = jax.pmap(lambda x: lax.pmean(x, "i"), axis_name="i")
compute_sum = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")
def get_first(obj: T) -> T:
return jax.tree_map(lambda x: x[0], obj)
def get_mean(obj: T) -> T:
return get_first(compute_mean(obj))
def get_sum(obj: T) -> T:
return get_first(compute_sum(obj))
broadcast_all_local_devices = jax.pmap(lambda x: x)
def replicate_all_local_devices(obj: T) -> T:
n = jax.local_device_count()
obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj)
return broadcast_all_local_devices(obj_stacked)
def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray:
rng = jax.random.fold_in(rng, jax.host_id())
rng = jax.random.split(rng, jax.local_device_count())
return broadcast_all_local_devices(rng)
p_split = jax.pmap(lambda key: tuple(jax.random.split(key)))
def scalar_mul(obj: T, scalar: Union[float, jnp.ndarray]) -> T:
return jax.tree_map(lambda x: x * scalar, obj)
def scalar_div(obj: T, scalar: Union[float, jnp.ndarray]) -> T:
return jax.tree_map(lambda x: x / scalar, obj)
def make_func_args(params, func_state, rng, batch, has_state: bool,
has_rng: bool):
"""Correctly puts all arguments to the function together."""
func_args = (params,)
if has_state:
if func_state is None:
raise ValueError("The `func_state` is None, but the argument `has_state` "
"is True.")
func_args += (func_state,)
if has_rng:
if rng is None:
raise ValueError("The `rng` is None, but the argument `has_rng` is True.")
func_args += (rng,)
func_args += (batch,)
return func_args
def extract_func_outputs(
raw_outputs: Any,
has_aux: bool,
has_state: bool,
) -> Tuple[jnp.ndarray, Any, Any]:
"""Given the function output returns separately the loss, func_state, aux."""
if not has_aux and not has_state:
return raw_outputs, None, None
loss, other = raw_outputs
if has_aux and has_state:
func_state, aux = other
elif has_aux:
func_state, aux = None, other
else:
func_state, aux = other, None
return loss, func_state, aux
def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
if jax.tree_structure(obj1) != jax.tree_structure(obj2):
raise ValueError("The two structures are not identical.")
elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
return sum(jax.tree_flatten(elements_product)[0])
def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.ndarray:
assert matrix.ndim == 2
identity = jnp.eye(matrix.shape[0])
matrix = matrix + damping * identity
return linalg.solve(matrix, identity, sym_pos=True)
def solve_maybe_small(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
"""Computes a^-1 b more efficiently for small matrices."""
assert a.shape[-1] == a.shape[-2] == b.shape[-1]
d = a.shape[-1]
if d == 0:
return a
elif d == 1:
return b / a[..., 0]
elif d == 2:
det = a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]
b_0 = a[..., 1, 1] * b[..., 0] - a[..., 0, 1] * b[..., 1]
b_1 = a[..., 0, 0] * b[..., 1] - a[..., 1, 0] * b[..., 0]
return jnp.stack([b_0, b_1], axis=-1) / det
elif d == 3:
raise NotImplementedError()
return jnp.linalg.solve(a, b)
def pi_adjusted_inverse(
factor_0: jnp.ndarray,
factor_1: jnp.ndarray,
damping: jnp.ndarray,
pmap_axis_name: str,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Performs inversion with pi-adjusted damping."""
# Compute the norms of each factor
norm_0 = jnp.trace(factor_0)
norm_1 = jnp.trace(factor_1)
  # We need to sync the norms here, because the reductions can be
  # non-deterministic (they are on GPUs by default, for better performance).
  # Hence, although factor_0 and factor_1 are synced, the trace operations
  # above can still produce different answers on different devices.
norm_0, norm_1 = pmean_if_pmap((norm_0, norm_1), axis_name=pmap_axis_name)
# Compute the overall scale
scale = norm_0 * norm_1
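  # The per-factor dampings below correspond to the KFAC pi-adjustment: with
  # pi = sqrt((tr(factor_0) / dim(factor_0)) / (tr(factor_1) / dim(factor_1))),
  # the returned values are, up to scalar factors that cancel in the Kronecker
  # product, (factor_0 + pi * sqrt(damping) * I)^-1 and
  # (factor_1 + sqrt(damping) / pi * I)^-1.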
def regular_inverse(
operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
factor0, factor1, norm0, norm1, s, d = operand
# Special cases with one or two scalar factors
if factor0.size == 1 and factor1.size == 1:
value = jnp.ones_like(factor0) / jnp.sqrt(s)
return value, value
if factor0.size == 1:
factor1_normed = factor1 / norm1
damping1 = d / norm1
factor1_inv = psd_inv_cholesky(factor1_normed, damping1)
return jnp.full((1, 1), s), factor1_inv
if factor1.size == 1:
factor0_normed = factor0 / norm0
damping0 = d / norm0
factor0_inv = psd_inv_cholesky(factor0_normed, damping0)
return factor0_inv, jnp.full((1, 1), s)
# Invert first factor
factor0_normed = factor0 / norm0
damping0 = jnp.sqrt(d * factor1.shape[0] / (s * factor0.shape[0]))
factor0_inv = psd_inv_cholesky(factor0_normed, damping0) / jnp.sqrt(s)
# Invert second factor
factor1_normed = factor1 / norm1
damping1 = jnp.sqrt(d * factor0.shape[0] / (s * factor1.shape[0]))
factor1_inv = psd_inv_cholesky(factor1_normed, damping1) / jnp.sqrt(s)
return factor0_inv, factor1_inv
def zero_inverse(
operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
return (jnp.eye(factor_0.shape[0]) / jnp.sqrt(operand[-1]),
jnp.eye(factor_1.shape[0]) / jnp.sqrt(operand[-1]))
  # In the special case where, for some reason, one of the factors is zero,
  # the correct inverse of `(0 kron A + lambda I)` is
  # `(I / sqrt(lambda)) kron (I / sqrt(lambda))`. However, because one of the
  # norms is zero, `pi` and `1/pi` would be 0 and infinity, leading to NaN
  # values. Hence, we need to make this check explicit.
return lax.cond(
jnp.greater(scale, 0.0),
regular_inverse,
zero_inverse,
operand=(factor_0, factor_1, norm_0, norm_1, scale, damping))
def convert_value_and_grad_to_value_func(
value_and_grad_func,
has_aux: bool = False,
):
"""Converts a value_and_grad function to value_func only."""
def value_func(*args, **kwargs):
out, _ = value_and_grad_func(*args, **kwargs)
if has_aux:
return out[0]
else:
return out
return value_func
def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None:
"""Verifies that the two objects have the same pytree structure."""
assert jax.tree_structure(obj1) == jax.tree_structure(obj2)
for v1, v2 in zip(jax.tree_flatten(obj1)[0], jax.tree_flatten(obj2)[0]):
assert v1.shape == v2.shape
assert v1.dtype == v2.dtype
def check_first_dim_is_batch_size(batch_size: int, *args: jnp.ndarray) -> None:
for i, arg in enumerate(args):
if arg.shape[0] != batch_size:
raise ValueError(f"Expecting first dimension of arg[{i}] with shape "
f"{arg.shape} to be equal to the batch size "
f"{batch_size}.")
def py_tree_registered_dataclass(cls, *args, **kwargs):
"""Creates a new dataclass type and registers it as a pytree node."""
dcls = dataclasses.dataclass(cls, *args, **kwargs)
tree_util.register_pytree_node(
dcls,
lambda instance: ( # pylint: disable=g-long-lambda
[getattr(instance, f.name)
for f in dataclasses.fields(instance)], None),
lambda _, instance_args: dcls(*instance_args))
return dcls
class WeightedMovingAverage:
"""A wrapped class for a variable for which we keep exponential moving average."""
def __init__(self, weight: jnp.ndarray, array: jnp.ndarray):
self._weight = weight
self._array = array
@staticmethod
def zero(shape: Sequence[int]) -> "WeightedMovingAverage":
return WeightedMovingAverage(weight=jnp.zeros([]), array=jnp.zeros(shape))
@property
def weight(self) -> jnp.ndarray:
return self._weight
@property
def value(self) -> jnp.ndarray:
return self._array / self._weight
@property
def raw_value(self) -> jnp.ndarray:
return self._array
def update(self, value: jnp.ndarray, old_weight_multiplier: float,
new_weight: float) -> None:
self._weight = old_weight_multiplier * self._weight + new_weight
self._array = old_weight_multiplier * self._array + new_weight * value
def sync(self, pmap_axis_name: str) -> None:
self._array = pmean_if_pmap(self._array, pmap_axis_name)
def __str__(self) -> str:
return (f"ExponentialMovingAverage(weight={self._weight}, "
f"array={self._array})")
def __repr__(self) -> str:
return self.__str__()
tree_util.register_pytree_node(
WeightedMovingAverage,
lambda instance: ((instance.weight, instance.raw_value), None),
lambda _, instance_args: WeightedMovingAverage(*instance_args),
)
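# A small illustrative sketch (not part of the module) of the moving average:
# starting from zero, a single update scales the weight and the array by the
# same new_weight, so `value` returns the supplied array unchanged.
#
#   ema = WeightedMovingAverage.zero([2, 2])
#   ema.update(jnp.eye(2), old_weight_multiplier=0.95, new_weight=0.05)
#   ema.value  # == jnp.eye(2), since (0.05 * jnp.eye(2)) / 0.05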
class Stateful:
"""A class for stateful objects."""
def __init__(self, stateful_fields_names: Optional[Sequence[str]] = ()):
self.__stateful_fields_names = stateful_fields_names
def _add_stateful_fields_names(self, value: Sequence[str]) -> None:
self.__stateful_fields_names += tuple(value)
def get_state(self) -> Mapping[str, Any]:
"""Returns the state of the object."""
state = dict()
for name in self.__stateful_fields_names:
state[name] = Stateful._get_state_from_instance(getattr(self, name))
return state
def set_state(self, value):
"""Sets the state of the object with the provided value and returns the object."""
assert isinstance(value, dict)
for name in self.__stateful_fields_names:
setattr(self, name,
Stateful._set_state_to_instance(getattr(self, name), value[name]))
return self
def clear_state(self) -> None:
"""Clears the state of the object."""
for name in self.__stateful_fields_names:
setattr(self, name,
Stateful._clear_state_from_instance(getattr(self, name)))
def pop_state(self) -> Mapping[str, Any]:
"""Returns the current state of the object, while simultaneously clearing it."""
state = self.get_state()
self.clear_state()
return state
@staticmethod
def _get_state_from_instance(obj):
"""Recursively gets the state of the object and returns it."""
if isinstance(obj, Stateful):
return obj.get_state()
if isinstance(obj, list):
return [Stateful._get_state_from_instance(i) for i in obj]
if isinstance(obj, tuple):
return tuple(Stateful._get_state_from_instance(i) for i in obj)
if isinstance(obj, collections.OrderedDict):
return collections.OrderedDict(
(k, Stateful._get_state_from_instance(v)) for k, v in obj.items())
if isinstance(obj, dict):
return dict(
(k, Stateful._get_state_from_instance(v)) for k, v in obj.items())
return obj
@staticmethod
def _set_state_to_instance(obj, value):
"""Recursively sets the state of the object and returns it."""
if isinstance(obj, Stateful):
obj.set_state(value)
return obj
if isinstance(value, list):
if obj is None:
obj = [None] * len(value)
return [
Stateful._set_state_to_instance(obj_i, value_i)
for obj_i, value_i in zip(obj, value)
]
if isinstance(value, tuple):
if obj is None:
obj = [None] * len(value)
return tuple(
Stateful._set_state_to_instance(obj_i, value_i)
for obj_i, value_i in zip(obj, value))
if isinstance(value, collections.OrderedDict):
if obj is None:
obj = dict((k, None) for k in value)
return collections.OrderedDict(
(k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj)
if isinstance(value, dict):
obj = dict((k, None) for k in value)
return dict(
(k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj)
return value
@staticmethod
def _clear_state_from_instance(obj):
"""Recursively clears the state of the object and returns it."""
if isinstance(obj, Stateful):
obj.clear_state()
return obj
if isinstance(obj, list):
return [Stateful._clear_state_from_instance(obj_i) for obj_i in obj]
if isinstance(obj, tuple):
return tuple(Stateful._clear_state_from_instance(obj_i) for obj_i in obj)
if isinstance(obj, collections.OrderedDict):
return collections.OrderedDict(
(k, Stateful._clear_state_from_instance(obj[k])) for k in obj)
if isinstance(obj, dict):
return dict((k, Stateful._clear_state_from_instance(obj[k])) for k in obj)
return None
@staticmethod
def infer_class_state(class_type):
"""Infers a stateful class state attributes from class annotations."""
if not issubclass(class_type, Stateful):
raise ValueError(
f"In order to annotate a class as stateful it must inherit "
f"{Stateful!r}")
class_type = dataclasses.dataclass(
class_type, init=False, repr=False, eq=False) # pytype: disable=wrong-keyword-args
fields_names = tuple(field.name for field in dataclasses.fields(class_type))
original_init = getattr(class_type, "__init__", None)
if original_init is None:
def injected_init(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, **kwargs) # pylint: disable=bad-super-call
Stateful._add_stateful_fields_names(self, fields_names)
for field_name in fields_names:
if getattr(self, field_name, None) is None:
setattr(self, field_name, None)
setattr(class_type, "__init__", injected_init)
else:
def injected_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
Stateful._add_stateful_fields_names(self, fields_names)
for field_name in fields_names:
if getattr(self, field_name, None) is None:
setattr(self, field_name, None)
setattr(class_type, "__init__", injected_init)
return class_type
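# A hedged usage sketch (the class and field names below are illustrative, not
# part of this module): `Stateful.infer_class_state` is meant to be applied as
# a class decorator, turning the class annotations into dataclass fields whose
# values are tracked via get_state/set_state/pop_state.
#
#   @Stateful.infer_class_state
#   class RunningEstimate(Stateful):
#     count: Optional[jnp.ndarray]
#     total: Optional[jnp.ndarray]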
def compute_sq_norm_relative_abs_diff(obj, pmap_axis_name):
sq_norm = inner_product(obj, obj)
synced_sq_norm = psum_if_pmap(sq_norm, pmap_axis_name)
synced_sq_norm = (synced_sq_norm - sq_norm) / (jax.device_count() - 1.0)
sq_norm_abs_diff = jnp.abs(sq_norm - synced_sq_norm)
return sq_norm_abs_diff / sq_norm
def product(iterable_object):
x = 1
for element in iterable_object:
x *= element
return x

View File

@@ -0,0 +1,53 @@
Implementation of the object-based transformer model from
["Object-based attention for spatio-temporal reasoning"](https://arxiv.org/abs/2012.08508)
[1].
This package includes source code for the transformer model,
pre-trained model parameters for the CLEVRER task,
and MONet [2] latent variables for all videos in the training
and validation sets. It does not include the model training code.
See Section 2 of [1] for details.
[1] David Ding, Felix Hill, Adam Santoro, Matt Botvinick. *Object-based
attention for spatio-temporal reasoning: Outperforming neuro-symbolic models
with flexible distributed architectures*.
arXiv preprint arXiv:2012.08508, 2020.
[2] Chris P. Burgess, Loic Matthey, Nick Watters, Rishabh Kabra, Irina Higgins,
Matt Botvinick, and Alexander Lerchner
*MONet: Unsupervised scene decomposition and representation*.
arXiv preprint arXiv:1901.11390, 2019.
# Instructions
Note: This code depends on TensorFlow 1 and Sonnet 1. TensorFlow 1 is only
available on PyPI for Python 3.7 and earlier.
To run this code, execute the following commands from the `deepmind_research/`
directory:
```shell
# Download checkpoints and MONet latents
wget https://storage.googleapis.com/object-attention-for-reasoning/checkpoints_and_latents.zip
unzip checkpoints_and_latents.zip
python3.7 -m venv object_based_attention_venv
source object_based_attention_venv/bin/activate
pip install --upgrade setuptools wheel
pip install -r requirements.txt
python -m object_attention_for_reasoning.run_model
```
If the code runs correctly, you should see the model's predicted answer to two
CLEVRER questions (a descriptive one and a multiple choice one), and both
answers should be correct.
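If you just want to inspect the downloaded data, the snippet below is a
minimal sketch that mirrors how `run_model.py` reads the files (the scene
index `1000` is only an example); it loads the vocabulary, the question
annotations, and the per-scene MONet latents for the training split:
```python
import json
import numpy as np

base_dir = "./clevrer_monet_latents"  # default --base_dir of run_model.py

# Token vocabulary and question/answer annotations.
with open(f"{base_dir}/vocab.json", "rb") as f:
  token_map = json.load(f)
with open(f"{base_dir}/train.json", "rb") as f:
  questions_data = json.load(f)

# One .npz file of MONet latents per training-set video.
latents = np.load(f"{base_dir}/train/1000.npz")

print(questions_data[1000]["questions"][0]["question"])
```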
If you find the provided code useful, please cite this paper:
```
@article{objectattention2020,
title={Object-based attention for spatio-temporal reasoning: Outperforming
neuro-symbolic models with flexible distributed architectures},
author={David Ding and Felix Hill and Adam Santoro and Matt Botvinick},
journal={arXiv preprint arXiv:2012.08508},
year={2020}
}
```

View File

@@ -0,0 +1,185 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model code. Provided settings are identical to what was used in the paper."""
import sonnet as snt
import tensorflow.compat.v1 as tf
from object_attention_for_reasoning import transformer
QUESTION_VOCAB_SIZE = 82
ANSWER_VOCAB_SIZE = 22
MAX_QUESTION_LENGTH = 20
MAX_CHOICE_LENGTH = 12
NUM_CHOICES = 4
EMBED_DIM = 16
PRETRAINED_MODEL_CONFIG = dict(
use_relative_positions=True,
shuffle_objects=True,
transformer_layers=28,
head_size=128,
num_heads=10,
embed_dim=EMBED_DIM,
)
def append_ids(tensor, id_vector, axis):
id_vector = tf.constant(id_vector, tf.float32)
for a in range(len(tensor.shape)):
if a != axis:
id_vector = tf.expand_dims(id_vector, axis=a)
tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
id_tensor = tf.tile(id_vector, tiling_vector)
return tf.concat([tensor, id_tensor], axis=axis)
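# Note: `append_ids` tiles the constant `id_vector` over every axis except
# `axis` and concatenates it, so the size of `axis` grows by len(id_vector).
# In `_apply_transformers` below, language tokens are tagged with [1, 0] and
# vision slots with [0, 1] before being fed to the shared transformer.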
class ClevrerTransformerModel(object):
"""Model from Ding et al. 2020 (https://arxiv.org/abs/2012.08508)."""
def __init__(self, use_relative_positions, shuffle_objects,
transformer_layers, num_heads, head_size, embed_dim):
"""Instantiate Sonnet modules."""
self._embed_dim = embed_dim
self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2)
self._shuffle_objects = shuffle_objects
self._memory_transformer = transformer.TransformerTower(
value_size=embed_dim + 2,
num_heads=num_heads,
num_layers=transformer_layers,
use_relative_positions=use_relative_positions,
causal=False)
self._final_layer_mc = snt.Sequential(
[snt.Linear(head_size), tf.nn.relu, snt.Linear(1)])
self._final_layer_descriptive = snt.Sequential(
[snt.Linear(head_size), tf.nn.relu,
snt.Linear(ANSWER_VOCAB_SIZE)])
self._dummy = tf.get_variable("dummy", [embed_dim + 2], tf.float32,
initializer=tf.zeros_initializer)
self._infill_linear = snt.Linear(embed_dim + 2)
self._mask_embedding = tf.get_variable(
"mask", [embed_dim + 2], tf.float32, initializer=tf.zeros_initializer)
def _apply_transformers(self, lang_embedding, vision_embedding):
"""Applies transformer to language and vision input.
Args:
lang_embedding: tensor,
vision_embedding: tensor, "validation", or "test".
Returns:
tuple, output at dummy token, all output embeddings, infill loss
"""
def _unroll(tensor):
"""Unroll the time dimension into the object dimension."""
return tf.reshape(
tensor, [tensor.shape[0], -1, tensor.shape[3]])
words = append_ids(lang_embedding, [1, 0], axis=2)
dummy_word = tf.tile(self._dummy[None, None, :], [tf.shape(words)[0], 1, 1])
vision_embedding = append_ids(vision_embedding, [0, 1], axis=3)
vision_over_time = _unroll(vision_embedding)
transformer_input = tf.concat([dummy_word, words, vision_over_time], axis=1)
output, _ = self._memory_transformer(transformer_input,
is_training=False)
return output[:, 0, :]
def apply_model_descriptive(self, inputs):
"""Applies model to CLEVRER descriptive questions.
Args:
inputs: dict of form: {
"question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
"monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
}
Returns:
Tensor of shape [batch, ANSWER_VOCAB_SIZE], representing logits for each
possible answer word.
"""
question = inputs["question"]
# Shape: [batch, question_len, embed_dim-2]
question_embedding = self._embed(question)
# Shape: [batch, question_len, embed_dim]
question_embedding = append_ids(question_embedding, [0, 1], 2)
choices_embedding = self._embed(
tf.zeros([question.shape[0], MAX_CHOICE_LENGTH], tf.int64))
choices_embedding = append_ids(choices_embedding, [0, 1], 2)
    # Shape: [batch, question_len + choice_len, embed_dim]
lang_embedding = tf.concat([question_embedding, choices_embedding], axis=1)
# Shape: [batch, frames, num_objects, embed_dim]
vision_embedding = inputs["monet_latents"]
if self._shuffle_objects:
vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
vision_embedding = tf.random.shuffle(vision_embedding)
vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
output = self._apply_transformers(lang_embedding, vision_embedding)
output = self._final_layer_descriptive(output)
return output
def apply_model_mc(self, inputs):
"""Applies model to CLEVRER multiple-choice questions.
Args:
inputs: dict of form: {
"question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
"choices": tf.int32 tensor of shape [batch, 4, MAX_CHOICE_LENGTH],
"monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
}
Returns:
Tensor of shape [batch, 4], representing logits for each choice
"""
question = inputs["question"]
choices = inputs["choices"]
# Shape: [batch, question_len, embed_dim-2]
question_embedding = self._embed(question)
# Shape: [batch, question_len, embed_dim]
question_embedding = append_ids(question_embedding, [1, 0], 2)
# Shape: [batch, choices, choice_len, embed_dim-2]
choices_embedding = snt.BatchApply(self._embed)(choices)
# Shape: [batch, choices, choice_len, embed_dim]
choices_embedding = append_ids(choices_embedding, [0, 1], 3)
# Shape: [batch, choices, question_len + choice_len, embed_dim]
lang_embedding = tf.concat([
tf.tile(question_embedding[:, None],
[1, choices_embedding.shape[1], 1, 1]),
choices_embedding], axis=2)
# Shape: [batch, frames, num_objects, embed_dim]
vision_embedding = inputs["monet_latents"]
if self._shuffle_objects:
vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
vision_embedding = tf.random.shuffle(vision_embedding)
vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
output_per_choice = []
for c in range(NUM_CHOICES):
output = self._apply_transformers(
lang_embedding[:, c, :, :], vision_embedding)
output_per_choice.append(output)
output = tf.stack(output_per_choice, axis=1)
output = tf.squeeze(snt.BatchApply(self._final_layer_mc)(output), axis=2)
return output

View File

@@ -0,0 +1,4 @@
absl-py==0.11.0
dm-sonnet==1.36
numpy==1.20.1
tensorflow==1.15.0

View File

@@ -0,0 +1,163 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example code for running model on CLEVRER."""
import json
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from object_attention_for_reasoning import model as modellib
BATCH_SIZE = 1
NUM_FRAMES = 25
NUM_OBJECTS = 8
_BASE_DIR = flags.DEFINE_string(
"base_dir", "./clevrer_monet_latents",
"Directory containing checkpoints and MONet latents.")
_SCENE_IDX = flags.DEFINE_integer(
    "scene_idx", 1000, "Scene index of CLEVRER video.")
def load_monet_latents(base_dir, scene_index):
filename = f"{base_dir}/train/{scene_index}.npz"
with open(filename, "rb") as f:
return np.load(f)
def _split_string(s):
"""Splits string to words and standardize alphabet."""
return s.lower().replace("?", "").split()
def _pad(array, length):
"""Pad an array to desired length."""
return np.pad(array, [(0, length - array.shape[0])], mode="constant")
def encode_sentence(token_map, sentence, pad_length):
"""Encode CLEVRER question/choice sentences as sequence of token ids."""
ret = np.array(
[token_map["question_vocab"][w] for w in _split_string(sentence)],
np.int32)
return _pad(ret, pad_length)
def encode_choices(token_map, choices):
"""Encode CLEVRER choices."""
arrays = [encode_sentence(token_map, choice["choice"],
modellib.MAX_CHOICE_LENGTH)
for choice in choices]
return _pad(np.stack(arrays, axis=0), modellib.NUM_CHOICES)
def main(unused_argv):
base_dir = _BASE_DIR.value
with open(f"{base_dir}/vocab.json", "rb") as f:
token_map = json.load(f)
reverse_answer_lookup = {v: k for k, v in token_map["answer_vocab"].items()}
with open(f"{base_dir}/train.json", "rb") as f:
questions_data = json.load(f)
tf.reset_default_graph()
model = modellib.ClevrerTransformerModel(**modellib.PRETRAINED_MODEL_CONFIG)
inputs_descriptive = {
"monet_latents": tf.placeholder(
tf.float32,
[BATCH_SIZE, NUM_FRAMES, NUM_OBJECTS, modellib.EMBED_DIM]),
"question": tf.placeholder(
tf.int32, [BATCH_SIZE, modellib.MAX_QUESTION_LENGTH]),
}
inputs_mc = {
"monet_latents": tf.placeholder(
tf.float32,
[BATCH_SIZE, NUM_FRAMES, NUM_OBJECTS, modellib.EMBED_DIM]),
"question": tf.placeholder(tf.int32,
[BATCH_SIZE, modellib.MAX_QUESTION_LENGTH]),
"choices": tf.placeholder(
tf.int32, [BATCH_SIZE, modellib.NUM_CHOICES,
modellib.MAX_CHOICE_LENGTH]),
}
output_descriptive = model.apply_model_descriptive(inputs_descriptive)
output_mc = model.apply_model_mc(inputs_mc)
# Restore from checkpoint
saver = tf.train.Saver()
checkpoint_dir = f"{base_dir}/checkpoints/"
sess = tf.train.SingularMonitoredSession(checkpoint_dir=checkpoint_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
saver.restore(sess, ckpt.model_checkpoint_path)
def eval_descriptive(monet_latents, question_json):
# CLEVRER provides videos with 128 frames. In our model, we subsample 25
# frames (as was done in Yi et al (2020)).
# For training, we randomize the choice of 25 frames, and for evaluation, we
# sample the 25 frames as evenly as possible.
# We do that by doing strided sampling of the frames.
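    # Illustrative example: with 128 stored frames, divmod(128, NUM_FRAMES)
    # gives stride=5 and rem=3, so frames 0, 5, ..., 120 are kept, i.e.
    # exactly NUM_FRAMES = 25 frames.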
stride, rem = divmod(monet_latents.shape[0], NUM_FRAMES)
monet_latents = monet_latents[None, :-rem:stride]
assert monet_latents.shape[1] == NUM_FRAMES
question = encode_sentence(token_map, question_json["question"],
modellib.MAX_QUESTION_LENGTH)
batched_question = np.expand_dims(question, axis=0)
logits = sess.run(output_descriptive, feed_dict={
inputs_descriptive["monet_latents"]: monet_latents,
inputs_descriptive["question"]: batched_question,
})
descriptive_answer = np.argmax(logits)
return reverse_answer_lookup[descriptive_answer]
def eval_mc(monet_latents, question_json):
stride, rem = divmod(monet_latents.shape[0], NUM_FRAMES)
monet_latents = monet_latents[None, :-rem:stride]
assert monet_latents.shape[1] == NUM_FRAMES
question = encode_sentence(
token_map, question_json["question"], modellib.MAX_QUESTION_LENGTH)
choices = encode_choices(
token_map, question_json["choices"])
mc_answer = sess.run(output_mc, feed_dict={
inputs_mc["monet_latents"]: monet_latents,
inputs_mc["question"]: np.expand_dims(question, axis=0),
inputs_mc["choices"]: np.expand_dims(choices, axis=0),
})
return mc_answer >= 0
sample_scene_idx = _SCENE_IDX.value
question_json = questions_data[sample_scene_idx]["questions"][0]
print("Descriptive Question: ", question_json["question"])
print("Model Answer: ",
eval_descriptive(load_monet_latents(base_dir, sample_scene_idx),
question_json))
print("True Answer: ", question_json["answer"])
question_json = questions_data[sample_scene_idx]["questions"][-1]
print("Multiple-Choice Question: ", question_json["question"])
for i, choice_json in enumerate(question_json["choices"]):
print(f"{i+1}) {choice_json['choice']}")
print("Model Answer: ",
eval_mc(load_monet_latents(base_dir, sample_scene_idx), question_json))
print("True Answer: ",
[choice_json["answer"] for choice_json in question_json["choices"]])
if __name__ == "__main__":
app.run(main)

File diff suppressed because it is too large