diff --git a/kfac_ferminet_alpha/README.md b/kfac_ferminet_alpha/README.md
new file mode 100644
index 0000000..87557b0
--- /dev/null
+++ b/kfac_ferminet_alpha/README.md
@@ -0,0 +1,38 @@
+# Accompanying code for Better, Faster Fermionic Neural Networks
+
+All package requirements are listed in `requirements.txt`.
+
+## Contributing
+
+This is purely research code, provided with no intention of further support
+and no guarantees of backward compatibility.
+
+## Installation
+
+```shell
+git clone git@github.com:deepmind/deepmind-research.git
+pip install deepmind-research/kfac_ferminet_alpha/
+```
+
+## Usage
+
+You can find examples of how to use the codebase through the [FermiNet project].
+
+We also provide an [example training script].
+
+## Reference
+
+**Better, Faster Fermionic Neural Networks**
+
+James S. Spencer, David Pfau, Aleksandar Botev, and W. M. C. Foulkes.
+
+URL: https://arxiv.org/abs/2011.07125
+
+**Optimizing Neural Networks with Kronecker-factored Approximate Curvature**
+
+James Martens and Roger Grosse.
+
+URL: https://arxiv.org/abs/1503.05671
+
+[FermiNet project]: https://github.com/deepmind/ferminet/
+[example training script]: https://github.com/deepmind/deepmind-research/blob/master/kfac_ferminet_alpha/example.py
diff --git a/kfac_ferminet_alpha/__init__.py b/kfac_ferminet_alpha/__init__.py
new file mode 100644
index 0000000..67a6ca4
--- /dev/null
+++ b/kfac_ferminet_alpha/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for anything that an end user would use."""
+
+from kfac_ferminet_alpha.loss_functions import register_normal_predictive_distribution
+from kfac_ferminet_alpha.loss_functions import register_squared_error_loss
+from kfac_ferminet_alpha.optimizer import Optimizer
diff --git a/kfac_ferminet_alpha/curvature_blocks.py b/kfac_ferminet_alpha/curvature_blocks.py
new file mode 100644
index 0000000..4c0b647
--- /dev/null
+++ b/kfac_ferminet_alpha/curvature_blocks.py
@@ -0,0 +1,496 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for all of the different curvature blocks.""" +import abc +from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Sequence, Union +import jax +from jax import core +import jax.numpy as jnp + +from kfac_ferminet_alpha import tag_graph_matcher as tgm +from kfac_ferminet_alpha import utils + +_Arrays = Sequence[jnp.ndarray] +_BlockInfo = Mapping[str, Any] + + +class CurvatureBlock(utils.Stateful, abc.ABC): + """Top level class.""" + + def __init__(self, layer_tag_eq: tgm.jax_core.JaxprEqn): + super(CurvatureBlock, self).__init__() + self._layer_tag_eq = layer_tag_eq + + @property + def layer_tag_primitive(self) -> tgm.tags.LayerTag: + assert isinstance(self._layer_tag_eq.primitive, tgm.tags.LayerTag) + return self._layer_tag_eq.primitive + + @property + def outputs_shapes(self) -> Sequence[Sequence[int]]: + output_vars = self.layer_tag_primitive.split_all_inputs( + self._layer_tag_eq.invars)[0] + return jax.tree_map(lambda x: x.aval.shape, output_vars) + + @property + def inputs_shapes(self) -> Sequence[Sequence[int]]: + input_vars = self.layer_tag_primitive.split_all_inputs( + self._layer_tag_eq.invars)[1] + return jax.tree_map(lambda x: x.aval.shape, input_vars) + + @property + def params_shapes(self) -> Sequence[Sequence[int]]: + params_vars = self.layer_tag_primitive.split_all_inputs( + self._layer_tag_eq.invars)[2] + return jax.tree_map(lambda x: x.aval.shape, params_vars) + + @abc.abstractmethod + def init(self, rng: jnp.ndarray) -> MutableMapping[str, Any]: + """This initializes/creates all of the arrays for the state of the block. + + Usually this would include the arrays used for storing the curvature + approximation, as well as the arrays for storing any approximate + inverses/powers of the curvature block. + + Args: + rng: The Jax PRNG key to use if any of the state is supposed to be + initialized randomly. + Returns: + A mutable mapping of the state. 
+ """ + + @abc.abstractmethod + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + @abc.abstractmethod + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + @abc.abstractmethod + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + pass + + +CurvatureBlockCtor = Callable[[core.JaxprEqn], CurvatureBlock] + + +@utils.Stateful.infer_class_state +class NaiveDiagonal(CurvatureBlock): + """The naively estimated diagonal block.""" + diagonal_factor: utils.WeightedMovingAverage + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + del rng + return dict( + diagonal_factor=utils.WeightedMovingAverage.zero( + self.outputs_shapes[0]) + ) + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + dw, = info["outputs_tangent"] + diagonal_update = dw * dw / batch_size + self.diagonal_factor.update(diagonal_update, ema_old, ema_new) + self.diagonal_factor.sync(pmap_axis_name) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + w, = vec + if exp == 1: + return w * (self.diagonal_factor.value + diagonal_weight), + elif exp == -1: + return w / (self.diagonal_factor.value + diagonal_weight), + else: + raise NotImplementedError() + + +@utils.Stateful.infer_class_state +class TwoKroneckerFactored(CurvatureBlock, abc.ABC): + """A factor that is the Kronecker product of two matrices.""" + inputs_factor: utils.WeightedMovingAverage + inputs_factor_inverse: jnp.ndarray + outputs_factor: utils.WeightedMovingAverage + outputs_factor_inverse: jnp.ndarray + extra_scale: Optional[Union[int, float, jnp.ndarray]] + + @property + def has_bias(self) -> bool: + return len(self._layer_tag_eq.invars) == 4 + + @abc.abstractmethod + def input_size(self) -> int: + pass + + @abc.abstractmethod + def output_size(self) -> int: + pass + + def compute_extra_scale(self) -> Optional[Union[int, float, jnp.ndarray]]: + return 1 + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + # The extra scale is technically a constant, but in general it could be + # useful for anyone examining the state to know it explicitly, + # hence we actually keep it as part of the state. + d_in = self.input_size() + d_out = self.output_size() + return dict( + inputs_factor=utils.WeightedMovingAverage.zero([d_in, d_in]), + inputs_factor_inverse=jnp.zeros([d_in, d_in]), + outputs_factor=utils.WeightedMovingAverage.zero([d_out, d_out]), + outputs_factor_inverse=jnp.zeros([d_out, d_out]), + extra_scale=self.compute_extra_scale() + ) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + self.inputs_factor.sync(pmap_axis_name) + self.outputs_factor.sync(pmap_axis_name) + + # This computes the approximate inverse factor using the pi-adjusted + # inversion from the original KFAC paper. 
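+    # Concretely, the paper's "factored Tikhonov damping" replaces the damped
+    # block A kron B + d * I with the factored approximation
+    #   (A + pi * sqrt(d) * I) kron (B + (sqrt(d) / pi) * I),
+    # where pi = sqrt((trace(A) / dim(A)) / (trace(B) / dim(B))) balances how
+    # much of the damping is assigned to each factor, so that each factor can
+    # be inverted on its own.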
+ # Note that the damping is divided by extra_scale since: + # (s * A kron B + lambda I)^-1 = s^-1 (A kron B + s^-1 * lambda I)^-1 + # And the extra division by the scale is included in `multiply_matpower`. + (self.inputs_factor_inverse, + self.outputs_factor_inverse) = utils.pi_adjusted_inverse( + factor_0=self.inputs_factor.value, + factor_1=self.outputs_factor.value, + damping=diagonal_weight / self.extra_scale, + pmap_axis_name=pmap_axis_name) + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + if self.has_bias: + w, b = vec + vec = jnp.concatenate([w.reshape([-1, w.shape[-1]]), b[None]], axis=0) + else: + w, = vec + vec = w.reshape([-1, w.shape[-1]]) + if exp == 1: + inputs_factor, outputs_factor = (self.inputs_factor.value, + self.outputs_factor.value) + scale = self.extra_scale + elif exp == -1: + inputs_factor, outputs_factor = (self.inputs_factor_inverse, + self.outputs_factor_inverse) + scale = 1.0 / self.extra_scale + diagonal_weight = 0 + else: + raise NotImplementedError() + + result = jnp.matmul(inputs_factor, vec) + result = jnp.matmul(result, outputs_factor) + result = result * scale + diagonal_weight * vec + + if self.has_bias: + w_new, b_new = result[:-1], result[-1] + return w_new.reshape(w.shape), b_new + else: + return result.reshape(w.shape), + + +class DenseTwoKroneckerFactored(TwoKroneckerFactored): + """Factor for a standard dense layer.""" + + def input_size(self) -> int: + if self.has_bias: + return self.params_shapes[0][0] + 1 + else: + return self.params_shapes[0][0] + + def output_size(self) -> int: + return self.params_shapes[0][1] + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + del pmap_axis_name + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + utils.check_first_dim_is_batch_size(batch_size, x, dy) + + if self.has_bias: + x_one = jnp.ones_like(x[:, :1]) + x = jnp.concatenate([x, x_one], axis=1) + input_stats = jnp.matmul(x.T, x) / batch_size + output_stats = jnp.matmul(dy.T, dy) / batch_size + self.inputs_factor.update(input_stats, ema_old, ema_new) + self.outputs_factor.update(output_stats, ema_old, ema_new) + + +@utils.Stateful.infer_class_state +class ScaleAndShiftDiagonal(CurvatureBlock): + """A scale and shift block with a diagonal approximation to the curvature.""" + scale_factor: Optional[utils.WeightedMovingAverage] + shift_factor: Optional[utils.WeightedMovingAverage] + + @property + def has_scale(self) -> bool: + return self._layer_tag_eq.params["has_scale"] + + @property + def has_shift(self) -> bool: + return self._layer_tag_eq.params["has_shift"] + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + del rng + if self.has_scale and self.has_shift: + return dict( + scale_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[0] + ), + shift_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[1] + ) + ) + elif self.has_scale: + return dict( + scale_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[0] + ), + shift_factor=None + ) + elif self.has_shift: + return dict( + scale_factor=None, + shift_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[0] + ), + ) + else: + raise ValueError("Neither `has_scale` nor `has_shift`.") + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: 
Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + utils.check_first_dim_is_batch_size(batch_size, x, dy) + + if self.has_scale: + assert self.scale_factor is not None + scale_shape = info["params"][0].shape + full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape + axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0] + d_scale = jnp.sum(x * dy, axis=axis) + scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size + self.scale_factor.update(scale_diag_update, ema_old, ema_new) + self.scale_factor.sync(pmap_axis_name) + + if self.has_shift: + assert self.shift_factor is not None + shift_shape = info["params"][1].shape + full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape + axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0] + d_shift = jnp.sum(dy, axis=axis) + shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size + self.shift_factor.update(shift_diag_update, ema_old, ema_new) + self.shift_factor.sync(pmap_axis_name) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + if self.has_scale and self.has_shift: + factors = (self.scale_factor.value, self.shift_factor.value) + elif self.has_scale: + factors = (self.scale_factor.value,) + elif self.has_shift: + factors = (self.shift_factor.value,) + else: + raise ValueError("Neither `has_scale` nor `has_shift`.") + factors = jax.tree_map(lambda x: x + diagonal_weight, factors) + if exp == 1: + return jax.tree_multimap(jnp.multiply, vec, factors) + elif exp == -1: + return jax.tree_multimap(jnp.divide, vec, factors) + else: + raise NotImplementedError() + + +@utils.Stateful.infer_class_state +class ScaleAndShiftFull(CurvatureBlock): + """A scale and shift block with full approximation to the curvature.""" + factor: utils.WeightedMovingAverage + inverse_factor: jnp.ndarray + + @property + def _has_scale(self) -> bool: + return self._layer_tag_eq.params["has_scale"] + + @property + def _has_shift(self) -> bool: + return self._layer_tag_eq.params["has_shift"] + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + del rng + dims = sum(utils.product(shape) for shape in self.params_shapes) + return dict( + factor=utils.WeightedMovingAverage.zero([dims, dims]), + inverse_factor=jnp.zeros([dims, dims]) + ) + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + del pmap_axis_name + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + utils.check_first_dim_is_batch_size(batch_size, x, dy) + + grads = list() + if self._has_scale: + # Scale gradients + scale_shape = info["params"][0].shape + full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape + axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0] + d_scale = jnp.sum(x * dy, axis=axis) + d_scale = d_scale.reshape([batch_size, -1]) + grads.append(d_scale) + + if self._has_shift: + # Shift gradients + shift_shape = info["params"][1].shape + full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape + axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0] + d_shift = jnp.sum(dy, axis=axis) + d_shift = 
d_shift.reshape([batch_size, -1]) + grads.append(d_shift) + + grads = jnp.concatenate(grads, axis=1) + factor_update = jnp.matmul(grads.T, grads) / batch_size + self.factor.update(factor_update, ema_old, ema_new) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + self.factor.sync(pmap_axis_name) + self.inverse_factor = utils.psd_inv_cholesky(self.factor.value, + diagonal_weight) + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + # Remember the vector is a tuple of all parameters + if self._has_scale and self._has_shift: + flat_vec = jnp.concatenate([v.flatten() for v in vec]) + else: + flat_vec = vec[0].flatten() + + if exp == 1: + flat_result = ( + jnp.matmul(self.factor.value, flat_vec) + diagonal_weight * flat_vec) + elif exp == -1: + flat_result = jnp.matmul(self.inverse_factor, flat_vec) + else: + raise NotImplementedError() + + if self._has_scale and self._has_shift: + scale_dims = int(vec[0].size) + scale_result = flat_result[:scale_dims].reshape(vec[0].shape) + shift_result = flat_result[scale_dims:].reshape(vec[1].shape) + return scale_result, shift_result + else: + return flat_vec.reshape(vec[0].shape), + + +_default_tag_to_block: MutableMapping[str, CurvatureBlockCtor] = dict( + dense_tag=DenseTwoKroneckerFactored, + generic_tag=NaiveDiagonal, + scale_and_shift_tag=ScaleAndShiftDiagonal, +) + + +def copy_default_tag_to_block() -> MutableMapping[str, CurvatureBlockCtor]: + return dict(_default_tag_to_block) + + +def get_default_tag_to_block(tag_name: str) -> CurvatureBlockCtor: + return _default_tag_to_block[tag_name] + + +def set_default_tag_to_block( + tag_name: str, + block_class: CurvatureBlockCtor, +) -> None: + _default_tag_to_block[tag_name] = block_class diff --git a/kfac_ferminet_alpha/distributions.py b/kfac_ferminet_alpha/distributions.py new file mode 100644 index 0000000..34ddbe8 --- /dev/null +++ b/kfac_ferminet_alpha/distributions.py @@ -0,0 +1,75 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for all distribution implementations needed for the loss functions.""" +import math +import jax +import jax.numpy as jnp + + +class MultivariateNormalDiag: + """Multivariate normal distribution on `R^k`.""" + + def __init__( + self, + loc: jnp.ndarray, + scale_diag: jnp.ndarray): + """Initializes a MultivariateNormalDiag distribution. + + Args: + loc: Mean vector of the distribution. Can also be a batch of vectors. + scale_diag: Vector of standard deviations. 
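+        Can also be a batch of vectors, provided it broadcasts against `loc`.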
+ """ + super().__init__() + self._loc = loc + self._scale_diag = scale_diag + + @property + def loc(self) -> jnp.ndarray: + """Mean of the distribution.""" + return self._loc + + @property + def scale_diag(self) -> jnp.ndarray: + """Scale of the distribution.""" + return self._scale_diag + + def _num_dims(self) -> int: + """Dimensionality of the events.""" + return self._scale_diag.shape[-1] + + def _standardize(self, value: jnp.ndarray) -> jnp.ndarray: + return (value - self._loc) / self._scale_diag + + def log_prob(self, value: jnp.ndarray) -> jnp.ndarray: + """See `Distribution.log_prob`.""" + log_unnormalized = -0.5 * jnp.square(self._standardize(value)) + log_normalization = 0.5 * math.log(2 * math.pi) + jnp.log(self._scale_diag) + return jnp.sum(log_unnormalized - log_normalization, axis=-1) + + def mean(self) -> jnp.ndarray: + """Calculates the mean.""" + return self.loc + + def sample(self, seed: jnp.ndarray) -> jnp.ndarray: + """Samples an event. + + Args: + seed: PRNG key or integer seed. + + Returns: + A sample. + """ + eps = jax.random.normal(seed, self.loc.shape) + return self.loc + eps * self.scale_diag diff --git a/kfac_ferminet_alpha/estimator.py b/kfac_ferminet_alpha/estimator.py new file mode 100644 index 0000000..f9a3841 --- /dev/null +++ b/kfac_ferminet_alpha/estimator.py @@ -0,0 +1,340 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the high-level Fisher estimator class.""" +import collections +from typing import Any, Callable, Mapping, Optional, Sequence, Union, TypeVar + +import jax +import jax.numpy as jnp +import jax.random as jnr +import numpy as np + +from kfac_ferminet_alpha import curvature_blocks +from kfac_ferminet_alpha import tracer +from kfac_ferminet_alpha import utils + +_CurvatureBlock = curvature_blocks.CurvatureBlock +TagMapping = Mapping[str, curvature_blocks.CurvatureBlockCtor] +BlockVector = Sequence[jnp.ndarray] + +_StructureT = TypeVar("_StructureT") +_OptionalStateT = TypeVar("_OptionalStateT", bound=Optional[Mapping[str, Any]]) + + +@utils.Stateful.infer_class_state +class CurvatureEstimator(utils.Stateful): + """Curvature estimator class supporting various curvature approximations.""" + blocks: "collections.OrderedDict[str, _CurvatureBlock]" + damping: Optional[jnp.ndarray] + + def __init__(self, + tagged_func: Callable[[Any], jnp.ndarray], + func_args: Sequence[Any], + l2_reg: Union[float, jnp.ndarray], + estimation_mode: str = "fisher_gradients", + params_index: int = 0, + layer_tag_to_block_cls: Optional[TagMapping] = None): + """Create a FisherEstimator object. + + Args: + tagged_func: The function which evaluates the model, in which layer and + loss tags has already been registered. + func_args: Arguments to trace the function for layer and loss tags. + l2_reg: Scalar. The L2 regularization coefficient, which represents + the following regularization function: `coefficient/2 ||theta||^2`. + estimation_mode: The type of curvature estimator to use. 
One of: * + 'fisher_gradients' - the basic estimation approach from the original + K-FAC paper. (Default) * 'fisher_curvature_prop' - method which + estimates the Fisher using self-products of random 1/-1 vectors times + "half-factors" of the + Fisher, as described here: https://arxiv.org/abs/1206.6464 * + 'fisher_exact' - is the obvious generalization of Curvature + Propagation to compute the exact Fisher (modulo any additional + diagonal or Kronecker approximations) by looping over one-hot + vectors for each coordinate of the output instead of using 1/-1 + vectors. It is more expensive to compute than the other three + options by a factor equal to the output dimension, roughly + speaking. * 'fisher_empirical' - computes the 'empirical' Fisher + information matrix (which uses the data's distribution for the + targets, as opposed to the true Fisher which uses the model's + distribution) and requires that each registered loss have + specified targets. * 'ggn_curvature_prop' - Analogous to + fisher_curvature_prop, but estimates the Generalized + Gauss-Newton matrix (GGN). * 'ggn_exact'- Analogous to + fisher_exact, but estimates the Generalized Gauss-Newton matrix + (GGN). + params_index: The index of the arguments accepted by `func` which + correspond to parameters. + layer_tag_to_block_cls: An optional dict mapping tags to specific classes + of block approximations, which to override the default ones. + """ + if estimation_mode not in ("fisher_gradients", "fisher_empirical", + "fisher_exact", "fisher_curvature_prop", + "ggn_exact", "ggn_curvature_prop"): + raise ValueError(f"Unrecognised estimation_mode={estimation_mode}.") + super().__init__() + self.tagged_func = tagged_func + self.l2_reg = l2_reg + self.estimation_mode = estimation_mode + self.params_index = params_index + self.vjp = tracer.trace_estimator_vjp(self.tagged_func) + + # Figure out the mapping from layer + self.layer_tag_to_block_cls = curvature_blocks.copy_default_tag_to_block() + if layer_tag_to_block_cls is None: + layer_tag_to_block_cls = dict() + layer_tag_to_block_cls = dict(**layer_tag_to_block_cls) + self.layer_tag_to_block_cls.update(layer_tag_to_block_cls) + + # Create the blocks + self._in_tree = jax.tree_structure(func_args) + self._jaxpr = jax.make_jaxpr(self.tagged_func)(*func_args).jaxpr + self._layer_tags, self._loss_tags = tracer.extract_tags(self._jaxpr) + self.blocks = collections.OrderedDict() + counters = dict() + for eqn in self._layer_tags: + cls = self.layer_tag_to_block_cls[eqn.primitive.name] + c = counters.get(cls.__name__, 0) + self.blocks[cls.__name__ + "_" + str(c)] = cls(eqn) + counters[cls.__name__] = c + 1 + + @property + def diagonal_weight(self) -> jnp.ndarray: + return self.l2_reg + self.damping + + def vectors_to_blocks( + self, + parameter_structured_vector: Any, + ) -> Sequence[BlockVector]: + """Splits the parameters to values for the corresponding blocks.""" + in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars) + params_vars = in_vars[self.params_index] + params_vars_flat = jax.tree_flatten(params_vars)[0] + params_values_flat = jax.tree_flatten(parameter_structured_vector)[0] + assert len(params_vars_flat) == len(params_values_flat) + params_dict = dict(zip(params_vars_flat, params_values_flat)) + per_block_vectors = [] + for eqn in self._layer_tags: + if eqn.primitive.name == "generic_tag": + block_vars = eqn.invars + else: + block_vars = eqn.primitive.split_all_inputs(eqn.invars)[2] + per_block_vectors.append(tuple(params_dict.pop(v) for v in block_vars)) + if 
params_dict: + raise ValueError(f"From the parameters the following structure is not " + f"assigned to any block: {params_dict}. Most likely " + f"this part of the parameters is not part of the graph " + f"reaching the losses.") + return tuple(per_block_vectors) + + def blocks_to_vectors(self, per_block_vectors: Sequence[BlockVector]) -> Any: + """Reverses the function self.vectors_to_blocks.""" + in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars) + params_vars = in_vars[self.params_index] + assigned_dict = dict() + for eqn, block_values in zip(self._layer_tags, per_block_vectors): + if eqn.primitive.name == "generic_tag": + block_params = eqn.invars + else: + block_params = eqn.primitive.split_all_inputs(eqn.invars)[2] + assigned_dict.update(zip(block_params, block_values)) + params_vars_flat, params_tree = jax.tree_flatten(params_vars) + params_values_flat = [assigned_dict[v] for v in params_vars_flat] + assert len(params_vars_flat) == len(params_values_flat) + return jax.tree_unflatten(params_tree, params_values_flat) + + def init( + self, + rng: jnp.ndarray, + init_damping: Optional[jnp.ndarray], + ) -> Mapping[str, Any]: + """Returns an initialized variables for the curvature approximations and the inverses..""" + return dict( + blocks=collections.OrderedDict( + (name, block.init(block_rng)) # + for (name, block), block_rng # + in zip(self.blocks.items(), jnr.split(rng, len(self.blocks)))), + damping=init_damping) + + @property + def mat_type(self) -> str: + return self.estimation_mode.split("_")[0] + + def vec_block_apply( + self, + func: Callable[[_CurvatureBlock, BlockVector], BlockVector], + parameter_structured_vector: Any, + ) -> Any: + """Executes func for each approximation block on vectors.""" + per_block_vectors = self.vectors_to_blocks(parameter_structured_vector) + assert len(per_block_vectors) == len(self.blocks) + results = jax.tree_multimap(func, tuple(self.blocks.values()), + per_block_vectors) + parameter_structured_result = self.blocks_to_vectors(results) + utils.check_structure_shapes_and_dtype(parameter_structured_vector, + parameter_structured_result) + return parameter_structured_result + + def multiply_inverse(self, parameter_structured_vector: Any) -> Any: + """Multiplies the vectors by the corresponding (damped) inverses of the blocks. + + Args: + parameter_structured_vector: Structure equivalent to the parameters of the + model. + + Returns: + A structured identical to `vectors` containing the product. + """ + return self.multiply_matpower(parameter_structured_vector, -1) + + def multiply(self, parameter_structured_vector: Any) -> Any: + """Multiplies the vectors by the corresponding (damped) blocks. + + Args: + parameter_structured_vector: A vector in the same structure as the + parameters of the model. + + Returns: + A structured identical to `vectors` containing the product. + """ + return self.multiply_matpower(parameter_structured_vector, 1) + + def multiply_matpower( + self, + parameter_structured_vector: _StructureT, + exp: int, + ) -> _StructureT: + """Multiplies the vectors by the corresponding matrix powers of the blocks. + + Args: + parameter_structured_vector: A vector in the same structure as the + parameters of the model. + exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + + Returns: + A structured identical to `vectors` containing the product. 
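+      Only exp=1 and exp=-1 are currently implemented by the underlying
+      curvature blocks; any other power raises NotImplementedError.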
+ """ + + def func(block: _CurvatureBlock, vec: BlockVector) -> BlockVector: + return block.multiply_matpower(vec, exp, self.diagonal_weight) + + return self.vec_block_apply(func, parameter_structured_vector) + + def update_curvature_matrix_estimate( + self, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + batch_size: int, + rng: jnp.ndarray, + func_args: Sequence[Any], + pmap_axis_name: str, + ) -> None: + """Updates the curvature estimate.""" + + # Compute the losses and the VJP function from the function inputs + losses, losses_vjp = self.vjp(func_args) + + # Helper function that updates the blocks given a vjp vector + def _update_blocks(vjp_vec_, ema_old_, ema_new_): + blocks_info_ = losses_vjp(vjp_vec_) + for block_, block_info_ in zip(self.blocks.values(), blocks_info_): + block_.update_curvature_matrix_estimate( + info=block_info_, + batch_size=batch_size, + ema_old=ema_old_, + ema_new=ema_new_, + pmap_axis_name=pmap_axis_name) + + if self.estimation_mode == "fisher_gradients": + keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng] + vjp_vec = tuple( + loss.grad_of_evaluate_on_sample(key, coefficient_mode="sqrt") + for loss, key in zip(losses, keys)) + _update_blocks(vjp_vec, ema_old, ema_new) + + elif self.estimation_mode in ("fisher_curvature_prop", + "ggn_curvature_prop"): + keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng] + vjp_vec = [] + for loss, key in zip(losses, keys): + if self.estimation_mode == "fisher_curvature_prop": + random_b = jnr.bernoulli(key, shape=loss.fisher_factor_inner_shape()) + vjp_vec.append(loss.multiply_fisher_factor(random_b * 2.0 - 1.0)) + else: + random_b = jnr.bernoulli(key, shape=loss.ggn_factor_inner_shape()) + vjp_vec.append(loss.multiply_ggn_factor(random_b * 2.0 - 1.0)) + _update_blocks(tuple(vjp_vec), ema_old, ema_new) + + elif self.estimation_mode in ("fisher_exact", "ggn_exact"): + # We use the following trick to simulate summation. The equation is: + # estimate = ema_old * estimate + ema_new * (sum_i estimate_index_i) + # weight = ema_old * weight + ema_new + # Instead we update the estimate n times with the following updates: + # for k = 1 + # estimate_k = ema_old * estimate + (ema_new/n) * (n*estimate_index_k) + # weight_k = ema_old * weight + (ema_new/n) + # for k > 1: + # estimate_k = 1.0 * estimate_k-1 + (ema_new/n) * (n*estimate_index_k) + # weight_k = 1.0 * weight_k-1 + (ema_new/n) + # Which is mathematically equivalent to the original version. + zero_tangents = jax.tree_map(jnp.zeros_like, + list(loss.inputs for loss in losses)) + if self.estimation_mode == "fisher_exact": + num_indices = [ + (l, int(np.prod(l.fisher_factor_inner_shape[1:]))) for l in losses + ] + else: + num_indices = [ + (l, int(np.prod(l.ggn_factor_inner_shape()))) for l in losses + ] + total_num_indices = sum(n for _, n in num_indices) + for i, (loss, loss_num_indices) in enumerate(num_indices): + for index in range(loss_num_indices): + vjp_vec = zero_tangents.copy() + if self.estimation_mode == "fisher_exact": + vjp_vec[i] = loss.multiply_fisher_factor_replicated_one_hot([index]) + else: + vjp_vec[i] = loss.multiply_ggn_factor_replicated_one_hot([index]) + if isinstance(vjp_vec[i], jnp.ndarray): + # In the special case of only one parameter, it still needs to be a + # tuple for the tangents. 
+ vjp_vec[i] = (vjp_vec[i],) + vjp_vec[i] = jax.tree_map(lambda x: x * total_num_indices, vjp_vec[i]) + _update_blocks(tuple(vjp_vec), ema_old, ema_new / total_num_indices) + ema_old = 1.0 + + elif self.estimation_mode == "fisher_empirical": + raise NotImplementedError() + else: + raise ValueError(f"Unrecognised estimation_mode={self.estimation_mode}") + + def update_curvature_estimate_inverse( + self, + pmap_axis_name: str, + state: _OptionalStateT, + ) -> _OptionalStateT: + if state is not None: + old_state = self.get_state() + self.set_state(state) + for block in self.blocks.values(): + block.update_curvature_inverse_estimate(self.diagonal_weight, + pmap_axis_name) + if state is None: + return None + else: + state = self.pop_state() + self.set_state(old_state) + return state diff --git a/kfac_ferminet_alpha/example.py b/kfac_ferminet_alpha/example.py new file mode 100644 index 0000000..1aed84d --- /dev/null +++ b/kfac_ferminet_alpha/example.py @@ -0,0 +1,171 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example of running KFAC.""" +from absl import app +from absl import flags +import jax +import jax.numpy as jnp + +import numpy as np +import kfac_ferminet_alpha as kfac_ferminet_alpha +from kfac_ferminet_alpha import utils + + +TRAINING_STEPS = flags.DEFINE_integer( + name="training_steps", + default=100, + help="Number of training steps to perform") +BATCH_SIZE = flags.DEFINE_integer( + name="batch_size", default=128, help="Batch size") +LEARNING_RATE = flags.DEFINE_float( + name="learning_rate", default=1e-3, help="Learning rate") +L2_REG = flags.DEFINE_float( + name="l2_reg", default=1e-3, help="L2 regularization coefficient") +MOMENTUM = flags.DEFINE_float( + name="momentum", default=0.8, help="Momentum coefficient") +DAMPING = flags.DEFINE_float( + name="damping", default=1e-2, help="Damping coefficient") +MULTI_DEVICE = flags.DEFINE_bool( + name="multi_device", + default=False, + help="Whether the computation should be replicated across multiple devices") +SEED = flags.DEFINE_integer(name="seed", default=12412321, help="JAX RNG seed") + + +def glorot_uniform(shape, key): + dim_in = np.prod(shape[:-1]) + dim_out = shape[-1] + c = jnp.sqrt(6 / (dim_in + dim_out)) + return jax.random.uniform(key, shape=shape, minval=-c, maxval=c) + + +def fully_connected_layer(params, x): + w, b = params + return jnp.matmul(x, w) + b[None] + + +def model_init(rng_key, batch, encoder_sizes=(1000, 500, 250, 30)): + """Initialize the standard autoencoder.""" + x_size = batch.shape[-1] + decoder_sizes = encoder_sizes[len(encoder_sizes) - 2::-1] + sizes = (x_size,) + encoder_sizes + decoder_sizes + (x_size,) + keys = jax.random.split(rng_key, len(sizes) - 1) + params = [] + for rng_key, dim_in, dim_out in zip(keys, sizes, sizes[1:]): + # Glorot uniform initialization + w = glorot_uniform((dim_in, dim_out), rng_key) + b = jnp.zeros([dim_out]) + params.append((w, b)) + return params, None + + +def model_loss(params, inputs, l2_reg): + """Evaluate the 
standard autoencoder.""" + h = inputs.reshape([inputs.shape[0], -1]) + for i, layer_params in enumerate(params): + h = fully_connected_layer(layer_params, h) + # Last layer does not have a nonlinearity + if i % 4 != 3: + h = jnp.tanh(h) + l2_value = 0.5 * sum(jnp.square(p).sum() for p in jax.tree_leaves(params)) + error = jax.nn.sigmoid(h) - inputs.reshape([inputs.shape[0], -1]) + mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0) + regularized_loss = mean_squared_error + l2_reg * l2_value + + return regularized_loss, dict(mean_squared_error=mean_squared_error) + + +def random_data(multi_device, batch_shape, rng): + if multi_device: + shape = (multi_device,) + tuple(batch_shape) + else: + shape = tuple(batch_shape) + while True: + rng, key = jax.random.split(rng) + yield jax.random.normal(key, shape) + + +def main(argv): + del argv # Unused. + + learning_rate = jnp.asarray([LEARNING_RATE.value]) + momentum = jnp.asarray([MOMENTUM.value]) + damping = jnp.asarray([DAMPING.value]) + + # RNG keys + global_step = jnp.zeros([]) + rng = jax.random.PRNGKey(SEED.value) + params_key, opt_key, step_key, data_key = jax.random.split(rng, 4) + dataset = random_data(MULTI_DEVICE.value, (BATCH_SIZE.value, 20), data_key) + example_batch = next(dataset) + + if MULTI_DEVICE.value: + global_step = utils.replicate_all_local_devices(global_step) + learning_rate = utils.replicate_all_local_devices(learning_rate) + momentum = utils.replicate_all_local_devices(momentum) + damping = utils.replicate_all_local_devices(damping) + params_key, opt_key = utils.replicate_all_local_devices( + (params_key, opt_key)) + step_key = utils.make_different_rng_key_on_all_devices(step_key) + split_key = jax.pmap(lambda x: tuple(jax.random.split(x))) + jit_init_parameters_func = jax.pmap(model_init) + else: + split_key = jax.random.split + jit_init_parameters_func = jax.jit(model_init) + + # Initialize or load parameters + params, func_state = jit_init_parameters_func(params_key, example_batch) + + # Make optimizer + optim = kfac_ferminet_alpha.Optimizer( + value_and_grad_func=jax.value_and_grad( + lambda p, x: model_loss(p, x, L2_REG.value), has_aux=True), + l2_reg=L2_REG.value, + value_func_has_aux=True, + value_func_has_state=False, + value_func_has_rng=False, + learning_rate_schedule=None, + momentum_schedule=None, + damping_schedule=None, + norm_constraint=1.0, + num_burnin_steps=10, + ) + + # Initialize optimizer + opt_state = optim.init(params, opt_key, example_batch, func_state) + + for t in range(TRAINING_STEPS.value): + step_key, key_t = split_key(step_key) + params, opt_state, stats = optim.step( + params, + opt_state, + key_t, + dataset, + learning_rate=learning_rate, + momentum=momentum, + damping=damping) + global_step = global_step + 1 + + # Log any of the statistics + print(f"iteration: {t}") + print(f"mini-batch loss = {stats['loss']}") + if "aux" in stats: + for k, v in stats["aux"].items(): + print(f"{k} = {v}") + print("----") + + +if __name__ == "__main__": + app.run(main) diff --git a/kfac_ferminet_alpha/layers_and_loss_tags.py b/kfac_ferminet_alpha/layers_and_loss_tags.py new file mode 100644 index 0000000..92fcfdb --- /dev/null +++ b/kfac_ferminet_alpha/layers_and_loss_tags.py @@ -0,0 +1,354 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A module for registering already known functions for tagging patterns.""" +import functools + +from typing import Sequence, Tuple, TypeVar + +import jax +from jax import core as jax_core +from jax import lax +from jax import lib as jax_lib +from jax.interpreters import batching as jax_batching +import jax.numpy as jnp + +_T = TypeVar("_T") + + +class LossTag(jax_core.Primitive): + """A tagging primitive specifically for losses.""" + multiple_results = True + + def __init__(self, cls, num_inputs: int, num_targets: int = 1): + super().__init__(cls.__name__ + "_tag") + self._cls = cls + self._num_inputs = num_inputs + self._num_targets = num_targets + jax.xla.translations[self] = self.xla_translation + jax.ad.primitive_jvps[self] = self.jvp + # This line defines how does the tag behave under vmap. It is required for + # any primitive that can be used inside a vmap. The reason why we want to + # allow this is two fold - one to not break user code when the tags are not + # used at all, and two - to be able to define a network with code for a + # single example which is the vmap-ed for a batch. + jax_batching.primitive_batchers[self] = self.batching + + @property + def num_inputs(self) -> int: + return self._num_inputs + + @property + def num_targets(self) -> int: + return self._num_targets + + def loss(self, *args, weight: float = 1.0, **kwargs): + return self._cls(*args, weight=weight, **kwargs) + + def loss_evaluate(self, *args, weight: float = 1.0, **kwargs): + return self.loss(*args, weight=weight, **kwargs).evaluate() + + def get_outputs(self, *args, weight: float, return_loss: bool, **kwargs): + if len(args) < self.num_inputs: + raise ValueError("Inputs to the tag are not enough.") + if len(args) < self.num_inputs + self.num_targets: + if len(args) != self.num_inputs: + raise ValueError("Inputs to the tag are not quite enough.") + if return_loss: + raise ValueError("Can not have return_loss=True when there are no " + "targets.") + return args + if len(args) > self.num_inputs + self.num_targets: + raise ValueError("Inputs to the tag are too many.") + if return_loss: + return self.loss(*args, weight=weight, **kwargs).evaluate() + else: + return args + + def impl(self, *args, weight: float, return_loss: bool, **kwargs): + return self.get_outputs(*args, weight=weight, return_loss=return_loss) + + def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs): + return self.get_outputs(*args, weight=weight, return_loss=return_loss) + + def xla_translation( + self, + c, + *args, + weight: float = 1.0, + return_loss: bool = False, + **kwargs, + ): + outputs = self.get_outputs( + *args, weight=weight, return_loss=return_loss, **kwargs) + if isinstance(outputs, tuple): + return jax_lib.xla_client.ops.Tuple(c, outputs) + return outputs + + def jvp( + self, + arg_values, + arg_tangents, + weight: float, + return_loss: bool, + **kwargs, + ): + if len(arg_values) != len(arg_tangents): + raise ValueError("Values and tangents are not the same length.") + primal_output = self.bind( + *arg_values, weight=weight, return_loss=return_loss, **kwargs) + if len(arg_values) == 
self.num_inputs: + tangents_out = self.get_outputs( + *arg_tangents, weight=weight, return_loss=return_loss, **kwargs) + elif return_loss: + tangents_out = jax.jvp( + functools.partial(self.loss_evaluate, weight=weight, **kwargs), + arg_tangents, arg_tangents)[1] + else: + tangents_out = arg_tangents + return primal_output, tangents_out + + def batching(self, batched_args, batched_dims, **kwargs): + return self.bind(*batched_args, **kwargs), batched_dims[0] + + +class LayerTag(jax_core.Primitive): + """A tagging primitive that is used to mark/tag computation.""" + + def __init__(self, name: str, num_inputs: int, num_outputs: int): + super().__init__(name) + if num_outputs > 1: + raise NotImplementedError( + f"Only single outputs are supported, got: num_outputs={num_outputs}") + self._num_outputs = num_outputs + self._num_inputs = num_inputs + jax.xla.translations[self] = self.xla_translation + jax.ad.deflinear(self, self.transpose) + jax.ad.primitive_transposes[self] = self.transpose + # This line defines how does the tag behave under vmap. It is required for + # any primitive that can be used inside a vmap. The reason why we want to + # allow this is two fold - one to not break user code when the tags are not + # used at all, and two - to be able to define a network with code for a + # single example which is the vmap-ed for a batch. + jax_batching.primitive_batchers[self] = self.batching + + @property + def num_outputs(self) -> int: + return self._num_outputs + + @property + def num_inputs(self) -> int: + return self._num_inputs + + def split_all_inputs( + self, + all_inputs: Sequence[_T], + ) -> Tuple[Sequence[_T], Sequence[_T], Sequence[_T]]: + outputs = tuple(all_inputs[:self.num_outputs]) + inputs = tuple(all_inputs[self.num_outputs:self.num_outputs + + self.num_inputs]) + params = tuple(all_inputs[self.num_outputs + self.num_inputs:]) + return outputs, inputs, params + + def get_outputs(self, *operands: _T, **kwargs) -> _T: + assert self.num_outputs == 1 + return operands[0] + + def xla_translation(self, c, *operands: _T, **kwargs) -> _T: + return self.get_outputs(*operands, **kwargs) + + @staticmethod + def transpose(cotangent, *operands, **kwargs): + return (cotangent,) + (None,) * (len(operands) - 1) + + def impl(self, *operands, **kwargs): + return self.get_outputs(*operands, **kwargs) + + def abstract_eval(self, *abstract_operands, **kwargs): + return self.get_outputs(*abstract_operands, **kwargs) + + def batching(self, batched_operands, batched_dims, **kwargs): + return self.bind(*batched_operands, **kwargs), batched_dims[0] + + +# _____ _ +# / ____| (_) +# | | __ ___ _ __ ___ _ __ _ ___ +# | | |_ |/ _ \ '_ \ / _ \ '__| |/ __| +# | |__| | __/ | | | __/ | | | (__ +# \_____|\___|_| |_|\___|_| |_|\___| +# +# + +generic_tag = LayerTag(name="generic_tag", num_inputs=0, num_outputs=1) + + +def register_generic(parameter: _T) -> _T: + return generic_tag.bind(parameter) + + +# _____ +# | __ \ +# | | | | ___ _ __ ___ ___ +# | | | |/ _ \ '_ \/ __|/ _ \ +# | |__| | __/ | | \__ \ __/ +# |_____/ \___|_| |_|___/\___| +# + +dense_tag = LayerTag(name="dense_tag", num_inputs=1, num_outputs=1) + + +def register_dense(y, x, w, b=None): + if b is None: + return dense_tag.bind(y, x, w) + return dense_tag.bind(y, x, w, b) + + +def dense_func(x, params): + """Example of a dense layer function.""" + w = params[0] + y = jnp.matmul(x, w) + if len(params) == 1: + # No bias + return y + # Add bias + return y + params[1] + + +def dense_tagging(jaxpr, inverse_map, values_map): + """Correctly registers a 
dense layer pattern.""" + del inverse_map + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + return register_dense(out_values[0], *in_values) + + +# ___ _____ _____ _ _ _ +# |__ \| __ \ / ____| | | | | (_) +# ) | | | | | | ___ _ ____ _____ | |_ _| |_ _ ___ _ __ +# / /| | | | | | / _ \| '_ \ \ / / _ \| | | | | __| |/ _ \| "_ \ +# / /_| |__| | | |___| (_) | | | \ V / (_) | | |_| | |_| | (_) | | | | +# |____|_____/ \_____\___/|_| |_|\_/ \___/|_|\__,_|\__|_|\___/|_| |_| +# + +conv2d_tag = LayerTag(name="conv2d_tag", num_inputs=1, num_outputs=1) + + +def register_conv2d(y, x, w, b=None, **kwargs): + if b is None: + return conv2d_tag.bind(y, x, w, **kwargs) + return conv2d_tag.bind(y, x, w, b, **kwargs) + + +def conv2d_func(x, params): + """Example of a conv2d layer function.""" + w = params[0] + y = lax.conv_general_dilated( + x, + w, + window_strides=(2, 2), + padding="SAME", + dimension_numbers=("NHWC", "HWIO", "NHWC")) + if len(params) == 1: + # No bias + return y + # Add bias + return y + params[1][None, None, None] + + +def conv2d_tagging(jaxpr, inverse_map, values_map): + """Correctly registers a conv2d layer pattern.""" + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + keys = [k for k in inverse_map.keys() if isinstance(k, str)] + keys = [k for k in keys if k.startswith("conv_general_dilated")] + if len(keys) != 1: + raise ValueError("Did not find any conv_general_dilated!") + kwargs = inverse_map[keys[0]].params + return register_conv2d(out_values[0], *in_values, **kwargs) + + +# _____ _ _ _____ _ _ __ _ +# / ____| | | | | / ____| | (_)/ _| | +# | (___ ___ __ _| | ___ __ _ _ __ __| | | (___ | |__ _| |_| |_ +# \___ \ / __/ _` | |/ _ \ / _` | '_ \ / _` | \___ \| '_ \| | _| __| +# ____) | (_| (_| | | __/ | (_| | | | | (_| | ____) | | | | | | | |_ +# |_____/ \___\__,_|_|\___| \__,_|_| |_|\__,_| |_____/|_| |_|_|_| \__| +# + +scale_and_shift_tag = LayerTag( + name="scale_and_shift_tag", num_inputs=1, num_outputs=1) + + +def register_scale_and_shift(y, args, has_scale: bool, has_shift: bool): + assert has_scale or has_shift + x, args = args[0], args[1:] + return scale_and_shift_tag.bind( + y, x, *args, has_scale=has_scale, has_shift=has_shift) + + +def scale_and_shift_func(x, params, has_scale: bool, has_shift: bool): + """Example of a scale and shift function.""" + if has_scale and has_shift: + scale, shift = params + return x * scale + shift + elif has_scale: + return x * params[0] + elif has_shift: + return x + params[0] + else: + raise ValueError() + + +def scale_and_shift_tagging( + jaxpr, + inverse_map, + values_map, + has_scale: bool, + has_shift: bool, +): + """Correctly registers a scale and shift layer pattern.""" + del inverse_map + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + return register_scale_and_shift(out_values[0], in_values, has_scale, + has_shift) + + +def batch_norm_func( + inputs: Tuple[jnp.ndarray, jnp.ndarray], + params: Tuple[jnp.ndarray, jnp.ndarray], +) -> jnp.ndarray: + """Example of batch norm as is defined in Haiku.""" + x, y = inputs + scale, shift = params + inv = scale * y + return x * inv + shift + + +def batch_norm_tagging_func( + jaxpr, + inverse_map, + values_map, + has_scale: bool, + has_shift: bool, +): + """Correctly registers a batch norm layer pattern as is defined in Haiku.""" + del inverse_map + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] 
for v in jaxpr.outvars] + # The first two are both multipliers with the scale so we merge them + in_values = [in_values[0] * in_values[1]] + in_values[2:] + return register_scale_and_shift(out_values[0], in_values, has_scale, + has_shift) diff --git a/kfac_ferminet_alpha/loss_functions.py b/kfac_ferminet_alpha/loss_functions.py new file mode 100644 index 0000000..3a08bf2 --- /dev/null +++ b/kfac_ferminet_alpha/loss_functions.py @@ -0,0 +1,653 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions to be used by LayerCollection.""" +import abc +from typing import Tuple, Optional, Union, Sequence + +import jax +import jax.numpy as jnp + +from kfac_ferminet_alpha import distributions +from kfac_ferminet_alpha import layers_and_loss_tags as tags +from kfac_ferminet_alpha import utils + +ArrayPair = Tuple[jnp.ndarray, jnp.ndarray] +FloatArray = Union[float, jnp.ndarray] +Index = Tuple[int] + + +class LossFunction(abc.ABC): + """Abstract base class for loss functions. + + Note that unlike typical loss functions used in neural networks these are + neither summed nor averaged over the batch and hence the output of evaluate() + will not be a scalar. It is up to the user to then to correctly manipulate + them as needed. + """ + + def __init__(self, weight: FloatArray): + self._weight = weight + + @property + def weight(self) -> FloatArray: + return self._weight + + @property + @abc.abstractmethod + def targets(self) -> Optional[jnp.ndarray]: + """The targets being predicted by the model. + + Returns: + None or Tensor of appropriate shape for calling self._evaluate() on. + """ + pass + + @property + @abc.abstractmethod + def inputs(self) -> Sequence[jnp.ndarray]: + """The inputs to the loss function (excluding the targets).""" + pass + + @abc.abstractmethod + def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]): + pass + + def evaluate( + self, + targets: Optional[jnp.ndarray] = None, + coefficient_mode: str = "regular", + ) -> jnp.ndarray: + """Evaluate the loss function on the targets.""" + if targets is None and self.targets is None: + raise ValueError("Cannot evaluate losses with unspecified targets.") + elif targets is None: + targets = self.targets + if coefficient_mode == "regular": + multiplier = self.weight + elif coefficient_mode == "sqrt": + multiplier = jnp.sqrt(self.weight) + elif coefficient_mode == "off": + multiplier = 1.0 + else: + raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode}.") + return self._evaluate(targets) * multiplier + + @abc.abstractmethod + def _evaluate(self, targets: jnp.ndarray) -> jnp.ndarray: + """Evaluates the negative log probability of the targets. + + Args: + targets: Tensor that distribution can calculate log_prob() of. + + Returns: + negative log probability of each target, summed across all targets. 
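+      As the class docstring notes, this is one value per batch case; the
+      result is not summed or averaged over the batch.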
+ """ + pass + + def grad_of_evaluate( + self, + targets: Optional[jnp.ndarray], + coefficient_mode: str, + ) -> Sequence[jnp.ndarray]: + """Evaluates the gradient of the loss function. + + Note that the targets of the loss must not be `None`. + + Args: + targets: The potential targets on which to evaluate the gradient. + coefficient_mode: The coefficient mode to use for evaluation. + + Returns: + The gradient of the loss evaluation function with respect to the inputs. + """ + def evaluate_sum(inputs: Sequence[jnp.ndarray]) -> jnp.ndarray: + instance = self.copy_with_different_inputs(inputs) + return jnp.sum(instance.evaluate(targets, coefficient_mode)) + return jax.grad(evaluate_sum)(self.inputs) + + def multiply_ggn(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by the GGN. Will be of the same shape(s) + as the 'inputs' property. + """ + return utils.scalar_mul(self.multiply_ggn_unweighted(vector), self.weight) + + @abc.abstractmethod + def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + """Same as `multiply_ggn`, but without taking into account the weight.""" + pass + + def multiply_ggn_factor(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by a factor B of the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = G where G is the GGN, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'ggn_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_ggn_factor_unweighted(vector), jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + """Same as `multiply_ggn_factor`, but without taking into account the weight.""" + pass + + def multiply_ggn_factor_transpose(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by the transpose of a factor B of the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = G where G is the GGN, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'ggn_factor_inner_shape' property. 
+ """ + return utils.scalar_mul( + self.multiply_ggn_factor_transpose_unweighted(vector), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_ggn_factor_transpose_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + """Same as `multiply_ggn_factor_transpose`, but without taking into account the weight.""" + pass + + def multiply_ggn_factor_replicated_one_hot(self, index: Index) -> jnp.ndarray: + """Right-multiply a replicated-one-hot vector by a factor B of the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = G where G is the GGN, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements of + the 'ggn_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B^T. Will be of the same shape(s) as the + 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_ggn_factor_replicated_one_hot_unweighted(index), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_ggn_factor_replicated_one_hot_unweighted( + self, + index: Index + ) -> jnp.ndarray: + pass + + @property + @abc.abstractmethod + def ggn_factor_inner_shape(self) -> Sequence[int]: + """The shape of the tensor returned by multiply_ggn_factor.""" + pass + + +class NegativeLogProbLoss(LossFunction): + """Abstract base class for loss functions that are negative log probs.""" + + @property + def inputs(self): + return self.params + + @property + @abc.abstractmethod + def params(self): + """Parameters to the underlying distribution.""" + pass + + def multiply_fisher(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by the Fisher. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by the Fisher. Will be of the same shape(s) + as the 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_fisher_unweighted(vector), self.weight) + + @abc.abstractmethod + def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + pass + + def multiply_fisher_factor(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'fisher_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. 
+ """ + return utils.scalar_mul( + self.multiply_fisher_factor_unweighted(vector), jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_fisher_factor_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + pass + + def multiply_fisher_factor_transpose( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + """Right-multiply a vector by the transpose of a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'fisher_factor_inner_shape' property. + """ + return utils.scalar_mul( + self.multiply_fisher_factor_transpose_unweighted(vector), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_fisher_factor_transpose_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + pass + + def multiply_fisher_factor_replicated_one_hot( + self, + index: Index + ) -> jnp.ndarray: + """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements of + the 'fisher_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_fisher_factor_replicated_one_hot_unweighted(index), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_fisher_factor_replicated_one_hot_unweighted( + self, + index: Index + ) -> jnp.ndarray: + pass + + @property + @abc.abstractmethod + def fisher_factor_inner_shape(self) -> Sequence[int]: + """The shape of the tensor returned by multiply_fisher_factor.""" + pass + + @abc.abstractmethod + def sample(self, rng_key: jnp.ndarray) -> jnp.ndarray: + """Sample 'targets' from the underlying distribution.""" + pass + + def grad_of_evaluate_on_sample( + self, + rng_key: jnp.ndarray, + coefficient_mode: str, + ) -> Sequence[jnp.ndarray]: + """Evaluates the gradient of the log probability on a random sample. + + Args: + rng_key: Jax PRNG key for sampling. + coefficient_mode: The coefficient mode to use for evaluation. 
+ + Returns: + The gradient of the log probability of targets sampled from the + distribution. + """ + return self.grad_of_evaluate(self.sample(rng_key), coefficient_mode) + + +class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss, abc.ABC): + """Base class for neg log prob losses whose inputs are 'natural' parameters. + + We will take the GGN of the loss to be the Fisher associated with the + distribution, which also happens to be equal to the Hessian for this class + of loss functions. See here: https://arxiv.org/abs/1412.1193 + + 'Natural parameters' are defined for exponential-family models. See for + example: https://en.wikipedia.org/wiki/Exponential_family + """ + + def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + return self.multiply_fisher_unweighted(vector) + + def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + return self.multiply_fisher_factor_unweighted(vector) + + def multiply_ggn_factor_transpose_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + return self.multiply_fisher_factor_transpose_unweighted(vector) + + def multiply_ggn_factor_replicated_one_hot_unweighted( + self, + index: Index + ) -> jnp.ndarray: + return self.multiply_fisher_factor_replicated_one_hot_unweighted(index) + + @property + def ggn_factor_inner_shape(self) -> Sequence[int]: + return self.fisher_factor_inner_shape + + +class DistributionNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses that use the distribution classes.""" + + @property + @abc.abstractmethod + def dist(self): + """The underlying distribution instance.""" + pass + + def _evaluate(self, targets: jnp.ndarray): + return -self.dist.log_prob(targets) + + def sample(self, rng_key: jnp.ndarray): + return self.dist.sample(seed=rng_key) + + @property + def fisher_factor_inner_shape(self) -> Sequence[int]: + return self.dist.mean().shape + + +class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a normal distribution parameterized by a mean vector. + + + Note that the covariance is treated as the identity divided by 2. + Also note that the Fisher for such a normal distribution with respect the mean + parameter is given by: + + F = (1 / variance) * I + + See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. 
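+
+  As an illustrative consequence (not part of the public contract), with the
+  default variance of 0.5 a Fisher-vector product is just a rescaling of the
+  vector:
+
+    loss = NormalMeanNegativeLogProbLoss(mean=mean, targets=targets)
+    fv = loss.multiply_fisher_unweighted(v)  # equals v / 0.5, i.e. 2.0 * v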
+ """ + + def __init__( + self, + mean: jnp.ndarray, + targets: Optional[jnp.ndarray] = None, + variance: float = 0.5, + weight: float = 1.0, + ): + super().__init__(weight=weight) + self._mean = mean + self._targets = targets + self._variance = variance + if not isinstance(variance, float): + raise ValueError("The `variance` argument should be python float.") + + @property + def targets(self) -> Optional[jnp.ndarray]: + return self._targets + + @property + def dist(self): + scale_diag = jnp.full_like(self._mean, jnp.sqrt(self._variance)) + return distributions.MultivariateNormalDiag(self._mean, scale_diag) + + @property + def params(self): + return self._mean, + + def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]): + [mean] = inputs + return NormalMeanNegativeLogProbLoss( + mean=mean, + targets=self.targets, + variance=self._variance, + weight=self.weight, + ) + + def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + return vector / self._variance + + def multiply_fisher_factor_unweighted( + self, + vector: jnp.ndarray, + ) -> jnp.ndarray: + return vector / jnp.sqrt(self._variance) + + def multiply_fisher_factor_transpose_unweighted( + self, + vector: jnp.ndarray, + ) -> jnp.ndarray: + return self.multiply_fisher_factor_unweighted(vector) # it's symmetric + + def multiply_fisher_factor_replicated_one_hot_unweighted( + self, + index: Index, + ) -> jnp.ndarray: + assert len(index) == 1, f"Length of index was {len(index)}." + index = index[0] + ones_slice = jnp.ones([self._mean.shape[0]])[..., None] + output_slice = ones_slice / jnp.sqrt(self._variance) + return insert_slice_in_zeros(output_slice, 1, self._mean.shape[1], index) + + +def insert_slice_in_zeros( + slice_to_insert: jnp.ndarray, + dim: int, + dim_size: int, + position: int, +) -> jnp.ndarray: + """Inserts slice into a larger tensor of zeros. + + Forms a new tensor which is the same shape as slice_to_insert, except that + the dimension given by 'dim' is expanded to the size given by 'dim_size'. + 'position' determines the position (index) at which to insert the slice within + that dimension. + + Assumes slice_to_insert.shape[dim] = 1. + + Args: + slice_to_insert: The slice to insert. + dim: The dimension which to expand with zeros. + dim_size: The new size of the 'dim' dimension. + position: The position of 'slice_to_insert' in the new tensor. + + Returns: + The new tensor. + + Raises: + ValueError: If the slice's shape at the given dim is not 1. 
+ """ + slice_shape = slice_to_insert.shape + if slice_shape[dim] != 1: + raise ValueError(f"Expected slice_to_insert.shape to have {dim} dim of 1," + f" but was {slice_to_insert.shape[dim]}.") + + before = [0] * int(len(slice_shape)) + after = before[:] + before[dim] = position + after[dim] = dim_size - position - 1 + return jnp.pad(slice_to_insert, list(zip(before, after))) + + +# _______ _____ _ _ _ _ +# |__ __| | __ \ (_) | | | | (_) +# | | __ _ __ _ | |__) |___ __ _ _ ___| |_ _ __ __ _| |_ _ ___ _ __ +# | |/ _` |/ _` | | _ // _ \/ _` | / __| __| '__/ _` | __| |/ _ \| '_ \ +# | | (_| | (_| | | | \ \ __/ (_| | \__ \ |_| | | (_| | |_| | (_) | | | | +# |_|\__,_|\__, | |_| \_\___|\__, |_|___/\__|_| \__,_|\__|_|\___/|_| |_| +# __/ | __/ | +# |___/ |___/ + + +NormalMeanNegativeLogProbLoss_tag = tags.LossTag( + NormalMeanNegativeLogProbLoss, num_inputs=1) + + +def register_normal_predictive_distribution( + mean: jnp.ndarray, + targets: Optional[jnp.ndarray] = None, + variance: float = 0.5, + weight: float = 1.0, +): + """Registers a normal predictive distribution. + + This corresponds to a squared error loss of the form + weight/(2*var) * ||target - mean||^2 + + Args: + mean: A tensor defining the mean vector of the distribution. The first + dimension must be the batch size. + targets: (OPTIONAL) The targets for the loss function. Only required if one + wants to use the "empirical Fisher" instead of the true Fisher (which is + controlled by the 'estimation_mode' to the optimizer). + (Default: None) + variance: float. The variance of the distribution. Note that the default + value of 0.5 corresponds to a standard squared error loss weight * + ||target - prediction||^2. If you want your squared error loss to be of + the form 0.5*coeff*||target - prediction||^2 you should use + variance=1.0. + (Default: 0.5) + weight: A scalar coefficient to multiply the log prob loss associated with + this distribution. The Fisher will be multiplied by the corresponding + factor. In general this is NOT equivalent to changing the temperature of + the distribution, but in the ase of normal distributions it may be. + (Default: 1.0) + + Returns: + The mean and targets as dependable on the tag. + """ + if targets is None: + targets = jnp.zeros_like(mean) + return NormalMeanNegativeLogProbLoss_tag.bind( + mean, targets, variance=variance, weight=weight, return_loss=False) + + +def register_squared_error_loss( + prediction: jnp.ndarray, + targets: Optional[jnp.ndarray] = None, + weight: float = 1.0, +): + """Registers a squared error loss function. + + This assumes the squared error loss of the form ||target - prediction||^2, + averaged across the mini-batch. If your loss uses a coefficient of 0.5 + you need to set the "weight" argument to reflect this. + + Args: + prediction: The prediction made by the network (i.e. its output). The first + dimension must be the batch size. + targets: (OPTIONAL) The targets for the loss function. Only required if one + wants to use the "empirical Fisher" instead of the true Fisher (which is + controlled by the 'estimation_mode' to the optimizer). + (Default: None) + weight: A float coefficient to multiply the loss function by. + (Default: 1.0) + Returns: + The mean and targets as dependable on the tag. 
+ """ + return register_normal_predictive_distribution( + prediction, targets=targets, variance=0.5, weight=weight) diff --git a/kfac_ferminet_alpha/optimizer.py b/kfac_ferminet_alpha/optimizer.py new file mode 100644 index 0000000..ff2a49f --- /dev/null +++ b/kfac_ferminet_alpha/optimizer.py @@ -0,0 +1,611 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A module for the main curvature optimizer class.""" +from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Union + +import jax +import jax.lax as lax +import jax.numpy as jnp +import jax.random as jnr + +from kfac_ferminet_alpha import estimator +from kfac_ferminet_alpha import tag_graph_matcher as tgm +from kfac_ferminet_alpha import utils + +ScheduleType = Callable[[jnp.ndarray], Optional[jnp.ndarray]] +Parameters = Any +Batch = Any +FuncState = Any +State = Mapping[str, Any] + + +@utils.Stateful.infer_class_state +class Optimizer(utils.Stateful): + """The default optimizer class.""" + velocities: Parameters + estimator: estimator.CurvatureEstimator + step_counter: jnp.ndarray + + def __init__( + self, + value_and_grad_func, + l2_reg: Union[float, jnp.ndarray], + value_func_has_aux: bool = False, + value_func_has_state: bool = False, + value_func_has_rng: bool = False, + learning_rate_schedule: Optional[ScheduleType] = None, + momentum_schedule: Optional[ScheduleType] = None, + damping_schedule: Optional[ScheduleType] = None, + min_damping: Union[float, jnp.ndarray] = 1e-8, + max_damping: Union[float, jnp.ndarray] = jnp.inf, + norm_constraint: Optional[Union[float, jnp.ndarray]] = None, + num_burnin_steps: int = 10, + estimation_mode: str = "fisher_gradients", + curvature_ema: Union[float, jnp.ndarray] = 0.95, + inverse_update_period: int = 5, + register_only_generic: bool = False, + layer_tag_to_block_cls: Optional[estimator.TagMapping] = None, + patterns_to_skip: Sequence[str] = (), + donate_parameters: bool = False, + donate_optimizer_state: bool = False, + donate_batch_inputs: bool = False, + donate_func_state: bool = False, + batch_process_func: Optional[Callable[[Any], Any]] = None, + multi_device: bool = False, + use_jax_cond: bool = True, + debug: bool = False, + pmap_axis_name="kfac_axis", + ): + """Initializes the K-FAC optimizer with the given settings. + + Args: + value_and_grad_func: Python callable. The function should return the value + of the loss to be optimized and its gradients. If the argument + `value_func_has_aux` is `False` then the interface should be: loss, + loss_grads = value_and_grad_func(params, batch) + If `value_func_has_aux` is `True` then the interface should be: (loss, + aux), loss_grads = value_and_grad_func(params, batch) + l2_reg: Scalar. Set this value to tell the optimizer what L2 + regularization coefficient you are using (if any). Note the coefficient + appears in the regularizer as coeff / 2 * sum(param**2). Note that the + user is still responsible for adding regularization to the loss. 
+      value_func_has_aux: Boolean. Specifies whether the provided callable
+        `value_and_grad_func` returns the loss value only, or also some
+        auxiliary data. (Default: False)
+      value_func_has_state: Boolean. Specifies whether the provided callable
+        `value_and_grad_func` has a persistent state that is passed in and
+        returned in updated form. (Default: False)
+      value_func_has_rng: Boolean. Specifies whether the provided callable
+        `value_and_grad_func` additionally takes as input an rng key.
+        (Default: False)
+      learning_rate_schedule: Callable. A schedule for the learning rate. This
+        should take as input the current step number and return a single
+        `jnp.ndarray` that represents the learning rate. (Default: None)
+      momentum_schedule: Callable. A schedule for the momentum. This should
+        take as input the current step number and return a single
+        `jnp.ndarray` that represents the momentum. (Default: None)
+      damping_schedule: Callable. A schedule for the damping. This should take
+        as input the current step number and return a single `jnp.ndarray`
+        that represents the damping. (Default: None)
+      min_damping: Scalar. Minimum value the damping parameter can take. Note
+        that the default value of 1e-8 is quite arbitrary, and you may have to
+        adjust this up or down for your particular problem. If you are using a
+        non-zero value of l2_reg you *may* be able to set this to zero.
+        (Default: 1e-8)
+      max_damping: Scalar. Maximum value the damping parameter can take.
+        (Default: Infinity)
+      norm_constraint: Scalar. If specified, the update is scaled down so that
+        its approximate squared Fisher norm `v^T F v` is at most the specified
+        value. (Note that here `F` is the approximate curvature matrix, not
+        the exact one.) (Default: None)
+      num_burnin_steps: Int. At the start of optimization, i.e. on the first
+        step, the optimizer will perform this many updates to the curvature
+        approximation without updating the actual parameters, before
+        performing the actual step. (Default: 10)
+      estimation_mode: String. The type of estimator to use for the curvature
+        matrix. Can be one of: 'fisher_empirical', 'fisher_exact',
+        'fisher_gradients', 'fisher_curvature_prop', 'ggn_exact' or
+        'ggn_curvature_prop'. See the doc-string for CurvatureEstimator (in
+        estimator.py) for a more detailed description of these options.
+        (Default: 'fisher_gradients').
+      curvature_ema: The decay factor used when calculating the covariance
+        estimate moving averages. (Default: 0.95)
+      inverse_update_period: Int. The number of steps in between updating the
+        computation of the inverse curvature approximation. (Default: 5)
+      register_only_generic: Boolean. Whether, when running the auto-tagger,
+        to register only generic parameters, or to allow it to use the graph
+        matcher to automatically pick up any kind of layer tags.
+        (Default: False)
+      layer_tag_to_block_cls: Dictionary. A mapping from layer tags to block
+        classes, which overrides the default choices of block approximation
+        for those specific tags. See the doc-string for CurvatureEstimator (in
+        estimator.py) for a more detailed description of this.
+      patterns_to_skip: Tuple. A list of any patterns that should be skipped
+        by the graph matcher when auto-tagging.
+      donate_parameters: Boolean. Whether to use jax's `donate_argnums` to
+        donate the parameter values of each call to `step`. Note that this
+        implies that you will not be able to access the old parameter values'
+        buffers after calling into `step`.
+      donate_optimizer_state: Boolean. Whether to use jax's `donate_argnums`
+        to donate the optimizer state of each call to `step`. Note that this
+        implies that you will not be able to access the old optimizer state
+        values' buffers after calling into `step`.
+      donate_batch_inputs: Boolean. Whether to use jax's `donate_argnums` to
+        donate the batch values of each call to `step`. Note that this implies
+        that you will not be able to access the old batch values' buffers
+        after calling into `step`.
+      donate_func_state: Boolean. Whether to use jax's `donate_argnums` to
+        donate the persistent function state of each call to `step`. Note that
+        this implies that you will not be able to access the old function
+        state values' buffers after calling into `step`.
+      batch_process_func: Callable. A function to be called on each batch
+        before it is fed to the optimizer on device. This could be useful for
+        specific device input optimizations.
+      multi_device: Boolean. Whether to use `pmap` and run the optimizer on
+        multiple devices. (Default: False)
+      use_jax_cond: Not used for the moment.
+      debug: Boolean. If True, neither the step nor the init functions are
+        jitted. Note that this also overrides `multi_device` and prevents
+        using `pmap`. (Default: False)
+      pmap_axis_name: String. The name of the `pmap` axis to use when
+        `multi_device` is set to True. (Default: kfac_axis)
+    """
+    super().__init__()
+    self.value_and_grad_func = value_and_grad_func
+    self.value_func_has_aux = value_func_has_aux
+    self.value_func_has_state = value_func_has_state
+    self.value_func_has_rng = value_func_has_rng
+    self.value_func = utils.convert_value_and_grad_to_value_func(
+        value_and_grad_func, has_aux=value_func_has_aux)
+    self.l2_reg = l2_reg
+    self.learning_rate_schedule = learning_rate_schedule
+    if momentum_schedule is not None:
+
+      def schedule_with_first_step_zero(global_step: jnp.ndarray):
+        value = momentum_schedule(global_step)
+        check = jnp.equal(global_step, 0)
+        return check * jnp.zeros_like(value) + (1 - check) * value
+
+      self.momentum_schedule = schedule_with_first_step_zero
+    else:
+      self.momentum_schedule = None
+    self.damping_schedule = damping_schedule
+    self.min_damping = min_damping
+    self.max_damping = max_damping
+    self.norm_constraint = norm_constraint
+    self.num_burnin_steps = num_burnin_steps
+    self.estimation_mode = estimation_mode
+    self.curvature_ema = curvature_ema
+    self.inverse_update_period = inverse_update_period
+    self.register_only_generic = register_only_generic
+    self.layer_tag_to_block_cls = layer_tag_to_block_cls
+    self.patterns_to_skip = patterns_to_skip
+    self.donate_parameters = donate_parameters
+    self.donate_optimizer_state = donate_optimizer_state
+    self.donate_batch_inputs = donate_batch_inputs
+    self.donate_func_state = donate_func_state
+    self.batch_process_func = batch_process_func or (lambda x: x)
+    self.multi_device = multi_device
+    self.use_jax_cond = use_jax_cond
+    self.debug = debug
+    self.pmap_axis_name = pmap_axis_name if multi_device else None
+    self._rng_split = utils.p_split if multi_device else jnr.split
+
+    # Attributes filled in during self.init()
+    self.finalized = False
+    self.tagged_func = None
+    self.flat_params_shapes = None
+    self.params_treedef = None
+    # Special attributes related to jitting/pmap
+    self._jit_init = None
+    self._jit_burnin = None
+    self._jit_step = None
+
+  def finalize(
+      self,
+      params: Parameters,
+      rng: jnp.ndarray,
+      batch: Batch,
+      func_state: Optional[FuncState] = None,
+  ) -> None:
+    """Finalizes the optimizer by tracing the model function with the params and
batch.""" + if self.finalized: + raise ValueError("Optimizer has already been finalized.") + if self.multi_device: + # We assume that the parameters and batch are replicated, while tracing + # must happen with parameters for a single device call + params, rng, batch = jax.tree_map(lambda x: x[0], (params, rng, batch)) + if func_state is not None: + func_state = jax.tree_map(lambda x: x[0], func_state) + batch = self.batch_process_func(batch) + # These are all tracing operations and we can run them with abstract values + func_args = utils.make_func_args(params, func_state, rng, batch, + self.value_func_has_state, + self.value_func_has_rng) + # Run all tracing with abstract values so no computation is done + flat_params, self.params_treedef = jax.tree_flatten(params) + self.flat_params_shapes = tuple(p.shape for p in flat_params) + self.tagged_func = tgm.auto_register_tags( + func=self.value_func, + func_args=func_args, + params_index=0, + register_only_generic=self.register_only_generic, + patterns_to_skip=self.patterns_to_skip) + self.estimator = estimator.CurvatureEstimator( + self.tagged_func, + func_args, + self.l2_reg, + self.estimation_mode, + layer_tag_to_block_cls=self.layer_tag_to_block_cls) + # Arguments: params, opt_state, rng, batch, func_state + donate_argnums = [] + if self.donate_parameters: + donate_argnums.append(0) + if self.donate_optimizer_state: + donate_argnums.append(1) + if self.donate_batch_inputs: + donate_argnums.append(3) + if self.donate_func_state and self.value_func_has_state: + donate_argnums.append(4) + donate_argnums = tuple(donate_argnums) + + if self.debug: + self._jit_init = self._init + self._jit_burnin = self._burnin + self._jit_step = self._step + elif self.multi_device: + self._jit_init = jax.pmap( + self._init, axis_name=self.pmap_axis_name, donate_argnums=[0]) + # batch size is static argnum and is at index 5 + self._jit_burnin = jax.pmap( + self._burnin, + axis_name=self.pmap_axis_name, + static_broadcasted_argnums=[5]) + self._jit_step = jax.pmap( + self._step, + axis_name=self.pmap_axis_name, + donate_argnums=donate_argnums, + static_broadcasted_argnums=[5]) + else: + self._jit_init = jax.jit(self._init, donate_argnums=[0]) + # batch size is static argnum and is at index 5 + self._jit_burnin = jax.jit(self._burnin, static_argnums=[5]) + self._jit_step = jax.jit( + self._step, donate_argnums=donate_argnums, static_argnums=[5]) + self.finalized = True + + def _init(self, rng: jnp.ndarray) -> State: + """This is the non-jitted version of initializing the state.""" + flat_velocities = [jnp.zeros(shape) for shape in self.flat_params_shapes] + return dict( + velocities=jax.tree_unflatten(self.params_treedef, flat_velocities), + estimator=self.estimator.init(rng, None), + step_counter=jnp.asarray(0)) + + def verify_args_and_get_step_counter( + self, + params: Parameters, + state: State, + rng: jnp.ndarray, + data_iterator: Iterator[Batch], + func_state: Optional[FuncState] = None, + learning_rate: Optional[jnp.ndarray] = None, + momentum: Optional[jnp.ndarray] = None, + damping: Optional[jnp.ndarray] = None, + global_step_int: Optional[int] = None, + ) -> int: + """Verifies that the arguments passed to `Optimizer.step` are correct.""" + if not self.finalized: + rng, rng_finalize = self._rng_split(rng) + self.finalize(params, rng_finalize, next(data_iterator), func_state) + # Verify correct arguments invocation + if self.learning_rate_schedule is not None and learning_rate is not None: + raise ValueError("When you have passed a `learning_rate_schedule` 
you " + "should not pass a value to the step function.") + if self.momentum_schedule is not None and momentum is not None: + raise ValueError("When you have passed a `momentum_schedule` you should " + "not pass a value to the step function.") + if self.damping_schedule is not None and damping is not None: + raise ValueError("When you have passed a `damping_schedule` you should " + "not pass a value to the step function.") + # Do a bunrnin on the first iteration + if global_step_int is None: + if self.multi_device: + return int(utils.get_first(state["step_counter"])) + else: + return int(state["step_counter"]) + return global_step_int + + def _burnin( + self, + params: Parameters, + state: State, + rng: jnp.ndarray, + batch: Batch, + func_state: Optional[FuncState], + batch_size: Optional[int], + ) -> Tuple[State, Optional[FuncState]]: + """This is the non-jitted version of a single burnin step.""" + self.set_state(state) + batch = self.batch_process_func(batch) + rng, func_rng = jnr.split(rng) if self.value_func_has_rng else (rng, None) + func_args = utils.make_func_args(params, func_state, func_rng, batch, + self.value_func_has_state, + self.value_func_has_rng) + + # Compute batch size + if batch_size is None: + batch_size = jax.tree_flatten(batch)[0][0].shape[0] + + # Update curvature estimate + ema_old, ema_new = 1.0, 1.0 / self.num_burnin_steps + self.estimator.update_curvature_matrix_estimate(ema_old, ema_new, + batch_size, rng, func_args, + self.pmap_axis_name) + + if func_state is not None: + out, _ = self.value_and_grad_func(*func_args) + _, func_state, _ = utils.extract_func_outputs(out, + self.value_func_has_aux, + self.value_func_has_state) + + return self.pop_state(), func_state + + def _step( + self, + params: Parameters, + state: State, + rng: jnp.ndarray, + batch: Batch, + func_state: Optional[FuncState], + batch_size: Optional[int], + learning_rate: Optional[jnp.ndarray], + momentum: Optional[jnp.ndarray], + damping: Optional[jnp.ndarray], + ) -> Union[Tuple[Parameters, State, FuncState, Mapping[str, jnp.ndarray]], + Tuple[Parameters, State, Mapping[str, jnp.ndarray]]]: + """This is the non-jitted version of a single step.""" + # Unpack and set the state + self.set_state(state) + if damping is not None: + assert self.estimator.damping is None + self.estimator.damping = damping + else: + assert self.estimator.damping is not None + + # Preprocess the batch and construct correctly the function arguments + batch = self.batch_process_func(batch) + rng, func_rng = jnr.split(rng) if self.value_func_has_rng else (rng, None) + func_args = utils.make_func_args(params, func_state, func_rng, batch, + self.value_func_has_state, + self.value_func_has_rng) + + # Compute the batch size + if batch_size is None: + batch_size = jax.tree_flatten(batch)[0][0].shape[0] + + # Compute schedules if applicable + if self.learning_rate_schedule is not None: + assert learning_rate is None + learning_rate = self.learning_rate_schedule(self.step_counter) + else: + assert learning_rate is not None + if self.momentum_schedule is not None: + assert momentum is None + momentum = self.momentum_schedule(self.step_counter) + else: + assert momentum is not None + if self.damping_schedule is not None: + assert damping is None + damping = self.damping_schedule(self.step_counter) + else: + assert damping is not None + + # Compute current loss and gradients + out, grads = self.value_and_grad_func(*func_args) + loss, new_func_state, aux = utils.extract_func_outputs( + out, self.value_func_has_aux, 
self.value_func_has_state) + # Sync loss and grads + loss, grads = utils.pmean_if_pmap((loss, grads), self.pmap_axis_name) + + # Update curvature estimate + self.estimator.update_curvature_matrix_estimate( + self.curvature_ema, + 1.0, + batch_size, + rng, + func_args, + self.pmap_axis_name, + ) + + # Optionally update the inverse estimate + self.estimator.set_state( + lax.cond( + self.step_counter % self.inverse_update_period == 0, + lambda s: self.estimator.update_curvature_estimate_inverse( # pylint: disable=g-long-lambda + self.pmap_axis_name, s), + lambda s: s, + self.estimator.pop_state())) + + # Compute proposed directions + vectors = self.propose_directions( + grads, + self.velocities, + learning_rate, + momentum, + ) + + # The learning rate is defined as the negative of the coefficient by which + # we multiply the gradients, while the momentum is the coefficient by + # which we multiply the velocities. + neg_learning_rate = -learning_rate + # Compute the coefficients of the update vectors + assert neg_learning_rate is not None and momentum is not None + coefficients = (neg_learning_rate, momentum) + + # Update velocities and compute new delta + self.velocities, delta = self.velocities_and_delta( + self.velocities, + vectors, + coefficients, + ) + + # Update parameters: params = params + delta + params = jax.tree_multimap(jnp.add, params, delta) + + # Optionally compute the reduction ratio and update the damping + self.estimator.damping = None + rho = jnp.nan + + # Statistics with useful information + stats = dict() + stats["step"] = self.step_counter + stats["loss"] = loss + stats["learning_rate"] = -coefficients[0] + stats["momentum"] = coefficients[1] + stats["damping"] = damping + stats["rho"] = rho + if self.value_func_has_aux: + stats["aux"] = aux + self.step_counter = self.step_counter + 1 + + if self.value_func_has_state: + return params, self.pop_state(), new_func_state, stats + else: + assert new_func_state is None + return params, self.pop_state(), stats + + def init( + self, + params: Parameters, + rng: jnp.ndarray, + batch: Batch, + func_state: Optional[FuncState] = None, + ) -> State: + """Initializes the optimizer and returns the appropriate optimizer state.""" + if not self.finalized: + self.finalize(params, rng, batch, func_state) + return self._jit_init(rng) + + def step( + self, + params: Parameters, + state: Mapping[str, Any], + rng: jnp.ndarray, + data_iterator: Iterator[Any], + func_state: Any = None, + learning_rate: Optional[jnp.ndarray] = None, + momentum: Optional[jnp.ndarray] = None, + damping: Optional[jnp.ndarray] = None, + batch_size: Optional[int] = None, + global_step_int: Optional[int] = None, + ) -> Union[Tuple[Parameters, State, FuncState, Mapping[str, jnp.ndarray]], + Tuple[Parameters, State, Mapping[str, jnp.ndarray]]]: + """Performs a single update step using the optimizer. + + Args: + params: The parameters of the model. + state: The state of the optimizer. + rng: A Jax PRNG key. + data_iterator: An iterator that returns a batch of data. + func_state: Any function state that gets passed in and returned. + learning_rate: This must be provided when + `use_adaptive_learning_rate=False` and `learning_rate_schedule=None`. + momentum: This must be provided when + `use_adaptive_momentum=False` and `momentum_schedule=None`. + damping: This must be provided when + `use_adaptive_damping=False` and `damping_schedule=None`. + batch_size: The batch size to use for KFAC. 
The default behaviour when it
+        is None is to use the leading dimension of the first data array.
+      global_step_int: The global step as a Python int. Note that this must
+        match the optimizer's internal step counter, which is part of its
+        state.
+
+    Returns:
+      (params, state, stats)
+      where:
+        params: The updated model parameters.
+        state: The updated optimizer state.
+        stats: A dictionary of key statistics provided to be logged.
+    """
+    step_counter_int = self.verify_args_and_get_step_counter(
+        params=params,
+        state=state,
+        rng=rng,
+        data_iterator=data_iterator,
+        func_state=func_state,
+        learning_rate=learning_rate,
+        momentum=momentum,
+        damping=damping,
+        global_step_int=global_step_int)
+
+    if step_counter_int == 0:
+      for _ in range(self.num_burnin_steps):
+        rng, rng_burn = self._rng_split(rng)
+        batch = next(data_iterator)
+        state, func_state = self._jit_burnin(params, state, rng_burn, batch,
+                                             func_state, batch_size)
+
+      # On the first step we always treat the momentum as 0.0
+      if self.momentum_schedule is None:
+        momentum = jnp.zeros([])
+        if self.multi_device:
+          momentum = utils.replicate_all_local_devices(momentum)
+
+    batch = next(data_iterator)
+    return self._jit_step(params, state, rng, batch, func_state, batch_size,
+                          learning_rate, momentum, damping)
+
+  def propose_directions(
+      self,
+      grads: Parameters,
+      velocities: Parameters,
+      learning_rate: Optional[jnp.ndarray],
+      momentum: Optional[jnp.ndarray],
+  ) -> Tuple[Parameters, Parameters]:
+    """Computes the vector proposals for the next step."""
+    del momentum  # not used in this, but could be used in subclasses
+    preconditioned_grads = self.estimator.multiply_matpower(grads, -1)
+
+    if self.norm_constraint is not None:
+      assert learning_rate is not None
+      sq_norm_grads = utils.inner_product(preconditioned_grads, grads)
+      sq_norm_scaled_grads = sq_norm_grads * learning_rate**2
+
+      # We need to sync the norms here because reductions can be
+      # non-deterministic; they are on GPUs by default for better performance.
+      # Hence, although grads and preconditioned_grads are synced, the
+      # inner_product operation can still produce different answers on
+      # different devices.
+ sq_norm_scaled_grads = utils.pmean_if_pmap(sq_norm_scaled_grads, + self.pmap_axis_name) + + max_coefficient = jnp.sqrt(self.norm_constraint / sq_norm_scaled_grads) + coefficient = jnp.minimum(max_coefficient, 1) + preconditioned_grads = utils.scalar_mul(preconditioned_grads, coefficient) + + return preconditioned_grads, velocities + + def velocities_and_delta( + self, + velocities: Parameters, + vectors: Sequence[Parameters], + coefficients: Sequence[jnp.ndarray], + ) -> Sequence[Parameters]: + """Computes the new velocities and delta (update to parameters).""" + del velocities + assert len(vectors) == len(coefficients) + delta = utils.scalar_mul(vectors[0], coefficients[0]) + for vi, wi in zip(vectors[1:], coefficients[1:]): + delta = jax.tree_multimap(jnp.add, delta, utils.scalar_mul(vi, wi)) + return delta, delta diff --git a/kfac_ferminet_alpha/requirements.txt b/kfac_ferminet_alpha/requirements.txt new file mode 100644 index 0000000..e93b2a6 --- /dev/null +++ b/kfac_ferminet_alpha/requirements.txt @@ -0,0 +1,7 @@ +jax>=0.2.10 +dataclasses>=0.6 +networkx>=2.1 +numpy>=1.16.4 +typing>=3.7.4.3 +ordered-set>=4.0.2 +absl-py>=0.12.0 diff --git a/kfac_ferminet_alpha/run.sh b/kfac_ferminet_alpha/run.sh new file mode 100755 index 0000000..e35f79f --- /dev/null +++ b/kfac_ferminet_alpha/run.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script installs kfac_ferminet_alpha in a clean virtualenv and runs an +# example training loop. It is designed to be run from the parent directory, +# e.g.: +# +# git clone git@github.com:deepmind/deepmind-research.git +# cd deepmind_research +# kfac_ferminet_alpha/run.sh + +python3 -m venv /tmp/kfac_ferminet_alpha_example +source /tmp/kfac_ferminet_alpha_example/bin/activate +pip3 install -U pip +pip3 install -r kfac_ferminet_alpha/requirements.txt +# For a GPU you have to do: +# pip3 install --upgrade jax jaxlib==0.1.64+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip3 install jaxlib +pip3 install kfac_ferminet_alpha/ +python3 kfac_ferminet_alpha/example.py diff --git a/kfac_ferminet_alpha/setup.py b/kfac_ferminet_alpha/setup.py new file mode 100644 index 0000000..ac77710 --- /dev/null +++ b/kfac_ferminet_alpha/setup.py @@ -0,0 +1,52 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Setup for pip package.""" + +from setuptools import setup + + +REQUIRED_PACKAGES = ( + "absl-py", + "dataclasses", + "jax", + "networkx", + "numpy", + "ordered-set", + "typing", +) + +LONG_DESCRIPTION = "\n".join([ + "Kronecker-Factored Approximate Curvature (K-FAC) optimizer implemented in " + "JAX.", + "", + "Accompanying code for 'Better, Faster Fermionic Neural Networks'", + "James S. Spencer, David Pfau, Aleksandar Botev, and W. M. C. Foulkes.", + "https://arxiv.org/abs/2011.07125.", +]) + + +setup( + name="kfac_ferminet_alpha", + version="0.0.1", + description="A K-FAC optimizer implemented in JAX", + long_description=LONG_DESCRIPTION, + url="https://github.com/deepmind/deepmind-research/kfac_ferminet_alpha", + author="DeepMind", + package_dir={"kfac_ferminet_alpha": "."}, + packages=["kfac_ferminet_alpha"], + install_requires=REQUIRED_PACKAGES, + platforms=["any"], + license="Apache License, Version 2.0", +) diff --git a/kfac_ferminet_alpha/tag_graph_matcher.py b/kfac_ferminet_alpha/tag_graph_matcher.py new file mode 100644 index 0000000..6a16efa --- /dev/null +++ b/kfac_ferminet_alpha/tag_graph_matcher.py @@ -0,0 +1,752 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A module for tagging and graph manipulation.""" +import collections +import functools +import itertools +from typing import Any, NamedTuple, Sequence + +from absl import logging +import jax +from jax import core as jax_core +from jax import lax +from jax import util as jax_util +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import networkx as nx +from networkx.algorithms import isomorphism +import numpy as np +import ordered_set + +from kfac_ferminet_alpha import layers_and_loss_tags as tags + +USE_NETWORKX = False + + +def match_nodes(g1, g2, mapping, node1, node2): + """Matching nodes when doing graph search.""" + + if not kfac_node_match(g1.nodes[node1], g2.nodes[node2]): + return False + # Check predecessors + p1 = set(n for n in g1.predecessors(node1) if n in mapping.keys()) + p2 = set(n for n in g2.predecessors(node2) if n in mapping.values()) + if len(p1) != len(p2): + return False + for p1_i in p1: + if mapping[p1_i] not in p2: + return False + # Check successors + s1 = set(n for n in g1.successors(node1) if n in mapping.keys()) + s2 = set(n for n in g2.successors(node2) if n in mapping.values()) + if len(s1) != len(s2): + return False + for s1_i in s1: + if mapping[s1_i] not in s2: + return False + return True + + +def generate_candidates(g1, g2, mapping, node1, node2): + """Generates the initial candidates for graph search.""" + # Check predecessors + p1 = set(n for n in g1.predecessors(node1) if n not in mapping.keys()) + p2 = set(n for n in g2.predecessors(node2) if n not in mapping.values()) + candidates = ordered_set.OrderedSet(itertools.product(p1, p2)) + s1 = set(n for n in g1.successors(node1) if n not in mapping.keys()) + s2 = set(n for n in g2.successors(node2) if n not in mapping.values()) + 
candidates.update(list(itertools.product(s1, s2))) + return candidates + + +def find_mappings(pattern, graph, mapping, terminals): + """Finds all mappings from graph search of the pattern.""" + if len(mapping) == len(pattern): + for k, v in terminals.items(): + v.add(mapping[k]) + return [frozenset(mapping.items())] + mappings = set() + nodes_list = list(mapping.keys()) + for node1 in reversed(nodes_list): + for s1 in pattern.successors(node1): + if s1 not in mapping.keys(): + for s2 in graph.successors(mapping[node1]): + if s2 not in mapping.values(): + if s1 not in terminals or s2 not in terminals[s1]: + if match_nodes(pattern, graph, mapping, s1, s2): + mapping[s1] = s2 + mappings.update( + find_mappings(pattern, graph, mapping, terminals)) + mapping.pop(s1) + for p1 in pattern.predecessors(node1): + if p1 not in mapping.keys(): + for p2 in graph.predecessors(mapping[node1]): + if p2 not in mapping.values(): + if p1 not in terminals or p2 not in terminals[p1]: + if match_nodes(pattern, graph, mapping, p1, p2): + mapping[p1] = p2 + mappings.update( + find_mappings(pattern, graph, mapping, terminals)) + mapping.pop(p1) + return mappings + + +def match_pattern(pattern, graph): + """Given a pattern returns all matches inside the graph.""" + if USE_NETWORKX: + matcher = isomorphism.GraphMatcher( + graph, pattern, node_match=kfac_node_match) + mappings = list( + dict((k, v) + for v, k in mapping.items()) + for mapping in matcher.subgraph_isomorphisms_iter()) + else: + mapping = collections.OrderedDict() + params1 = [n for n in pattern.nodes if pattern.nodes[n]["op"] == "param"] + params2 = [n for n in graph.nodes if graph.nodes[n]["op"] == "param"] + terminals = { + n: set() for n in pattern.nodes if not list(pattern.successors(n)) + } + + mappings = set() + for node1, node2 in itertools.product(params1, params2): + mapping[node1] = node2 + mappings.update(find_mappings(pattern, graph, mapping, terminals)) + mapping.pop(node1) + for v in terminals.values(): + v.clear() + mappings = list(dict(mapping) for mapping in mappings) + + var_mappings = [] + for mapping in mappings: + var_mappings.append(dict()) + for k, v in mapping.items(): + cond = pattern.nodes[k]["op"] in ("param", "array") + source = pattern.nodes[k]["var"] if cond else k + target = graph.nodes[v]["var"] if cond else graph.nodes[v]["eqn"] + var_mappings[-1][source] = target + + return var_mappings + + +def read_env(env, var): + # Literals are values baked into the Jaxpr + if isinstance(var, jax.core.Literal): + return var.val + return env[var] + + +def write_env(env, var, val): + env[var] = val + + +def abstract_single_value(value): + if isinstance(value, jnp.ndarray): + value = jax.ShapedArray(np.shape(value), np.result_type(value)) + return pe.PartialVal.unknown(value) + else: + return value + + +def abstract_args(args): + return jax.tree_map(abstract_single_value, args) + + +def evaluate_eqn(eqn, in_values, write_func): + """Evaluate a single Jax equation and writes the outputs.""" + in_values = list(in_values) + # This is logic specifically to handle `xla_call` + call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params) + if call_jaxpr: + subfuns = [ + jax.core.lu.wrap_init( + functools.partial(jax.core.eval_jaxpr, call_jaxpr, ())) + ] + else: + subfuns = [] + ans = eqn.primitive.bind(*(subfuns + in_values), **params) + if eqn.primitive.multiple_results: + jax_util.safe_map(write_func, eqn.outvars, ans) + else: + write_func(eqn.outvars[0], ans) + return ans + + +def clean_jaxpr_eqns(jaxpr, 
preserve_tags=True): + """Performs dead code elimination on the jaxpr, preserving loss and layer tags.""" + eqns = [] + dependants = set(jaxpr.outvars) + for eqn in reversed(jaxpr.eqns): + check = False + for v in eqn.outvars: + if v in dependants: + dependants.remove(v) + check = True + if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)): + check = check or preserve_tags + if check: + eqns.append(eqn) + new_dependants = set( + v for v in eqn.invars if not isinstance(v, jax_core.Literal)) + dependants = dependants.union(new_dependants) + # Dependants should only be invars + dependants = dependants - set(jaxpr.invars + jaxpr.constvars) + + if dependants: + raise ValueError("Something went wrong with the dead code elimination.") + return reversed(eqns) + + +def broadcast_merger(f): + """Transforms `f` into a function where all consecutive broadcasts are merged.""" + + def merged_func(*func_args): + typed_jaxpr, out_avals = jax.make_jaxpr(f, return_shape=True)(*func_args) + out_tree = jax.tree_structure(out_avals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + + # Mapping from variable -> value + env = dict() + read = functools.partial(read_env, env) + write = functools.partial(write_env, env) + + # Bind args and consts to environment + flat_args = jax.tree_flatten(func_args)[0] + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, flat_args) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, flat_args) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + broadcasts_outputs = dict() + for eqn in clean_jaxpr_eqns(jaxpr): + # We ignore broadcasting of constants + if (eqn.primitive.name == "broadcast_in_dim" and + not all(isinstance(v, jax_core.Literal) for v in eqn.invars)): + if eqn.invars[0] in broadcasts_outputs: + x, dims = broadcasts_outputs[eqn.invars[0]] + kept_dims = eqn.params["broadcast_dimensions"] + kept_dims = [kept_dims[d] for d in dims] + y = lax.broadcast_in_dim(x, eqn.params["shape"], kept_dims) + jax_util.safe_map(write, eqn.outvars, [y]) + broadcasts_outputs[eqn.outvars[0]] = (x, kept_dims) + else: + inputs = jax_util.safe_map(read, eqn.invars) + evaluate_eqn(eqn, inputs, write) + broadcasts_outputs[eqn.outvars[0]] = ( + inputs[0], eqn.params["broadcast_dimensions"]) + else: + evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + return jax.tree_unflatten(out_tree, jax_util.safe_map(read, jaxpr.outvars)) + + return merged_func + + +class JaxGraph(NamedTuple): + jaxpr: Any + consts: Any + params: Any + params_tree: Any + in_tree: Any + out_tree: Any + digraph: nx.DiGraph + tagging_func: Any + + +SPECIAL_OP_COMPARE_RULES = dict() + + +def default_compare(node1, node2): + if node1["op"] != node2["op"]: + return False + params1, params2 = node1["eqn"].params, node2["eqn"].params + if set(params1.keys()) != set(params2.keys()): + return False + for k in params1.keys(): + if params1[k] != params2[k]: + return False + return True + + +def reshape_compare(node1, node2): + """Compares two reshape nodes.""" + assert node1["op"] == node2["op"] == "reshape" + params1, params2 = node1["eqn"].params, node2["eqn"].params + if params1["dimensions"] != params2["dimensions"]: + return False + return True + + +def broadcast_in_dim_compare(node1, node2): + """Compares two reshape nodes.""" + assert node1["op"] == node2["op"] == "broadcast_in_dim" + 
return True + + +def conv_compare(node1, node2): + """Compares two conv_general_dialted nodes.""" + assert node1["op"] == node2["op"] == "conv_general_dilated" + params1, params2 = node1["eqn"].params, node2["eqn"].params + for k in ("window_strides", "padding", "lhs_dilation", "rhs_dilation", + "lhs_shape", "rhs_shape"): + if len(params1[k]) != len(params2[k]): + return False + if (len(params1["dimension_numbers"].lhs_spec) != # + len(params2["dimension_numbers"].lhs_spec)): + return False + if (len(params1["dimension_numbers"].rhs_spec) != # + len(params2["dimension_numbers"].rhs_spec)): + return False + if (len(params1["dimension_numbers"].out_spec) != # + len(params2["dimension_numbers"].out_spec)): + return False + if ((params1["feature_group_count"] > 1) != # + (params2["feature_group_count"] > 1)): + return False + if ((params1["batch_group_count"] > 1) != # + (params2["batch_group_count"] > 1)): + return False + return True + + +SPECIAL_OP_COMPARE_RULES["reshape"] = reshape_compare +SPECIAL_OP_COMPARE_RULES["broadcast_in_dim"] = broadcast_in_dim_compare +SPECIAL_OP_COMPARE_RULES["conv_general_dilated"] = conv_compare + + +def kfac_node_match(node1, node2): + """Checks if two nodes are equivalent.""" + # Parameters match with each other and nothing else + if node1["op"] == "param" and node2["op"] == "param": + return True + # return node1["rank"] == node2["rank"] + if node1["op"] == "param" or node2["op"] == "param": + return False + # Arrays always match each other and nothing else + if node1["op"] == "array" and node2["op"] == "array": + return True + if node1["op"] == "array" or node2["op"] == "array": + return False + # Operators match first on name + if node1["op"] != node2["op"]: + return False + compare = SPECIAL_OP_COMPARE_RULES.get(node1["op"], default_compare) + return compare(node1, node2) + + +def var_to_str(var): + """Returns a string representation of the variable of a Jax expression.""" + if isinstance(var, jax.core.Literal): + return str(var) + elif isinstance(var, jax.core.UnitVar): + return "*" + elif not isinstance(var, jax.core.Var): + raise ValueError(f"Idk what to do with this {type(var)}?") + c = int(var.count) + if c == -1: + return "_" + str_rep = "" + while c > 25: + str_rep += chr(c % 26 + ord("a")) + c = c // 26 + str_rep += chr(c + ord("a")) + return str_rep[::-1] + + +def extract_param_vars_flat(jaxpr, in_tree, params_index): + if params_index is None: + params_index = [] + elif isinstance(params_index, int): + params_index = [params_index] + in_vars = jax.tree_unflatten(in_tree, jaxpr.invars) + return jax.tree_flatten([in_vars[i] for i in params_index]) + + +def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None): + """Fills the graph with the jaxpr.""" + in_vars = in_vars or [var_to_str(v) for v in jaxpr.invars + jaxpr.constvars] + in_map = dict(zip(jaxpr.invars + jaxpr.constvars, in_vars)) + out_vars = out_vars or [var_to_str(v) for v in jaxpr.outvars] + out_map = dict(zip(jaxpr.outvars, out_vars)) + + for eqn in jaxpr.eqns: + in_vars = [] + for v in eqn.invars: + if isinstance(v, (jax.core.Literal, jax.core.UnitVar)): + in_vars.append(var_to_str(v)) + else: + in_vars.append(in_map.get(v, var_to_str(v))) + out_vars = [out_map.get(v, var_to_str(v)) for v in eqn.outvars] + in_str = ",".join(in_vars) + out_str = ",".join(out_vars) + if isinstance(eqn.primitive, tags.LossTag): + func_name = "__loss_tag" + elif isinstance(eqn.primitive, tags.LayerTag): + func_name = "__layer_tag" + else: + func_name = eqn.primitive.name + node_c = 
f"{func_name}({in_str})->{out_str}" + graph.add_node(node_c, op=eqn.primitive.name, eqn=eqn) + + # Create incoming edges + for v, name in zip(eqn.invars, in_vars): + if (not isinstance(v, jax.core.Literal) and + not isinstance(v, jax.core.UnitVar)): + graph.add_edge(name, node_c) + + # Create output nodes and edges + for v, name in zip(eqn.outvars, out_vars): + graph.add_node(name, op="array", var=v) + graph.add_edge(node_c, name) + + +def create_digraph(jaxpr, params): + """Creates a directed graph from the given jaxpr and parameters.""" + graph = nx.DiGraph() + # Create input nodes + for v in jaxpr.invars + jaxpr.constvars: + if v in params: + graph.add_node(var_to_str(v), op="param", var=v) + else: + graph.add_node(var_to_str(v), op="array", var=v) + fill_jaxpr_to_graph(graph, jaxpr) + + return graph + + +def function_to_jax_graph(func, args, params_index, tagging_func=None): + """Creates a `JaxGraph` instance from the provided function.""" + in_tree = jax.tree_structure(args) + typed_jaxpr = jax.make_jaxpr(func)(*args) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + params, params_tree = extract_param_vars_flat(jaxpr, in_tree, params_index) + + digraph = create_digraph(jaxpr, params) + if tagging_func is not None: + tagging_func = functools.partial(tagging_func, jaxpr) + return JaxGraph( + jaxpr=jaxpr, + consts=consts, + params=params, + params_tree=params_tree, + in_tree=in_tree, + out_tree=None, + digraph=digraph, + tagging_func=tagging_func) + + +def print_nice_jaxpr(jaxpr): + for eqn in jaxpr.eqns: + print(tuple(eqn.invars), "->", eqn.primitive.name, tuple(eqn.outvars)) + + +def auto_register_tags(func, + func_args, + params_index: int = 0, + register_only_generic: bool = False, + compute_only_loss_tags: bool = True, + patterns_to_skip: Sequence[str] = ()): + """Transform the function to one that is populated with tags.""" + func = broadcast_merger(func) + graph = function_to_jax_graph(func, func_args, params_index=params_index) + matches = dict() + + # Extract the tagged losses variables and all their ancestors + loss_output_vars = [] + num_losses = 0 + loss_ancestors = set() + for node in graph.digraph.nodes: + if node.startswith("__loss_tag"): + num_losses += 1 + ancestors = nx.ancestors(graph.digraph, node) + ancestors.add(node) + for output_node in node.split("->")[-1].split(","): + ancestors.add(output_node) + loss_output_vars.append(graph.digraph.nodes[output_node]["var"]) + loss_ancestors = loss_ancestors.union(ancestors) + loss_output_vars = tuple(loss_output_vars) + + # Extract the sub-graph that leads to losses + sub_graph = nx.induced_subgraph(graph.digraph, loss_ancestors) + + # First collect all parameters that are already part of a layer tag + tagged_params = dict() + pattern_counters = dict() + for tag_node in ( + node for node in sub_graph.nodes if node.startswith("__layer_tag")): + inputs = graph.digraph.nodes[tag_node]["eqn"].invars + tag_instance = graph.digraph.nodes[tag_node]["eqn"].primitive + if tag_instance.name == "generic_tag": + tag_params = tag_instance.split_all_inputs(inputs)[0] + else: + tag_params = tag_instance.split_all_inputs(inputs)[2] + pattern_number = pattern_counters.get(tag_instance.name, 0) + for param in tag_params: + if param not in graph.params: + raise ValueError(f"You have registered a layer tag with parameter " + f"that is not part of the parameters at index " + f"{params_index}.") + if param in tagged_params: + raise ValueError(f"You have registered twice the parameter {param}.") + tagged_params[param] = 
f"Manual[{tag_instance.name}_{pattern_number}]" + if tag_instance.name not in pattern_counters: + pattern_counters[tag_instance.name] = 1 + else: + pattern_counters[tag_instance.name] += 1 + + if not register_only_generic: + for pattern_name, patterns in get_graph_patterns(): + if pattern_name in patterns_to_skip: + logging.info("Skipping graph pattern %s", pattern_name) + continue + logging.info("Matching graph pattern %s", pattern_name) + for pattern in patterns: + for match_map in match_pattern(pattern.digraph, sub_graph): + if len(pattern.jaxpr.outvars) > 1: + raise NotImplementedError() + output = pattern.jaxpr.outvars[0] + if matches.get(match_map[output]) is not None: + raise ValueError(f"Found more than one match for equation " + f"{match_map[output]}. Examine the jaxpr:\n " + f"{graph.jaxpr}") + # Mark the parameters as already tagged + match_params = set() + match_params_already_tagged = False + for param in match_map.values(): + if param in graph.params: + match_params.add(param) + if param in tagged_params.keys(): + match_params_already_tagged = True + # Register the match only if no parameters are already registered + if not match_params_already_tagged: + matches[match_map[output]] = (match_map, pattern.tagging_func) + pattern_number = pattern_counters.get(pattern_name, 0) + for param in match_params: + tagged_params[param] = f"Auto[{pattern_name}_{pattern_number}]" + if pattern_name not in pattern_counters: + pattern_counters[pattern_name] = 1 + else: + pattern_counters[pattern_name] += 1 + + # Mark remaining parameters as orphans + orphan_params = sorted( + set(graph.params) - set(tagged_params.keys()), key=lambda v: v.count) + params_regs = [tagged_params.get(p, "Orphan") for p in graph.params] + params_regs = jax.tree_unflatten(graph.params_tree, params_regs) + logging.info("=" * 50) + logging.info("Graph parameter registrations:") + logging.info(params_regs) + logging.info("=" * 50) + + # Construct a function with all of the extra tag registrations + @functools.wraps(func) + def wrapped_auto_registered(*args): + flat_args, _ = jax.tree_flatten(args) + # Mapping from variable -> value + env = {} + + read = functools.partial(read_env, env) + write = functools.partial(write_env, env) + + def tag(var): + if matches.get(var) is not None: + inv_map, tagging_func = matches[var] + var_map = {k: v for k, v in inv_map.items() if not isinstance(k, str)} + val_map = jax.tree_map(read, var_map) + val = tagging_func(inv_map, val_map) + env[var] = val + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, graph.jaxpr.invars, flat_args) + jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts) + + # Register any orphan parameters as generic + for param_var in orphan_params: + write(param_var, tags.register_generic(read(param_var))) + + # Set the correct output variables + if compute_only_loss_tags: + output_vars = loss_output_vars + out_tree = jax.tree_structure(loss_output_vars) + else: + output_vars = graph.jaxpr.outvars + out_tree = graph.out_tree + + # Loop through equations and evaluate primitives using `bind` + losses_evaluated = 0 + for eqn in graph.jaxpr.eqns: + evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + jax_util.safe_map(tag, eqn.outvars) + + # If we want to output only tagged losses + if isinstance(eqn.primitive, tags.LossTag): + losses_evaluated += 1 + if compute_only_loss_tags and num_losses == losses_evaluated: + break + + outputs = jax_util.safe_map(read, output_vars) + return 
jax.tree_unflatten(out_tree, outputs) + + return wrapped_auto_registered + + +# Registered graphs +NAME_TO_JAX_GRAPH = dict() +DEFERRED_REGISTRATIONS = [] + + +def register_function(name, func, tagging_func, example_args, params_index, + precedence): + """Registers a function as a pattern in the graph matcher registry. + + The graph matcher needs to trace at least once the full function, which means + you need to provide it with dummy arguments. The shapes of the arguments do + not matter, as the graph matcher ignores their values, however the rank does. + Especially if there is some broadcasting happening you should register with + every possible broadcast pattern. As a general advice avoid using a shape to + be 1, unless you want the pattern to specifically match that, as some + operations, like squeeze for example, can have special behaviour then. + + Args: + name: The name of the pattern that is being registered to. + func: The function that performs the computation. + tagging_func: Function that correctly creates the tag. + example_args: Example arguments that can be inputted into `func`. + params_index: Specifies at which index of the `example_args` are considered + a parameter. + precedence: This specifies what precedence the graph matcher is going to + assign to the provided pattern. The graph matcher will go from lowest to + highest precedence, randomly breaking ties, when matching. Note that the + pattern that matches a parameter with the lowest precedence will get + registered and no other will. Specifically useful when there is a pattern + for a layer with and without bias, in which case the with bias + registration always should go with lower precedence. + """ + + # This is required because we can not use Jax before InitGoogle() runs + def register(): + jnp_args = jax.tree_map(jnp.asarray, example_args) + graph = function_to_jax_graph( + func, jnp_args, params_index=params_index, tagging_func=tagging_func) + if NAME_TO_JAX_GRAPH.get(name) is None: + NAME_TO_JAX_GRAPH[name] = (precedence, []) + assert precedence == NAME_TO_JAX_GRAPH[name][0] + NAME_TO_JAX_GRAPH[name][1].append(graph) + + DEFERRED_REGISTRATIONS.append(register) + + +def get_graph_patterns(): + """Returns all graph patterns sorted by their precedence.""" + while DEFERRED_REGISTRATIONS: + DEFERRED_REGISTRATIONS.pop()() + return [(name, pattern) for name, (_, pattern) in sorted( + NAME_TO_JAX_GRAPH.items(), key=lambda pair: pair[1][0])] + + +# Dense with bias +register_function( + "dense_with_bias", + tags.dense_func, + tags.dense_tagging, + [np.zeros([11, 13]), [np.zeros([13, 7]), np.zeros([7])]], + params_index=1, + precedence=0) + +# Dense without bias +register_function( + "dense_no_bias", + tags.dense_func, + tags.dense_tagging, [np.zeros([11, 13]), [np.zeros([13, 7])]], + params_index=1, + precedence=1) + +# Conv2d with bias +register_function( + "conv2d_with_bias", + tags.conv2d_func, + tags.conv2d_tagging, + [np.zeros([2, 8, 8, 5]), [np.zeros([3, 3, 5, 4]), + np.zeros([4])]], + params_index=1, + precedence=0) + +# Conv2d without bias +register_function( + "conv2d_no_bias", + tags.conv2d_func, + tags.conv2d_tagging, [np.zeros([2, 8, 8, 5]), [np.zeros([3, 3, 5, 4])]], + params_index=1, + precedence=1) + +# Standard scale and shift with both scale and shift +register_function( + "scale_and_shift", + functools.partial( + tags.scale_and_shift_func, has_scale=True, has_shift=True), + functools.partial( + tags.scale_and_shift_tagging, has_scale=True, has_shift=True), + [np.zeros([2, 13]), [np.zeros([13]), 
np.zeros([13])]], + params_index=1, + precedence=0) + +# Same but no broadcasting +register_function( + "scale_and_shift", + functools.partial( + tags.scale_and_shift_func, has_scale=True, has_shift=True), + functools.partial( + tags.scale_and_shift_tagging, has_scale=True, has_shift=True), + [np.zeros([13]), [np.zeros([13]), np.zeros([13])]], + params_index=1, + precedence=0) + +# Scale and shift as implemented in batch norm layers in Haiku +register_function( + "scale_and_shift", + tags.batch_norm_func, + functools.partial( + tags.batch_norm_tagging_func, has_scale=True, has_shift=True), + [[np.zeros([2, 13]), np.zeros([13])], [np.zeros([13]), + np.zeros([13])]], + params_index=1, + precedence=0) + +# Same but no broadcasting +register_function( + "scale_and_shift", + tags.batch_norm_func, + functools.partial( + tags.batch_norm_tagging_func, has_scale=True, has_shift=True), + [[np.zeros([13]), np.zeros([13])], [np.zeros([13]), + np.zeros([13])]], + params_index=1, + precedence=0) + +# Only scale +register_function( + "scale_only", + functools.partial( + tags.scale_and_shift_func, has_scale=True, has_shift=False), + functools.partial( + tags.scale_and_shift_tagging, has_scale=True, has_shift=False), + [np.zeros([2, 13]), [np.zeros([13])]], + params_index=1, + precedence=1) diff --git a/kfac_ferminet_alpha/tests/common.py b/kfac_ferminet_alpha/tests/common.py new file mode 100644 index 0000000..0b40492 --- /dev/null +++ b/kfac_ferminet_alpha/tests/common.py @@ -0,0 +1,76 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common functions used across more than one test.""" +import jax +import jax.numpy as jnp +import jax.random as jnr + +from kfac_ferminet_alpha import loss_functions + + +def fully_connected_layer(params, x): + w, b = params + return jnp.matmul(x, w) + b[None] + + +def init_autoencoder(key, data_shape): + """Initialize the standard autoencoder.""" + assert len(data_shape) == 1 + x_size = data_shape[0] + sizes = [x_size, 1000, 500, 250, 30, 250, 500, 1000, x_size] + keys = jnr.split(key, len(sizes) - 1) + params = [] + for key, dim_in, dim_out in zip(keys, sizes, sizes[1:]): + # Glorot uniform initialization + c = jnp.sqrt(6 / (dim_in + dim_out)) + w = jax.random.uniform(key, shape=(dim_in, dim_out), minval=-c, maxval=c) + b = jnp.zeros([dim_out]) + params.append((w, b)) + return params + + +def autoencoder(all_params, x_in): + """Evaluate the standard autoencoder. + + Note that the objective of this autoencoder is not standard, bur rather a sum + of the standard sigmoid crossentropy and squared loss. The reason for this is + to test on handling multiple losses. + + Args: + all_params: All parameter values. + x_in: Inputs to the network. + + Returns: + The value of the two losses and intermediate layer values. 
+ """ + h_in = x_in + layers_values = [] + for i, params in enumerate(all_params): + h_out = fully_connected_layer(params, h_in) + layers_values.append((h_out, h_in)) + # Last layer does not have a nonlinearity + if i % 4 != 3: + # h_in = nn.leaky_relu(h_out) + h_in = jnp.tanh(h_out) + else: + h_in = h_out + h1, _ = loss_functions.register_normal_predictive_distribution(h_in, x_in) + h2, _ = loss_functions.register_normal_predictive_distribution( + h_in, targets=x_in, weight=0.1) + l1 = (h1 - x_in)**2 + jnp.log(jnp.pi) / 2 + l1 = jnp.sum(l1, axis=-1) + l2 = (h2 - x_in)**2 + jnp.log(jnp.pi) / 2 + l2 = jnp.sum(l2, axis=-1) + return [l1, l2 * 0.1], layers_values diff --git a/kfac_ferminet_alpha/tests/graph_matcher_test.py b/kfac_ferminet_alpha/tests/graph_matcher_test.py new file mode 100644 index 0000000..18a7c60 --- /dev/null +++ b/kfac_ferminet_alpha/tests/graph_matcher_test.py @@ -0,0 +1,85 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +import jax +import jax.numpy as jnp +import jax.random as jnr +import jax.test_util as jtu + +from kfac_ferminet_alpha import layers_and_loss_tags +from kfac_ferminet_alpha import loss_functions +from kfac_ferminet_alpha import tag_graph_matcher +from kfac_ferminet_alpha.tests import common + + +def tagged_autoencoder(all_params, x_in): + h_in = x_in + layers_values = [] + for i, params in enumerate(all_params): + h_out = common.fully_connected_layer(params, h_in) + h_out = layers_and_loss_tags.register_dense(h_out, h_in, params[0], + params[1],) + layers_values.append((h_out, h_in)) + # Last layer does not have a nonlinearity + if i % 4 != 3: + h_in = jnp.tanh(h_out) + else: + h_in = h_out + h1, _ = loss_functions.register_normal_predictive_distribution( + h_in, targets=x_in, weight=1.0) + h2, t2 = loss_functions.register_normal_predictive_distribution( + h_in, targets=x_in, weight=0.1) + return [[h1, t2], [h2, t2]] + + +class TestGraphMatcher(jtu.JaxTestCase): + """Class for running all of the tests for integrating the systems.""" + + def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape): + data_shape = tuple(data_shape) + rng_key = jnr.PRNGKey(12345) + init_key, data_key = jnr.split(rng_key) + params = init_func(init_key, data_shape) + data = jnr.normal(data_key, (11,) + data_shape) + func = tag_graph_matcher.auto_register_tags(model_func, (params, data)) + jaxpr = jax.make_jaxpr(func)(params, data).jaxpr + tagged_jaxpr = jax.make_jaxpr(tagged_model)(params, data).jaxpr + self.assertEqual(len(jaxpr.invars), len(tagged_jaxpr.invars)) + self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars)) + self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars)) + for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns): + eq_in_vars = [v for v in eq.invars if not isinstance(v, jax.core.UnitVar)] + tagged_in_vars = [ + v for v in tagged_eq.invars if not isinstance(v, jax.core.UnitVar) + ] + self.assertEqual(len(eq_in_vars), len(tagged_in_vars)) + 
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars)) + self.assertEqual(eq.primitive, tagged_eq.primitive) + for variable, t_variable in zip(eq_in_vars + eq.outvars, + tagged_in_vars + tagged_eq.outvars): + if isinstance(variable, jax.core.Literal): + self.assertEqual(variable.aval, t_variable.aval) + else: + if variable.count != t_variable.count: + print("0") + self.assertEqual(variable.count, t_variable.count) + + def test_autoencoder(self): + self._test_jaxpr(common.init_autoencoder, common.autoencoder, + tagged_autoencoder, [784]) + + +if __name__ == "__main__": + absltest.main() diff --git a/kfac_ferminet_alpha/tests/tracer_test.py b/kfac_ferminet_alpha/tests/tracer_test.py new file mode 100644 index 0000000..8236d4d --- /dev/null +++ b/kfac_ferminet_alpha/tests/tracer_test.py @@ -0,0 +1,198 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +import jax +import jax.numpy as jnp +import jax.random as jnr +import jax.test_util as jtu + +from kfac_ferminet_alpha import loss_functions +from kfac_ferminet_alpha import tag_graph_matcher as tgm +from kfac_ferminet_alpha import tracer +from kfac_ferminet_alpha import utils +from kfac_ferminet_alpha.tests import common + + +def autoencoder_aux(all_aux, all_params, x_in): + h_in = x_in + layers_values = [] + for i, (params, aux) in enumerate(zip(all_params, all_aux)): + h_out = common.fully_connected_layer(params, h_in + aux[1]) + aux[0] + layers_values.append((h_out, h_in)) + # Last layer does not have a nonlinearity + if i % 4 != 3: + h_in = jnp.tanh(h_out) + else: + h_in = h_out + h1, _ = loss_functions.register_normal_predictive_distribution(h_in, x_in) + h2, _ = loss_functions.register_normal_predictive_distribution( + h_in, targets=x_in, weight=0.1) + l1 = (h1 - x_in)**2 + jnp.log(jnp.pi) / 2 + l2 = (h2 - x_in)**2 + jnp.log(jnp.pi) / 2 + return [l1, l2 * 0.1], layers_values + + +class TestTracer(jtu.JaxTestCase): + """Class for running all of the tests for integrating the systems.""" + + @staticmethod + def generate_data(init_func, func, data_shape, rng_key): + n = 3 + + rng_key, key = jnr.split(rng_key) + params = init_func(key, data_shape) + rng_key, key = jnr.split(rng_key) + p_tangents = init_func(key, data_shape) + rng_key, key = jnr.split(rng_key) + data = jnr.normal(key, [n] + data_shape) + + loss_vals, layer_vals = func(params, data) + h = layer_vals[-1][0] + keys = jnr.split(key, len(loss_vals)) + h_tangents = tuple(jnr.normal(key, shape=h.shape) for key in keys) + + return params, data, p_tangents, h_tangents + + def assertStructureAllClose(self, x, y, **kwargs): + x_v, x_tree = jax.tree_flatten(x) + y_v, y_tree = jax.tree_flatten(y) + self.assertEqual(x_tree, y_tree) + for xi, yi in zip(x_v, y_v): + self.assertEqual(xi.shape, yi.shape) + self.assertAllClose(xi, yi, check_dtypes=True, **kwargs) + + def test_tacer_jvp(self): + init_func = common.init_autoencoder + func = common.autoencoder + data_shape = [784] + rng_key = jnr.PRNGKey(12345) + 
params, data, p_tangents, _ = self.generate_data(init_func, func, + data_shape, rng_key) + + def no_data_func(args): + outputs = func(args, data) + return outputs[0], outputs[1][-1][0] + + # True computation + (primals_out, tangents_out) = jax.jvp(no_data_func, [params], [p_tangents]) + loss_vals, _ = primals_out + _, h_tangents = tangents_out + loss_tangents = ((h_tangents,),) * len(loss_vals) + # Tracer computation + tracer_jvp = tracer.trace_losses_matrix_vector_jvp(func) + tracer_losses, tracer_loss_tangents = tracer_jvp((params, data), p_tangents) + tracer_losses = [loss.evaluate(None) for loss in tracer_losses] + + self.assertStructureAllClose(loss_vals, tracer_losses) + self.assertStructureAllClose(loss_tangents, tracer_loss_tangents) + + def test_tracer_vjp(self): + init_func = common.init_autoencoder + func = common.autoencoder + data_shape = [784] + rng_key = jnr.PRNGKey(12345) + params, data, _, h_tangents = self.generate_data(init_func, func, + data_shape, rng_key) + + def no_data_func(args): + outputs = func(args, data) + return outputs[0], outputs[1][-1][0] + + # True computation + (loss_vals, _), vjp_func = jax.vjp(no_data_func, params) + loss_tangents = jax.tree_map(jnp.zeros_like, loss_vals) + summed_h_tangents = sum(jax.tree_flatten(h_tangents)[0]) + p_tangents = vjp_func((loss_tangents, summed_h_tangents)) + # Tracer computation + trace_vjp = tracer.trace_losses_matrix_vector_vjp(func) + tracer_losses, tracer_vjp_func = trace_vjp(params, data) + tracer_losses = [loss.evaluate(None) for loss in tracer_losses] + tracer_p_tangents = tracer_vjp_func(h_tangents) + + self.assertStructureAllClose(loss_vals, tracer_losses) + self.assertStructureAllClose(p_tangents, tracer_p_tangents) + + def test_tracer_hvp(self): + init_func = common.init_autoencoder + func = common.autoencoder + data_shape = [784] + rng_key = jnr.PRNGKey(12345) + params, data, p_tangents, _ = self.generate_data(init_func, func, + data_shape, rng_key) + + def no_data_func(args): + outputs = func(args, data) + return sum(jax.tree_map(jnp.sum, outputs[0])) + + # True computation + grad_func = jax.grad(no_data_func) + + def grad_time_tangents(args): + return utils.inner_product(grad_func(args), p_tangents) + + hvp = jax.grad(grad_time_tangents) + hvp_vectors = hvp(params) + # Tracer computation + tracer_hvp = tracer.trace_losses_matrix_vector_hvp(func) + tracer_hvp_vectors = tracer_hvp((params, data), p_tangents) + + self.assertStructureAllClose(hvp_vectors, tracer_hvp_vectors, atol=1e-4) + + def test_trace_estimator(self): + init_func = common.init_autoencoder + func = common.autoencoder + aux_func = autoencoder_aux + data_shape = [784] + rng_key = jnr.PRNGKey(12345) + params, data, _, h_tangents = self.generate_data(init_func, func, + data_shape, rng_key) + + def aux_last_layer(aux, args): + outs = aux_func(aux, args, data) + return outs[1][-1][0] + + # True computation + loss_vals, layer_vals = func(params, data) + aux_vals = jax.tree_map(jnp.zeros_like, layer_vals) + _, vjp = jax.vjp(aux_last_layer, aux_vals, params) + summed_h_tangents = sum(jax.tree_flatten(h_tangents)[0]) + aux_tangents, p_tangents = vjp(summed_h_tangents) + layers_info = [] + for aux_p, p_p in zip(layer_vals, params): + info = dict() + info["outputs"] = (aux_p[0],) + info["inputs"] = (aux_p[1],) + info["params"] = (p_p[0], p_p[1]) + layers_info.append(info) + for i, (aux_t, p_t) in enumerate(zip(aux_tangents, p_tangents)): + info = dict() + info["outputs_tangent"] = (aux_t[0],) + info["inputs_tangent"] = (aux_t[1],) + 
info["params_tangent"] = (p_t[0], p_t[1]) + layers_info[i].update(info) + layers_info = tuple(layers_info) + + func = tgm.auto_register_tags(func, (params, data)) + tracer_vjp = tracer.trace_estimator_vjp(func) + tracer_losses, tracer_vjp_func = tracer_vjp((params, data)) + tracer_losses = [loss.evaluate(None) for loss in tracer_losses] + tracer_outputs = tracer_vjp_func((h_tangents[:1], h_tangents[1:])) + + self.assertStructureAllClose(loss_vals, tracer_losses) + self.assertStructureAllClose(tracer_outputs, layers_info) + + +if __name__ == "__main__": + absltest.main() diff --git a/kfac_ferminet_alpha/tracer.py b/kfac_ferminet_alpha/tracer.py new file mode 100644 index 0000000..4882b20 --- /dev/null +++ b/kfac_ferminet_alpha/tracer.py @@ -0,0 +1,327 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for the Jax tracer functionality for tags.""" +import functools +from typing import Any, Callable, Sequence, Tuple + +import jax +from jax import core +from jax import util as jax_util +import jax.numpy as jnp + +from kfac_ferminet_alpha import layers_and_loss_tags as tags +from kfac_ferminet_alpha import tag_graph_matcher as tgm +from kfac_ferminet_alpha import utils + +_Function = Callable[[Any], Any] +_Loss = tags.LossTag + + +def extract_tags( + jaxpr: core.Jaxpr +) -> Tuple[Sequence[core.JaxprEqn], Sequence[core.JaxprEqn]]: + """Extracts all of the tag equations.""" + # Loop through equations and evaluate primitives using `bind` + layer_tags = [] + loss_tags = [] + for eqn in jaxpr.eqns: + if isinstance(eqn.primitive, tags.LossTag): + loss_tags.append(eqn) + elif isinstance(eqn.primitive, tags.LayerTag): + layer_tags.append(eqn) + return tuple(layer_tags), tuple(loss_tags) + + +def construct_compute_losses_inputs( + jaxpr: core.Jaxpr, + consts: Tuple[Any], + num_losses: int, + primals: Any, + params_index: int) -> Callable[[Any], Sequence[Sequence[jnp.ndarray]]]: + """Constructs a function that computes all of the inputs to all losses.""" + primals_ = list(primals) + + def forward_compute_losses( + params_primals: Any, + ) -> Sequence[Sequence[jnp.ndarray]]: + primals_[params_index] = params_primals + flat_args = jax.tree_flatten(primals_)[0] + # Mapping from variable -> value + env = dict() + read = functools.partial(tgm.read_env, env) + write = functools.partial(tgm.write_env, env) + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, flat_args) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + losses_so_far = 0 + loss_tags = [] + for eqn in jaxpr.eqns: + tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + if isinstance(eqn.primitive, tags.LossTag): + loss_tags.append(eqn) + losses_so_far += 1 + if num_losses is not None and losses_so_far == num_losses: + break + return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags) + # return tuple(jax_util.safe_map(read, 
tag.invars) for tag in loss_tags) + return forward_compute_losses + + +# We know when `.primitive` will be either a `LossTag` or a `LayerTag`, however +# pytype cannot infer its subclass, so we need to unbox it. + + +def _unbox_loss_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LossTag: + assert isinstance(jaxpr_eqn.primitive, tags.LossTag) + return jaxpr_eqn.primitive + + +def _unbox_layer_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LayerTag: + assert isinstance(jaxpr_eqn.primitive, tags.LayerTag) + return jaxpr_eqn.primitive + + +def trace_losses_matrix_vector_vjp(tagged_func: _Function, + params_index: int = 0): + """Returns the Jacobian-transposed vector product (backward mode) function in equivalent form to jax.vjp.""" + def vjp(*primals): + typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + _, loss_jaxpr_eqns = extract_tags(jaxpr) + n = len(loss_jaxpr_eqns) + losses_func = construct_compute_losses_inputs( + jaxpr, consts, n, primals, params_index) + losses_inputs, full_vjp_func = jax.vjp(losses_func, primals[params_index]) + losses = [] + for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs): + loss_tag = _unbox_loss_tag(jaxpr_eqn) + losses.append(loss_tag.loss(*inputs, weight=jaxpr_eqn.params["weight"])) + losses = tuple(losses) + + def vjp_func(tangents): + flat_tangents = jax.tree_flatten(tangents)[0] + loss_invars = [] + loss_targets = [] + for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs): + num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs + loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs])) + loss_targets.append(inputs[num_inputs:]) + treedef = jax.tree_structure(loss_invars) + tangents = jax.tree_unflatten(treedef, flat_tangents) + # Since the losses could also take and targets as inputs and we don't want + # this function to computes vjp w.r.t to those (e.g. the user should not + # be providing tangent vectors for the targets, only for inputs) we have + # to manually fill in these "extra" tangents with zeros. + targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets) + tangents = tuple(ti + tti for ti, tti in zip(tangents, targets_tangents)) + input_tangents = full_vjp_func(tangents)[0] + return input_tangents, + return losses, vjp_func + return vjp + + +def trace_losses_matrix_vector_jvp( + tagged_func: _Function, + params_index: int = 0): + """Returns the Jacobian vector product (forward mode) function in equivalent form to jax.jvp.""" + def jvp(primals, params_tangents): + typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + _, loss_tags = extract_tags(jaxpr) + n = len(loss_tags) + losses_func = construct_compute_losses_inputs(jaxpr, consts, n, + primals, params_index) + primals = (primals[params_index],) + tangents = (params_tangents,) + (primals_out, tangents_out) = jax.jvp(losses_func, primals, tangents) + tangents_out = tuple(tuple(t[:tag.primitive.num_inputs]) + for t, tag in zip(tangents_out, loss_tags)) + losses = tuple(tag.primitive.loss(*inputs, weight=tag.params["weight"]) + for tag, inputs in zip(loss_tags, primals_out)) + return losses, tangents_out + return jvp + + +def trace_losses_matrix_vector_hvp(tagged_func, params_index=0): + """Returns the Hessian vector product function of **the tagged losses**, rather than the output value of `tagged_func`.""" + # The function uses backward-over-forward mode. 
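+  # A minimal sketch (comments only, not executed) of the same pattern in
+  # plain JAX, where `f` is a hypothetical scalar-valued function of the
+  # parameters and `v` is a tangent vector with the same structure:
+  #
+  #   def plain_hvp(f, params, v):
+  #     grad_dot_v = lambda p: utils.inner_product(jax.grad(f)(p), v)
+  #     return jax.grad(grad_dot_v)(params)
+  #
+  # `hvp` below follows exactly this recipe, with `f` being the sum of all
+  # tagged losses extracted from `tagged_func`.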
+
+  def hvp(primals, params_tangents):
+    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
+    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
+    _, loss_tags = extract_tags(jaxpr)
+    n = len(loss_tags)
+    losses_func = construct_compute_losses_inputs(
+        jaxpr, consts, n, primals, params_index)
+
+    def losses_sum(param_primals):
+      loss_inputs = losses_func(param_primals)
+      losses = [
+          _unbox_loss_tag(jaxpr_eqn).loss(
+              *inputs, weight=jaxpr_eqn.params["weight"])
+          for jaxpr_eqn, inputs in zip(loss_tags, loss_inputs)
+      ]
+      # Summing the evaluated losses makes it possible to use jax.grad rather
+      # than jax.vjp for taking derivatives.
+      return sum(jnp.sum(loss.evaluate(None)) for loss in losses)
+
+    def grads_times_tangents(params_primals):
+      grads = jax.grad(losses_sum)(params_primals)
+      return utils.inner_product(grads, params_tangents)
+
+    return jax.grad(grads_times_tangents)(primals[params_index])
+  return hvp
+
+
+def trace_estimator_vjp(tagged_func: _Function) -> _Function:
+  """Creates the function needed for an estimator of curvature matrices.
+
+  Args:
+    tagged_func: A function that has been annotated with tags for both layers
+      and losses.
+
+  Returns:
+    A function with the same signature as `tagged_func`, which when provided
+    with inputs returns two things:
+    1. The instances of all loss objects that are tagged.
+    2. A second function, which when provided with tangent vectors for each
+       of the loss instances' parameters, returns for every tagged layer a
+       dictionary containing the following elements:
+         inputs - The primal values of the inputs to the layer.
+         outputs - The primal values of the outputs of the layer.
+         params - The primal values of the parameters of the layer.
+         inputs_tangent - The tangent values of the inputs to the layer,
+           given the provided tangents of the losses.
+         outputs_tangent - The tangent values of the outputs of the layer,
+           given the provided tangents of the losses.
+         params_tangent - The tangent values of the parameters of the layer,
+           given the provided tangents of the losses.
+ """ + def full_vjp_func(func_args): + # Trace the tagged function + typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + layer_tags, loss_tags = extract_tags(jaxpr) + + layer_vars_flat = jax.tree_flatten([tag.invars for tag in layer_tags])[0] + layer_input_vars = tuple(set(layer_vars_flat)) + + def forward(): + own_func_args = func_args + # Mapping from variable -> value + env = dict() + read = functools.partial(tgm.read_env, env) + write = functools.partial(tgm.write_env, env) + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0]) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + num_losses_passed = 0 + for eqn in jaxpr.eqns: + tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + if isinstance(eqn.primitive, tags.LossTag): + num_losses_passed += 1 + if num_losses_passed == len(loss_tags): + break + if num_losses_passed != len(loss_tags): + raise ValueError("This should be unreachable.") + + return jax_util.safe_map(read, layer_input_vars) + + def forward_aux(aux): + own_func_args = func_args + # Mapping from variable -> value + env = dict() + read = functools.partial(tgm.read_env, env) + def write(var, val): + if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)): + val = val + aux[var] if var in aux else val + env[var] = val + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0]) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + num_losses_passed = 0 + losses_inputs_values = [] + losses_kwargs_values = [] + for eqn in jaxpr.eqns: + input_values = jax_util.safe_map(read, eqn.invars) + tgm.evaluate_eqn(eqn, input_values, write) + if isinstance(eqn.primitive, tags.LossTag): + loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"]) + losses_inputs_values.append(loss.inputs) + losses_kwargs_values.append(dict( + targets=loss.targets, + weight=eqn.params["weight"] + )) + num_losses_passed += 1 + if num_losses_passed == len(loss_tags): + break + if num_losses_passed != len(loss_tags): + raise ValueError("This should be unreachable.") + # Read the inputs to the loss functions, but also return the target values + return tuple(losses_inputs_values), tuple(losses_kwargs_values) + + layer_input_values = forward() + primals_dict = dict(zip(layer_input_vars, layer_input_values)) + primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0])) + aux_values = jax.tree_map(jnp.zeros_like, layer_input_values) + aux_dict = dict(zip(layer_input_vars, aux_values)) + + losses_args, aux_vjp, losses_kwargs = jax.vjp(forward_aux, aux_dict, + has_aux=True) + losses = tuple(tag.primitive.loss(*inputs, **kwargs) + for tag, inputs, kwargs in + zip(loss_tags, losses_args, losses_kwargs)) + + def vjp_func(tangents): + all_tangents = aux_vjp(tangents) + tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:] + inputs_tangents = jax.tree_flatten(inputs_tangents)[0] + tangents_dict.update(zip(jaxpr.invars, inputs_tangents)) + + read_primals = functools.partial(tgm.read_env, primals_dict) + read_tangents = functools.partial(tgm.read_env, tangents_dict) + layers_info = [] + for jaxpr_eqn in layer_tags: + layer_tag = _unbox_layer_tag(jaxpr_eqn) + info = dict() + 
primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars)) + ( + info["outputs"], + info["inputs"], + info["params"], + ) = layer_tag.split_all_inputs(primals) + tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars)) + ( + info["outputs_tangent"], + info["inputs_tangent"], + info["params_tangent"], + ) = layer_tag.split_all_inputs(tangents) + layers_info.append(info) + return tuple(layers_info) + + return losses, vjp_func + return full_vjp_func diff --git a/kfac_ferminet_alpha/utils.py b/kfac_ferminet_alpha/utils.py new file mode 100644 index 0000000..b7b772e --- /dev/null +++ b/kfac_ferminet_alpha/utils.py @@ -0,0 +1,455 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities related to multi-device operations.""" +import collections +from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union +import dataclasses +import jax +from jax import core +from jax import lax +import jax.numpy as jnp +from jax.scipy import linalg +import jax.tree_util as tree_util + +T = TypeVar("T") + + +def wrap_if_pmap(p_func): + + def p_func_if_pmap(obj, axis_name): + try: + core.axis_frame(axis_name) + return p_func(obj, axis_name) + except NameError: + return obj + + return p_func_if_pmap + + +pmean_if_pmap = wrap_if_pmap(lax.pmean) +psum_if_pmap = wrap_if_pmap(lax.psum) +compute_mean = jax.pmap(lambda x: lax.pmean(x, "i"), axis_name="i") +compute_sum = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i") + + +def get_first(obj: T) -> T: + return jax.tree_map(lambda x: x[0], obj) + + +def get_mean(obj: T) -> T: + return get_first(compute_mean(obj)) + + +def get_sum(obj: T) -> T: + return get_first(compute_sum(obj)) + + +broadcast_all_local_devices = jax.pmap(lambda x: x) + + +def replicate_all_local_devices(obj: T) -> T: + n = jax.local_device_count() + obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj) + return broadcast_all_local_devices(obj_stacked) + + +def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray: + rng = jax.random.fold_in(rng, jax.host_id()) + rng = jax.random.split(rng, jax.local_device_count()) + return broadcast_all_local_devices(rng) + + +p_split = jax.pmap(lambda key: tuple(jax.random.split(key))) + + +def scalar_mul(obj: T, scalar: Union[float, jnp.ndarray]) -> T: + return jax.tree_map(lambda x: x * scalar, obj) + + +def scalar_div(obj: T, scalar: Union[float, jnp.ndarray]) -> T: + return jax.tree_map(lambda x: x / scalar, obj) + + +def make_func_args(params, func_state, rng, batch, has_state: bool, + has_rng: bool): + """Correctly puts all arguments to the function together.""" + func_args = (params,) + if has_state: + if func_state is None: + raise ValueError("The `func_state` is None, but the argument `has_state` " + "is True.") + func_args += (func_state,) + if has_rng: + if rng is None: + raise ValueError("The `rng` is None, but the argument `has_rng` is True.") + func_args += (rng,) + func_args += (batch,) + return func_args + + +def 
extract_func_outputs( + raw_outputs: Any, + has_aux: bool, + has_state: bool, +) -> Tuple[jnp.ndarray, Any, Any]: + """Given the function output returns separately the loss, func_state, aux.""" + if not has_aux and not has_state: + return raw_outputs, None, None + loss, other = raw_outputs + if has_aux and has_state: + func_state, aux = other + elif has_aux: + func_state, aux = None, other + else: + func_state, aux = other, None + return loss, func_state, aux + + +def inner_product(obj1: T, obj2: T) -> jnp.ndarray: + if jax.tree_structure(obj1) != jax.tree_structure(obj2): + raise ValueError("The two structures are not identical.") + elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2) + return sum(jax.tree_flatten(elements_product)[0]) + + +def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.ndarray: + assert matrix.ndim == 2 + identity = jnp.eye(matrix.shape[0]) + matrix = matrix + damping * identity + return linalg.solve(matrix, identity, sym_pos=True) + + +def solve_maybe_small(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Computes a^-1 b more efficiently for small matrices.""" + assert a.shape[-1] == a.shape[-2] == b.shape[-1] + d = a.shape[-1] + if d == 0: + return a + elif d == 1: + return b / a[..., 0] + elif d == 2: + det = a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0] + b_0 = a[..., 1, 1] * b[..., 0] - a[..., 0, 1] * b[..., 1] + b_1 = a[..., 0, 0] * b[..., 1] - a[..., 1, 0] * b[..., 0] + return jnp.stack([b_0, b_1], axis=-1) / det + elif d == 3: + raise NotImplementedError() + return jnp.linalg.solve(a, b) + + +def pi_adjusted_inverse( + factor_0: jnp.ndarray, + factor_1: jnp.ndarray, + damping: jnp.ndarray, + pmap_axis_name: str, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Performs inversion with pi-adjusted damping.""" + # Compute the norms of each factor + norm_0 = jnp.trace(factor_0) + norm_1 = jnp.trace(factor_1) + + # We need to sync the norms here, because reduction can be non-deterministic. + # They specifically are on GPUs by default for better performance. + # Hence although factor_0 and factor_1 are synced, the trace operation above + # can still produce different answers on different devices. 
+  norm_0, norm_1 = pmean_if_pmap((norm_0, norm_1), axis_name=pmap_axis_name)
+
+  # Compute the overall scale
+  scale = norm_0 * norm_1
+
+  def regular_inverse(
+      operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    factor0, factor1, norm0, norm1, s, d = operand
+    # Special cases with one or two scalar factors
+    if factor0.size == 1 and factor1.size == 1:
+      value = jnp.ones_like(factor0) / jnp.sqrt(s)
+      return value, value
+    if factor0.size == 1:
+      factor1_normed = factor1 / norm1
+      damping1 = d / norm1
+      factor1_inv = psd_inv_cholesky(factor1_normed, damping1)
+      return jnp.full((1, 1), s), factor1_inv
+    if factor1.size == 1:
+      factor0_normed = factor0 / norm0
+      damping0 = d / norm0
+      factor0_inv = psd_inv_cholesky(factor0_normed, damping0)
+      return factor0_inv, jnp.full((1, 1), s)
+
+    # Invert first factor
+    factor0_normed = factor0 / norm0
+    damping0 = jnp.sqrt(d * factor1.shape[0] / (s * factor0.shape[0]))
+    factor0_inv = psd_inv_cholesky(factor0_normed, damping0) / jnp.sqrt(s)
+
+    # Invert second factor
+    factor1_normed = factor1 / norm1
+    damping1 = jnp.sqrt(d * factor0.shape[0] / (s * factor1.shape[0]))
+    factor1_inv = psd_inv_cholesky(factor1_normed, damping1) / jnp.sqrt(s)
+    return factor0_inv, factor1_inv
+
+  def zero_inverse(
+      operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    return (jnp.eye(factor_0.shape[0]) / jnp.sqrt(operand[-1]),
+            jnp.eye(factor_1.shape[0]) / jnp.sqrt(operand[-1]))
+
+  # In the special case where one of the factors is zero, the correct inverse
+  # of `(0 kron A + lambda I)` is `I/sqrt(lambda) kron I/sqrt(lambda)`.
+  # However, because one of the norms is zero, `pi` and `1/pi` would be 0 and
+  # infinity, leading to NaN values. Hence, we need to make this check
+  # explicitly.
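+  # Note (for reference): with the dampings chosen in `regular_inverse`,
+  #   damping0 * damping1 == damping / scale,
+  # so the Kronecker product of the two returned inverses is exactly the
+  # inverse of
+  #   factor_0 kron factor_1 + damping * I + (two pi-adjusted cross terms),
+  # i.e. an approximation to `(factor_0 kron factor_1 + damping * I)^-1`.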
+ return lax.cond( + jnp.greater(scale, 0.0), + regular_inverse, + zero_inverse, + operand=(factor_0, factor_1, norm_0, norm_1, scale, damping)) + + +def convert_value_and_grad_to_value_func( + value_and_grad_func, + has_aux: bool = False, +): + """Converts a value_and_grad function to value_func only.""" + + def value_func(*args, **kwargs): + out, _ = value_and_grad_func(*args, **kwargs) + if has_aux: + return out[0] + else: + return out + + return value_func + + +def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None: + """Verifies that the two objects have the same pytree structure.""" + assert jax.tree_structure(obj1) == jax.tree_structure(obj2) + for v1, v2 in zip(jax.tree_flatten(obj1)[0], jax.tree_flatten(obj2)[0]): + assert v1.shape == v2.shape + assert v1.dtype == v2.dtype + + +def check_first_dim_is_batch_size(batch_size: int, *args: jnp.ndarray) -> None: + for i, arg in enumerate(args): + if arg.shape[0] != batch_size: + raise ValueError(f"Expecting first dimension of arg[{i}] with shape " + f"{arg.shape} to be equal to the batch size " + f"{batch_size}.") + + +def py_tree_registered_dataclass(cls, *args, **kwargs): + """Creates a new dataclass type and registers it as a pytree node.""" + dcls = dataclasses.dataclass(cls, *args, **kwargs) + tree_util.register_pytree_node( + dcls, + lambda instance: ( # pylint: disable=g-long-lambda + [getattr(instance, f.name) + for f in dataclasses.fields(instance)], None), + lambda _, instance_args: dcls(*instance_args)) + return dcls + + +class WeightedMovingAverage: + """A wrapped class for a variable for which we keep exponential moving average.""" + + def __init__(self, weight: jnp.ndarray, array: jnp.ndarray): + self._weight = weight + self._array = array + + @staticmethod + def zero(shape: Sequence[int]) -> "WeightedMovingAverage": + return WeightedMovingAverage(weight=jnp.zeros([]), array=jnp.zeros(shape)) + + @property + def weight(self) -> jnp.ndarray: + return self._weight + + @property + def value(self) -> jnp.ndarray: + return self._array / self._weight + + @property + def raw_value(self) -> jnp.ndarray: + return self._array + + def update(self, value: jnp.ndarray, old_weight_multiplier: float, + new_weight: float) -> None: + self._weight = old_weight_multiplier * self._weight + new_weight + self._array = old_weight_multiplier * self._array + new_weight * value + + def sync(self, pmap_axis_name: str) -> None: + self._array = pmean_if_pmap(self._array, pmap_axis_name) + + def __str__(self) -> str: + return (f"ExponentialMovingAverage(weight={self._weight}, " + f"array={self._array})") + + def __repr__(self) -> str: + return self.__str__() + + +tree_util.register_pytree_node( + WeightedMovingAverage, + lambda instance: ((instance.weight, instance.raw_value), None), + lambda _, instance_args: WeightedMovingAverage(*instance_args), +) + + +class Stateful: + """A class for stateful objects.""" + + def __init__(self, stateful_fields_names: Optional[Sequence[str]] = ()): + self.__stateful_fields_names = stateful_fields_names + + def _add_stateful_fields_names(self, value: Sequence[str]) -> None: + self.__stateful_fields_names += tuple(value) + + def get_state(self) -> Mapping[str, Any]: + """Returns the state of the object.""" + state = dict() + for name in self.__stateful_fields_names: + state[name] = Stateful._get_state_from_instance(getattr(self, name)) + return state + + def set_state(self, value): + """Sets the state of the object with the provided value and returns the object.""" + assert isinstance(value, dict) + for name 
in self.__stateful_fields_names: + setattr(self, name, + Stateful._set_state_to_instance(getattr(self, name), value[name])) + return self + + def clear_state(self) -> None: + """Clears the state of the object.""" + for name in self.__stateful_fields_names: + setattr(self, name, + Stateful._clear_state_from_instance(getattr(self, name))) + + def pop_state(self) -> Mapping[str, Any]: + """Returns the current state of the object, while simultaneously clearing it.""" + state = self.get_state() + self.clear_state() + return state + + @staticmethod + def _get_state_from_instance(obj): + """Recursively gets the state of the object and returns it.""" + if isinstance(obj, Stateful): + return obj.get_state() + if isinstance(obj, list): + return [Stateful._get_state_from_instance(i) for i in obj] + if isinstance(obj, tuple): + return tuple(Stateful._get_state_from_instance(i) for i in obj) + if isinstance(obj, collections.OrderedDict): + return collections.OrderedDict( + (k, Stateful._get_state_from_instance(v)) for k, v in obj.items()) + if isinstance(obj, dict): + return dict( + (k, Stateful._get_state_from_instance(v)) for k, v in obj.items()) + return obj + + @staticmethod + def _set_state_to_instance(obj, value): + """Recursively sets the state of the object and returns it.""" + if isinstance(obj, Stateful): + obj.set_state(value) + return obj + if isinstance(value, list): + if obj is None: + obj = [None] * len(value) + return [ + Stateful._set_state_to_instance(obj_i, value_i) + for obj_i, value_i in zip(obj, value) + ] + if isinstance(value, tuple): + if obj is None: + obj = [None] * len(value) + return tuple( + Stateful._set_state_to_instance(obj_i, value_i) + for obj_i, value_i in zip(obj, value)) + if isinstance(value, collections.OrderedDict): + if obj is None: + obj = dict((k, None) for k in value) + return collections.OrderedDict( + (k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj) + if isinstance(value, dict): + obj = dict((k, None) for k in value) + return dict( + (k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj) + return value + + @staticmethod + def _clear_state_from_instance(obj): + """Recursively clears the state of the object and returns it.""" + if isinstance(obj, Stateful): + obj.clear_state() + return obj + if isinstance(obj, list): + return [Stateful._clear_state_from_instance(obj_i) for obj_i in obj] + if isinstance(obj, tuple): + return tuple(Stateful._clear_state_from_instance(obj_i) for obj_i in obj) + if isinstance(obj, collections.OrderedDict): + return collections.OrderedDict( + (k, Stateful._clear_state_from_instance(obj[k])) for k in obj) + if isinstance(obj, dict): + return dict((k, Stateful._clear_state_from_instance(obj[k])) for k in obj) + return None + + @staticmethod + def infer_class_state(class_type): + """Infers a stateful class state attributes from class annotations.""" + if not issubclass(class_type, Stateful): + raise ValueError( + f"In order to annotate a class as stateful it must inherit " + f"{Stateful!r}") + + class_type = dataclasses.dataclass( + class_type, init=False, repr=False, eq=False) # pytype: disable=wrong-keyword-args + fields_names = tuple(field.name for field in dataclasses.fields(class_type)) + original_init = getattr(class_type, "__init__", None) + if original_init is None: + + def injected_init(self, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) # pylint: disable=bad-super-call + Stateful._add_stateful_fields_names(self, fields_names) + for field_name in fields_names: + 
if getattr(self, field_name, None) is None: + setattr(self, field_name, None) + + setattr(class_type, "__init__", injected_init) + else: + + def injected_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + Stateful._add_stateful_fields_names(self, fields_names) + for field_name in fields_names: + if getattr(self, field_name, None) is None: + setattr(self, field_name, None) + + setattr(class_type, "__init__", injected_init) + return class_type + + +def compute_sq_norm_relative_abs_diff(obj, pmap_axis_name): + sq_norm = inner_product(obj, obj) + synced_sq_norm = psum_if_pmap(sq_norm, pmap_axis_name) + synced_sq_norm = (synced_sq_norm - sq_norm) / (jax.device_count() - 1.0) + sq_norm_abs_diff = jnp.abs(sq_norm - synced_sq_norm) + return sq_norm_abs_diff / sq_norm + + +def product(iterable_object): + x = 1 + for element in iterable_object: + x *= element + return x diff --git a/object_attention_for_reasoning/README.md b/object_attention_for_reasoning/README.md new file mode 100644 index 0000000..b62c2f5 --- /dev/null +++ b/object_attention_for_reasoning/README.md @@ -0,0 +1,53 @@ +Implementation of the object-based transformer model from +["Object-based attention for spatio-temporal reasoning"](https://arxiv.org/abs/2012.08508) +[1]. + +This package includes source code for the transformer model, +pre-trained model parameters for the CLEVRER task, +and MONet [2] latent variables for all videos in the training +and validation sets. It does not include the model training code. +See Section 2 of [1] for details. + +[1] David Ding, Felix Hill, Adam Santoro, Matt Botvinick. *Object-based +attention for spatio-temporal reasoning: Outperforming neuro-symbolic models +with flexible distributed architectures*. +arXiv preprint arXiv:2012.08508, 2020. + +[2] Chris P. Burgess, Loic Matthey, Nick Watters, Rishabh Kabra, Irina Higgins, +Matt Botvinick, and Alexander Lerchner +*MONet: Unsupervised scene decomposition and representation*. +arXiv preprint arXiv:1901.11390, 2019. + + +# Instructions + +Note: This code depends on Tensorflow 1 and Sonnet 1. Tensorflow 1 is only +available on PYPI for Python 3.7 and earlier. + +To run this code, execute the following commands from the `deepmind_research/` +directory: + +```shell +# Download checkpoints and MONet latents +wget https://storage.googleapis.com/object-attention-for-reasoning/checkpoints_and_latents.zip +unzip checkpoints_and_latents.zip +python3.7 -m venv object_based_attention_venv +source object_based_attention_venv/bin/activate +pip install --upgrade setuptools wheel +pip install -r requirements.txt +python -m object_attention_for_reasoning.run_model +``` +If the code runs correctly, you should see the model's predicted answer to two +CLEVRER questions (a descriptive one and a multiple choice one), and both +answers should be correct. 
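+
+By default the script reads the checkpoints and MONet latents from
+`./clevrer_monet_latents` and evaluates one training-set video. Both can be
+changed through the flags defined in `run_model.py`. A hypothetical invocation
+(adjust the path to wherever you unpacked the archive):
+
+```shell
+python -m object_attention_for_reasoning.run_model \
+  --base_dir=/path/to/clevrer_monet_latents \
+  --scene_idx=1000
+```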
+ +If you find the provided code useful, please cite this paper: +``` +@article{objectattention2020, + title={Object-based attention for spatio-temporal reasoning: Outperforming + neuro-symbolic models with flexible distributed architectures}, + author={David Ding and Felix Hill and Adam Santoro and Matt Botvinick}, + journal={arXiv preprint arXiv:2012.08508}, + year={2020} +} +``` diff --git a/object_attention_for_reasoning/model.py b/object_attention_for_reasoning/model.py new file mode 100644 index 0000000..98c1435 --- /dev/null +++ b/object_attention_for_reasoning/model.py @@ -0,0 +1,185 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model code. Provided settings are identical to what was used in the paper.""" + +import sonnet as snt +import tensorflow.compat.v1 as tf + +from object_attention_for_reasoning import transformer + + +QUESTION_VOCAB_SIZE = 82 +ANSWER_VOCAB_SIZE = 22 + +MAX_QUESTION_LENGTH = 20 +MAX_CHOICE_LENGTH = 12 + +NUM_CHOICES = 4 +EMBED_DIM = 16 + +PRETRAINED_MODEL_CONFIG = dict( + use_relative_positions=True, + shuffle_objects=True, + transformer_layers=28, + head_size=128, + num_heads=10, + embed_dim=EMBED_DIM, +) + + +def append_ids(tensor, id_vector, axis): + id_vector = tf.constant(id_vector, tf.float32) + for a in range(len(tensor.shape)): + if a != axis: + id_vector = tf.expand_dims(id_vector, axis=a) + tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)] + id_tensor = tf.tile(id_vector, tiling_vector) + return tf.concat([tensor, id_tensor], axis=axis) + + +class ClevrerTransformerModel(object): + """Model from Ding et al. 2020 (https://arxiv.org/abs/2012.08508).""" + + def __init__(self, use_relative_positions, shuffle_objects, + transformer_layers, num_heads, head_size, embed_dim): + """Instantiate Sonnet modules.""" + self._embed_dim = embed_dim + self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2) + self._shuffle_objects = shuffle_objects + self._memory_transformer = transformer.TransformerTower( + value_size=embed_dim + 2, + num_heads=num_heads, + num_layers=transformer_layers, + use_relative_positions=use_relative_positions, + causal=False) + + self._final_layer_mc = snt.Sequential( + [snt.Linear(head_size), tf.nn.relu, snt.Linear(1)]) + self._final_layer_descriptive = snt.Sequential( + [snt.Linear(head_size), tf.nn.relu, + snt.Linear(ANSWER_VOCAB_SIZE)]) + + self._dummy = tf.get_variable("dummy", [embed_dim + 2], tf.float32, + initializer=tf.zeros_initializer) + self._infill_linear = snt.Linear(embed_dim + 2) + self._mask_embedding = tf.get_variable( + "mask", [embed_dim + 2], tf.float32, initializer=tf.zeros_initializer) + + def _apply_transformers(self, lang_embedding, vision_embedding): + """Applies transformer to language and vision input. + + Args: + lang_embedding: tensor, + vision_embedding: tensor, "validation", or "test". 
+ + Returns: + tuple, output at dummy token, all output embeddings, infill loss + """ + def _unroll(tensor): + """Unroll the time dimension into the object dimension.""" + return tf.reshape( + tensor, [tensor.shape[0], -1, tensor.shape[3]]) + + words = append_ids(lang_embedding, [1, 0], axis=2) + dummy_word = tf.tile(self._dummy[None, None, :], [tf.shape(words)[0], 1, 1]) + vision_embedding = append_ids(vision_embedding, [0, 1], axis=3) + vision_over_time = _unroll(vision_embedding) + transformer_input = tf.concat([dummy_word, words, vision_over_time], axis=1) + + output, _ = self._memory_transformer(transformer_input, + is_training=False) + return output[:, 0, :] + + def apply_model_descriptive(self, inputs): + """Applies model to CLEVRER descriptive questions. + + Args: + inputs: dict of form: { + "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH], + "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16], + } + Returns: + Tensor of shape [batch, ANSWER_VOCAB_SIZE], representing logits for each + possible answer word. + """ + question = inputs["question"] + + # Shape: [batch, question_len, embed_dim-2] + question_embedding = self._embed(question) + # Shape: [batch, question_len, embed_dim] + question_embedding = append_ids(question_embedding, [0, 1], 2) + choices_embedding = self._embed( + tf.zeros([question.shape[0], MAX_CHOICE_LENGTH], tf.int64)) + choices_embedding = append_ids(choices_embedding, [0, 1], 2) + # Shape: [batch, choices, question_len + choice_len, embed_dim] + lang_embedding = tf.concat([question_embedding, choices_embedding], axis=1) + + # Shape: [batch, frames, num_objects, embed_dim] + vision_embedding = inputs["monet_latents"] + + if self._shuffle_objects: + vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3]) + vision_embedding = tf.random.shuffle(vision_embedding) + vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3]) + + output = self._apply_transformers(lang_embedding, vision_embedding) + output = self._final_layer_descriptive(output) + return output + + def apply_model_mc(self, inputs): + """Applies model to CLEVRER multiple-choice questions. 
+ + Args: + inputs: dict of form: { + "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH], + "choices": tf.int32 tensor of shape [batch, 4, MAX_CHOICE_LENGTH], + "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16], + } + Returns: + Tensor of shape [batch, 4], representing logits for each choice + """ + question = inputs["question"] + choices = inputs["choices"] + + # Shape: [batch, question_len, embed_dim-2] + question_embedding = self._embed(question) + # Shape: [batch, question_len, embed_dim] + question_embedding = append_ids(question_embedding, [1, 0], 2) + # Shape: [batch, choices, choice_len, embed_dim-2] + choices_embedding = snt.BatchApply(self._embed)(choices) + # Shape: [batch, choices, choice_len, embed_dim] + choices_embedding = append_ids(choices_embedding, [0, 1], 3) + # Shape: [batch, choices, question_len + choice_len, embed_dim] + lang_embedding = tf.concat([ + tf.tile(question_embedding[:, None], + [1, choices_embedding.shape[1], 1, 1]), + choices_embedding], axis=2) + + # Shape: [batch, frames, num_objects, embed_dim] + vision_embedding = inputs["monet_latents"] + + if self._shuffle_objects: + vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3]) + vision_embedding = tf.random.shuffle(vision_embedding) + vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3]) + + output_per_choice = [] + for c in range(NUM_CHOICES): + output = self._apply_transformers( + lang_embedding[:, c, :, :], vision_embedding) + output_per_choice.append(output) + + output = tf.stack(output_per_choice, axis=1) + output = tf.squeeze(snt.BatchApply(self._final_layer_mc)(output), axis=2) + return output diff --git a/object_attention_for_reasoning/requirements.txt b/object_attention_for_reasoning/requirements.txt new file mode 100644 index 0000000..54d5e90 --- /dev/null +++ b/object_attention_for_reasoning/requirements.txt @@ -0,0 +1,4 @@ +absl-py==0.11.0 +dm-sonnet==1.36 +numpy==1.20.1 +tensorflow==1.15.0 diff --git a/object_attention_for_reasoning/run_model.py b/object_attention_for_reasoning/run_model.py new file mode 100644 index 0000000..97ca284 --- /dev/null +++ b/object_attention_for_reasoning/run_model.py @@ -0,0 +1,163 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Example code for running model on CLEVRER.""" +import json + +from absl import app +from absl import flags +import numpy as np +import tensorflow.compat.v1 as tf + +from object_attention_for_reasoning import model as modellib + + +BATCH_SIZE = 1 +NUM_FRAMES = 25 +NUM_OBJECTS = 8 + +_BASE_DIR = flags.DEFINE_string( + "base_dir", "./clevrer_monet_latents", + "Directory containing checkpoints and MONet latents.") +_SCENE_IDX = flags.DEFINE_string( + "scene_idx", 1000, "Scene index of CLEVRER video.") + + +def load_monet_latents(base_dir, scene_index): + filename = f"{base_dir}/train/{scene_index}.npz" + with open(filename, "rb") as f: + return np.load(f) + + +def _split_string(s): + """Splits string to words and standardize alphabet.""" + return s.lower().replace("?", "").split() + + +def _pad(array, length): + """Pad an array to desired length.""" + return np.pad(array, [(0, length - array.shape[0])], mode="constant") + + +def encode_sentence(token_map, sentence, pad_length): + """Encode CLEVRER question/choice sentences as sequence of token ids.""" + ret = np.array( + [token_map["question_vocab"][w] for w in _split_string(sentence)], + np.int32) + return _pad(ret, pad_length) + + +def encode_choices(token_map, choices): + """Encode CLEVRER choices.""" + arrays = [encode_sentence(token_map, choice["choice"], + modellib.MAX_CHOICE_LENGTH) + for choice in choices] + return _pad(np.stack(arrays, axis=0), modellib.NUM_CHOICES) + + +def main(unused_argv): + base_dir = _BASE_DIR.value + with open(f"{base_dir}/vocab.json", "rb") as f: + token_map = json.load(f) + + reverse_answer_lookup = {v: k for k, v in token_map["answer_vocab"].items()} + + with open(f"{base_dir}/train.json", "rb") as f: + questions_data = json.load(f) + + tf.reset_default_graph() + model = modellib.ClevrerTransformerModel(**modellib.PRETRAINED_MODEL_CONFIG) + + inputs_descriptive = { + "monet_latents": tf.placeholder( + tf.float32, + [BATCH_SIZE, NUM_FRAMES, NUM_OBJECTS, modellib.EMBED_DIM]), + "question": tf.placeholder( + tf.int32, [BATCH_SIZE, modellib.MAX_QUESTION_LENGTH]), + } + + inputs_mc = { + "monet_latents": tf.placeholder( + tf.float32, + [BATCH_SIZE, NUM_FRAMES, NUM_OBJECTS, modellib.EMBED_DIM]), + "question": tf.placeholder(tf.int32, + [BATCH_SIZE, modellib.MAX_QUESTION_LENGTH]), + "choices": tf.placeholder( + tf.int32, [BATCH_SIZE, modellib.NUM_CHOICES, + modellib.MAX_CHOICE_LENGTH]), + } + + output_descriptive = model.apply_model_descriptive(inputs_descriptive) + output_mc = model.apply_model_mc(inputs_mc) + + # Restore from checkpoint + saver = tf.train.Saver() + checkpoint_dir = f"{base_dir}/checkpoints/" + sess = tf.train.SingularMonitoredSession(checkpoint_dir=checkpoint_dir) + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) + saver.restore(sess, ckpt.model_checkpoint_path) + + def eval_descriptive(monet_latents, question_json): + # CLEVRER provides videos with 128 frames. In our model, we subsample 25 + # frames (as was done in Yi et al (2020)). + # For training, we randomize the choice of 25 frames, and for evaluation, we + # sample the 25 frames as evenly as possible. + # We do that by doing strided sampling of the frames. 
+ stride, rem = divmod(monet_latents.shape[0], NUM_FRAMES) + monet_latents = monet_latents[None, :-rem:stride] + assert monet_latents.shape[1] == NUM_FRAMES + question = encode_sentence(token_map, question_json["question"], + modellib.MAX_QUESTION_LENGTH) + batched_question = np.expand_dims(question, axis=0) + logits = sess.run(output_descriptive, feed_dict={ + inputs_descriptive["monet_latents"]: monet_latents, + inputs_descriptive["question"]: batched_question, + }) + descriptive_answer = np.argmax(logits) + return reverse_answer_lookup[descriptive_answer] + + def eval_mc(monet_latents, question_json): + stride, rem = divmod(monet_latents.shape[0], NUM_FRAMES) + monet_latents = monet_latents[None, :-rem:stride] + assert monet_latents.shape[1] == NUM_FRAMES + question = encode_sentence( + token_map, question_json["question"], modellib.MAX_QUESTION_LENGTH) + choices = encode_choices( + token_map, question_json["choices"]) + mc_answer = sess.run(output_mc, feed_dict={ + inputs_mc["monet_latents"]: monet_latents, + inputs_mc["question"]: np.expand_dims(question, axis=0), + inputs_mc["choices"]: np.expand_dims(choices, axis=0), + }) + return mc_answer >= 0 + + sample_scene_idx = _SCENE_IDX.value + question_json = questions_data[sample_scene_idx]["questions"][0] + print("Descriptive Question: ", question_json["question"]) + print("Model Answer: ", + eval_descriptive(load_monet_latents(base_dir, sample_scene_idx), + question_json)) + print("True Answer: ", question_json["answer"]) + + question_json = questions_data[sample_scene_idx]["questions"][-1] + print("Multiple-Choice Question: ", question_json["question"]) + for i, choice_json in enumerate(question_json["choices"]): + print(f"{i+1}) {choice_json['choice']}") + print("Model Answer: ", + eval_mc(load_monet_latents(base_dir, sample_scene_idx), question_json)) + print("True Answer: ", + [choice_json["answer"] for choice_json in question_json["choices"]]) + + +if __name__ == "__main__": + app.run(main) diff --git a/object_attention_for_reasoning/transformer.py b/object_attention_for_reasoning/transformer.py new file mode 100644 index 0000000..c1b5474 --- /dev/null +++ b/object_attention_for_reasoning/transformer.py @@ -0,0 +1,667 @@ +# Fork of Sonnet transformer model with small modifications +# +# Copyright 2017 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implementation of Transformer networks. + +Size glossary: + * Batch size (B). + * Sequence length (N). + * Memory size (M). The size of the optional memory, passed in via `state`. + * Number of heads (H): the number of attention heads. + * Value size (V): the size of each value embedding per head. + * Key size (K): the size of each key embedding per head. Equally, the size + of each query embedding per head. Typically K <= V. + * Embedding size (HV). The size of the activation or embedding relating to + each input between layers. Equal to value_size * num_heads. 
+ * All attention size (F). The size of all attention activations over every + head. + * QKV size (F / H): The size of the query, key and value per head. Equal to + 2K + V or equivalently F / H. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np +from sonnet.python.modules import base +from sonnet.python.modules import basic +from sonnet.python.modules import layer_norm as snt_ln +from sonnet.python.modules import util +from sonnet.python.modules.nets import mlp as snt_mlp +import tensorflow.compat.v1 as tf + +AttentionState = collections.namedtuple('AttentionState', + ('queries', 'keys', 'values', 'logits', + 'weights', 'embeddings', 'read_words')) + +CompressedMemoryState = collections.namedtuple( + 'CompressedMemoryState', ('episodic_memory', 'compressed_memory', 'index')) + + +def rel_shift(position_logits): + """Shifting of logits for relative attention. + + Args: + position_logits: A tensor of shape [B, H, N, N + M]. + + Returns: + The shifted logits. Example, for input (H=1, B=1): + [5, 4, 3, 2, 1] + [5, 4, 3, 2, 1] + [5, 4, 3, 2, 1] + [5, 4, 3, 2, 1] + [5, 4, 3, 2, 1] + + the function outputs: + [1, 0, 5, 4, 3] + [2, 1, 0, 5, 4] + [3, 2, 1, 0, 5] + [4, 3, 2, 1, 0] + [5, 4, 3, 2, 1] + + Raises: + ValueError if position_logits is not 4D. + + Note: this is not an exact shift as the upper triangle is non-zero. This + works as intended in the causally-masked case. If this is used with un-masked + attention, we'd want these to also be zero. + """ + if position_logits.get_shape().ndims != 4: + raise ValueError('Expected 4D position logits.') + + input_shape = position_logits.shape + batch_size = input_shape[0] + num_heads = input_shape[1] + t1 = input_shape[2] + t2 = input_shape[3] + # We prepend zeros on the final timescale dimension. + to_pad = tf.zeros([batch_size, num_heads, t1, 1]) + position_logits = tf.concat([to_pad, position_logits], -1) + # Reshape trick to shift input. + position_logits = tf.reshape(position_logits, + [batch_size, num_heads, t2 + 1, t1]) + # Remove extra time dimension and re-shape. + position_logits = position_logits[:, :, 1:] + position_logits = tf.reshape(position_logits, input_shape) + return position_logits + + +def _layer_norm(inputs): + if inputs.get_shape().ndims > 2: + return basic.BatchApply(snt_ln.LayerNorm())(inputs) + else: + return snt_ln.LayerNorm()(inputs) + + +def _concat_and_slice(prev_memory, new_memory): + original_memory_size = prev_memory.get_shape().as_list()[1] + concat_memory = tf.concat([prev_memory, new_memory], 1) + memory = concat_memory[:, -original_memory_size:] + return memory, concat_memory + + +def simple_attention(queries, keys, values): + logits = tf.matmul(queries, keys, transpose_b=True) + weights = tf.nn.softmax(logits) + return tf.matmul(weights, values) + + +class ResidualDropoutWrapper(base.AbstractModule): + """Wrapper class that applies residual connections, dropout and layer norm. + + By default applies a relu to the module output before the other operations. 
+ """ + + def __init__(self, + layer, + dropout_rate, + layer_norm='input', + name='residual_dropout_wrapper'): + self._module = layer + self._dropout_rate = dropout_rate + self._layer_norm = layer_norm + super(ResidualDropoutWrapper, self).__init__(name=name) + + def _build(self, inputs, *args, **kwargs): + if self._layer_norm in ('both', 'input'): + normed_inputs = _layer_norm(inputs) + else: + normed_inputs = inputs + module_output = self._module(normed_inputs, *args, **kwargs) + module_state = None + # If module outputs multiple items, assumes (output, state) tuple. + if isinstance(module_output, tuple): + module_output, module_state = module_output + if kwargs['is_training']: # kwargs must contain is_training. + module_output = tf.nn.dropout(module_output, rate=self._dropout_rate) + output = inputs + module_output + if self._layer_norm in ('both', 'output'): + output = _layer_norm(output) + if module_state is None: + return output + else: + return output, module_state + + +def future_mask(chunk_size, dtype): + """Creates attention mask to ensure an element i cannot attend to j > i.""" + square = tf.ones([chunk_size, chunk_size], dtype=dtype) + # Create upper diagonal matrix and remove diagonal entries (allow self-attn). + mask = tf.matrix_band_part(square, 0, -1) - tf.matrix_band_part(square, 0, 0) + # Multiply by -1e6 and expand to broadcast with [B, H, N, N] logits. + mask = -1e6 * tf.reshape(mask, [1, 1, chunk_size, chunk_size]) + return mask + + +def _memory_size(state): + if isinstance(state, CompressedMemoryState): + return (state.episodic_memory.get_shape().as_list()[1] + + state.compressed_memory.get_shape().as_list()[1]) + else: + return state.get_shape().as_list()[1] + + +def create_mask(inputs, state, equal_window): + """Creates mask for future sequence positions. + + Args: + inputs: inputs tensor of shape [B, N, D] + state: optional tensor of shape [B, M, D], CompressedMemoryState or a list + where the ith entry corresponds to the ith layer's state. + equal_window: if True, then each activation has an equally-sized attention + window of length 'M'. This only makes sense if a state is given. + + Returns: + Float tensor of shape [1, 1, N, N + M], to be summed with logits. 
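+
+    For example, with chunk_size N = 3, memory size M = 2 and
+    equal_window=False, the returned mask is
+
+      [[[[0., 0., 0., -1e6, -1e6],
+         [0., 0., 0.,   0., -1e6],
+         [0., 0., 0.,   0.,   0.]]]]
+
+    i.e. each position may attend to every memory slot, to itself, and to
+    earlier positions in the current chunk, but not to later positions.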
+ """ + chunk_size = inputs.get_shape().as_list()[1] + dtype = inputs.dtype + mask = future_mask(chunk_size, dtype) + if state is not None: + if isinstance(state, (tuple, list)): + largest_memory_layer = np.argmax([_memory_size(s) for s in state]) + state = state[largest_memory_layer] + mem_size = _memory_size(state) + mask = tf.concat( + [tf.zeros([1, 1, chunk_size, mem_size], dtype=dtype), mask], 3) + + if equal_window: + attn_mask = tf.ones([chunk_size, chunk_size], dtype=dtype) + mask_dia = tf.cast(tf.matrix_band_part(attn_mask, 0, 0), dtype=dtype) + mask_l = tf.cast(tf.matrix_band_part(attn_mask, -1, 0), dtype=dtype) + start_mask = tf.reshape(mask_l - mask_dia, + [1, 1, chunk_size, chunk_size]) * -1e6 + mask = tf.concat( + [mask[:, :, :, :chunk_size] + start_mask, mask[:, :, :, chunk_size:]], + 3) + return mask + + +def default_mlp(hidden_sizes, activate_final=False, init_std=2., **kwargs): + """Standard batch-applied MLP for transformer modules.""" + init = {'w': tf.variance_scaling_initializer(init_std, distribution='normal')} + mlp = snt_mlp.MLP( + hidden_sizes, + activate_final=activate_final, + use_dropout=True, + initializers=init, + **kwargs) + return basic.BatchApply(mlp) + + +def get_position_encodings(sequence_length, + hidden_size, + clamp_value, + max_timescale=10000., + min_timescale=2.0): + """Creates sinusoidal encodings of shape [1, N + M, D].""" + # NOTE: when not using relative position encodings, min_timescale must be 2.0 + # and hidden_size must be an even number. Otherwise, the dimensions do not + # match. + pos_seq = tf.range(sequence_length - 1, -1, -1.0) + if clamp_value > 0: + pos_seq = tf.minimum(pos_seq, clamp_value) + freqs = tf.range(0, hidden_size, min_timescale) + inv_freq = 1 / (max_timescale**(freqs / hidden_size)) + sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq) + pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) + pos_emb = tf.expand_dims(pos_emb, 0) + + output_dim = pos_emb.get_shape().as_list()[-1] + if output_dim != hidden_size: + raise ValueError( + 'position embedding dimension ({}) does not match that of the input ({}).' + .format(output_dim, hidden_size)) + return pos_emb + + +class MultiheadAttention(base.AbstractModule): + """Implements multi-head attention with optional state context.""" + + def __init__(self, + value_size, + key_size, + num_heads, + mask=None, + scaling=True, + positional_encodings=None, + use_relative_positions=False, + init_std=2., + name='multihead_attention'): + """Creates a MultiheadAttention module. + + Args: + value_size: V parameter. See size glossary in class docstring. + key_size: K parameter. See size glossary in class docstring. + num_heads: The number of independent queries per timestep. + mask: Optional mask to attention logits. This can prevent attending to + future positions or unused memory slots. + scaling: Whether to scale the attention logits. + positional_encodings: Either None (none given), or an iterable of + `(key_positional_encodings, query_positional_encodings)` tuples, where + the first encodings in the list indicate the oldest entries in memory + and the final encodings indicate the newest entries in memory and the + sequence. + use_relative_positions: If True then relative positions are incorporated, + vs absolute, into the attention logits. This is done exactly as + described in the TransformerXL, Dai et al. 2019. + init_std: scaling of standard deviation for weight matrices init. + name: Name of module. 
+ """ + + super(MultiheadAttention, self).__init__(name=name) + self._value_size = value_size + self._key_size = key_size + self._sizes = { + 'value': self._value_size, + 'key': self._key_size, + 'query': self._key_size, + 'relative_keys': self._key_size, + 'relative_keys_0': self._key_size, + } + self._num_heads = num_heads + self._mask = mask + self._scaling = scaling + self._positional_encodings = positional_encodings + self._use_relative_positions = use_relative_positions + self._init = {'w': tf.variance_scaling_initializer(init_std)} + + @util.reuse_variables + def multihead_linear(self, inputs, name): + with tf.variable_scope(name, reuse=tf.AUTO_REUSE): + hidden_size = self._sizes[name] + input_size = inputs.shape[-1].value + w = tf.get_variable( + 'linear/w', + shape=[input_size, self._num_heads * hidden_size], + initializer=self._init['w']) + w = tf.reshape(w, [input_size, self._num_heads, hidden_size]) + out = tf.einsum('bij,jhk->bhik', inputs, w) + return out + + def _build(self, + inputs, + query_inputs=None, + state=None, + is_training=False, + dropout_keep_prob=0.5, + key_value_inputs=None): + """Calculates multi-layer self attention. + + Args: + inputs: Tensor of shape [batch_size, num_steps, output_dim_size]. Inputs + used as the query, key, and value to the attention layer. + query_inputs: optional Tensor of shape [batch_size, num_steps, + output_dim_size]. Query inputs to the attention layer. Set when + query_inputs is different from the inputs argument. + state: optional CompressedMemoryState or a Tensor of shape [batch_size, + memory_size, dim_size] concatenated to the inputs. Set when attend to + the memory from previous steps. + is_training: if currently training. + dropout_keep_prob: dropout rate applied to attention weights. + key_value_inputs: optional Tensor of shape [batch_size, num_steps, + output_dim_size]. It is used as the key and value of the multihead + attention. Set when the key and value are different from the inputs + argument. + + Returns: + output: the result Tensor of shape + [batch_size, num_steps, output_dim_size]. + attention_state: named tuple of AttentionState. + """ + if key_value_inputs is not None and state is not None: + raise ValueError('Only one of the key_value_input and state is needed.') + embedding_size = self._value_size * self._num_heads + + q_inputs = inputs if query_inputs is None else query_inputs + # Denoted by L. If query_inputs is None, L = N. + _, query_size = q_inputs.get_shape().as_list()[:2] + + if key_value_inputs is not None: + k_inputs = key_value_inputs + v_inputs = k_inputs + elif state is not None: + if isinstance(state, CompressedMemoryState): + state_memory_list = [state.compressed_memory, state.episodic_memory] + else: + state_memory_list = [state] + + k_inputs = tf.concat(state_memory_list + [inputs], 1) + v_inputs = k_inputs + else: + k_inputs = inputs + v_inputs = inputs + + # Batch size denoted by B + batch_size = tf.shape(inputs)[0] + # Chunk_size denoted by N + chunk_size = inputs.get_shape().as_list()[1] + # Denoted by N + M + att_size = k_inputs.get_shape().as_list()[1] + + if self._positional_encodings and not self._use_relative_positions: + if len(self._positional_encodings) != 1: + raise ValueError( + 'Absolute positional encodings only supported for 1 memory. ' + 'Found %i.' 
% len(self._positional_encodings)) + key_positions, query_positions = self._positional_encodings[0] + k_inputs += key_positions + q_inputs += query_positions + + # [B, H, L, K] + q = self.multihead_linear(q_inputs, 'query') + # [B, H, N + M, K] + k = self.multihead_linear(k_inputs, 'key') + # [B, H, N + M, V] + v = self.multihead_linear(v_inputs, 'value') + + # Scaling the dot-product + if self._scaling: + q *= self._key_size**-0.5 + + # [B, H, L, N + M] + if self._use_relative_positions: + r_w_bias = tf.get_variable( + 'r_w_bias', [1, self._num_heads, 1, self._key_size], + dtype=inputs.dtype) + content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True) + all_relative_logits = [] + # Loop over multiple positional encodings, for the case of multiple + # memory types. + for i, positional_encodings in enumerate(self._positional_encodings): + key_positions, query_positions = positional_encodings + if key_positions.get_shape().as_list()[-1] != att_size: + key_positions = key_positions[:, -att_size:] # Crop to layer mem size + is_final = i == len(self._positional_encodings) - 1 + suffix = '' if is_final else '_%d' % i + relative_keys = self.multihead_linear( + key_positions, name='relative_keys' + suffix) + # [B, H, N, D] + r_r_bias = tf.get_variable( + 'r_r_bias' + suffix, [1, self._num_heads, 1, self._key_size], + dtype=inputs.dtype) + relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1]) + relative_logits = tf.matmul( + q + r_r_bias, relative_keys, transpose_b=True) + relative_logits = rel_shift(relative_logits) + if not is_final: # Include relative positions for input sequence. + relative_logits = relative_logits[:, :, :, :-chunk_size] + all_relative_logits.append(relative_logits) + all_relative_logits = tf.concat(all_relative_logits, 3) + logits = content_logits + all_relative_logits + else: + # [B, H, N, N + M] + logits = tf.matmul(q, k, transpose_b=True) + content_logits = logits + + if self._mask is not None: + if self._mask.get_shape().as_list()[-1] != att_size: + mask = self._mask[:, :, :, -att_size:] + else: + mask = self._mask + logits += mask + + weights = tf.nn.softmax(logits) + if is_training: + weights = tf.nn.dropout(weights, dropout_keep_prob) + # [B, L, H, V], where V is value_size + output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v) + + # [B, L, H, V] -> [B, L, HV] + attended_inputs = basic.BatchReshape([query_size, embedding_size])( + output_transpose) + # Apply final mlp to mix information between heads. + output = basic.BatchApply(basic.Linear(embedding_size))(attended_inputs) + + attention_state = AttentionState( + queries=q, + keys=k, + values=v, + weights=weights, + logits=content_logits, + embeddings=inputs, + read_words=output) + return output, attention_state + + +class TransformerTower(base.AbstractModule): + """Transformer tower. + + Deep residual network using blocks of attention and MLPs, specified in + Vaswani et al. 2017. + """ + + def __init__(self, + value_size, + num_heads, + num_layers, + causal=True, + key_size=None, + shared_attention=False, + output_size=None, + mlp_hidden_sizes=tuple([1024]), + dropout_rate=0.1, + use_relative_positions=True, + clamp_time_range=0, + same_attention_length=False, + layer_norm='input', + name='transformer_tower'): + """Initializes TransformerTower. + + Args: + value_size: dimensionality of values per-head. + num_heads: number of attention heads. + num_layers: number of transformer blocks, where each block contains a + multi-head attention layer and an MLP. + causal: if True, applies a causal mask. 
+ key_size: optional dimensionality of key size. If unspecified then it is + set to `value_size`. + shared_attention: if True, attention params are shared across all layers. + output_size: if set, the desired output dimensionality. By default the + output size is `value_size` x `num_heads`. + mlp_hidden_sizes: tuple containing dimensionality of mlp layer(s). If + multiple values are specified, the mlp contains multiple layers for each + transformer block. + dropout_rate: dropout rate applied to hidden activations, attention, and + positional encodings. + use_relative_positions: if False, applies absolute positional encodings. + If true, uses relative positional encodings from Dai et al. 2019. + clamp_time_range: clamps max temporal positional encoding if specified. + same_attention_length: if True, attention is masked to ensure each + position in the sequence contains the same length of attention. + layer_norm: Where to apply layer-norm in Transformer block. Can be one of + 'input' (Vaswani et al. 2017), 'output', or 'both'. + name: name of variable scope. + """ + super(TransformerTower, self).__init__(name=name) + self._causal = causal + self._mask = None + + if key_size is None: + key_size = value_size + self._key_size = key_size + self._value_size = value_size + self._shared_attention = shared_attention + self._num_heads = num_heads + self._num_layers = num_layers + self._output_size = output_size + self._embedding_size = self._value_size * self._num_heads + self._mlp_hidden_sizes = list(mlp_hidden_sizes) + [self._embedding_size] + self._multihead_attention = None + self._object_embeddings = None + self._dropout_rate = dropout_rate + self._positional_encodings = None + self._use_relative_positions = use_relative_positions + self._clamp_time_range = clamp_time_range + self._same_attention_length = same_attention_length + self._layer_norm = layer_norm + self._attention_modules = [] + self._object_mlps = [] + + def get_sublayers(self, is_training): + if self._multihead_attention is None or not self._shared_attention: + attention_module = MultiheadAttention( + value_size=self._value_size, + key_size=self._key_size, + num_heads=self._num_heads, + mask=self._mask, + positional_encodings=self._positional_encodings, + use_relative_positions=self._use_relative_positions, + init_std=2. / np.sqrt(self._num_layers), + ) + self._multihead_attention = ResidualDropoutWrapper( + attention_module, self._dropout_rate, layer_norm=self._layer_norm) + mlp = default_mlp( + self._mlp_hidden_sizes, init_std=2. / np.sqrt(self._num_layers)) + object_mlp = ResidualDropoutWrapper( + mlp, self._dropout_rate, layer_norm=self._layer_norm) + + self._attention_modules.append(attention_module) + self._object_mlps.append(mlp) + return self._multihead_attention, object_mlp + + def _build(self, + inputs, + state=None, + condition=None, + is_training=True, + final_layer_key_value_inputs=None): + """Calculates multi-layer self attention and mlp transformation. + + Args: + inputs: Tensor of shape [batch_size, num_steps, dim_size]. + state: optional list of length num_layers of tensors of shape + [batch_size, memory_size, dim_size]. + condition: optional tensor to condition on. The shape is shape + [batch_size, dim_size]. + is_training: If true, dropout is applied. + final_layer_key_value_inputs: optional Tensor to be used as the key and + value for the final multi-head attention layer of shape + [batch_size, num_steps, dim_size]. Useful when the tower is a Seq2Seq + decoder and it can attend to encoder outputs. 
+
+    Returns:
+      output: tensor of shape [batch_size, num_steps, output_dim_size].
+      state: list of length `num_layers` containing AttentionState tuples.
+    """
+    # inputs: [B, N, F]
+    if final_layer_key_value_inputs is not None and state is not None and len(
+        state) == (self._num_layers - 1):
+      raise ValueError('When final_layer_key_value_inputs is set, exclude '
+                       'the state of the last layer.')
+
+    if condition is not None:
+      condition_tile = tf.tile(
+          tf.expand_dims(condition, 1), [1, tf.shape(inputs)[1], 1])
+      inputs = tf.concat([inputs, condition_tile], -1)
+
+    # Map inputs to be of `embedding_size` dimension.
+    if inputs.get_shape().as_list()[-1] != self._embedding_size:
+      inputs = default_mlp([self._embedding_size], activate_final=True)(
+          inputs,
+          is_training=is_training,
+          dropout_keep_prob=1 - self._dropout_rate)
+
+    if state is None:
+      memory_sizes = [0]
+    elif isinstance(state[0], CompressedMemoryState):
+      cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
+      em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
+      memory_sizes = [cm_mem_size, em_mem_size]
+    else:
+      memory_sizes = [max([_memory_size(s) for s in state])]
+    chunk_size = inputs.get_shape().as_list()[1]
+    self._positional_encodings = []
+    # Creates positional encodings for different memory types.
+    for i, memory_size in enumerate(memory_sizes):
+      seq_len = chunk_size + memory_size
+      key_positions = get_position_encodings(
+          sequence_length=seq_len,
+          hidden_size=inputs.get_shape().as_list()[2],
+          clamp_value=self._clamp_time_range,
+      )
+      if is_training:
+        key_positions = tf.nn.dropout(key_positions, rate=self._dropout_rate)
+      key_positions = tf.cast(key_positions, dtype=inputs.dtype)
+      query_positions = key_positions[:, -chunk_size:, :]
+      self._positional_encodings.append((key_positions, query_positions))
+
+    if self._causal:
+      self._mask = create_mask(inputs, state, self._same_attention_length)
+
+    layer_i_inputs = inputs
+    attention_states = []
+    key_value_inputs = None
+
+    for i in range(self._num_layers):
+      with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
+        multihead_attention, object_mlp = self.get_sublayers(is_training)
+        # Multihead attention with residuals.
+        state_i = None if state is None else state[i]
+        if i == (self._num_layers -
+                 1) and final_layer_key_value_inputs is not None:
+          # When final_layer_key_value_inputs is set, the final layer of
+          # attention will use it as the key & value, thus no need for state.
+          key_value_inputs = final_layer_key_value_inputs
+          state_i = None
+
+        attention_outputs, attention_state = multihead_attention(
+            layer_i_inputs,
+            state=state_i,
+            is_training=is_training,
+            dropout_keep_prob=1. - self._dropout_rate,
+            key_value_inputs=key_value_inputs)
+        attention_states.append(attention_state)
+        # Feed-forward with residuals.
+        output = object_mlp(
+            attention_outputs,
+            is_training=is_training,
+            dropout_keep_prob=1 - self._dropout_rate)
+        layer_i_inputs = output
+
+    if self._output_size is not None:
+      output = basic.BatchApply(
+          basic.Linear(self._output_size, use_bias=False))(
+              output)
+
+    return output, attention_states
+
+  def attention_module(self, i):
+    """Returns the i-th layer attention module."""
+    return self._attention_modules[i]
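A minimal usage sketch of `TransformerTower`, assuming the TF 1.15 / Sonnet 1.36 environment pinned in `requirements.txt` and the package layout imported by `run_model.py`; the sizes below are hypothetical and chosen only for illustration:

```python
import numpy as np
import tensorflow.compat.v1 as tf

from object_attention_for_reasoning import transformer

# Hypothetical sizes for this sketch.
batch_size, seq_len, input_dim = 2, 7, 32

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
# value_size * num_heads = 64, so the 32-dim inputs are first mapped to a
# 64-dim embedding by the tower's input MLP.
tower = transformer.TransformerTower(value_size=16, num_heads=4, num_layers=2)
output, attention_states = tower(inputs, is_training=False)

with tf.train.SingularMonitoredSession() as sess:
  out = sess.run(output, feed_dict={
      inputs: np.zeros([batch_size, seq_len, input_dim], np.float32)})
  print(out.shape)              # (2, 7, 64)
  print(len(attention_states))  # one AttentionState per layer -> 2
```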