mirror of https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00

Ensure pip is up to date in kfac_ferminet_alpha/run.sh

Also create the `venv` in `/tmp/` rather than messing with the source tree.

PiperOrigin-RevId: 368225759

committed by Diego de Las Casas
parent ce4db84f12
commit d029e06aa2

38 kfac_ferminet_alpha/README.md Normal file
@@ -0,0 +1,38 @@
# Accompanying code for Better, Faster Fermionic Neural Networks

All package requirements are listed in `requirements.txt`.

## Contributing

This is purely research code, provided with no further intentions of support or
any guarantees of backward compatibility.

## Installation

```shell
git clone git@github.com:deepmind/deepmind-research.git
pip install deepmind-research/kfac_ferminet_alpha/
```

## Usage

You can find examples of how to use the codebase through the [FermiNet project].

We also provide an [example training script].

## Reference

**Better, Faster Fermionic Neural Networks**

James S. Spencer, David Pfau, Aleksandar Botev, and W. M. C. Foulkes.

URL: https://arxiv.org/abs/2011.07125.

**Optimizing Neural Networks with Kronecker-factored Approximate Curvature**

James Martens, Roger Grosse.

URL: https://arxiv.org/abs/1503.05671

[FermiNet project]: https://github.com/deepmind/ferminet/
[example training script]: https://github.com/deepmind/deepmind-research/kfac_ferminet_alpha/example.py
19 kfac_ferminet_alpha/__init__.py Normal file
@@ -0,0 +1,19 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for anything that an end user would use."""

from kfac_ferminet_alpha.loss_functions import register_normal_predictive_distribution
from kfac_ferminet_alpha.loss_functions import register_squared_error_loss
from kfac_ferminet_alpha.optimizer import Optimizer
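The three re-exports above are the entire public surface of the package. A minimal sketch of how they fit together, assuming a stand-in linear `predict` model (the `Optimizer` keyword arguments mirror those used in `example.py` further down; the exact loss-registration signature lives in `loss_functions.py`, whose diff is suppressed below):

```python
import jax
import jax.numpy as jnp

import kfac_ferminet_alpha as kfac


def predict(params, x):
  # A stand-in linear model so the sketch is self-contained.
  w, b = params
  return jnp.matmul(x, w) + b


def loss_fn(params, batch):
  preds = predict(params, batch["x"])
  # Registering the loss lets the optimizer locate it in the traced graph
  # and build its Fisher approximation around it.
  kfac.register_squared_error_loss(preds, batch["y"])
  return jnp.mean(jnp.sum((preds - batch["y"]) ** 2, axis=-1))


optimizer = kfac.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=0.0,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    learning_rate_schedule=None,
    momentum_schedule=None,
    damping_schedule=None,
    norm_constraint=1.0,
)
```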
496 kfac_ferminet_alpha/curvature_blocks.py Normal file
@@ -0,0 +1,496 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all of the different curvature blocks."""
import abc
from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Sequence, Union
import jax
from jax import core
import jax.numpy as jnp

from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import utils

_Arrays = Sequence[jnp.ndarray]
_BlockInfo = Mapping[str, Any]


class CurvatureBlock(utils.Stateful, abc.ABC):
  """Top level class."""

  def __init__(self, layer_tag_eq: tgm.jax_core.JaxprEqn):
    super(CurvatureBlock, self).__init__()
    self._layer_tag_eq = layer_tag_eq

  @property
  def layer_tag_primitive(self) -> tgm.tags.LayerTag:
    assert isinstance(self._layer_tag_eq.primitive, tgm.tags.LayerTag)
    return self._layer_tag_eq.primitive

  @property
  def outputs_shapes(self) -> Sequence[Sequence[int]]:
    output_vars = self.layer_tag_primitive.split_all_inputs(
        self._layer_tag_eq.invars)[0]
    return jax.tree_map(lambda x: x.aval.shape, output_vars)

  @property
  def inputs_shapes(self) -> Sequence[Sequence[int]]:
    input_vars = self.layer_tag_primitive.split_all_inputs(
        self._layer_tag_eq.invars)[1]
    return jax.tree_map(lambda x: x.aval.shape, input_vars)

  @property
  def params_shapes(self) -> Sequence[Sequence[int]]:
    params_vars = self.layer_tag_primitive.split_all_inputs(
        self._layer_tag_eq.invars)[2]
    return jax.tree_map(lambda x: x.aval.shape, params_vars)

  @abc.abstractmethod
  def init(self, rng: jnp.ndarray) -> MutableMapping[str, Any]:
    """Initializes/creates all of the arrays for the state of the block.

    Usually this would include the arrays used for storing the curvature
    approximation, as well as the arrays for storing any approximate
    inverses/powers of the curvature block.

    Args:
      rng: The Jax PRNG key to use if any of the state is supposed to be
        initialized randomly.

    Returns:
      A mutable mapping of the state.
    """

  @abc.abstractmethod
  def update_curvature_matrix_estimate(
      self,
      info: _BlockInfo,
      batch_size: int,
      ema_old: Union[float, jnp.ndarray],
      ema_new: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    pass

  @abc.abstractmethod
  def update_curvature_inverse_estimate(
      self,
      diagonal_weight: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    pass

  @abc.abstractmethod
  def multiply_matpower(
      self,
      vec: _Arrays,
      exp: Union[float, int],
      diagonal_weight: Union[float, jnp.ndarray]
  ) -> _Arrays:
    pass


CurvatureBlockCtor = Callable[[core.JaxprEqn], CurvatureBlock]


@utils.Stateful.infer_class_state
class NaiveDiagonal(CurvatureBlock):
  """The naively estimated diagonal block."""
  diagonal_factor: utils.WeightedMovingAverage

  def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
    del rng
    return dict(
        diagonal_factor=utils.WeightedMovingAverage.zero(
            self.outputs_shapes[0])
    )

  def update_curvature_matrix_estimate(
      self,
      info: _BlockInfo,
      batch_size: int,
      ema_old: Union[float, jnp.ndarray],
      ema_new: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    dw, = info["outputs_tangent"]
    diagonal_update = dw * dw / batch_size
    self.diagonal_factor.update(diagonal_update, ema_old, ema_new)
    self.diagonal_factor.sync(pmap_axis_name)

  def update_curvature_inverse_estimate(
      self,
      diagonal_weight: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    pass

  def multiply_matpower(
      self,
      vec: _Arrays,
      exp: Union[float, int],
      diagonal_weight: Union[float, jnp.ndarray]
  ) -> _Arrays:
    w, = vec
    if exp == 1:
      return w * (self.diagonal_factor.value + diagonal_weight),
    elif exp == -1:
      return w / (self.diagonal_factor.value + diagonal_weight),
    else:
      raise NotImplementedError()


@utils.Stateful.infer_class_state
class TwoKroneckerFactored(CurvatureBlock, abc.ABC):
  """A factor that is the Kronecker product of two matrices."""
  inputs_factor: utils.WeightedMovingAverage
  inputs_factor_inverse: jnp.ndarray
  outputs_factor: utils.WeightedMovingAverage
  outputs_factor_inverse: jnp.ndarray
  extra_scale: Optional[Union[int, float, jnp.ndarray]]

  @property
  def has_bias(self) -> bool:
    return len(self._layer_tag_eq.invars) == 4

  @abc.abstractmethod
  def input_size(self) -> int:
    pass

  @abc.abstractmethod
  def output_size(self) -> int:
    pass

  def compute_extra_scale(self) -> Optional[Union[int, float, jnp.ndarray]]:
    return 1

  def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
    # The extra scale is technically a constant, but in general it could be
    # useful for anyone examining the state to know it explicitly,
    # hence we actually keep it as part of the state.
    d_in = self.input_size()
    d_out = self.output_size()
    return dict(
        inputs_factor=utils.WeightedMovingAverage.zero([d_in, d_in]),
        inputs_factor_inverse=jnp.zeros([d_in, d_in]),
        outputs_factor=utils.WeightedMovingAverage.zero([d_out, d_out]),
        outputs_factor_inverse=jnp.zeros([d_out, d_out]),
        extra_scale=self.compute_extra_scale()
    )

  def update_curvature_inverse_estimate(
      self,
      diagonal_weight: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    self.inputs_factor.sync(pmap_axis_name)
    self.outputs_factor.sync(pmap_axis_name)

    # This computes the approximate inverse factor using the pi-adjusted
    # inversion from the original KFAC paper.
    # Note that the damping is divided by extra_scale since:
    # (s * A kron B + lambda I)^-1 = s^-1 (A kron B + s^-1 * lambda I)^-1
    # And the extra division by the scale is included in `multiply_matpower`.
    (self.inputs_factor_inverse,
     self.outputs_factor_inverse) = utils.pi_adjusted_inverse(
         factor_0=self.inputs_factor.value,
         factor_1=self.outputs_factor.value,
         damping=diagonal_weight / self.extra_scale,
         pmap_axis_name=pmap_axis_name)

  def multiply_matpower(
      self,
      vec: _Arrays,
      exp: Union[float, int],
      diagonal_weight: Union[float, jnp.ndarray]
  ) -> _Arrays:
    if self.has_bias:
      w, b = vec
      vec = jnp.concatenate([w.reshape([-1, w.shape[-1]]), b[None]], axis=0)
    else:
      w, = vec
      vec = w.reshape([-1, w.shape[-1]])
    if exp == 1:
      inputs_factor, outputs_factor = (self.inputs_factor.value,
                                       self.outputs_factor.value)
      scale = self.extra_scale
    elif exp == -1:
      inputs_factor, outputs_factor = (self.inputs_factor_inverse,
                                       self.outputs_factor_inverse)
      scale = 1.0 / self.extra_scale
      diagonal_weight = 0
    else:
      raise NotImplementedError()

    result = jnp.matmul(inputs_factor, vec)
    result = jnp.matmul(result, outputs_factor)
    result = result * scale + diagonal_weight * vec

    if self.has_bias:
      w_new, b_new = result[:-1], result[-1]
      return w_new.reshape(w.shape), b_new
    else:
      return result.reshape(w.shape),


class DenseTwoKroneckerFactored(TwoKroneckerFactored):
  """Factor for a standard dense layer."""

  def input_size(self) -> int:
    if self.has_bias:
      return self.params_shapes[0][0] + 1
    else:
      return self.params_shapes[0][0]

  def output_size(self) -> int:
    return self.params_shapes[0][1]

  def update_curvature_matrix_estimate(
      self,
      info: _BlockInfo,
      batch_size: int,
      ema_old: Union[float, jnp.ndarray],
      ema_new: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    del pmap_axis_name
    (x,), (dy,) = info["inputs"], info["outputs_tangent"]
    utils.check_first_dim_is_batch_size(batch_size, x, dy)

    if self.has_bias:
      x_one = jnp.ones_like(x[:, :1])
      x = jnp.concatenate([x, x_one], axis=1)
    input_stats = jnp.matmul(x.T, x) / batch_size
    output_stats = jnp.matmul(dy.T, dy) / batch_size
    self.inputs_factor.update(input_stats, ema_old, ema_new)
    self.outputs_factor.update(output_stats, ema_old, ema_new)


@utils.Stateful.infer_class_state
class ScaleAndShiftDiagonal(CurvatureBlock):
  """A scale and shift block with a diagonal approximation to the curvature."""
  scale_factor: Optional[utils.WeightedMovingAverage]
  shift_factor: Optional[utils.WeightedMovingAverage]

  @property
  def has_scale(self) -> bool:
    return self._layer_tag_eq.params["has_scale"]

  @property
  def has_shift(self) -> bool:
    return self._layer_tag_eq.params["has_shift"]

  def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
    del rng
    if self.has_scale and self.has_shift:
      return dict(
          scale_factor=utils.WeightedMovingAverage.zero(
              self.params_shapes[0]
          ),
          shift_factor=utils.WeightedMovingAverage.zero(
              self.params_shapes[1]
          )
      )
    elif self.has_scale:
      return dict(
          scale_factor=utils.WeightedMovingAverage.zero(
              self.params_shapes[0]
          ),
          shift_factor=None
      )
    elif self.has_shift:
      return dict(
          scale_factor=None,
          shift_factor=utils.WeightedMovingAverage.zero(
              self.params_shapes[0]
          ),
      )
    else:
      raise ValueError("Neither `has_scale` nor `has_shift`.")

  def update_curvature_matrix_estimate(
      self,
      info: _BlockInfo,
      batch_size: int,
      ema_old: Union[float, jnp.ndarray],
      ema_new: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    (x,), (dy,) = info["inputs"], info["outputs_tangent"]
    utils.check_first_dim_is_batch_size(batch_size, x, dy)

    if self.has_scale:
      assert self.scale_factor is not None
      scale_shape = info["params"][0].shape
      full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape
      axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
      d_scale = jnp.sum(x * dy, axis=axis)
      scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
      self.scale_factor.update(scale_diag_update, ema_old, ema_new)
      self.scale_factor.sync(pmap_axis_name)

    if self.has_shift:
      assert self.shift_factor is not None
      shift_shape = info["params"][1].shape
      full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape
      axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0]
      d_shift = jnp.sum(dy, axis=axis)
      shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
      self.shift_factor.update(shift_diag_update, ema_old, ema_new)
      self.shift_factor.sync(pmap_axis_name)

  def update_curvature_inverse_estimate(
      self,
      diagonal_weight: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    pass

  def multiply_matpower(
      self,
      vec: _Arrays,
      exp: Union[float, int],
      diagonal_weight: Union[float, jnp.ndarray]
  ) -> _Arrays:
    if self.has_scale and self.has_shift:
      factors = (self.scale_factor.value, self.shift_factor.value)
    elif self.has_scale:
      factors = (self.scale_factor.value,)
    elif self.has_shift:
      factors = (self.shift_factor.value,)
    else:
      raise ValueError("Neither `has_scale` nor `has_shift`.")
    factors = jax.tree_map(lambda x: x + diagonal_weight, factors)
    if exp == 1:
      return jax.tree_multimap(jnp.multiply, vec, factors)
    elif exp == -1:
      return jax.tree_multimap(jnp.divide, vec, factors)
    else:
      raise NotImplementedError()


@utils.Stateful.infer_class_state
class ScaleAndShiftFull(CurvatureBlock):
  """A scale and shift block with full approximation to the curvature."""
  factor: utils.WeightedMovingAverage
  inverse_factor: jnp.ndarray

  @property
  def _has_scale(self) -> bool:
    return self._layer_tag_eq.params["has_scale"]

  @property
  def _has_shift(self) -> bool:
    return self._layer_tag_eq.params["has_shift"]

  def init(self, rng: jnp.ndarray) -> Dict[str, Any]:
    del rng
    dims = sum(utils.product(shape) for shape in self.params_shapes)
    return dict(
        factor=utils.WeightedMovingAverage.zero([dims, dims]),
        inverse_factor=jnp.zeros([dims, dims])
    )

  def update_curvature_matrix_estimate(
      self,
      info: _BlockInfo,
      batch_size: int,
      ema_old: Union[float, jnp.ndarray],
      ema_new: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    del pmap_axis_name
    (x,), (dy,) = info["inputs"], info["outputs_tangent"]
    utils.check_first_dim_is_batch_size(batch_size, x, dy)

    grads = list()
    if self._has_scale:
      # Scale gradients
      scale_shape = info["params"][0].shape
      full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape
      axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
      d_scale = jnp.sum(x * dy, axis=axis)
      d_scale = d_scale.reshape([batch_size, -1])
      grads.append(d_scale)

    if self._has_shift:
      # Shift gradients
      shift_shape = info["params"][1].shape
      full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape
      axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0]
      d_shift = jnp.sum(dy, axis=axis)
      d_shift = d_shift.reshape([batch_size, -1])
      grads.append(d_shift)

    grads = jnp.concatenate(grads, axis=1)
    factor_update = jnp.matmul(grads.T, grads) / batch_size
    self.factor.update(factor_update, ema_old, ema_new)

  def update_curvature_inverse_estimate(
      self,
      diagonal_weight: Union[float, jnp.ndarray],
      pmap_axis_name: str
  ) -> None:
    self.factor.sync(pmap_axis_name)
    self.inverse_factor = utils.psd_inv_cholesky(self.factor.value,
                                                 diagonal_weight)

  def multiply_matpower(
      self,
      vec: _Arrays,
      exp: Union[float, int],
      diagonal_weight: Union[float, jnp.ndarray]
  ) -> _Arrays:
    # Remember the vector is a tuple of all parameters
    if self._has_scale and self._has_shift:
      flat_vec = jnp.concatenate([v.flatten() for v in vec])
    else:
      flat_vec = vec[0].flatten()

    if exp == 1:
      flat_result = (
          jnp.matmul(self.factor.value, flat_vec) + diagonal_weight * flat_vec)
    elif exp == -1:
      flat_result = jnp.matmul(self.inverse_factor, flat_vec)
    else:
      raise NotImplementedError()

    if self._has_scale and self._has_shift:
      scale_dims = int(vec[0].size)
      scale_result = flat_result[:scale_dims].reshape(vec[0].shape)
      shift_result = flat_result[scale_dims:].reshape(vec[1].shape)
      return scale_result, shift_result
    else:
      return flat_result.reshape(vec[0].shape),


_default_tag_to_block: MutableMapping[str, CurvatureBlockCtor] = dict(
    dense_tag=DenseTwoKroneckerFactored,
    generic_tag=NaiveDiagonal,
    scale_and_shift_tag=ScaleAndShiftDiagonal,
)


def copy_default_tag_to_block() -> MutableMapping[str, CurvatureBlockCtor]:
  return dict(_default_tag_to_block)


def get_default_tag_to_block(tag_name: str) -> CurvatureBlockCtor:
  return _default_tag_to_block[tag_name]


def set_default_tag_to_block(
    tag_name: str,
    block_class: CurvatureBlockCtor,
) -> None:
  _default_tag_to_block[tag_name] = block_class
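A note on `TwoKroneckerFactored.multiply_matpower` above: it never materialises the Kronecker product. For a reshaped weight matrix `W` it computes `A_in @ W @ A_out`, relying on the standard `vec` identity. A small self-contained check of that identity (plain NumPy, independent of this library):

```python
import numpy as np

rng = np.random.default_rng(0)
a_in = rng.standard_normal((3, 3))   # input-side Kronecker factor
a_out = rng.standard_normal((2, 2))  # output-side Kronecker factor
w = rng.standard_normal((3, 2))      # weight matrix of a dense layer

# vec(A_in @ W @ A_out) == (A_out^T kron A_in) @ vec(W),
# where vec() stacks columns (Fortran order).
lhs = (a_in @ w @ a_out).flatten(order="F")
rhs = np.kron(a_out.T, a_in) @ w.flatten(order="F")
np.testing.assert_allclose(lhs, rhs, rtol=1e-10)
```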
75 kfac_ferminet_alpha/distributions.py Normal file
@@ -0,0 +1,75 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for all distribution implementations needed for the loss functions."""
import math
import jax
import jax.numpy as jnp


class MultivariateNormalDiag:
  """Multivariate normal distribution on `R^k`."""

  def __init__(
      self,
      loc: jnp.ndarray,
      scale_diag: jnp.ndarray):
    """Initializes a MultivariateNormalDiag distribution.

    Args:
      loc: Mean vector of the distribution. Can also be a batch of vectors.
      scale_diag: Vector of standard deviations.
    """
    super().__init__()
    self._loc = loc
    self._scale_diag = scale_diag

  @property
  def loc(self) -> jnp.ndarray:
    """Mean of the distribution."""
    return self._loc

  @property
  def scale_diag(self) -> jnp.ndarray:
    """Scale of the distribution."""
    return self._scale_diag

  def _num_dims(self) -> int:
    """Dimensionality of the events."""
    return self._scale_diag.shape[-1]

  def _standardize(self, value: jnp.ndarray) -> jnp.ndarray:
    return (value - self._loc) / self._scale_diag

  def log_prob(self, value: jnp.ndarray) -> jnp.ndarray:
    """See `Distribution.log_prob`."""
    log_unnormalized = -0.5 * jnp.square(self._standardize(value))
    log_normalization = 0.5 * math.log(2 * math.pi) + jnp.log(self._scale_diag)
    return jnp.sum(log_unnormalized - log_normalization, axis=-1)

  def mean(self) -> jnp.ndarray:
    """Calculates the mean."""
    return self.loc

  def sample(self, seed: jnp.ndarray) -> jnp.ndarray:
    """Samples an event.

    Args:
      seed: PRNG key or integer seed.

    Returns:
      A sample.
    """
    eps = jax.random.normal(seed, self.loc.shape)
    return self.loc + eps * self.scale_diag
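For reference, `log_prob` above is the closed-form log-density of a diagonal Gaussian, summed over the event dimension:

```latex
\log p(x) = \sum_{i=1}^{k}\left[-\frac{1}{2}\left(\frac{x_i-\mu_i}{\sigma_i}\right)^{2} - \log\sigma_i - \frac{1}{2}\log(2\pi)\right]
```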
340 kfac_ferminet_alpha/estimator.py Normal file
@@ -0,0 +1,340 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the high-level Fisher estimator class."""
import collections
from typing import Any, Callable, Mapping, Optional, Sequence, Union, TypeVar

import jax
import jax.numpy as jnp
import jax.random as jnr
import numpy as np

from kfac_ferminet_alpha import curvature_blocks
from kfac_ferminet_alpha import tracer
from kfac_ferminet_alpha import utils

_CurvatureBlock = curvature_blocks.CurvatureBlock
TagMapping = Mapping[str, curvature_blocks.CurvatureBlockCtor]
BlockVector = Sequence[jnp.ndarray]

_StructureT = TypeVar("_StructureT")
_OptionalStateT = TypeVar("_OptionalStateT", bound=Optional[Mapping[str, Any]])


@utils.Stateful.infer_class_state
class CurvatureEstimator(utils.Stateful):
  """Curvature estimator class supporting various curvature approximations."""
  blocks: "collections.OrderedDict[str, _CurvatureBlock]"
  damping: Optional[jnp.ndarray]

  def __init__(self,
               tagged_func: Callable[[Any], jnp.ndarray],
               func_args: Sequence[Any],
               l2_reg: Union[float, jnp.ndarray],
               estimation_mode: str = "fisher_gradients",
               params_index: int = 0,
               layer_tag_to_block_cls: Optional[TagMapping] = None):
    """Creates a FisherEstimator object.

    Args:
      tagged_func: The function which evaluates the model, in which layer and
        loss tags have already been registered.
      func_args: Arguments to trace the function for layer and loss tags.
      l2_reg: Scalar. The L2 regularization coefficient, which represents
        the following regularization function: `coefficient/2 ||theta||^2`.
      estimation_mode: The type of curvature estimator to use. One of:
        * 'fisher_gradients' - the basic estimation approach from the
          original K-FAC paper. (Default)
        * 'fisher_curvature_prop' - method which estimates the Fisher using
          self-products of random 1/-1 vectors times "half-factors" of the
          Fisher, as described here: https://arxiv.org/abs/1206.6464
        * 'fisher_exact' - is the obvious generalization of Curvature
          Propagation to compute the exact Fisher (modulo any additional
          diagonal or Kronecker approximations) by looping over one-hot
          vectors for each coordinate of the output instead of using 1/-1
          vectors. It is more expensive to compute than the other options
          by a factor equal to the output dimension, roughly speaking.
        * 'fisher_empirical' - computes the 'empirical' Fisher information
          matrix (which uses the data's distribution for the targets, as
          opposed to the true Fisher which uses the model's distribution)
          and requires that each registered loss have specified targets.
        * 'ggn_curvature_prop' - Analogous to fisher_curvature_prop, but
          estimates the Generalized Gauss-Newton matrix (GGN).
        * 'ggn_exact' - Analogous to fisher_exact, but estimates the
          Generalized Gauss-Newton matrix (GGN).
      params_index: The index of the arguments accepted by `func` which
        correspond to parameters.
      layer_tag_to_block_cls: An optional dict mapping tags to specific
        classes of block approximations, which override the default ones.
    """
    if estimation_mode not in ("fisher_gradients", "fisher_empirical",
                               "fisher_exact", "fisher_curvature_prop",
                               "ggn_exact", "ggn_curvature_prop"):
      raise ValueError(f"Unrecognised estimation_mode={estimation_mode}.")
    super().__init__()
    self.tagged_func = tagged_func
    self.l2_reg = l2_reg
    self.estimation_mode = estimation_mode
    self.params_index = params_index
    self.vjp = tracer.trace_estimator_vjp(self.tagged_func)

    # Figure out the mapping from layer tags to block classes.
    self.layer_tag_to_block_cls = curvature_blocks.copy_default_tag_to_block()
    if layer_tag_to_block_cls is None:
      layer_tag_to_block_cls = dict()
    layer_tag_to_block_cls = dict(**layer_tag_to_block_cls)
    self.layer_tag_to_block_cls.update(layer_tag_to_block_cls)

    # Create the blocks
    self._in_tree = jax.tree_structure(func_args)
    self._jaxpr = jax.make_jaxpr(self.tagged_func)(*func_args).jaxpr
    self._layer_tags, self._loss_tags = tracer.extract_tags(self._jaxpr)
    self.blocks = collections.OrderedDict()
    counters = dict()
    for eqn in self._layer_tags:
      cls = self.layer_tag_to_block_cls[eqn.primitive.name]
      c = counters.get(cls.__name__, 0)
      self.blocks[cls.__name__ + "_" + str(c)] = cls(eqn)
      counters[cls.__name__] = c + 1

  @property
  def diagonal_weight(self) -> jnp.ndarray:
    return self.l2_reg + self.damping

  def vectors_to_blocks(
      self,
      parameter_structured_vector: Any,
  ) -> Sequence[BlockVector]:
    """Splits the parameters to values for the corresponding blocks."""
    in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars)
    params_vars = in_vars[self.params_index]
    params_vars_flat = jax.tree_flatten(params_vars)[0]
    params_values_flat = jax.tree_flatten(parameter_structured_vector)[0]
    assert len(params_vars_flat) == len(params_values_flat)
    params_dict = dict(zip(params_vars_flat, params_values_flat))
    per_block_vectors = []
    for eqn in self._layer_tags:
      if eqn.primitive.name == "generic_tag":
        block_vars = eqn.invars
      else:
        block_vars = eqn.primitive.split_all_inputs(eqn.invars)[2]
      per_block_vectors.append(tuple(params_dict.pop(v) for v in block_vars))
    if params_dict:
      raise ValueError(f"From the parameters the following structure is not "
                       f"assigned to any block: {params_dict}. Most likely "
                       f"this part of the parameters is not part of the graph "
                       f"reaching the losses.")
    return tuple(per_block_vectors)

  def blocks_to_vectors(self, per_block_vectors: Sequence[BlockVector]) -> Any:
    """Reverses the function self.vectors_to_blocks."""
    in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars)
    params_vars = in_vars[self.params_index]
    assigned_dict = dict()
    for eqn, block_values in zip(self._layer_tags, per_block_vectors):
      if eqn.primitive.name == "generic_tag":
        block_params = eqn.invars
      else:
        block_params = eqn.primitive.split_all_inputs(eqn.invars)[2]
      assigned_dict.update(zip(block_params, block_values))
    params_vars_flat, params_tree = jax.tree_flatten(params_vars)
    params_values_flat = [assigned_dict[v] for v in params_vars_flat]
    assert len(params_vars_flat) == len(params_values_flat)
    return jax.tree_unflatten(params_tree, params_values_flat)

  def init(
      self,
      rng: jnp.ndarray,
      init_damping: Optional[jnp.ndarray],
  ) -> Mapping[str, Any]:
    """Returns initialized variables for the curvature approximations and the inverses."""
    return dict(
        blocks=collections.OrderedDict(
            (name, block.init(block_rng))  #
            for (name, block), block_rng  #
            in zip(self.blocks.items(), jnr.split(rng, len(self.blocks)))),
        damping=init_damping)

  @property
  def mat_type(self) -> str:
    return self.estimation_mode.split("_")[0]

  def vec_block_apply(
      self,
      func: Callable[[_CurvatureBlock, BlockVector], BlockVector],
      parameter_structured_vector: Any,
  ) -> Any:
    """Executes func for each approximation block on vectors."""
    per_block_vectors = self.vectors_to_blocks(parameter_structured_vector)
    assert len(per_block_vectors) == len(self.blocks)
    results = jax.tree_multimap(func, tuple(self.blocks.values()),
                                per_block_vectors)
    parameter_structured_result = self.blocks_to_vectors(results)
    utils.check_structure_shapes_and_dtype(parameter_structured_vector,
                                           parameter_structured_result)
    return parameter_structured_result

  def multiply_inverse(self, parameter_structured_vector: Any) -> Any:
    """Multiplies the vectors by the corresponding (damped) inverses of the blocks.

    Args:
      parameter_structured_vector: Structure equivalent to the parameters of
        the model.

    Returns:
      A structure identical to `vectors` containing the product.
    """
    return self.multiply_matpower(parameter_structured_vector, -1)

  def multiply(self, parameter_structured_vector: Any) -> Any:
    """Multiplies the vectors by the corresponding (damped) blocks.

    Args:
      parameter_structured_vector: A vector in the same structure as the
        parameters of the model.

    Returns:
      A structure identical to `vectors` containing the product.
    """
    return self.multiply_matpower(parameter_structured_vector, 1)

  def multiply_matpower(
      self,
      parameter_structured_vector: _StructureT,
      exp: int,
  ) -> _StructureT:
    """Multiplies the vectors by the corresponding matrix powers of the blocks.

    Args:
      parameter_structured_vector: A vector in the same structure as the
        parameters of the model.
      exp: A float representing the power to raise the blocks by before
        multiplying it by the vector.

    Returns:
      A structure identical to `vectors` containing the product.
    """

    def func(block: _CurvatureBlock, vec: BlockVector) -> BlockVector:
      return block.multiply_matpower(vec, exp, self.diagonal_weight)

    return self.vec_block_apply(func, parameter_structured_vector)

  def update_curvature_matrix_estimate(
      self,
      ema_old: Union[float, jnp.ndarray],
      ema_new: Union[float, jnp.ndarray],
      batch_size: int,
      rng: jnp.ndarray,
      func_args: Sequence[Any],
      pmap_axis_name: str,
  ) -> None:
    """Updates the curvature estimate."""

    # Compute the losses and the VJP function from the function inputs
    losses, losses_vjp = self.vjp(func_args)

    # Helper function that updates the blocks given a vjp vector
    def _update_blocks(vjp_vec_, ema_old_, ema_new_):
      blocks_info_ = losses_vjp(vjp_vec_)
      for block_, block_info_ in zip(self.blocks.values(), blocks_info_):
        block_.update_curvature_matrix_estimate(
            info=block_info_,
            batch_size=batch_size,
            ema_old=ema_old_,
            ema_new=ema_new_,
            pmap_axis_name=pmap_axis_name)

    if self.estimation_mode == "fisher_gradients":
      keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng]
      vjp_vec = tuple(
          loss.grad_of_evaluate_on_sample(key, coefficient_mode="sqrt")
          for loss, key in zip(losses, keys))
      _update_blocks(vjp_vec, ema_old, ema_new)

    elif self.estimation_mode in ("fisher_curvature_prop",
                                  "ggn_curvature_prop"):
      keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng]
      vjp_vec = []
      for loss, key in zip(losses, keys):
        if self.estimation_mode == "fisher_curvature_prop":
          random_b = jnr.bernoulli(key, shape=loss.fisher_factor_inner_shape())
          vjp_vec.append(loss.multiply_fisher_factor(random_b * 2.0 - 1.0))
        else:
          random_b = jnr.bernoulli(key, shape=loss.ggn_factor_inner_shape())
          vjp_vec.append(loss.multiply_ggn_factor(random_b * 2.0 - 1.0))
      _update_blocks(tuple(vjp_vec), ema_old, ema_new)

    elif self.estimation_mode in ("fisher_exact", "ggn_exact"):
      # We use the following trick to simulate summation. The equation is:
      #   estimate = ema_old * estimate + ema_new * (sum_i estimate_index_i)
      #   weight = ema_old * weight + ema_new
      # Instead we update the estimate n times with the following updates:
      #   for k = 1:
      #     estimate_k = ema_old * estimate + (ema_new/n) * (n*estimate_index_k)
      #     weight_k = ema_old * weight + (ema_new/n)
      #   for k > 1:
      #     estimate_k = 1.0 * estimate_{k-1} + (ema_new/n) * (n*estimate_index_k)
      #     weight_k = 1.0 * weight_{k-1} + (ema_new/n)
      # Which is mathematically equivalent to the original version.
      zero_tangents = jax.tree_map(jnp.zeros_like,
                                   list(loss.inputs for loss in losses))
      if self.estimation_mode == "fisher_exact":
        num_indices = [
            (l, int(np.prod(l.fisher_factor_inner_shape()[1:])))
            for l in losses
        ]
      else:
        num_indices = [
            (l, int(np.prod(l.ggn_factor_inner_shape()))) for l in losses
        ]
      total_num_indices = sum(n for _, n in num_indices)
      for i, (loss, loss_num_indices) in enumerate(num_indices):
        for index in range(loss_num_indices):
          vjp_vec = zero_tangents.copy()
          if self.estimation_mode == "fisher_exact":
            vjp_vec[i] = loss.multiply_fisher_factor_replicated_one_hot([index])
          else:
            vjp_vec[i] = loss.multiply_ggn_factor_replicated_one_hot([index])
          if isinstance(vjp_vec[i], jnp.ndarray):
            # In the special case of only one parameter, it still needs to be a
            # tuple for the tangents.
            vjp_vec[i] = (vjp_vec[i],)
          vjp_vec[i] = jax.tree_map(lambda x: x * total_num_indices, vjp_vec[i])
          _update_blocks(tuple(vjp_vec), ema_old, ema_new / total_num_indices)
          ema_old = 1.0

    elif self.estimation_mode == "fisher_empirical":
      raise NotImplementedError()
    else:
      raise ValueError(f"Unrecognised estimation_mode={self.estimation_mode}")

  def update_curvature_estimate_inverse(
      self,
      pmap_axis_name: str,
      state: _OptionalStateT,
  ) -> _OptionalStateT:
    if state is not None:
      old_state = self.get_state()
      self.set_state(state)
    for block in self.blocks.values():
      block.update_curvature_inverse_estimate(self.diagonal_weight,
                                              pmap_axis_name)
    if state is None:
      return None
    else:
      state = self.pop_state()
      self.set_state(old_state)
      return state
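The comment in the `fisher_exact`/`ggn_exact` branch claims that one summed EMA update can be replaced by `n` sequential updates. A quick numerical check of that equivalence (plain Python with arbitrary values; the `weight` accumulator tracked by `WeightedMovingAverage` follows by the same algebra):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
per_index = rng.standard_normal(n)   # the per-coordinate estimates
estimate, ema_old, ema_new = 3.0, 0.9, 0.1

# One-shot update: estimate = ema_old * estimate + ema_new * sum_i estimate_i
one_shot = ema_old * estimate + ema_new * per_index.sum()

# Sequential form used by the estimator: ema_old applies on the first step
# only (it is reset to 1.0 afterwards), and each step adds
# (ema_new / n) * (n * estimate_index_k).
sequential = estimate
for k, e_k in enumerate(per_index):
    decay = ema_old if k == 0 else 1.0
    sequential = decay * sequential + (ema_new / n) * (n * e_k)

np.testing.assert_allclose(one_shot, sequential)
```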
171 kfac_ferminet_alpha/example.py Normal file
@@ -0,0 +1,171 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of running KFAC."""
from absl import app
from absl import flags
import jax
import jax.numpy as jnp

import numpy as np
import kfac_ferminet_alpha
from kfac_ferminet_alpha import utils


TRAINING_STEPS = flags.DEFINE_integer(
    name="training_steps",
    default=100,
    help="Number of training steps to perform")
BATCH_SIZE = flags.DEFINE_integer(
    name="batch_size", default=128, help="Batch size")
LEARNING_RATE = flags.DEFINE_float(
    name="learning_rate", default=1e-3, help="Learning rate")
L2_REG = flags.DEFINE_float(
    name="l2_reg", default=1e-3, help="L2 regularization coefficient")
MOMENTUM = flags.DEFINE_float(
    name="momentum", default=0.8, help="Momentum coefficient")
DAMPING = flags.DEFINE_float(
    name="damping", default=1e-2, help="Damping coefficient")
MULTI_DEVICE = flags.DEFINE_bool(
    name="multi_device",
    default=False,
    help="Whether the computation should be replicated across multiple devices")
SEED = flags.DEFINE_integer(name="seed", default=12412321, help="JAX RNG seed")


def glorot_uniform(shape, key):
  dim_in = np.prod(shape[:-1])
  dim_out = shape[-1]
  c = jnp.sqrt(6 / (dim_in + dim_out))
  return jax.random.uniform(key, shape=shape, minval=-c, maxval=c)


def fully_connected_layer(params, x):
  w, b = params
  return jnp.matmul(x, w) + b[None]


def model_init(rng_key, batch, encoder_sizes=(1000, 500, 250, 30)):
  """Initialize the standard autoencoder."""
  x_size = batch.shape[-1]
  decoder_sizes = encoder_sizes[len(encoder_sizes) - 2::-1]
  sizes = (x_size,) + encoder_sizes + decoder_sizes + (x_size,)
  keys = jax.random.split(rng_key, len(sizes) - 1)
  params = []
  for rng_key, dim_in, dim_out in zip(keys, sizes, sizes[1:]):
    # Glorot uniform initialization
    w = glorot_uniform((dim_in, dim_out), rng_key)
    b = jnp.zeros([dim_out])
    params.append((w, b))
  return params, None


def model_loss(params, inputs, l2_reg):
  """Evaluate the standard autoencoder."""
  h = inputs.reshape([inputs.shape[0], -1])
  for i, layer_params in enumerate(params):
    h = fully_connected_layer(layer_params, h)
    # The last layer does not have a nonlinearity
    if i % 4 != 3:
      h = jnp.tanh(h)
  l2_value = 0.5 * sum(jnp.square(p).sum() for p in jax.tree_leaves(params))
  error = jax.nn.sigmoid(h) - inputs.reshape([inputs.shape[0], -1])
  mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0)
  regularized_loss = mean_squared_error + l2_reg * l2_value

  return regularized_loss, dict(mean_squared_error=mean_squared_error)


def random_data(multi_device, batch_shape, rng):
  if multi_device:
    shape = (multi_device,) + tuple(batch_shape)
  else:
    shape = tuple(batch_shape)
  while True:
    rng, key = jax.random.split(rng)
    yield jax.random.normal(key, shape)


def main(argv):
  del argv  # Unused.

  learning_rate = jnp.asarray([LEARNING_RATE.value])
  momentum = jnp.asarray([MOMENTUM.value])
  damping = jnp.asarray([DAMPING.value])

  # RNG keys
  global_step = jnp.zeros([])
  rng = jax.random.PRNGKey(SEED.value)
  params_key, opt_key, step_key, data_key = jax.random.split(rng, 4)
  dataset = random_data(MULTI_DEVICE.value, (BATCH_SIZE.value, 20), data_key)
  example_batch = next(dataset)

  if MULTI_DEVICE.value:
    global_step = utils.replicate_all_local_devices(global_step)
    learning_rate = utils.replicate_all_local_devices(learning_rate)
    momentum = utils.replicate_all_local_devices(momentum)
    damping = utils.replicate_all_local_devices(damping)
    params_key, opt_key = utils.replicate_all_local_devices(
        (params_key, opt_key))
    step_key = utils.make_different_rng_key_on_all_devices(step_key)
    split_key = jax.pmap(lambda x: tuple(jax.random.split(x)))
    jit_init_parameters_func = jax.pmap(model_init)
  else:
    split_key = jax.random.split
    jit_init_parameters_func = jax.jit(model_init)

  # Initialize or load parameters
  params, func_state = jit_init_parameters_func(params_key, example_batch)

  # Make optimizer
  optim = kfac_ferminet_alpha.Optimizer(
      value_and_grad_func=jax.value_and_grad(
          lambda p, x: model_loss(p, x, L2_REG.value), has_aux=True),
      l2_reg=L2_REG.value,
      value_func_has_aux=True,
      value_func_has_state=False,
      value_func_has_rng=False,
      learning_rate_schedule=None,
      momentum_schedule=None,
      damping_schedule=None,
      norm_constraint=1.0,
      num_burnin_steps=10,
  )

  # Initialize optimizer
  opt_state = optim.init(params, opt_key, example_batch, func_state)

  for t in range(TRAINING_STEPS.value):
    step_key, key_t = split_key(step_key)
    params, opt_state, stats = optim.step(
        params,
        opt_state,
        key_t,
        dataset,
        learning_rate=learning_rate,
        momentum=momentum,
        damping=damping)
    global_step = global_step + 1

    # Log any of the statistics
    print(f"iteration: {t}")
    print(f"mini-batch loss = {stats['loss']}")
    if "aux" in stats:
      for k, v in stats["aux"].items():
        print(f"{k} = {v}")
    print("----")


if __name__ == "__main__":
  app.run(main)
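Assuming the package and its requirements are installed, the example above can be run directly and its flags overridden on the command line, for instance `python3 kfac_ferminet_alpha/example.py --training_steps=200 --learning_rate=1e-2`; `run.sh` below performs the same end-to-end in a clean virtualenv.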
354 kfac_ferminet_alpha/layers_and_loss_tags.py Normal file
@@ -0,0 +1,354 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A module for registering already known functions for tagging patterns."""
|
||||
import functools
|
||||
|
||||
from typing import Sequence, Tuple, TypeVar
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax import lax
|
||||
from jax import lib as jax_lib
|
||||
from jax.interpreters import batching as jax_batching
|
||||
import jax.numpy as jnp
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class LossTag(jax_core.Primitive):
|
||||
"""A tagging primitive specifically for losses."""
|
||||
multiple_results = True
|
||||
|
||||
def __init__(self, cls, num_inputs: int, num_targets: int = 1):
|
||||
super().__init__(cls.__name__ + "_tag")
|
||||
self._cls = cls
|
||||
self._num_inputs = num_inputs
|
||||
self._num_targets = num_targets
|
||||
jax.xla.translations[self] = self.xla_translation
|
||||
jax.ad.primitive_jvps[self] = self.jvp
|
||||
# This line defines how does the tag behave under vmap. It is required for
|
||||
# any primitive that can be used inside a vmap. The reason why we want to
|
||||
# allow this is two fold - one to not break user code when the tags are not
|
||||
# used at all, and two - to be able to define a network with code for a
|
||||
# single example which is the vmap-ed for a batch.
|
||||
jax_batching.primitive_batchers[self] = self.batching
|
||||
|
||||
@property
|
||||
def num_inputs(self) -> int:
|
||||
return self._num_inputs
|
||||
|
||||
@property
|
||||
def num_targets(self) -> int:
|
||||
return self._num_targets
|
||||
|
||||
def loss(self, *args, weight: float = 1.0, **kwargs):
|
||||
return self._cls(*args, weight=weight, **kwargs)
|
||||
|
||||
def loss_evaluate(self, *args, weight: float = 1.0, **kwargs):
|
||||
return self.loss(*args, weight=weight, **kwargs).evaluate()
|
||||
|
||||
def get_outputs(self, *args, weight: float, return_loss: bool, **kwargs):
|
||||
if len(args) < self.num_inputs:
|
||||
raise ValueError("Inputs to the tag are not enough.")
|
||||
if len(args) < self.num_inputs + self.num_targets:
|
||||
if len(args) != self.num_inputs:
|
||||
raise ValueError("Inputs to the tag are not quite enough.")
|
||||
if return_loss:
|
||||
raise ValueError("Can not have return_loss=True when there are no "
|
||||
"targets.")
|
||||
return args
|
||||
if len(args) > self.num_inputs + self.num_targets:
|
||||
raise ValueError("Inputs to the tag are too many.")
|
||||
if return_loss:
|
||||
return self.loss(*args, weight=weight, **kwargs).evaluate()
|
||||
else:
|
||||
return args
|
||||
|
||||
def impl(self, *args, weight: float, return_loss: bool, **kwargs):
|
||||
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
||||
|
||||
def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
|
||||
return self.get_outputs(*args, weight=weight, return_loss=return_loss)
|
||||
|
||||
def xla_translation(
|
||||
self,
|
||||
c,
|
||||
*args,
|
||||
weight: float = 1.0,
|
||||
return_loss: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
outputs = self.get_outputs(
|
||||
*args, weight=weight, return_loss=return_loss, **kwargs)
|
||||
if isinstance(outputs, tuple):
|
||||
return jax_lib.xla_client.ops.Tuple(c, outputs)
|
||||
return outputs
|
||||
|
||||
def jvp(
|
||||
self,
|
||||
arg_values,
|
||||
arg_tangents,
|
||||
weight: float,
|
||||
return_loss: bool,
|
||||
**kwargs,
|
||||
):
|
||||
if len(arg_values) != len(arg_tangents):
|
||||
raise ValueError("Values and tangents are not the same length.")
|
||||
primal_output = self.bind(
|
||||
*arg_values, weight=weight, return_loss=return_loss, **kwargs)
|
||||
if len(arg_values) == self.num_inputs:
|
||||
tangents_out = self.get_outputs(
|
||||
*arg_tangents, weight=weight, return_loss=return_loss, **kwargs)
|
||||
elif return_loss:
|
||||
tangents_out = jax.jvp(
|
||||
functools.partial(self.loss_evaluate, weight=weight, **kwargs),
|
||||
arg_tangents, arg_tangents)[1]
|
||||
else:
|
||||
tangents_out = arg_tangents
|
||||
return primal_output, tangents_out
|
||||
|
||||
def batching(self, batched_args, batched_dims, **kwargs):
|
||||
return self.bind(*batched_args, **kwargs), batched_dims[0]
|
||||
|
||||
|
||||
class LayerTag(jax_core.Primitive):
|
||||
"""A tagging primitive that is used to mark/tag computation."""
|
||||
|
||||
def __init__(self, name: str, num_inputs: int, num_outputs: int):
|
||||
super().__init__(name)
|
||||
if num_outputs > 1:
|
||||
raise NotImplementedError(
|
||||
f"Only single outputs are supported, got: num_outputs={num_outputs}")
|
||||
self._num_outputs = num_outputs
|
||||
self._num_inputs = num_inputs
|
||||
jax.xla.translations[self] = self.xla_translation
|
||||
jax.ad.deflinear(self, self.transpose)
|
||||
jax.ad.primitive_transposes[self] = self.transpose
|
||||
# This line defines how does the tag behave under vmap. It is required for
|
||||
# any primitive that can be used inside a vmap. The reason why we want to
|
||||
# allow this is two fold - one to not break user code when the tags are not
|
||||
# used at all, and two - to be able to define a network with code for a
|
||||
# single example which is the vmap-ed for a batch.
|
||||
jax_batching.primitive_batchers[self] = self.batching
|
||||
|
||||
@property
|
||||
def num_outputs(self) -> int:
|
||||
return self._num_outputs
|
||||
|
||||
@property
|
||||
def num_inputs(self) -> int:
|
||||
return self._num_inputs
|
||||
|
||||
def split_all_inputs(
|
||||
self,
|
||||
all_inputs: Sequence[_T],
|
||||
) -> Tuple[Sequence[_T], Sequence[_T], Sequence[_T]]:
|
||||
outputs = tuple(all_inputs[:self.num_outputs])
|
||||
inputs = tuple(all_inputs[self.num_outputs:self.num_outputs +
|
||||
self.num_inputs])
|
||||
params = tuple(all_inputs[self.num_outputs + self.num_inputs:])
|
||||
return outputs, inputs, params
|
||||
|
||||
def get_outputs(self, *operands: _T, **kwargs) -> _T:
|
||||
assert self.num_outputs == 1
|
||||
return operands[0]
|
||||
|
||||
def xla_translation(self, c, *operands: _T, **kwargs) -> _T:
|
||||
return self.get_outputs(*operands, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def transpose(cotangent, *operands, **kwargs):
|
||||
return (cotangent,) + (None,) * (len(operands) - 1)
|
||||
|
||||
def impl(self, *operands, **kwargs):
|
||||
return self.get_outputs(*operands, **kwargs)
|
||||
|
||||
def abstract_eval(self, *abstract_operands, **kwargs):
|
||||
return self.get_outputs(*abstract_operands, **kwargs)
|
||||
|
||||
def batching(self, batched_operands, batched_dims, **kwargs):
|
||||
return self.bind(*batched_operands, **kwargs), batched_dims[0]
|
||||
|
||||
|
||||
# _____ _
|
||||
# / ____| (_)
|
||||
# | | __ ___ _ __ ___ _ __ _ ___
|
||||
# | | |_ |/ _ \ '_ \ / _ \ '__| |/ __|
|
||||
# | |__| | __/ | | | __/ | | | (__
|
||||
# \_____|\___|_| |_|\___|_| |_|\___|
|
||||
#
|
||||
#
|
||||
|
||||
generic_tag = LayerTag(name="generic_tag", num_inputs=0, num_outputs=1)
|
||||
|
||||
|
||||
def register_generic(parameter: _T) -> _T:
|
||||
return generic_tag.bind(parameter)
|
||||
|
||||
|
||||
# _____
|
||||
# | __ \
|
||||
# | | | | ___ _ __ ___ ___
|
||||
# | | | |/ _ \ '_ \/ __|/ _ \
|
||||
# | |__| | __/ | | \__ \ __/
|
||||
# |_____/ \___|_| |_|___/\___|
|
||||
#
|
||||
|
||||
dense_tag = LayerTag(name="dense_tag", num_inputs=1, num_outputs=1)
|
||||
|
||||
|
||||
def register_dense(y, x, w, b=None):
|
||||
if b is None:
|
||||
return dense_tag.bind(y, x, w)
|
||||
return dense_tag.bind(y, x, w, b)
|
||||
|
||||
|
||||
def dense_func(x, params):
|
||||
"""Example of a dense layer function."""
|
||||
w = params[0]
|
||||
y = jnp.matmul(x, w)
|
||||
if len(params) == 1:
|
||||
# No bias
|
||||
return y
|
||||
# Add bias
|
||||
return y + params[1]
|
||||
|
||||
|
||||
def dense_tagging(jaxpr, inverse_map, values_map):
|
||||
"""Correctly registers a dense layer pattern."""
|
||||
del inverse_map
|
||||
in_values = [values_map[v] for v in jaxpr.invars]
|
||||
out_values = [values_map[v] for v in jaxpr.outvars]
|
||||
return register_dense(out_values[0], *in_values)
|
||||
|
||||
|
||||
# ___ _____ _____ _ _ _
|
||||
# |__ \| __ \ / ____| | | | | (_)
|
||||
# ) | | | | | | ___ _ ____ _____ | |_ _| |_ _ ___ _ __
|
||||
# / /| | | | | | / _ \| '_ \ \ / / _ \| | | | | __| |/ _ \| "_ \
|
||||
# / /_| |__| | | |___| (_) | | | \ V / (_) | | |_| | |_| | (_) | | | |
|
||||
# |____|_____/ \_____\___/|_| |_|\_/ \___/|_|\__,_|\__|_|\___/|_| |_|
|
||||
#
|
||||
|
||||
conv2d_tag = LayerTag(name="conv2d_tag", num_inputs=1, num_outputs=1)
|
||||
|
||||
|
||||
def register_conv2d(y, x, w, b=None, **kwargs):
|
||||
if b is None:
|
||||
return conv2d_tag.bind(y, x, w, **kwargs)
|
||||
return conv2d_tag.bind(y, x, w, b, **kwargs)
|
||||
|
||||
|
||||
def conv2d_func(x, params):
|
||||
"""Example of a conv2d layer function."""
|
||||
w = params[0]
|
||||
y = lax.conv_general_dilated(
|
||||
x,
|
||||
w,
|
||||
window_strides=(2, 2),
|
||||
padding="SAME",
|
||||
dimension_numbers=("NHWC", "HWIO", "NHWC"))
|
||||
if len(params) == 1:
|
||||
# No bias
|
||||
return y
|
||||
# Add bias
|
||||
return y + params[1][None, None, None]
|
||||
|
||||
|
||||
def conv2d_tagging(jaxpr, inverse_map, values_map):
|
||||
"""Correctly registers a conv2d layer pattern."""
|
||||
in_values = [values_map[v] for v in jaxpr.invars]
|
||||
out_values = [values_map[v] for v in jaxpr.outvars]
|
||||
keys = [k for k in inverse_map.keys() if isinstance(k, str)]
|
||||
keys = [k for k in keys if k.startswith("conv_general_dilated")]
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Did not find any conv_general_dilated!")
|
||||
kwargs = inverse_map[keys[0]].params
|
||||
return register_conv2d(out_values[0], *in_values, **kwargs)
|
||||
|
||||
|
||||
# _____ _ _ _____ _ _ __ _
|
||||
# / ____| | | | | / ____| | (_)/ _| |
|
||||
# | (___ ___ __ _| | ___ __ _ _ __ __| | | (___ | |__ _| |_| |_
|
||||
# \___ \ / __/ _` | |/ _ \ / _` | '_ \ / _` | \___ \| '_ \| | _| __|
|
||||
# ____) | (_| (_| | | __/ | (_| | | | | (_| | ____) | | | | | | | |_
|
||||
# |_____/ \___\__,_|_|\___| \__,_|_| |_|\__,_| |_____/|_| |_|_|_| \__|
|
||||
#
|
||||
|
||||
scale_and_shift_tag = LayerTag(
    name="scale_and_shift_tag", num_inputs=1, num_outputs=1)


def register_scale_and_shift(y, args, has_scale: bool, has_shift: bool):
  assert has_scale or has_shift
  x, args = args[0], args[1:]
  return scale_and_shift_tag.bind(
      y, x, *args, has_scale=has_scale, has_shift=has_shift)


def scale_and_shift_func(x, params, has_scale: bool, has_shift: bool):
  """Example of a scale and shift function."""
  if has_scale and has_shift:
    scale, shift = params
    return x * scale + shift
  elif has_scale:
    return x * params[0]
  elif has_shift:
    return x + params[0]
  else:
    raise ValueError("At least one of `has_scale` and `has_shift` must be "
                     "True.")


def scale_and_shift_tagging(
    jaxpr,
    inverse_map,
    values_map,
    has_scale: bool,
    has_shift: bool,
):
  """Correctly registers a scale and shift layer pattern."""
  del inverse_map
  in_values = [values_map[v] for v in jaxpr.invars]
  out_values = [values_map[v] for v in jaxpr.outvars]
  return register_scale_and_shift(out_values[0], in_values, has_scale,
                                  has_shift)
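

# Editor's illustration (not part of the original file): a sketch of tagging
# a scale-and-shift op (e.g. the affine part of a normalization layer). The
# input comes first in the second argument, followed by the parameters,
# matching `register_scale_and_shift` above. The function name is
# hypothetical.
def example_tagged_scale_and_shift(params, x):
  scale, shift = params
  y = x * scale + shift
  return register_scale_and_shift(
      y, [x, scale, shift], has_scale=True, has_shift=True)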
def batch_norm_func(
    inputs: Tuple[jnp.ndarray, jnp.ndarray],
    params: Tuple[jnp.ndarray, jnp.ndarray],
) -> jnp.ndarray:
  """Example of batch norm as is defined in Haiku."""
  x, y = inputs
  scale, shift = params
  inv = scale * y
  return x * inv + shift


def batch_norm_tagging_func(
    jaxpr,
    inverse_map,
    values_map,
    has_scale: bool,
    has_shift: bool,
):
  """Correctly registers a batch norm layer pattern as is defined in Haiku."""
  del inverse_map
  in_values = [values_map[v] for v in jaxpr.invars]
  out_values = [values_map[v] for v in jaxpr.outvars]
  # The first two inputs are both multiplied by the scale, so we merge them
  # into a single input.
  in_values = [in_values[0] * in_values[1]] + in_values[2:]
  return register_scale_and_shift(out_values[0], in_values, has_scale,
                                  has_shift)
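
# Editor's note (sketch, assuming Haiku-style batch norm): the pattern the
# tagger sees is `x * (scale * inv_std) + shift`, where `inv_std` is the
# reciprocal standard deviation. Merging `x * inv_std` into one input reduces
# batch norm to the plain scale-and-shift pattern registered above.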
653
kfac_ferminet_alpha/loss_functions.py
Normal file
File diff suppressed because it is too large
Load Diff
611
kfac_ferminet_alpha/optimizer.py
Normal file
File diff suppressed because it is too large
Load Diff
7
kfac_ferminet_alpha/requirements.txt
Normal file
@@ -0,0 +1,7 @@
jax>=0.2.10
dataclasses>=0.6
networkx>=2.1
numpy>=1.16.4
typing>=3.7.4.3
ordered-set>=4.0.2
absl-py>=0.12.0
33
kfac_ferminet_alpha/run.sh
Executable file
@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script installs kfac_ferminet_alpha in a clean virtualenv and runs an
# example training loop. It is designed to be run from the parent directory,
# e.g.:
#
# git clone git@github.com:deepmind/deepmind-research.git
# cd deepmind-research
# kfac_ferminet_alpha/run.sh

python3 -m venv /tmp/kfac_ferminet_alpha_example
source /tmp/kfac_ferminet_alpha_example/bin/activate
pip3 install -U pip
pip3 install -r kfac_ferminet_alpha/requirements.txt
# For a GPU you have to do:
# pip3 install --upgrade jax jaxlib==0.1.64+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip3 install jaxlib
pip3 install kfac_ferminet_alpha/
python3 kfac_ferminet_alpha/example.py
52
kfac_ferminet_alpha/setup.py
Normal file
@@ -0,0 +1,52 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup for pip package."""

from setuptools import setup


REQUIRED_PACKAGES = (
    "absl-py",
    "dataclasses",
    "jax",
    "networkx",
    "numpy",
    "ordered-set",
    "typing",
)

LONG_DESCRIPTION = "\n".join([
    "Kronecker-Factored Approximate Curvature (K-FAC) optimizer implemented in "
    "JAX.",
    "",
    "Accompanying code for 'Better, Faster Fermionic Neural Networks'",
    "James S. Spencer, David Pfau, Aleksandar Botev, and W. M. C. Foulkes.",
    "https://arxiv.org/abs/2011.07125.",
])


setup(
    name="kfac_ferminet_alpha",
    version="0.0.1",
    description="A K-FAC optimizer implemented in JAX",
    long_description=LONG_DESCRIPTION,
    url="https://github.com/deepmind/deepmind-research/kfac_ferminet_alpha",
    author="DeepMind",
    package_dir={"kfac_ferminet_alpha": "."},
    packages=["kfac_ferminet_alpha"],
    install_requires=REQUIRED_PACKAGES,
    platforms=["any"],
    license="Apache License, Version 2.0",
)
752
kfac_ferminet_alpha/tag_graph_matcher.py
Normal file
File diff suppressed because it is too large
Load Diff
76
kfac_ferminet_alpha/tests/common.py
Normal file
@@ -0,0 +1,76 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common functions used across more than one test."""
import jax
import jax.numpy as jnp
import jax.random as jnr

from kfac_ferminet_alpha import loss_functions


def fully_connected_layer(params, x):
  w, b = params
  return jnp.matmul(x, w) + b[None]


def init_autoencoder(key, data_shape):
  """Initialize the standard autoencoder."""
  assert len(data_shape) == 1
  x_size = data_shape[0]
  sizes = [x_size, 1000, 500, 250, 30, 250, 500, 1000, x_size]
  keys = jnr.split(key, len(sizes) - 1)
  params = []
  for key, dim_in, dim_out in zip(keys, sizes, sizes[1:]):
    # Glorot uniform initialization
    c = jnp.sqrt(6 / (dim_in + dim_out))
    w = jax.random.uniform(key, shape=(dim_in, dim_out), minval=-c, maxval=c)
    b = jnp.zeros([dim_out])
    params.append((w, b))
  return params


def autoencoder(all_params, x_in):
  """Evaluate the standard autoencoder.

  Note that the objective of this autoencoder is not standard, but rather a
  sum of the standard sigmoid cross-entropy and squared loss. The reason for
  this is to test the handling of multiple losses.

  Args:
    all_params: All parameter values.
    x_in: Inputs to the network.

  Returns:
    The value of the two losses and intermediate layer values.
  """
  h_in = x_in
  layers_values = []
  for i, params in enumerate(all_params):
    h_out = fully_connected_layer(params, h_in)
    layers_values.append((h_out, h_in))
    # The last layer does not have a nonlinearity.
    if i % 4 != 3:
      # h_in = nn.leaky_relu(h_out)
      h_in = jnp.tanh(h_out)
    else:
      h_in = h_out
  h1, _ = loss_functions.register_normal_predictive_distribution(h_in, x_in)
  h2, _ = loss_functions.register_normal_predictive_distribution(
      h_in, targets=x_in, weight=0.1)
  l1 = (h1 - x_in)**2 + jnp.log(jnp.pi) / 2
  l1 = jnp.sum(l1, axis=-1)
  l2 = (h2 - x_in)**2 + jnp.log(jnp.pi) / 2
  l2 = jnp.sum(l2, axis=-1)
  return [l1, l2 * 0.1], layers_values
85
kfac_ferminet_alpha/tests/graph_matcher_test.py
Normal file
@@ -0,0 +1,85 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.test_util as jtu

from kfac_ferminet_alpha import layers_and_loss_tags
from kfac_ferminet_alpha import loss_functions
from kfac_ferminet_alpha import tag_graph_matcher
from kfac_ferminet_alpha.tests import common


def tagged_autoencoder(all_params, x_in):
  h_in = x_in
  layers_values = []
  for i, params in enumerate(all_params):
    h_out = common.fully_connected_layer(params, h_in)
    h_out = layers_and_loss_tags.register_dense(h_out, h_in, params[0],
                                                params[1],)
    layers_values.append((h_out, h_in))
    # The last layer does not have a nonlinearity.
    if i % 4 != 3:
      h_in = jnp.tanh(h_out)
    else:
      h_in = h_out
  h1, _ = loss_functions.register_normal_predictive_distribution(
      h_in, targets=x_in, weight=1.0)
  h2, t2 = loss_functions.register_normal_predictive_distribution(
      h_in, targets=x_in, weight=0.1)
  return [[h1, t2], [h2, t2]]


class TestGraphMatcher(jtu.JaxTestCase):
  """Class for running all of the tests for integrating the systems."""

  def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape):
    data_shape = tuple(data_shape)
    rng_key = jnr.PRNGKey(12345)
    init_key, data_key = jnr.split(rng_key)
    params = init_func(init_key, data_shape)
    data = jnr.normal(data_key, (11,) + data_shape)
    func = tag_graph_matcher.auto_register_tags(model_func, (params, data))
    jaxpr = jax.make_jaxpr(func)(params, data).jaxpr
    tagged_jaxpr = jax.make_jaxpr(tagged_model)(params, data).jaxpr
    self.assertEqual(len(jaxpr.invars), len(tagged_jaxpr.invars))
    self.assertEqual(len(jaxpr.constvars), len(tagged_jaxpr.constvars))
    self.assertEqual(len(jaxpr.outvars), len(tagged_jaxpr.outvars))
    for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
      eq_in_vars = [v for v in eq.invars if not isinstance(v, jax.core.UnitVar)]
      tagged_in_vars = [
          v for v in tagged_eq.invars if not isinstance(v, jax.core.UnitVar)
      ]
      self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
      self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
      self.assertEqual(eq.primitive, tagged_eq.primitive)
      for variable, t_variable in zip(eq_in_vars + eq.outvars,
                                      tagged_in_vars + tagged_eq.outvars):
        if isinstance(variable, jax.core.Literal):
          self.assertEqual(variable.aval, t_variable.aval)
        else:
          self.assertEqual(variable.count, t_variable.count)

  def test_autoencoder(self):
    self._test_jaxpr(common.init_autoencoder, common.autoencoder,
                     tagged_autoencoder, [784])


if __name__ == "__main__":
  absltest.main()
198
kfac_ferminet_alpha/tests/tracer_test.py
Normal file
@@ -0,0 +1,198 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.test_util as jtu

from kfac_ferminet_alpha import loss_functions
from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import tracer
from kfac_ferminet_alpha import utils
from kfac_ferminet_alpha.tests import common


def autoencoder_aux(all_aux, all_params, x_in):
  h_in = x_in
  layers_values = []
  for i, (params, aux) in enumerate(zip(all_params, all_aux)):
    h_out = common.fully_connected_layer(params, h_in + aux[1]) + aux[0]
    layers_values.append((h_out, h_in))
    # The last layer does not have a nonlinearity.
    if i % 4 != 3:
      h_in = jnp.tanh(h_out)
    else:
      h_in = h_out
  h1, _ = loss_functions.register_normal_predictive_distribution(h_in, x_in)
  h2, _ = loss_functions.register_normal_predictive_distribution(
      h_in, targets=x_in, weight=0.1)
  l1 = (h1 - x_in)**2 + jnp.log(jnp.pi) / 2
  l2 = (h2 - x_in)**2 + jnp.log(jnp.pi) / 2
  return [l1, l2 * 0.1], layers_values


class TestTracer(jtu.JaxTestCase):
  """Class for running all of the tests for integrating the systems."""

  @staticmethod
  def generate_data(init_func, func, data_shape, rng_key):
    n = 3

    rng_key, key = jnr.split(rng_key)
    params = init_func(key, data_shape)
    rng_key, key = jnr.split(rng_key)
    p_tangents = init_func(key, data_shape)
    rng_key, key = jnr.split(rng_key)
    data = jnr.normal(key, [n] + data_shape)

    loss_vals, layer_vals = func(params, data)
    h = layer_vals[-1][0]
    keys = jnr.split(key, len(loss_vals))
    h_tangents = tuple(jnr.normal(key, shape=h.shape) for key in keys)

    return params, data, p_tangents, h_tangents

  def assertStructureAllClose(self, x, y, **kwargs):
    x_v, x_tree = jax.tree_flatten(x)
    y_v, y_tree = jax.tree_flatten(y)
    self.assertEqual(x_tree, y_tree)
    for xi, yi in zip(x_v, y_v):
      self.assertEqual(xi.shape, yi.shape)
      self.assertAllClose(xi, yi, check_dtypes=True, **kwargs)

  def test_tracer_jvp(self):
    init_func = common.init_autoencoder
    func = common.autoencoder
    data_shape = [784]
    rng_key = jnr.PRNGKey(12345)
    params, data, p_tangents, _ = self.generate_data(init_func, func,
                                                     data_shape, rng_key)

    def no_data_func(args):
      outputs = func(args, data)
      return outputs[0], outputs[1][-1][0]

    # True computation
    (primals_out, tangents_out) = jax.jvp(no_data_func, [params], [p_tangents])
    loss_vals, _ = primals_out
    _, h_tangents = tangents_out
    loss_tangents = ((h_tangents,),) * len(loss_vals)
    # Tracer computation
    tracer_jvp = tracer.trace_losses_matrix_vector_jvp(func)
    tracer_losses, tracer_loss_tangents = tracer_jvp((params, data), p_tangents)
    tracer_losses = [loss.evaluate(None) for loss in tracer_losses]

    self.assertStructureAllClose(loss_vals, tracer_losses)
    self.assertStructureAllClose(loss_tangents, tracer_loss_tangents)

  def test_tracer_vjp(self):
    init_func = common.init_autoencoder
    func = common.autoencoder
    data_shape = [784]
    rng_key = jnr.PRNGKey(12345)
    params, data, _, h_tangents = self.generate_data(init_func, func,
                                                     data_shape, rng_key)

    def no_data_func(args):
      outputs = func(args, data)
      return outputs[0], outputs[1][-1][0]

    # True computation
    (loss_vals, _), vjp_func = jax.vjp(no_data_func, params)
    loss_tangents = jax.tree_map(jnp.zeros_like, loss_vals)
    summed_h_tangents = sum(jax.tree_flatten(h_tangents)[0])
    p_tangents = vjp_func((loss_tangents, summed_h_tangents))
    # Tracer computation
    trace_vjp = tracer.trace_losses_matrix_vector_vjp(func)
    tracer_losses, tracer_vjp_func = trace_vjp(params, data)
    tracer_losses = [loss.evaluate(None) for loss in tracer_losses]
    tracer_p_tangents = tracer_vjp_func(h_tangents)

    self.assertStructureAllClose(loss_vals, tracer_losses)
    self.assertStructureAllClose(p_tangents, tracer_p_tangents)

  def test_tracer_hvp(self):
    init_func = common.init_autoencoder
    func = common.autoencoder
    data_shape = [784]
    rng_key = jnr.PRNGKey(12345)
    params, data, p_tangents, _ = self.generate_data(init_func, func,
                                                     data_shape, rng_key)

    def no_data_func(args):
      outputs = func(args, data)
      return sum(jax.tree_map(jnp.sum, outputs[0]))

    # True computation
    grad_func = jax.grad(no_data_func)

    def grad_time_tangents(args):
      return utils.inner_product(grad_func(args), p_tangents)

    hvp = jax.grad(grad_time_tangents)
    hvp_vectors = hvp(params)
    # Tracer computation
    tracer_hvp = tracer.trace_losses_matrix_vector_hvp(func)
    tracer_hvp_vectors = tracer_hvp((params, data), p_tangents)

    self.assertStructureAllClose(hvp_vectors, tracer_hvp_vectors, atol=1e-4)

  def test_trace_estimator(self):
    init_func = common.init_autoencoder
    func = common.autoencoder
    aux_func = autoencoder_aux
    data_shape = [784]
    rng_key = jnr.PRNGKey(12345)
    params, data, _, h_tangents = self.generate_data(init_func, func,
                                                     data_shape, rng_key)

    def aux_last_layer(aux, args):
      outs = aux_func(aux, args, data)
      return outs[1][-1][0]

    # True computation
    loss_vals, layer_vals = func(params, data)
    aux_vals = jax.tree_map(jnp.zeros_like, layer_vals)
    _, vjp = jax.vjp(aux_last_layer, aux_vals, params)
    summed_h_tangents = sum(jax.tree_flatten(h_tangents)[0])
    aux_tangents, p_tangents = vjp(summed_h_tangents)
    layers_info = []
    for aux_p, p_p in zip(layer_vals, params):
      info = dict()
      info["outputs"] = (aux_p[0],)
      info["inputs"] = (aux_p[1],)
      info["params"] = (p_p[0], p_p[1])
      layers_info.append(info)
    for i, (aux_t, p_t) in enumerate(zip(aux_tangents, p_tangents)):
      info = dict()
      info["outputs_tangent"] = (aux_t[0],)
      info["inputs_tangent"] = (aux_t[1],)
      info["params_tangent"] = (p_t[0], p_t[1])
      layers_info[i].update(info)
    layers_info = tuple(layers_info)

    func = tgm.auto_register_tags(func, (params, data))
    tracer_vjp = tracer.trace_estimator_vjp(func)
    tracer_losses, tracer_vjp_func = tracer_vjp((params, data))
    tracer_losses = [loss.evaluate(None) for loss in tracer_losses]
    tracer_outputs = tracer_vjp_func((h_tangents[:1], h_tangents[1:]))

    self.assertStructureAllClose(loss_vals, tracer_losses)
    self.assertStructureAllClose(tracer_outputs, layers_info)


if __name__ == "__main__":
  absltest.main()
327
kfac_ferminet_alpha/tracer.py
Normal file
@@ -0,0 +1,327 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the Jax tracer functionality for tags."""
import functools
from typing import Any, Callable, Sequence, Tuple

import jax
from jax import core
from jax import util as jax_util
import jax.numpy as jnp

from kfac_ferminet_alpha import layers_and_loss_tags as tags
from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import utils

_Function = Callable[[Any], Any]
_Loss = tags.LossTag


def extract_tags(
    jaxpr: core.Jaxpr
) -> Tuple[Sequence[core.JaxprEqn], Sequence[core.JaxprEqn]]:
  """Extracts all of the tag equations."""
  # Loop through equations and evaluate primitives using `bind`
  layer_tags = []
  loss_tags = []
  for eqn in jaxpr.eqns:
    if isinstance(eqn.primitive, tags.LossTag):
      loss_tags.append(eqn)
    elif isinstance(eqn.primitive, tags.LayerTag):
      layer_tags.append(eqn)
  return tuple(layer_tags), tuple(loss_tags)


def construct_compute_losses_inputs(
    jaxpr: core.Jaxpr,
    consts: Tuple[Any],
    num_losses: int,
    primals: Any,
    params_index: int) -> Callable[[Any], Sequence[Sequence[jnp.ndarray]]]:
  """Constructs a function that computes all of the inputs to all losses."""
  primals_ = list(primals)

  def forward_compute_losses(
      params_primals: Any,
  ) -> Sequence[Sequence[jnp.ndarray]]:
    primals_[params_index] = params_primals
    flat_args = jax.tree_flatten(primals_)[0]
    # Mapping from variable -> value
    env = dict()
    read = functools.partial(tgm.read_env, env)
    write = functools.partial(tgm.write_env, env)

    # Bind args and consts to environment
    write(jax.core.unitvar, jax.core.unit)
    jax_util.safe_map(write, jaxpr.invars, flat_args)
    jax_util.safe_map(write, jaxpr.constvars, consts)

    # Loop through equations and evaluate primitives using `bind`
    losses_so_far = 0
    loss_tags = []
    for eqn in jaxpr.eqns:
      tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
      if isinstance(eqn.primitive, tags.LossTag):
        loss_tags.append(eqn)
        losses_so_far += 1
      if num_losses is not None and losses_so_far == num_losses:
        break
    return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags)
    # return tuple(jax_util.safe_map(read, tag.invars) for tag in loss_tags)
  return forward_compute_losses


# We know that `.primitive` will be either a `LossTag` or a `LayerTag`, but
# pytype cannot infer the subclass, so we need to unbox it.


def _unbox_loss_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LossTag:
  assert isinstance(jaxpr_eqn.primitive, tags.LossTag)
  return jaxpr_eqn.primitive


def _unbox_layer_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LayerTag:
  assert isinstance(jaxpr_eqn.primitive, tags.LayerTag)
  return jaxpr_eqn.primitive


def trace_losses_matrix_vector_vjp(tagged_func: _Function,
                                   params_index: int = 0):
  """Returns the Jacobian-transposed vector product (backward mode) function in equivalent form to jax.vjp."""
  def vjp(*primals):
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    _, loss_jaxpr_eqns = extract_tags(jaxpr)
    n = len(loss_jaxpr_eqns)
    losses_func = construct_compute_losses_inputs(
        jaxpr, consts, n, primals, params_index)
    losses_inputs, full_vjp_func = jax.vjp(losses_func, primals[params_index])
    losses = []
    for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
      loss_tag = _unbox_loss_tag(jaxpr_eqn)
      losses.append(loss_tag.loss(*inputs, weight=jaxpr_eqn.params["weight"]))
    losses = tuple(losses)

    def vjp_func(tangents):
      flat_tangents = jax.tree_flatten(tangents)[0]
      loss_invars = []
      loss_targets = []
      for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
        num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs
        loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs]))
        loss_targets.append(inputs[num_inputs:])
      treedef = jax.tree_structure(loss_invars)
      tangents = jax.tree_unflatten(treedef, flat_tangents)
      # Since the losses can also take targets as inputs, and we don't want
      # this function to compute the vjp w.r.t. those (i.e. the user should
      # not be providing tangent vectors for the targets, only for the
      # inputs), we have to manually fill in these "extra" tangents with
      # zeros.
      targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets)
      tangents = tuple(ti + tti for ti, tti in zip(tangents, targets_tangents))
      input_tangents = full_vjp_func(tangents)[0]
      return input_tangents,
    return losses, vjp_func
  return vjp


def trace_losses_matrix_vector_jvp(
    tagged_func: _Function,
    params_index: int = 0):
  """Returns the Jacobian vector product (forward mode) function in equivalent form to jax.jvp."""
  def jvp(primals, params_tangents):
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    _, loss_tags = extract_tags(jaxpr)
    n = len(loss_tags)
    losses_func = construct_compute_losses_inputs(jaxpr, consts, n,
                                                  primals, params_index)
    primals = (primals[params_index],)
    tangents = (params_tangents,)
    (primals_out, tangents_out) = jax.jvp(losses_func, primals, tangents)
    tangents_out = tuple(tuple(t[:tag.primitive.num_inputs])
                         for t, tag in zip(tangents_out, loss_tags))
    losses = tuple(tag.primitive.loss(*inputs, weight=tag.params["weight"])
                   for tag, inputs in zip(loss_tags, primals_out))
    return losses, tangents_out
  return jvp
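

# Editor's note (sketch, mirroring the accompanying tests): given a tagged
# model `f(params, data)`, the wrapper above is used as
#   jvp_fn = trace_losses_matrix_vector_jvp(f)
#   losses, loss_tangents = jvp_fn((params, data), params_tangents)
# where each element of `losses` is a loss object, evaluated via
# `loss.evaluate(...)`.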


def trace_losses_matrix_vector_hvp(tagged_func, params_index=0):
  """Returns the Hessian vector product function of **the tagged losses**, rather than the output value of `tagged_func`."""
  # The function uses backward-over-forward mode.

  def hvp(primals, params_tangents):
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    _, loss_tags = extract_tags(jaxpr)
    n = len(loss_tags)
    losses_func = construct_compute_losses_inputs(
        jaxpr, consts, n, primals, params_index)

    def losses_sum(param_primals):
      loss_inputs = losses_func(param_primals)
      losses = [
          _unbox_loss_tag(jaxpr_eqn).loss(
              *inputs, weight=jaxpr_eqn.params["weight"])
          for jaxpr_eqn, inputs in zip(loss_tags, loss_inputs)
      ]
      # This computes the sum of the evaluated losses, which makes things
      # easier, as we can now use jax.grad rather than jax.vjp for taking
      # derivatives.
      return sum(jnp.sum(loss.evaluate(None)) for loss in losses)

    def grads_times_tangents(params_primals):
      grads = jax.grad(losses_sum)(params_primals)
      return utils.inner_product(grads, params_tangents)

    return jax.grad(grads_times_tangents)(primals[params_index])
  return hvp
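

# Editor's illustration (not part of the original file): the
# grad-of-inner-product pattern used by `hvp` above, applied to a plain
# scalar function `f` of a pytree `p`. For scalar-valued `f` this returns
# H(p) @ v.
def example_hvp(f, p, v):
  grad_dot_v = lambda q: utils.inner_product(jax.grad(f)(q), v)
  return jax.grad(grad_dot_v)(p)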


def trace_estimator_vjp(tagged_func: _Function) -> _Function:
  """Creates the function needed for an estimator of curvature matrices.

  Args:
    tagged_func: A function that has been annotated with tags both for layers
      and losses.

  Returns:
    A function with the same signature as `tagged_func`, which when provided
    with inputs returns two things:
      1. The instances of all loss objects that are tagged.
      2. A second function, which when provided with tangent vectors for each
        of the loss instances' parameters, returns for every tagged layer a
        dictionary containing the following elements:
          inputs - The primal values of the inputs to the layer.
          outputs - The primal values of the outputs of the layer.
          params - The primal values of the parameters of the layer.
          inputs_tangent - The tangent values of the inputs to the layer,
            given the provided tangents of the losses.
          outputs_tangent - The tangent values of the outputs of the layer,
            given the provided tangents of the losses.
          params_tangent - The tangent values of the parameters of the layer,
            given the provided tangents of the losses.
  """
  def full_vjp_func(func_args):
    # Trace the tagged function
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    layer_tags, loss_tags = extract_tags(jaxpr)

    layer_vars_flat = jax.tree_flatten([tag.invars for tag in layer_tags])[0]
    layer_input_vars = tuple(set(layer_vars_flat))

    def forward():
      own_func_args = func_args
      # Mapping from variable -> value
      env = dict()
      read = functools.partial(tgm.read_env, env)
      write = functools.partial(tgm.write_env, env)

      # Bind args and consts to environment
      write(jax.core.unitvar, jax.core.unit)
      jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
      jax_util.safe_map(write, jaxpr.constvars, consts)

      # Loop through equations and evaluate primitives using `bind`
      num_losses_passed = 0
      for eqn in jaxpr.eqns:
        tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
        if isinstance(eqn.primitive, tags.LossTag):
          num_losses_passed += 1
          if num_losses_passed == len(loss_tags):
            break
      if num_losses_passed != len(loss_tags):
        raise ValueError("This should be unreachable.")

      return jax_util.safe_map(read, layer_input_vars)

    def forward_aux(aux):
      own_func_args = func_args
      # Mapping from variable -> value
      env = dict()
      read = functools.partial(tgm.read_env, env)
      def write(var, val):
        if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
          val = val + aux[var] if var in aux else val
        env[var] = val

      # Bind args and consts to environment
      write(jax.core.unitvar, jax.core.unit)
      jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
      jax_util.safe_map(write, jaxpr.constvars, consts)

      # Loop through equations and evaluate primitives using `bind`
      num_losses_passed = 0
      losses_inputs_values = []
      losses_kwargs_values = []
      for eqn in jaxpr.eqns:
        input_values = jax_util.safe_map(read, eqn.invars)
        tgm.evaluate_eqn(eqn, input_values, write)
        if isinstance(eqn.primitive, tags.LossTag):
          loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"])
          losses_inputs_values.append(loss.inputs)
          losses_kwargs_values.append(dict(
              targets=loss.targets,
              weight=eqn.params["weight"]
          ))
          num_losses_passed += 1
          if num_losses_passed == len(loss_tags):
            break
      if num_losses_passed != len(loss_tags):
        raise ValueError("This should be unreachable.")
      # Read the inputs to the loss functions, but also return the target
      # values.
      return tuple(losses_inputs_values), tuple(losses_kwargs_values)

    layer_input_values = forward()
    primals_dict = dict(zip(layer_input_vars, layer_input_values))
    primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0]))
    aux_values = jax.tree_map(jnp.zeros_like, layer_input_values)
    aux_dict = dict(zip(layer_input_vars, aux_values))

    losses_args, aux_vjp, losses_kwargs = jax.vjp(forward_aux, aux_dict,
                                                  has_aux=True)
    losses = tuple(tag.primitive.loss(*inputs, **kwargs)
                   for tag, inputs, kwargs in
                   zip(loss_tags, losses_args, losses_kwargs))

    def vjp_func(tangents):
      all_tangents = aux_vjp(tangents)
      tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
      inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
      tangents_dict.update(zip(jaxpr.invars, inputs_tangents))

      read_primals = functools.partial(tgm.read_env, primals_dict)
      read_tangents = functools.partial(tgm.read_env, tangents_dict)
      layers_info = []
      for jaxpr_eqn in layer_tags:
        layer_tag = _unbox_layer_tag(jaxpr_eqn)
        info = dict()
        primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars))
        (
            info["outputs"],
            info["inputs"],
            info["params"],
        ) = layer_tag.split_all_inputs(primals)
        tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars))
        (
            info["outputs_tangent"],
            info["inputs_tangent"],
            info["params_tangent"],
        ) = layer_tag.split_all_inputs(tangents)
        layers_info.append(info)
      return tuple(layers_info)

    return losses, vjp_func
  return full_vjp_func
455
kfac_ferminet_alpha/utils.py
Normal file
@@ -0,0 +1,455 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to multi-device operations."""
import collections
from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union
import dataclasses
import jax
from jax import core
from jax import lax
import jax.numpy as jnp
from jax.scipy import linalg
import jax.tree_util as tree_util

T = TypeVar("T")


def wrap_if_pmap(p_func):

  def p_func_if_pmap(obj, axis_name):
    try:
      core.axis_frame(axis_name)
      return p_func(obj, axis_name)
    except NameError:
      return obj

  return p_func_if_pmap


pmean_if_pmap = wrap_if_pmap(lax.pmean)
psum_if_pmap = wrap_if_pmap(lax.psum)
compute_mean = jax.pmap(lambda x: lax.pmean(x, "i"), axis_name="i")
compute_sum = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")


def get_first(obj: T) -> T:
  return jax.tree_map(lambda x: x[0], obj)


def get_mean(obj: T) -> T:
  return get_first(compute_mean(obj))


def get_sum(obj: T) -> T:
  return get_first(compute_sum(obj))


broadcast_all_local_devices = jax.pmap(lambda x: x)


def replicate_all_local_devices(obj: T) -> T:
  n = jax.local_device_count()
  obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj)
  return broadcast_all_local_devices(obj_stacked)


def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray:
  rng = jax.random.fold_in(rng, jax.host_id())
  rng = jax.random.split(rng, jax.local_device_count())
  return broadcast_all_local_devices(rng)


p_split = jax.pmap(lambda key: tuple(jax.random.split(key)))
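
# Editor's note (sketch, an assumption about the intended usage): typical
# multi-device RNG handling with the helpers above would look like
#   rng = make_different_rng_key_on_all_devices(jax.random.PRNGKey(0))
#   rng, step_rng = p_split(rng)  # one fresh per-device subkey per step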


def scalar_mul(obj: T, scalar: Union[float, jnp.ndarray]) -> T:
  return jax.tree_map(lambda x: x * scalar, obj)


def scalar_div(obj: T, scalar: Union[float, jnp.ndarray]) -> T:
  return jax.tree_map(lambda x: x / scalar, obj)


def make_func_args(params, func_state, rng, batch, has_state: bool,
                   has_rng: bool):
  """Correctly puts all arguments to the function together."""
  func_args = (params,)
  if has_state:
    if func_state is None:
      raise ValueError("The `func_state` is None, but the argument "
                       "`has_state` is True.")
    func_args += (func_state,)
  if has_rng:
    if rng is None:
      raise ValueError("The `rng` is None, but the argument `has_rng` is True.")
    func_args += (rng,)
  func_args += (batch,)
  return func_args


def extract_func_outputs(
    raw_outputs: Any,
    has_aux: bool,
    has_state: bool,
) -> Tuple[jnp.ndarray, Any, Any]:
  """Given the function output returns separately the loss, func_state, aux."""
  if not has_aux and not has_state:
    return raw_outputs, None, None
  loss, other = raw_outputs
  if has_aux and has_state:
    func_state, aux = other
  elif has_aux:
    func_state, aux = None, other
  else:
    func_state, aux = other, None
  return loss, func_state, aux


def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
  if jax.tree_structure(obj1) != jax.tree_structure(obj2):
    raise ValueError("The two structures are not identical.")
  elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
  return sum(jax.tree_flatten(elements_product)[0])


def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.ndarray:
  assert matrix.ndim == 2
  identity = jnp.eye(matrix.shape[0])
  matrix = matrix + damping * identity
  return linalg.solve(matrix, identity, sym_pos=True)


def solve_maybe_small(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
  """Computes a^-1 b more efficiently for small matrices."""
  assert a.shape[-1] == a.shape[-2] == b.shape[-1]
  d = a.shape[-1]
  if d == 0:
    return a
  elif d == 1:
    return b / a[..., 0]
  elif d == 2:
    det = a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]
    b_0 = a[..., 1, 1] * b[..., 0] - a[..., 0, 1] * b[..., 1]
    b_1 = a[..., 0, 0] * b[..., 1] - a[..., 1, 0] * b[..., 0]
    return jnp.stack([b_0, b_1], axis=-1) / det
  elif d == 3:
    raise NotImplementedError()
  return jnp.linalg.solve(a, b)
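
# Editor's note (sketch): the d == 2 branch above is Cramer's rule,
#   inv([[p, q], [r, s]]) b = [[s, -q], [-r, p]] b / (p*s - q*r),
# applied directly to the right-hand side so the inverse is never formed.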


def pi_adjusted_inverse(
    factor_0: jnp.ndarray,
    factor_1: jnp.ndarray,
    damping: jnp.ndarray,
    pmap_axis_name: str,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Performs inversion with pi-adjusted damping."""
  # Compute the norms of each factor
  norm_0 = jnp.trace(factor_0)
  norm_1 = jnp.trace(factor_1)

  # We need to sync the norms here, because reductions can be
  # non-deterministic. They specifically are non-deterministic on GPUs by
  # default for better performance. Hence, although factor_0 and factor_1 are
  # synced, the trace operation above can still produce different answers on
  # different devices.
  norm_0, norm_1 = pmean_if_pmap((norm_0, norm_1), axis_name=pmap_axis_name)

  # Compute the overall scale
  scale = norm_0 * norm_1

  def regular_inverse(
      operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    factor0, factor1, norm0, norm1, s, d = operand
    # Special cases with one or two scalar factors
    if factor0.size == 1 and factor1.size == 1:
      value = jnp.ones_like(factor0) / jnp.sqrt(s)
      return value, value
    if factor0.size == 1:
      factor1_normed = factor1 / norm1
      damping1 = d / norm1
      factor1_inv = psd_inv_cholesky(factor1_normed, damping1)
      return jnp.full((1, 1), s), factor1_inv
    if factor1.size == 1:
      factor0_normed = factor0 / norm0
      damping0 = d / norm0
      factor0_inv = psd_inv_cholesky(factor0_normed, damping0)
      return factor0_inv, jnp.full((1, 1), s)

    # Invert first factor
    factor0_normed = factor0 / norm0
    damping0 = jnp.sqrt(d * factor1.shape[0] / (s * factor0.shape[0]))
    factor0_inv = psd_inv_cholesky(factor0_normed, damping0) / jnp.sqrt(s)

    # Invert second factor
    factor1_normed = factor1 / norm1
    damping1 = jnp.sqrt(d * factor0.shape[0] / (s * factor1.shape[0]))
    factor1_inv = psd_inv_cholesky(factor1_normed, damping1) / jnp.sqrt(s)
    return factor0_inv, factor1_inv

  def zero_inverse(
      operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    return (jnp.eye(factor_0.shape[0]) / jnp.sqrt(operand[-1]),
            jnp.eye(factor_1.shape[0]) / jnp.sqrt(operand[-1]))

  # In the special case where for some reason one of the factors is zero, the
  # correct inverse of `(0 kron A + lambda I)` is
  # `(I/sqrt(lambda)) kron (I/sqrt(lambda))`. However, because one of the
  # norms is zero, `pi` and `1/pi` would be 0 and infinity, leading to NaN
  # values. Hence, we need to make this check explicitly.
  return lax.cond(
      jnp.greater(scale, 0.0),
      regular_inverse,
      zero_inverse,
      operand=(factor_0, factor_1, norm_0, norm_1, scale, damping))
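
# Editor's note (sketch of the underlying identity, the factored Tikhonov
# damping rule of Martens & Grosse, 2015): for a Kronecker-factored block
# A kron B, the damped inverse (A kron B + lambda I)^-1 is approximated by
#   (A + pi * sqrt(lambda) I)^-1 kron (B + sqrt(lambda) / pi I)^-1,
# with pi = sqrt((tr(A) / dim A) / (tr(B) / dim B)). After the normalization
# by the traces, this is what the per-factor `damping0`/`damping1` terms
# above implement.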


def convert_value_and_grad_to_value_func(
    value_and_grad_func,
    has_aux: bool = False,
):
  """Converts a value_and_grad function to a value-only function."""

  def value_func(*args, **kwargs):
    out, _ = value_and_grad_func(*args, **kwargs)
    if has_aux:
      return out[0]
    else:
      return out

  return value_func


def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None:
  """Verifies that the two objects have the same pytree structure, shapes and dtypes."""
  assert jax.tree_structure(obj1) == jax.tree_structure(obj2)
  for v1, v2 in zip(jax.tree_flatten(obj1)[0], jax.tree_flatten(obj2)[0]):
    assert v1.shape == v2.shape
    assert v1.dtype == v2.dtype


def check_first_dim_is_batch_size(batch_size: int, *args: jnp.ndarray) -> None:
  for i, arg in enumerate(args):
    if arg.shape[0] != batch_size:
      raise ValueError(f"Expecting first dimension of arg[{i}] with shape "
                       f"{arg.shape} to be equal to the batch size "
                       f"{batch_size}.")


def py_tree_registered_dataclass(cls, *args, **kwargs):
  """Creates a new dataclass type and registers it as a pytree node."""
  dcls = dataclasses.dataclass(cls, *args, **kwargs)
  tree_util.register_pytree_node(
      dcls,
      lambda instance: (  # pylint: disable=g-long-lambda
          [getattr(instance, f.name)
           for f in dataclasses.fields(instance)], None),
      lambda _, instance_args: dcls(*instance_args))
  return dcls


class WeightedMovingAverage:
  """A wrapped class for a variable for which we keep an exponential moving average."""

  def __init__(self, weight: jnp.ndarray, array: jnp.ndarray):
    self._weight = weight
    self._array = array

  @staticmethod
  def zero(shape: Sequence[int]) -> "WeightedMovingAverage":
    return WeightedMovingAverage(weight=jnp.zeros([]), array=jnp.zeros(shape))

  @property
  def weight(self) -> jnp.ndarray:
    return self._weight

  @property
  def value(self) -> jnp.ndarray:
    return self._array / self._weight

  @property
  def raw_value(self) -> jnp.ndarray:
    return self._array

  def update(self, value: jnp.ndarray, old_weight_multiplier: float,
             new_weight: float) -> None:
    self._weight = old_weight_multiplier * self._weight + new_weight
    self._array = old_weight_multiplier * self._array + new_weight * value

  def sync(self, pmap_axis_name: str) -> None:
    self._array = pmean_if_pmap(self._array, pmap_axis_name)

  def __str__(self) -> str:
    return (f"WeightedMovingAverage(weight={self._weight}, "
            f"array={self._array})")

  def __repr__(self) -> str:
    return self.__str__()


tree_util.register_pytree_node(
    WeightedMovingAverage,
    lambda instance: ((instance.weight, instance.raw_value), None),
    lambda _, instance_args: WeightedMovingAverage(*instance_args),
)
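

# Editor's illustration (not part of the original file): the moving average
# stores a raw weighted sum plus its total weight, so `value` is
# automatically bias-corrected.
def example_weighted_moving_average():
  ema = WeightedMovingAverage.zero([2])
  ema.update(jnp.ones([2]), old_weight_multiplier=0.9, new_weight=1.0)
  ema.update(2 * jnp.ones([2]), old_weight_multiplier=0.9, new_weight=1.0)
  # ema.value == (0.9 * 1.0 + 1.0 * 2.0) / (0.9 * 1.0 + 1.0) ~= 1.526
  return ema.value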


class Stateful:
  """A class for stateful objects."""

  def __init__(self, stateful_fields_names: Optional[Sequence[str]] = ()):
    self.__stateful_fields_names = stateful_fields_names

  def _add_stateful_fields_names(self, value: Sequence[str]) -> None:
    self.__stateful_fields_names += tuple(value)

  def get_state(self) -> Mapping[str, Any]:
    """Returns the state of the object."""
    state = dict()
    for name in self.__stateful_fields_names:
      state[name] = Stateful._get_state_from_instance(getattr(self, name))
    return state

  def set_state(self, value):
    """Sets the state of the object with the provided value and returns the object."""
    assert isinstance(value, dict)
    for name in self.__stateful_fields_names:
      setattr(self, name,
              Stateful._set_state_to_instance(getattr(self, name), value[name]))
    return self

  def clear_state(self) -> None:
    """Clears the state of the object."""
    for name in self.__stateful_fields_names:
      setattr(self, name,
              Stateful._clear_state_from_instance(getattr(self, name)))

  def pop_state(self) -> Mapping[str, Any]:
    """Returns the current state of the object, while simultaneously clearing it."""
    state = self.get_state()
    self.clear_state()
    return state

  @staticmethod
  def _get_state_from_instance(obj):
    """Recursively gets the state of the object and returns it."""
    if isinstance(obj, Stateful):
      return obj.get_state()
    if isinstance(obj, list):
      return [Stateful._get_state_from_instance(i) for i in obj]
    if isinstance(obj, tuple):
      return tuple(Stateful._get_state_from_instance(i) for i in obj)
    if isinstance(obj, collections.OrderedDict):
      return collections.OrderedDict(
          (k, Stateful._get_state_from_instance(v)) for k, v in obj.items())
    if isinstance(obj, dict):
      return dict(
          (k, Stateful._get_state_from_instance(v)) for k, v in obj.items())
    return obj

  @staticmethod
  def _set_state_to_instance(obj, value):
    """Recursively sets the state of the object and returns it."""
    if isinstance(obj, Stateful):
      obj.set_state(value)
      return obj
    if isinstance(value, list):
      if obj is None:
        obj = [None] * len(value)
      return [
          Stateful._set_state_to_instance(obj_i, value_i)
          for obj_i, value_i in zip(obj, value)
      ]
    if isinstance(value, tuple):
      if obj is None:
        obj = [None] * len(value)
      return tuple(
          Stateful._set_state_to_instance(obj_i, value_i)
          for obj_i, value_i in zip(obj, value))
    if isinstance(value, collections.OrderedDict):
      if obj is None:
        obj = dict((k, None) for k in value)
      return collections.OrderedDict(
          (k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj)
    if isinstance(value, dict):
      if obj is None:
        obj = dict((k, None) for k in value)
      return dict(
          (k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj)
    return value

  @staticmethod
  def _clear_state_from_instance(obj):
    """Recursively clears the state of the object and returns it."""
    if isinstance(obj, Stateful):
      obj.clear_state()
      return obj
    if isinstance(obj, list):
      return [Stateful._clear_state_from_instance(obj_i) for obj_i in obj]
    if isinstance(obj, tuple):
      return tuple(Stateful._clear_state_from_instance(obj_i) for obj_i in obj)
    if isinstance(obj, collections.OrderedDict):
      return collections.OrderedDict(
          (k, Stateful._clear_state_from_instance(obj[k])) for k in obj)
    if isinstance(obj, dict):
      return dict((k, Stateful._clear_state_from_instance(obj[k])) for k in obj)
    return None

  @staticmethod
  def infer_class_state(class_type):
    """Infers a stateful class's state attributes from class annotations."""
    if not issubclass(class_type, Stateful):
      raise ValueError(
          f"In order to annotate a class as stateful it must inherit "
          f"{Stateful!r}")

    class_type = dataclasses.dataclass(
        class_type, init=False, repr=False, eq=False)  # pytype: disable=wrong-keyword-args
    fields_names = tuple(field.name for field in dataclasses.fields(class_type))
    original_init = getattr(class_type, "__init__", None)
    if original_init is None:

      def injected_init(self, *args, **kwargs):
        super(self.__class__, self).__init__(*args, **kwargs)  # pylint: disable=bad-super-call
        Stateful._add_stateful_fields_names(self, fields_names)
        for field_name in fields_names:
          if getattr(self, field_name, None) is None:
            setattr(self, field_name, None)

      setattr(class_type, "__init__", injected_init)
    else:

      def injected_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        Stateful._add_stateful_fields_names(self, fields_names)
        for field_name in fields_names:
          if getattr(self, field_name, None) is None:
            setattr(self, field_name, None)

      setattr(class_type, "__init__", injected_init)
    return class_type
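
# Editor's note (sketch, an assumed usage pattern): `infer_class_state` is
# applied as a decorator on a `Stateful` subclass, turning its annotated
# fields into tracked state, e.g.
#
# @Stateful.infer_class_state
# class ExampleEstimator(Stateful):
#   factors: Any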


def compute_sq_norm_relative_abs_diff(obj, pmap_axis_name):
  sq_norm = inner_product(obj, obj)
  synced_sq_norm = psum_if_pmap(sq_norm, pmap_axis_name)
  synced_sq_norm = (synced_sq_norm - sq_norm) / (jax.device_count() - 1.0)
  sq_norm_abs_diff = jnp.abs(sq_norm - synced_sq_norm)
  return sq_norm_abs_diff / sq_norm


def product(iterable_object):
  x = 1
  for element in iterable_object:
    x *= element
  return x
53
object_attention_for_reasoning/README.md
Normal file
@@ -0,0 +1,53 @@
Implementation of the object-based transformer model from
["Object-based attention for spatio-temporal reasoning"](https://arxiv.org/abs/2012.08508)
[1].

This package includes source code for the transformer model,
pre-trained model parameters for the CLEVRER task,
and MONet [2] latent variables for all videos in the training
and validation sets. It does not include the model training code.
See Section 2 of [1] for details.

[1] David Ding, Felix Hill, Adam Santoro, Matt Botvinick. *Object-based
attention for spatio-temporal reasoning: Outperforming neuro-symbolic models
with flexible distributed architectures*.
arXiv preprint arXiv:2012.08508, 2020.

[2] Chris P. Burgess, Loic Matthey, Nick Watters, Rishabh Kabra, Irina Higgins,
Matt Botvinick, and Alexander Lerchner.
*MONet: Unsupervised scene decomposition and representation*.
arXiv preprint arXiv:1901.11390, 2019.


# Instructions

Note: This code depends on TensorFlow 1 and Sonnet 1. TensorFlow 1 is only
available on PyPI for Python 3.7 and earlier.

To run this code, execute the following commands from the `deepmind_research/`
directory:

```shell
# Download checkpoints and MONet latents
wget https://storage.googleapis.com/object-attention-for-reasoning/checkpoints_and_latents.zip
unzip checkpoints_and_latents.zip
python3.7 -m venv object_based_attention_venv
source object_based_attention_venv/bin/activate
pip install --upgrade setuptools wheel
pip install -r requirements.txt
python -m object_attention_for_reasoning.run_model
```
If the code runs correctly, you should see the model's predicted answers to two
CLEVRER questions (a descriptive one and a multiple-choice one), and both
answers should be correct.

If you find the provided code useful, please cite this paper:
```
@article{objectattention2020,
  title={Object-based attention for spatio-temporal reasoning: Outperforming
  neuro-symbolic models with flexible distributed architectures},
  author={David Ding and Felix Hill and Adam Santoro and Matt Botvinick},
  journal={arXiv preprint arXiv:2012.08508},
  year={2020}
}
```
185
object_attention_for_reasoning/model.py
Normal file
@@ -0,0 +1,185 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model code. Provided settings are identical to what was used in the paper."""

import sonnet as snt
import tensorflow.compat.v1 as tf

from object_attention_for_reasoning import transformer


QUESTION_VOCAB_SIZE = 82
ANSWER_VOCAB_SIZE = 22

MAX_QUESTION_LENGTH = 20
MAX_CHOICE_LENGTH = 12

NUM_CHOICES = 4
EMBED_DIM = 16

PRETRAINED_MODEL_CONFIG = dict(
    use_relative_positions=True,
    shuffle_objects=True,
    transformer_layers=28,
    head_size=128,
    num_heads=10,
    embed_dim=EMBED_DIM,
)


def append_ids(tensor, id_vector, axis):
  id_vector = tf.constant(id_vector, tf.float32)
  for a in range(len(tensor.shape)):
    if a != axis:
      id_vector = tf.expand_dims(id_vector, axis=a)
  tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
  id_tensor = tf.tile(id_vector, tiling_vector)
  return tf.concat([tensor, id_tensor], axis=axis)
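
# Editor's note (sketch): `append_ids` tags every element of a tensor with a
# fixed id vector along `axis`. Below it is used to mark word tokens with
# [1, 0] and object slots with [0, 1] before both are concatenated into a
# single transformer input sequence.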
|
||||
|
||||
class ClevrerTransformerModel(object):
|
||||
"""Model from Ding et al. 2020 (https://arxiv.org/abs/2012.08508)."""
|
||||
|
||||
def __init__(self, use_relative_positions, shuffle_objects,
|
||||
transformer_layers, num_heads, head_size, embed_dim):
|
||||
"""Instantiate Sonnet modules."""
|
||||
self._embed_dim = embed_dim
|
||||
self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2)
|
||||
self._shuffle_objects = shuffle_objects
|
||||
self._memory_transformer = transformer.TransformerTower(
|
||||
value_size=embed_dim + 2,
|
||||
num_heads=num_heads,
|
||||
num_layers=transformer_layers,
|
||||
use_relative_positions=use_relative_positions,
|
||||
causal=False)
|
||||
|
||||
self._final_layer_mc = snt.Sequential(
|
||||
[snt.Linear(head_size), tf.nn.relu, snt.Linear(1)])
|
||||
self._final_layer_descriptive = snt.Sequential(
|
||||
[snt.Linear(head_size), tf.nn.relu,
|
||||
snt.Linear(ANSWER_VOCAB_SIZE)])
|
||||
|
||||
self._dummy = tf.get_variable("dummy", [embed_dim + 2], tf.float32,
|
||||
initializer=tf.zeros_initializer)
|
||||
self._infill_linear = snt.Linear(embed_dim + 2)
|
||||
self._mask_embedding = tf.get_variable(
|
||||
"mask", [embed_dim + 2], tf.float32, initializer=tf.zeros_initializer)
|
||||
|
||||
def _apply_transformers(self, lang_embedding, vision_embedding):
|
||||
"""Applies transformer to language and vision input.
|
||||
|
||||
Args:
|
||||
lang_embedding: tensor,
|
||||
vision_embedding: tensor, "validation", or "test".
|
||||
|
||||
Returns:
|
||||
tuple, output at dummy token, all output embeddings, infill loss
|
||||
"""
|
||||
def _unroll(tensor):
|
||||
"""Unroll the time dimension into the object dimension."""
|
||||
return tf.reshape(
|
||||
tensor, [tensor.shape[0], -1, tensor.shape[3]])
|
||||
|
||||
words = append_ids(lang_embedding, [1, 0], axis=2)
|
||||
dummy_word = tf.tile(self._dummy[None, None, :], [tf.shape(words)[0], 1, 1])
|
||||
vision_embedding = append_ids(vision_embedding, [0, 1], axis=3)
|
||||
vision_over_time = _unroll(vision_embedding)
|
||||
transformer_input = tf.concat([dummy_word, words, vision_over_time], axis=1)
|
||||
|
||||
output, _ = self._memory_transformer(transformer_input,
|
||||
is_training=False)
|
||||
return output[:, 0, :]

  def apply_model_descriptive(self, inputs):
    """Applies the model to CLEVRER descriptive questions.

    Args:
      inputs: dict of form: {
          "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
          "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
      }

    Returns:
      Tensor of shape [batch, ANSWER_VOCAB_SIZE], representing logits for each
      possible answer word.
    """
    question = inputs["question"]

    # Shape: [batch, question_len, embed_dim-2]
    question_embedding = self._embed(question)
    # Shape: [batch, question_len, embed_dim]
    question_embedding = append_ids(question_embedding, [0, 1], 2)
    # Descriptive questions have no choices; embed an all-zero dummy choice so
    # the input format matches the multiple-choice path.
    choices_embedding = self._embed(
        tf.zeros([question.shape[0], MAX_CHOICE_LENGTH], tf.int64))
    choices_embedding = append_ids(choices_embedding, [0, 1], 2)
    # Shape: [batch, question_len + choice_len, embed_dim]
    lang_embedding = tf.concat([question_embedding, choices_embedding], axis=1)

    # Shape: [batch, frames, num_objects, embed_dim]
    vision_embedding = inputs["monet_latents"]

    if self._shuffle_objects:
      # tf.random.shuffle only permutes axis 0, so move the object axis to the
      # front, shuffle, and move it back.
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
      vision_embedding = tf.random.shuffle(vision_embedding)
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])

    output = self._apply_transformers(lang_embedding, vision_embedding)
    output = self._final_layer_descriptive(output)
    return output
  def apply_model_mc(self, inputs):
    """Applies the model to CLEVRER multiple-choice questions.

    Args:
      inputs: dict of form: {
          "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
          "choices": tf.int32 tensor of shape [batch, 4, MAX_CHOICE_LENGTH],
          "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
      }

    Returns:
      Tensor of shape [batch, 4], representing logits for each choice.
    """
    question = inputs["question"]
    choices = inputs["choices"]

    # Shape: [batch, question_len, embed_dim-2]
    question_embedding = self._embed(question)
    # Shape: [batch, question_len, embed_dim]
    question_embedding = append_ids(question_embedding, [1, 0], 2)
    # Shape: [batch, choices, choice_len, embed_dim-2]
    choices_embedding = snt.BatchApply(self._embed)(choices)
    # Shape: [batch, choices, choice_len, embed_dim]
    choices_embedding = append_ids(choices_embedding, [0, 1], 3)
    # Shape: [batch, choices, question_len + choice_len, embed_dim]
    lang_embedding = tf.concat([
        tf.tile(question_embedding[:, None],
                [1, choices_embedding.shape[1], 1, 1]),
        choices_embedding], axis=2)

    # Shape: [batch, frames, num_objects, embed_dim]
    vision_embedding = inputs["monet_latents"]

    if self._shuffle_objects:
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
      vision_embedding = tf.random.shuffle(vision_embedding)
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])

    # Score each (question, choice) pair independently.
    output_per_choice = []
    for c in range(NUM_CHOICES):
      output = self._apply_transformers(
          lang_embedding[:, c, :, :], vision_embedding)
      output_per_choice.append(output)

    output = tf.stack(output_per_choice, axis=1)
    output = tf.squeeze(snt.BatchApply(self._final_layer_mc)(output), axis=2)
    return output
4 object_attention_for_reasoning/requirements.txt Normal file
@@ -0,0 +1,4 @@
absl-py==0.11.0
dm-sonnet==1.36
numpy==1.20.1
tensorflow==1.15.0
163 object_attention_for_reasoning/run_model.py Normal file
@@ -0,0 +1,163 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example code for running the model on CLEVRER."""
import json

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf

from object_attention_for_reasoning import model as modellib


BATCH_SIZE = 1
NUM_FRAMES = 25
NUM_OBJECTS = 8

_BASE_DIR = flags.DEFINE_string(
    "base_dir", "./clevrer_monet_latents",
    "Directory containing checkpoints and MONet latents.")
_SCENE_IDX = flags.DEFINE_integer(
    "scene_idx", 1000, "Scene index of CLEVRER video.")

def load_monet_latents(base_dir, scene_index):
  filename = f"{base_dir}/train/{scene_index}.npz"
  with open(filename, "rb") as f:
    return np.load(f)


def _split_string(s):
  """Splits a string into words and standardizes the alphabet."""
  return s.lower().replace("?", "").split()


def _pad(array, length):
  """Pads a 1-D array to the desired length with zeros."""
  return np.pad(array, [(0, length - array.shape[0])], mode="constant")
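# For example, _pad(np.array([3, 1, 4], np.int32), 5) returns
# array([3, 1, 4, 0, 0], dtype=int32).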

def encode_sentence(token_map, sentence, pad_length):
  """Encodes a CLEVRER question/choice sentence as a sequence of token ids."""
  ret = np.array(
      [token_map["question_vocab"][w] for w in _split_string(sentence)],
      np.int32)
  return _pad(ret, pad_length)
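# For example, with a hypothetical vocab {"what": 5, "color": 9, ...},
# "What color?" is lowercased, stripped of "?", mapped word-by-word through
# token_map["question_vocab"], and zero-padded to `pad_length` ids.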

def encode_choices(token_map, choices):
  """Encodes CLEVRER multiple-choice options as a [NUM_CHOICES, len] array."""
  arrays = [encode_sentence(token_map, choice["choice"],
                            modellib.MAX_CHOICE_LENGTH)
            for choice in choices]
  stacked = np.stack(arrays, axis=0)
  # Pad only the choice axis up to NUM_CHOICES; a bare [(0, n)] pad width
  # would be broadcast by np.pad to every axis of the 2-D array.
  return np.pad(stacked,
                [(0, modellib.NUM_CHOICES - stacked.shape[0]), (0, 0)],
                mode="constant")
def main(unused_argv):
  base_dir = _BASE_DIR.value
  with open(f"{base_dir}/vocab.json", "rb") as f:
    token_map = json.load(f)

  reverse_answer_lookup = {v: k for k, v in token_map["answer_vocab"].items()}

  with open(f"{base_dir}/train.json", "rb") as f:
    questions_data = json.load(f)

  tf.reset_default_graph()
  model = modellib.ClevrerTransformerModel(**modellib.PRETRAINED_MODEL_CONFIG)

  inputs_descriptive = {
      "monet_latents": tf.placeholder(
          tf.float32,
          [BATCH_SIZE, NUM_FRAMES, NUM_OBJECTS, modellib.EMBED_DIM]),
      "question": tf.placeholder(
          tf.int32, [BATCH_SIZE, modellib.MAX_QUESTION_LENGTH]),
  }

  inputs_mc = {
      "monet_latents": tf.placeholder(
          tf.float32,
          [BATCH_SIZE, NUM_FRAMES, NUM_OBJECTS, modellib.EMBED_DIM]),
      "question": tf.placeholder(tf.int32,
                                 [BATCH_SIZE, modellib.MAX_QUESTION_LENGTH]),
      "choices": tf.placeholder(
          tf.int32, [BATCH_SIZE, modellib.NUM_CHOICES,
                     modellib.MAX_CHOICE_LENGTH]),
  }

  output_descriptive = model.apply_model_descriptive(inputs_descriptive)
  output_mc = model.apply_model_mc(inputs_mc)

  # Restore from checkpoint.
  saver = tf.train.Saver()
  checkpoint_dir = f"{base_dir}/checkpoints/"
  sess = tf.train.SingularMonitoredSession(checkpoint_dir=checkpoint_dir)
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  saver.restore(sess, ckpt.model_checkpoint_path)
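  # At this point the session holds the weights of the newest checkpoint
  # (tf.train.get_checkpoint_state picks the most recent entry recorded in
  # the directory's `checkpoint` index file).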
  def eval_descriptive(monet_latents, question_json):
    # CLEVRER provides videos with 128 frames. Our model subsamples 25 frames
    # (as was done in Yi et al. (2020)). For training, the choice of the 25
    # frames is randomized; for evaluation, the 25 frames are sampled as
    # evenly as possible, using strided sampling.
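    # For example, with 128 input frames and NUM_FRAMES = 25,
    # divmod(128, 25) = (5, 3), so monet_latents[:-3:5] keeps frames
    # 0, 5, ..., 120 (exactly 25 frames).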
    stride, rem = divmod(monet_latents.shape[0], NUM_FRAMES)
    monet_latents = monet_latents[None, :-rem:stride]
    assert monet_latents.shape[1] == NUM_FRAMES
    question = encode_sentence(token_map, question_json["question"],
                               modellib.MAX_QUESTION_LENGTH)
    batched_question = np.expand_dims(question, axis=0)
    logits = sess.run(output_descriptive, feed_dict={
        inputs_descriptive["monet_latents"]: monet_latents,
        inputs_descriptive["question"]: batched_question,
    })
    descriptive_answer = np.argmax(logits)
    return reverse_answer_lookup[descriptive_answer]
  def eval_mc(monet_latents, question_json):
    stride, rem = divmod(monet_latents.shape[0], NUM_FRAMES)
    monet_latents = monet_latents[None, :-rem:stride]
    assert monet_latents.shape[1] == NUM_FRAMES
    question = encode_sentence(
        token_map, question_json["question"], modellib.MAX_QUESTION_LENGTH)
    choices = encode_choices(
        token_map, question_json["choices"])
    mc_answer = sess.run(output_mc, feed_dict={
        inputs_mc["monet_latents"]: monet_latents,
        inputs_mc["question"]: np.expand_dims(question, axis=0),
        inputs_mc["choices"]: np.expand_dims(choices, axis=0),
    })
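    # output_mc holds one logit per choice; a choice is predicted to be
    # "correct" when its logit is non-negative.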
    return mc_answer >= 0

  sample_scene_idx = _SCENE_IDX.value
  question_json = questions_data[sample_scene_idx]["questions"][0]
  print("Descriptive Question: ", question_json["question"])
  print("Model Answer: ",
        eval_descriptive(load_monet_latents(base_dir, sample_scene_idx),
                         question_json))
  print("True Answer: ", question_json["answer"])

  question_json = questions_data[sample_scene_idx]["questions"][-1]
  print("Multiple-Choice Question: ", question_json["question"])
  for i, choice_json in enumerate(question_json["choices"]):
    print(f"{i+1}) {choice_json['choice']}")
  print("Model Answer: ",
        eval_mc(load_monet_latents(base_dir, sample_scene_idx), question_json))
  print("True Answer: ",
        [choice_json["answer"] for choice_json in question_json["choices"]])


if __name__ == "__main__":
  app.run(main)
667 object_attention_for_reasoning/transformer.py Normal file
File diff suppressed because it is too large