Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-09 21:07:49 +08:00)
Adds the code for "gated_linear_networks" to the files release.bara.sky and README.md for public release on github.
PiperOrigin-RevId: 338219746
gated_linear_networks/README.md
@@ -0,0 +1,44 @@
# Gated Linear Networks

Gated Linear Networks (GLNs) are a family of backpropagation-free neural networks.
Each neuron in a GLN predicts the target density (or probability mass) based on
the outputs of the previous layer and is trained under a logarithmic loss.
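For intuition, here is a minimal NumPy sketch of a single gated neuron (the Bernoulli case; the function and argument names are illustrative only and are not part of this library's API):

```python
import numpy as np

def neuron_predict(prev_probs, side_info, hyperplanes, hyperplane_bias, weights, eps=1e-2):
  """One gated neuron: pick a weight vector by gating on side_info, then mix in logit space."""
  # Half-space gating: each hyperplane contributes one bit of the context index.
  bits = (hyperplanes @ side_info > hyperplane_bias).astype(int)
  context = int(np.dot(bits, 2 ** np.arange(len(bits))))
  # Geometric mixing: a weighted sum of the previous layer's log-odds, squashed by a sigmoid.
  clipped = np.clip(prev_probs, eps, 1 - eps)
  logits = np.log(clipped / (1 - clipped))
  return 1.0 / (1.0 + np.exp(-np.dot(weights[context], logits)))
```

Training each neuron amounts to a local gradient step on the logarithmic loss of its own prediction.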
## GLN variants

Neurons have probabilistic "activation functions". Implementations are provided
for the following distributions:

- Gaussian, for regression.

- Bernoulli, for binary classification and multi-class classification using a
  one-vs-all scheme.
## Examples

Usage examples are provided in [`examples`](examples).
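As a quick orientation, a minimal sketch of driving the Bernoulli variant through Haiku (hyperparameters and shapes below are illustrative; see the example scripts for a complete, runnable workflow):

```python
import haiku as hk
import jax
import jax.numpy as jnp

from gated_linear_networks import bernoulli

def inference_fn(inputs, side_info):
  gln = bernoulli.GatedLinearNetwork(output_sizes=(4, 4, 1), context_dim=2)
  return gln.inference(inputs, side_info)

init_fn, apply_fn = hk.without_apply_rng(hk.transform_with_state(inference_fn))

inputs = jnp.full([8], 0.5)   # base predictions (probabilities) fed to the first layer
side_info = jnp.ones([8])     # raw features used only for gating
params, state = init_fn(jax.random.PRNGKey(0), inputs, side_info)
predictions, _ = apply_fn(params, state, inputs, side_info)  # one probability per neuron
```

The `update` method follows the same pattern but additionally takes a target and a
learning rate, and returns updated parameters alongside the (pre-update) predictions
and log loss.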
## Implementation details

### Constraint satisfaction

Because each neuron implements a probability density/mass function, we need to
ensure that it is well defined. For example, the scale parameter of a Gaussian
density must be positive. We enforce these constraints using linear projections
and clipping.
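For the Gaussian variant, for instance, weights are clipped to a fixed range and then
linearly projected so that the variance a neuron outputs stays within valid bounds. A
rough sketch of the lower-bound projection (constants and names are illustrative; the
actual logic lives in `gaussian.py`):

```python
import jax.numpy as jnp

MIN_WEIGHT, MAX_WEIGHT = -1e3, 1e3

def constrain_weights(weights, precisions, min_sigma_sq):
  """Clip weights, then nudge them so the implied output variance is at least min_sigma_sq."""
  weights = jnp.clip(weights, MIN_WEIGHT, MAX_WEIGHT)
  sigma_sq_out = 1. / weights.dot(precisions)  # output variance implied by the current weights
  correction = precisions * (1. / sigma_sq_out - 1. / min_sigma_sq) / jnp.sum(precisions**2)
  return jnp.where(sigma_sq_out < min_sigma_sq, weights - correction, weights)
```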
### Aggregation

Because every neuron predicts the target, any neuron's output can serve as the
"network output"; we are not bound to the last layer. The last-layer neuron(s)
are typically the best predictors, although in theory they may take longer to
converge. In this implementation we use a single neuron in the last layer, and
its prediction forms the network output.

There are alternative ways of aggregating, e.g. see Switching Aggregation in
Appendix D of [*Gaussian Gated Linear Networks*](https://arxiv.org/pdf/2006.05964.pdf).
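A sketch of what this looks like in practice (reusing the transform pattern from the
Examples section above): wrapping the network factory in a `LastNeuronAggregator`
reduces the concatenated per-neuron predictions to that single last neuron.

```python
from gated_linear_networks import bernoulli

def aggregated_inference_fn(inputs, side_info):
  gln_factory = lambda: bernoulli.GatedLinearNetwork(output_sizes=(4, 4, 1), context_dim=2)
  # Output is now a single probability (the last-layer neuron) instead of one per neuron.
  return bernoulli.LastNeuronAggregator(gln_factory).inference(inputs, side_info)
```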
## References

Coming soon.
gated_linear_networks/base.py
@@ -0,0 +1,347 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Base classes for Gated Linear Networks."""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
Array = chex.Array
|
||||
DType = Any
|
||||
Initializer = hk.initializers.Initializer
|
||||
Shape = Sequence[int]
|
||||
|
||||
EPS = 1e-12
|
||||
MIN_ALPHA = 1e-5
|
||||
|
||||
|
||||
def _l2_normalize(x: Array, axis: int) -> Array:
|
||||
return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis, keepdims=True), EPS))
|
||||
|
||||
|
||||
def _wrapped_fn_argnames(fun):
|
||||
"""Returns list of argnames of a (possibly wrapped) function."""
|
||||
return tuple(inspect.signature(fun).parameters)
|
||||
|
||||
|
||||
def _vmap(fun, in_axes=0, out_axes=0, parameters=None):
|
||||
"""JAX vmap with human-friendly axes."""
|
||||
|
||||
def _axes(fun, d):
|
||||
"""Maps dict {kwarg_i, : val_i} to [None, ..., val_i, ..., None]."""
|
||||
argnames = _wrapped_fn_argnames(fun) if not parameters else parameters
|
||||
for key in d:
|
||||
if key not in argnames:
|
||||
raise ValueError(f"{key} is not a valid axis.")
|
||||
return tuple(d.get(key, None) for key in argnames)
|
||||
|
||||
in_axes = _axes(fun, in_axes) if isinstance(in_axes, dict) else in_axes
|
||||
return jax.vmap(fun, in_axes, out_axes)
|
||||
|
||||
# Map a neuron-level function across a layer.
|
||||
_layer_vmap = functools.partial(
|
||||
_vmap,
|
||||
in_axes=({
|
||||
"weights": 0,
|
||||
"hyperplanes": 0,
|
||||
"hyperplane_bias": 0,
|
||||
}))
|
||||
|
||||
|
||||
class NormalizedRandomNormal(hk.initializers.RandomNormal):
|
||||
"""Random normal initializer with l2-normalization."""
|
||||
|
||||
def __init__(self,
|
||||
stddev: float = 1.,
|
||||
mean: float = 0.,
|
||||
normalize_axis: int = 0):
|
||||
super(NormalizedRandomNormal, self).__init__(stddev, mean)
|
||||
self._normalize_axis = normalize_axis
|
||||
|
||||
def __call__(self, shape: Shape, dtype: DType) -> Array:
|
||||
if self._normalize_axis >= len(shape):
|
||||
raise ValueError("Cannot normalize axis {} for ndim = {}.".format(
|
||||
self._normalize_axis, len(shape)))
|
||||
weights = super(NormalizedRandomNormal, self).__call__(shape, dtype)
|
||||
return _l2_normalize(weights, axis=self._normalize_axis)
|
||||
|
||||
|
||||
class ShapeScaledConstant(hk.initializers.Initializer):
|
||||
"""Initializes with a constant dependent on last dimension of input shape."""
|
||||
|
||||
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
|
||||
constant = 1. / shape[-1]
|
||||
return jnp.broadcast_to(constant, shape).astype(dtype)
|
||||
|
||||
|
||||
class LocalUpdateModule(hk.Module):
|
||||
"""Abstract base class for GLN variants and utils."""
|
||||
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
if hasattr(self, "__call__"):
|
||||
raise ValueError("Do not implement `__call__` for a LocalUpdateModule." +
|
||||
" Implement `inference` and `update` instead.")
|
||||
super(LocalUpdateModule, self).__init__(name)
|
||||
|
||||
@abc.abstractmethod
|
||||
def inference(self, *args, **kwargs):
|
||||
"""Module inference step."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, *args, **kwargs):
|
||||
"""Module update step."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def output_sizes(self) -> Shape:
|
||||
"""Returns network output sizes."""
|
||||
|
||||
|
||||
class GatedLinearNetwork(LocalUpdateModule):
|
||||
"""Abstract base class for a multi-layer Gated Linear Network."""
|
||||
|
||||
def __init__(self,
|
||||
output_sizes: Shape,
|
||||
context_dim: int,
|
||||
inference_fn: Callable[..., Array],
|
||||
update_fn: Callable[..., Array],
|
||||
init: Initializer,
|
||||
hyp_w_init: Optional[Initializer] = None,
|
||||
hyp_b_init: Optional[Initializer] = None,
|
||||
dtype: DType = jnp.float32,
|
||||
name: str = "gated_linear_network"):
|
||||
"""Initialize a GatedLinearNetwork as a sequence of GatedLinearLayers."""
|
||||
super(GatedLinearNetwork, self).__init__(name=name)
|
||||
|
||||
self._layers = []
|
||||
self._output_sizes = output_sizes
|
||||
for i, output_size in enumerate(self._output_sizes):
|
||||
layer = _GatedLinearLayer(
|
||||
output_size=output_size,
|
||||
context_dim=context_dim,
|
||||
update_fn=update_fn,
|
||||
inference_fn=inference_fn,
|
||||
init=init,
|
||||
hyp_w_init=hyp_w_init,
|
||||
hyp_b_init=hyp_b_init,
|
||||
dtype=dtype,
|
||||
name=name + "_layer_{}".format(i))
|
||||
self._layers.append(layer)
|
||||
self._name = name
|
||||
|
||||
@abc.abstractmethod
|
||||
def _add_bias(self, inputs):
|
||||
pass
|
||||
|
||||
def inference(self, inputs: Array, side_info: Array, *args,
|
||||
**kwargs) -> Array:
|
||||
"""GatedLinearNetwork inference."""
|
||||
predictions_per_layer = []
|
||||
predictions = inputs
|
||||
for layer in self._layers:
|
||||
predictions = self._add_bias(predictions)
|
||||
predictions = layer.inference(predictions, side_info, *args, **kwargs)
|
||||
predictions_per_layer.append(predictions)
|
||||
|
||||
return jnp.concatenate(predictions_per_layer, axis=0)
|
||||
|
||||
def update(self, inputs, side_info, target, learning_rate, *args, **kwargs):
|
||||
"""GatedLinearNetwork update."""
|
||||
all_params = []
|
||||
all_predictions = []
|
||||
all_losses = []
|
||||
predictions = inputs
|
||||
for layer in self._layers:
|
||||
predictions = self._add_bias(predictions)
|
||||
|
||||
# Note: This is correct because returned predictions are pre-update.
|
||||
params, predictions, log_loss = layer.update(predictions, side_info,
|
||||
target, learning_rate, *args,
|
||||
**kwargs)
|
||||
all_params.append(params)
|
||||
all_predictions.append(predictions)
|
||||
all_losses.append(log_loss)
|
||||
|
||||
new_params = dict(collections.ChainMap(*all_params))
|
||||
predictions = jnp.concatenate(all_predictions, axis=0)
|
||||
log_loss = jnp.concatenate(all_losses, axis=0)
|
||||
|
||||
return new_params, predictions, log_loss
|
||||
|
||||
@property
|
||||
def output_sizes(self):
|
||||
return self._output_sizes
|
||||
|
||||
@staticmethod
|
||||
def _compute_context(
|
||||
side_info: Array, # [side_info_size]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
) -> Array:
|
||||
# Index weights by side information.
|
||||
context_dim = hyperplane_bias.shape[0]
|
||||
proj = jnp.dot(hyperplanes, side_info)
|
||||
bits = (proj > hyperplane_bias).astype(jnp.int32)
|
||||
weight_index = jnp.sum(
|
||||
bits *
|
||||
jnp.array([2**i for i in range(context_dim)])) if context_dim else 0
|
||||
return weight_index
|
||||
|
||||
|
||||
class _GatedLinearLayer(LocalUpdateModule):
|
||||
"""A single layer of a Gated Linear Network."""
|
||||
|
||||
def __init__(self,
|
||||
output_size: int,
|
||||
context_dim: int,
|
||||
inference_fn: Callable[..., Array],
|
||||
update_fn: Callable[..., Array],
|
||||
init: Initializer,
|
||||
hyp_w_init: Optional[Initializer] = None,
|
||||
hyp_b_init: Optional[Initializer] = None,
|
||||
dtype: DType = jnp.float32,
|
||||
name: str = "gated_linear_layer"):
|
||||
"""Initialize a GatedLinearLayer."""
|
||||
super(_GatedLinearLayer, self).__init__(name=name)
|
||||
self._output_size = output_size
|
||||
self._context_dim = context_dim
|
||||
self._inference_fn = inference_fn
|
||||
self._update_fn = update_fn
|
||||
self._init = init
|
||||
self._hyp_w_init = hyp_w_init
|
||||
self._hyp_b_init = hyp_b_init
|
||||
self._dtype = dtype
|
||||
self._name = name
|
||||
|
||||
def _get_weights(self, input_size):
|
||||
"""Get (or initialize) weight parameters."""
|
||||
weights = hk.get_parameter(
|
||||
"weights",
|
||||
shape=(self._output_size, 2**self._context_dim, input_size),
|
||||
dtype=self._dtype,
|
||||
init=self._init,
|
||||
)
|
||||
|
||||
return weights
|
||||
|
||||
def _get_hyperplanes(self, side_info_size):
|
||||
"""Get (or initialize) hyperplane weights and bias."""
|
||||
|
||||
hyp_w_init = self._hyp_w_init or NormalizedRandomNormal(
|
||||
stddev=1., normalize_axis=1)
|
||||
hyperplanes = hk.get_state(
|
||||
"hyperplanes",
|
||||
shape=(self._output_size, self._context_dim, side_info_size),
|
||||
init=hyp_w_init)
|
||||
|
||||
hyp_b_init = self._hyp_b_init or hk.initializers.RandomNormal(stddev=0.05)
|
||||
hyperplane_bias = hk.get_state(
|
||||
"hyperplane_bias",
|
||||
shape=(self._output_size, self._context_dim),
|
||||
init=hyp_b_init)
|
||||
|
||||
return hyperplanes, hyperplane_bias
|
||||
|
||||
def inference(self, inputs: Array, side_info: Array, *args,
|
||||
**kwargs) -> Array:
|
||||
"""GatedLinearLayer inference."""
|
||||
# Initialize layer weights.
|
||||
weights = self._get_weights(inputs.shape[0])
|
||||
|
||||
# Initialize fixed random hyperplanes.
|
||||
side_info_size = side_info.shape[0]
|
||||
hyperplanes, hyperplane_bias = self._get_hyperplanes(side_info_size)
|
||||
|
||||
# Perform layer-wise inference by mapping along output_size (num_neurons).
|
||||
layer_inference = _layer_vmap(self._inference_fn)
|
||||
predictions = layer_inference(inputs, side_info, weights, hyperplanes,
|
||||
hyperplane_bias, *args, **kwargs)
|
||||
|
||||
return predictions
|
||||
|
||||
def update(self, inputs: Array, side_info: Array, target: Array,
|
||||
learning_rate: float, *args,
|
||||
**kwargs) -> Tuple[Array, Array, Array]:
|
||||
"""GatedLinearLayer update."""
|
||||
# Fetch layer weights.
|
||||
weights = self._get_weights(inputs.shape[0])
|
||||
|
||||
# Fetch fixed random hyperplanes.
|
||||
side_info_size = side_info.shape[0]
|
||||
hyperplanes, hyperplane_bias = self._get_hyperplanes(side_info_size)
|
||||
|
||||
# Perform layer-wise update by mapping along output_size (num_neurons).
|
||||
layer_update = _layer_vmap(self._update_fn)
|
||||
new_weights, predictions, log_loss = layer_update(inputs, side_info,
|
||||
weights, hyperplanes,
|
||||
hyperplane_bias, target,
|
||||
learning_rate, *args,
|
||||
**kwargs)
|
||||
|
||||
assert new_weights.shape == weights.shape
|
||||
params = {self.module_name: {"weights": new_weights}}
|
||||
return params, predictions, log_loss
|
||||
|
||||
@property
|
||||
def output_sizes(self):
|
||||
return self._output_size
|
||||
|
||||
|
||||
class Mutator(LocalUpdateModule):
|
||||
"""Abstract base class for GLN Mutators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network_factory: Callable[..., LocalUpdateModule],
|
||||
name: str,
|
||||
):
|
||||
super(Mutator, self).__init__(name=name)
|
||||
self._network = network_factory()
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def output_sizes(self):
|
||||
return self._network.output_sizes
|
||||
|
||||
|
||||
class LastNeuronAggregator(Mutator):
|
||||
"""Last neuron aggregator: network output is read from the last neuron."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network_factory: Callable[..., LocalUpdateModule],
|
||||
name: str = "last_neuron",
|
||||
):
|
||||
super(LastNeuronAggregator, self).__init__(network_factory, name)
|
||||
if self._network.output_sizes[-1] != 1:
|
||||
raise ValueError(
|
||||
"LastNeuronAggregator requires the last GLN layer to have"
|
||||
" output_size = 1.")
|
||||
|
||||
def inference(self, *args, **kwargs) -> Array:
|
||||
predictions = self._network.inference(*args, **kwargs)
|
||||
return predictions[-1]
|
||||
|
||||
def update(self, *args, **kwargs) -> Tuple[Array, Array, Array]:
|
||||
params_t, predictions_tm1, loss_tm1 = self._network.update(*args, **kwargs)
|
||||
return params_t, predictions_tm1[-1], loss_tm1[-1]
|
||||
gated_linear_networks/bernoulli.py
@@ -0,0 +1,107 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Bernoulli Gated Linear Network."""
|
||||
|
||||
from typing import List, Text, Tuple
|
||||
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import rlax
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
from gated_linear_networks import base
|
||||
|
||||
tfp = tfp.experimental.substrates.jax
|
||||
tfd = tfp.distributions
|
||||
|
||||
Array = chex.Array
|
||||
|
||||
GLN_EPS = 0.01
|
||||
MAX_WEIGHT = 200.
|
||||
|
||||
|
||||
class GatedLinearNetwork(base.GatedLinearNetwork):
|
||||
"""Bernoulli Gated Linear Network."""
|
||||
|
||||
def __init__(self,
|
||||
output_sizes: List[int],
|
||||
context_dim: int,
|
||||
name: Text = "bernoulli_gln"):
|
||||
"""Initialize a Bernoulli GLN."""
|
||||
super(GatedLinearNetwork, self).__init__(
|
||||
output_sizes,
|
||||
context_dim,
|
||||
inference_fn=GatedLinearNetwork._inference_fn,
|
||||
update_fn=GatedLinearNetwork._update_fn,
|
||||
init=jnp.zeros,
|
||||
dtype=jnp.float32,
|
||||
name=name)
|
||||
|
||||
def _add_bias(self, inputs):
|
||||
return jnp.append(inputs, rlax.sigmoid(1.))
|
||||
|
||||
@staticmethod
|
||||
def _inference_fn(
|
||||
inputs: Array, # [input_size]
|
||||
side_info: Array, # [side_info_size]
|
||||
weights: Array, # [2**context_dim, input_size]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
) -> Array:
|
||||
"""Inference step for a single Beurnolli neuron."""
|
||||
|
||||
weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
|
||||
hyperplane_bias)
|
||||
used_weights = weights[weight_index]
|
||||
inputs = rlax.logit(jnp.clip(inputs, GLN_EPS, 1. - GLN_EPS))
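    # Geometric mixing: take a weighted sum of the clipped input log-odds, then map
    # back to a probability with the sigmoid below.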
|
||||
prediction = rlax.sigmoid(jnp.dot(used_weights, inputs))
|
||||
|
||||
return prediction
|
||||
|
||||
@staticmethod
|
||||
def _update_fn(
|
||||
inputs: Array, # [input_size]
|
||||
side_info: Array, # [side_info_size]
|
||||
weights: Array, # [2**context_dim, num_features]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
target: Array, # []
|
||||
learning_rate: float,
|
||||
) -> Tuple[Array, Array, Array]:
|
||||
"""Update step for a single Bernoulli neuron."""
|
||||
|
||||
def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias,
|
||||
target):
|
||||
"""Log loss for a single Bernoulli neuron."""
|
||||
prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights,
|
||||
hyperplanes,
|
||||
hyperplane_bias)
|
||||
prediction = jnp.clip(prediction, GLN_EPS, 1. - GLN_EPS)
|
||||
return rlax.log_loss(prediction, target), prediction
|
||||
|
||||
grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True)
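    # argnums=2 differentiates the loss w.r.t. `weights` only; has_aux=True also
    # returns the (pre-update) prediction alongside the loss.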
|
||||
((log_loss, prediction),
|
||||
dloss_dweights) = grad_log_loss(inputs, side_info, weights, hyperplanes,
|
||||
hyperplane_bias, target)
|
||||
|
||||
delta_weights = learning_rate * dloss_dweights
|
||||
new_weights = jnp.clip(weights - delta_weights, -MAX_WEIGHT, MAX_WEIGHT)
|
||||
return new_weights, prediction, log_loss
|
||||
|
||||
|
||||
class LastNeuronAggregator(base.LastNeuronAggregator):
|
||||
"""Bernoulli last neuron aggregator, implemented by the super class."""
|
||||
pass
|
||||
gated_linear_networks/bernoulli_test.py
@@ -0,0 +1,215 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `bernoulli.py`."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
from gated_linear_networks import bernoulli
|
||||
|
||||
|
||||
def _get_dataset(input_size, batch_size=None):
|
||||
"""Get mock dataset."""
|
||||
if batch_size:
|
||||
inputs = jnp.ones([batch_size, input_size])
|
||||
side_info = jnp.ones([batch_size, input_size])
|
||||
targets = jnp.ones([batch_size])
|
||||
else:
|
||||
inputs = jnp.ones([input_size])
|
||||
side_info = jnp.ones([input_size])
|
||||
targets = jnp.ones([])
|
||||
|
||||
return inputs, side_info, targets
|
||||
|
||||
|
||||
class GatedLinearNetworkTest(parameterized.TestCase):
|
||||
|
||||
# TODO(b/170843789): Factor out common test utilities.
|
||||
def setUp(self):
|
||||
super(GatedLinearNetworkTest, self).setUp()
|
||||
self._name = "test_network"
|
||||
self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))
|
||||
|
||||
self._output_sizes = (4, 5, 6)
|
||||
self._context_dim = 2
|
||||
|
||||
def gln_factory():
|
||||
return bernoulli.GatedLinearNetwork(
|
||||
output_sizes=self._output_sizes,
|
||||
context_dim=self._context_dim,
|
||||
name=self._name)
|
||||
|
||||
def inference_fn(inputs, side_info):
|
||||
return gln_factory().inference(inputs, side_info)
|
||||
|
||||
def batch_inference_fn(inputs, side_info):
|
||||
return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)
|
||||
|
||||
def update_fn(inputs, side_info, label, learning_rate):
|
||||
params, predictions, unused_loss = gln_factory().update(
|
||||
inputs, side_info, label, learning_rate)
|
||||
return predictions, params
|
||||
|
||||
def batch_update_fn(inputs, side_info, label, learning_rate):
|
||||
predictions, params = jax.vmap(
|
||||
update_fn, in_axes=(0, 0, 0, None))(inputs, side_info, label,
|
||||
learning_rate)
|
||||
avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params)
|
||||
return predictions, avg_params
|
||||
|
||||
# Haiku transform functions.
|
||||
self._init_fn, inference_fn_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(inference_fn))
|
||||
self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(batch_inference_fn))
|
||||
_, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
|
||||
_, batch_update_fn_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(batch_update_fn))
|
||||
|
||||
self._inference_fn = jax.jit(inference_fn_)
|
||||
self._batch_inference_fn = jax.jit(batch_inference_fn_)
|
||||
self._update_fn = jax.jit(update_fn_)
|
||||
self._batch_update_fn = jax.jit(batch_update_fn_)
|
||||
|
||||
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
|
||||
def test_shapes(self, batch_size):
|
||||
"""Test shapes in online and batch regimes."""
|
||||
if batch_size is None:
|
||||
init_fn = self._init_fn
|
||||
inference_fn = self._inference_fn
|
||||
else:
|
||||
init_fn = self._batch_init_fn
|
||||
inference_fn = self._batch_inference_fn
|
||||
|
||||
input_size = 10
|
||||
inputs, side_info, _ = _get_dataset(input_size, batch_size)
|
||||
input_size = inputs.shape[-1]
|
||||
|
||||
# Initialize network.
|
||||
gln_params, gln_state = init_fn(next(self._rng), inputs, side_info)
|
||||
|
||||
# Test shapes of parameters layer-wise.
|
||||
layer_input_size = input_size
|
||||
for layer_idx, output_size in enumerate(self._output_sizes):
|
||||
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
|
||||
weights = gln_params[name]["weights"]
|
||||
expected_shape = (output_size, 2**self._context_dim, layer_input_size + 1)
|
||||
self.assertEqual(weights.shape, expected_shape)
|
||||
|
||||
layer_input_size = output_size
|
||||
|
||||
# Test shape of output.
|
||||
output_size = sum(self._output_sizes)
|
||||
predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info)
|
||||
expected_shape = (batch_size, output_size) if batch_size else (output_size,)
|
||||
self.assertEqual(predictions.shape, expected_shape)
|
||||
|
||||
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
|
||||
def test_update(self, batch_size):
|
||||
"""Test network updates in online and batch regimes."""
|
||||
if batch_size is None:
|
||||
init_fn = self._init_fn
|
||||
inference_fn = self._inference_fn
|
||||
update_fn = self._update_fn
|
||||
else:
|
||||
init_fn = self._batch_init_fn
|
||||
inference_fn = self._batch_inference_fn
|
||||
update_fn = self._batch_update_fn
|
||||
|
||||
input_size = 10
|
||||
inputs, side_info, targets = _get_dataset(input_size, batch_size)
|
||||
|
||||
# Initialize network.
|
||||
initial_params, gln_state = init_fn(next(self._rng), inputs, side_info)
|
||||
|
||||
# Initial predictions.
|
||||
initial_predictions, _ = inference_fn(initial_params, gln_state, inputs,
|
||||
side_info)
|
||||
|
||||
# Test that params remain valid after consecutive updates.
|
||||
gln_params = initial_params
|
||||
|
||||
for _ in range(3):
|
||||
(_, gln_params), gln_state = update_fn(
|
||||
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4)
|
||||
|
||||
# Check updated weights layer-wise.
|
||||
for layer_idx in range(len(self._output_sizes)):
|
||||
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
|
||||
|
||||
initial_weights = initial_params[name]["weights"]
|
||||
new_weights = gln_params[name]["weights"]
|
||||
|
||||
# Shape consistency.
|
||||
self.assertEqual(new_weights.shape, initial_weights.shape)
|
||||
|
||||
# Check that different weights yield different predictions.
|
||||
new_predictions, _ = inference_fn(gln_params, gln_state, inputs,
|
||||
side_info)
|
||||
self.assertFalse(np.array_equal(new_predictions, initial_predictions))
|
||||
|
||||
def test_batch_consistency(self):
|
||||
"""Test consistency between online and batch updates."""
|
||||
|
||||
input_size = 10
|
||||
batch_size = 3
|
||||
inputs, side_info, targets = _get_dataset(input_size, batch_size)
|
||||
|
||||
# Initialize network.
|
||||
gln_params, gln_state = self._batch_init_fn(
|
||||
next(self._rng), inputs, side_info)
|
||||
test_layer = "{}/~/{}_layer_0".format(self._name, self._name)
|
||||
|
||||
for _ in range(10):
|
||||
|
||||
# Update on full batch.
|
||||
(expected_predictions, expected_params), _ = self._batch_update_fn(
|
||||
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-3)
|
||||
|
||||
# Average updates across batch and check equivalence.
|
||||
accum_predictions = []
|
||||
accum_weights = []
|
||||
for inputs_, side_info_, targets_ in zip(inputs, side_info, targets):
|
||||
(predictions, params), _ = self._update_fn(
|
||||
gln_params,
|
||||
gln_state,
|
||||
inputs_,
|
||||
side_info_,
|
||||
targets_,
|
||||
learning_rate=1e-3)
|
||||
accum_predictions.append(predictions)
|
||||
accum_weights.append(params[test_layer]["weights"])
|
||||
|
||||
# Check prediction equivalence.
|
||||
actual_predictions = np.stack(accum_predictions, axis=0)
|
||||
np.testing.assert_array_almost_equal(actual_predictions,
|
||||
expected_predictions)
|
||||
|
||||
# Check weight equivalence.
|
||||
actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0)
|
||||
expected_weights = expected_params[test_layer]["weights"]
|
||||
np.testing.assert_array_almost_equal(actual_weights, expected_weights)
|
||||
|
||||
gln_params = expected_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
gated_linear_networks/examples/bernoulli_mnist.py
@@ -0,0 +1,144 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Online MNIST classification example with Bernoulli GLN."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import rlax
|
||||
|
||||
from gated_linear_networks import bernoulli
|
||||
from gated_linear_networks.examples import utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# Small example network, achieves ~95% test set accuracy =======================
|
||||
# Network parameters.
|
||||
flags.DEFINE_integer('num_layers', 2, '')
|
||||
flags.DEFINE_integer('neurons_per_layer', 100, '')
|
||||
flags.DEFINE_integer('context_dim', 1, '')
|
||||
|
||||
# Learning rate schedule.
|
||||
flags.DEFINE_float('max_lr', 0.003, '')
|
||||
flags.DEFINE_float('lr_constant', 1.0, '')
|
||||
flags.DEFINE_float('lr_decay', 0.1, '')
|
||||
|
||||
# Logging parameters.
|
||||
flags.DEFINE_integer('evaluate_every', 1000, '')
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
# Load MNIST dataset =========================================================
|
||||
mnist_data, info = utils.load_deskewed_mnist(
|
||||
name='mnist', batch_size=-1, with_info=True)
|
||||
num_classes = info.features['label'].num_classes
|
||||
|
||||
(train_images, train_labels) = (mnist_data['train']['image'],
|
||||
mnist_data['train']['label'])
|
||||
|
||||
(test_images, test_labels) = (mnist_data['test']['image'],
|
||||
mnist_data['test']['label'])
|
||||
|
||||
# Build a (binary) GLN classifier ============================================
|
||||
def network_factory():
|
||||
|
||||
def gln_factory():
|
||||
output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1]
|
||||
return bernoulli.GatedLinearNetwork(
|
||||
output_sizes=output_sizes, context_dim=FLAGS.context_dim)
|
||||
|
||||
return bernoulli.LastNeuronAggregator(gln_factory)
|
||||
|
||||
def extract_features(image):
|
||||
mean, stddev = utils.MeanStdEstimator()(image)
|
||||
standardized_img = (image - mean) / (stddev + 1.)
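    # The +1. keeps the denominator bounded away from zero while the online stddev
    # estimate is still small (e.g. on the very first samples).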
|
||||
inputs = rlax.sigmoid(standardized_img)
|
||||
side_info = standardized_img
|
||||
return inputs, side_info
|
||||
|
||||
def inference_fn(image, *args, **kwargs):
|
||||
inputs, side_info = extract_features(image)
|
||||
return network_factory().inference(inputs, side_info, *args, **kwargs)
|
||||
|
||||
def update_fn(image, *args, **kwargs):
|
||||
inputs, side_info = extract_features(image)
|
||||
return network_factory().update(inputs, side_info, *args, **kwargs)
|
||||
|
||||
init_, inference_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(inference_fn))
|
||||
_, update_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
|
||||
|
||||
# Map along class dimension to create a one-vs-all classifier ================
|
||||
@jax.jit
|
||||
def init(dummy_image, key):
|
||||
"""One-vs-all classifier init fn."""
|
||||
dummy_images = jnp.stack([dummy_image] * num_classes, axis=0)
|
||||
keys = jax.random.split(key, num_classes)
|
||||
return jax.vmap(init_, in_axes=(0, 0))(keys, dummy_images)
|
||||
|
||||
@jax.jit
|
||||
def accuracy(params, state, image, label):
|
||||
"""One-vs-all classifier inference fn."""
|
||||
fn = jax.vmap(inference_, in_axes=(0, 0, None))
|
||||
predictions, unused_state = fn(params, state, image)
|
||||
return (jnp.argmax(predictions) == label).astype(jnp.float32)
|
||||
|
||||
@jax.jit
|
||||
def update(params, state, step, image, label):
|
||||
"""One-vs-all classifier update fn."""
|
||||
|
||||
# Learning rate schedules.
|
||||
learning_rate = jnp.minimum(
|
||||
FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step))
|
||||
|
||||
# Update weights and report log-loss.
|
||||
targets = hk.one_hot(jnp.asarray(label), num_classes)
|
||||
|
||||
fn = jax.vmap(update_, in_axes=(0, 0, None, 0, None))
|
||||
out = fn(params, state, image, targets, learning_rate)
|
||||
(params, unused_predictions, log_loss), state = out
|
||||
return (jnp.mean(log_loss), params), state
|
||||
|
||||
# Train on train split =======================================================
|
||||
dummy_image = train_images[0]
|
||||
params, state = init(dummy_image, jax.random.PRNGKey(42))
|
||||
|
||||
for step, (image, label) in enumerate(zip(train_images, train_labels), 1):
|
||||
(unused_loss, params), state = update(
|
||||
params,
|
||||
state,
|
||||
step,
|
||||
image,
|
||||
label,
|
||||
)
|
||||
|
||||
# Evaluate on test split ===================================================
|
||||
if not step % FLAGS.evaluate_every:
|
||||
batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0))
|
||||
accuracies = batch_accuracy(params, state, test_images, test_labels)
|
||||
total_accuracy = float(jnp.mean(accuracies))
|
||||
|
||||
# Report statistics.
|
||||
print({
|
||||
'step': step,
|
||||
'accuracy': float(total_accuracy),
|
||||
})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
gated_linear_networks/examples/utils.py
@@ -0,0 +1,97 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Haiku modules for feature processing."""
|
||||
|
||||
import copy
|
||||
from typing import Tuple
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from scipy.ndimage import interpolation
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
Array = chex.Array
|
||||
|
||||
|
||||
def _moments(image):
|
||||
"""Compute the first and second moments of a given image."""
|
||||
c0, c1 = np.mgrid[:image.shape[0], :image.shape[1]]
|
||||
total_image = np.sum(image)
|
||||
m0 = np.sum(c0 * image) / total_image
|
||||
m1 = np.sum(c1 * image) / total_image
|
||||
m00 = np.sum((c0 - m0)**2 * image) / total_image
|
||||
m11 = np.sum((c1 - m1)**2 * image) / total_image
|
||||
m01 = np.sum((c0 - m0) * (c1 - m1) * image) / total_image
|
||||
mu_vector = np.array([m0, m1])
|
||||
covariance_matrix = np.array([[m00, m01], [m01, m11]])
|
||||
return mu_vector, covariance_matrix
|
||||
|
||||
|
||||
def _deskew(image):
|
||||
"""Image deskew."""
|
||||
c, v = _moments(image)
|
||||
alpha = v[0, 1] / v[0, 0]
|
||||
affine = np.array([[1, 0], [alpha, 1]])
|
||||
ocenter = np.array(image.shape) / 2.0
|
||||
offset = c - np.dot(affine, ocenter)
|
||||
return interpolation.affine_transform(image, affine, offset=offset)
|
||||
|
||||
|
||||
def _deskew_dataset(dataset):
|
||||
"""Dataset deskew."""
|
||||
deskewed = copy.deepcopy(dataset)
|
||||
for k, before in dataset.items():
|
||||
images = before["image"]
|
||||
num_images = images.shape[0]
|
||||
after = np.stack([_deskew(i) for i in np.squeeze(images, axis=-1)], axis=0)
|
||||
deskewed[k]["image"] = np.reshape(after, (num_images, -1))
|
||||
return deskewed
|
||||
|
||||
|
||||
def load_deskewed_mnist(*a, **k):
|
||||
"""Returns deskewed MNIST numpy dataset."""
|
||||
mnist_data, info = tfds.load(*a, **k)
|
||||
mnist_data = tfds.as_numpy(mnist_data)
|
||||
deskewed_data = _deskew_dataset(mnist_data)
|
||||
return deskewed_data, info
|
||||
|
||||
|
||||
class MeanStdEstimator(hk.Module):
|
||||
"""Online mean and standard deviation estimator using Welford's algorithm."""
|
||||
|
||||
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
|
||||
if len(sample.shape) > 1:
|
||||
raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
|
||||
|
||||
count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros)
|
||||
mean = hk.get_state(
|
||||
"mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
|
||||
m2 = hk.get_state(
|
||||
"m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
|
||||
|
||||
count += 1
|
||||
delta = sample - mean
|
||||
mean += delta / count
|
||||
delta_2 = sample - mean
|
||||
m2 += delta * delta_2
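    # m2 accumulates the sum of squared deviations from the running mean; dividing by
    # count below yields the (population) variance, whose square root is returned.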
|
||||
|
||||
hk.set_state("count", count)
|
||||
hk.set_state("mean", mean)
|
||||
hk.set_state("m2", m2)
|
||||
|
||||
stddev = jnp.sqrt(m2 / count)
|
||||
return mean, stddev
|
||||
gated_linear_networks/examples/utils_test.py
@@ -0,0 +1,52 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `utils.py`."""
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from gated_linear_networks.examples import utils
|
||||
|
||||
|
||||
class MeanStdEstimator(absltest.TestCase):
|
||||
|
||||
def test_statistics(self):
|
||||
num_features = 100
|
||||
feature_size = 3
|
||||
samples = np.random.normal(
|
||||
loc=5., scale=2., size=(num_features, feature_size))
|
||||
true_mean = np.mean(samples, axis=0)
|
||||
true_std = np.std(samples, axis=0)
|
||||
|
||||
def tick_(sample):
|
||||
return utils.MeanStdEstimator()(sample)
|
||||
|
||||
init_fn, apply_fn = hk.without_apply_rng(hk.transform_with_state(tick_))
|
||||
tick = jax.jit(apply_fn)
|
||||
|
||||
params, state = init_fn(rng=None, sample=samples[0])
|
||||
|
||||
for sample in samples:
|
||||
(mean, std), state = tick(params, state, sample)
|
||||
|
||||
np.testing.assert_array_almost_equal(mean, true_mean, decimal=5)
|
||||
np.testing.assert_array_almost_equal(std, true_std, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
gated_linear_networks/gaussian.py
@@ -0,0 +1,197 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Gaussian Gated Linear Network."""
|
||||
|
||||
from typing import Callable, List, Text, Tuple
|
||||
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
from gated_linear_networks import base
|
||||
|
||||
tfp = tfp.experimental.substrates.jax
|
||||
tfd = tfp.distributions
|
||||
|
||||
Array = chex.Array
|
||||
|
||||
MIN_SIGMA_SQ_AGGREGATOR = 0.5
|
||||
MAX_SIGMA_SQ = 1e5
|
||||
MAX_WEIGHT = 1e3
|
||||
MIN_WEIGHT = -1e3
|
||||
|
||||
|
||||
def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]:
|
||||
inputs = jnp.atleast_2d(inputs)
|
||||
chex.assert_rank(inputs, 2)
|
||||
(mu, sigma_sq) = [jnp.squeeze(x, 1) for x in jnp.hsplit(inputs, 2)]
|
||||
return mu, sigma_sq
|
||||
|
||||
|
||||
def _pack_inputs(mu: Array, sigma_sq: Array) -> Array:
|
||||
mu = jnp.atleast_1d(mu)
|
||||
sigma_sq = jnp.atleast_1d(sigma_sq)
|
||||
chex.assert_rank([mu, sigma_sq], 1)
|
||||
return jnp.vstack([mu, sigma_sq]).T
|
||||
|
||||
|
||||
class GatedLinearNetwork(base.GatedLinearNetwork):
|
||||
"""Gaussian Gated Linear Network."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_sizes: List[int],
|
||||
context_dim: int,
|
||||
bias_len: int = 3,
|
||||
bias_max_mu: float = 1.,
|
||||
bias_sigma_sq: float = 1.,
|
||||
name: Text = "gaussian_gln"):
|
||||
"""Initialize a Gaussian GLN."""
|
||||
super(GatedLinearNetwork, self).__init__(
|
||||
output_sizes,
|
||||
context_dim,
|
||||
inference_fn=GatedLinearNetwork._inference_fn,
|
||||
update_fn=GatedLinearNetwork._update_fn,
|
||||
init=base.ShapeScaledConstant(),
|
||||
dtype=jnp.float64,
|
||||
name=name)
|
||||
|
||||
self._bias_len = bias_len
|
||||
self._bias_max_mu = bias_max_mu
|
||||
self._bias_sigma_sq = bias_sigma_sq
|
||||
|
||||
def _add_bias(self, inputs):
|
||||
mu = jnp.linspace(-1. * self._bias_max_mu, self._bias_max_mu,
|
||||
self._bias_len)
|
||||
sigma_sq = self._bias_sigma_sq * jnp.ones_like(mu)
|
||||
bias = _pack_inputs(mu, sigma_sq)
|
||||
return jnp.concatenate([inputs, bias], axis=0)
|
||||
|
||||
@staticmethod
|
||||
def _inference_fn(
|
||||
inputs: Array, # [input_size, 2]
|
||||
side_info: Array, # [side_info_size]
|
||||
weights: Array, # [2**context_dim, input_size]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
min_sigma_sq: float,
|
||||
) -> Array:
|
||||
"""Inference step for a single Gaussian neuron."""
|
||||
|
||||
mu_in, sigma_sq_in = _unpack_inputs(inputs)
|
||||
weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
|
||||
hyperplane_bias)
|
||||
used_weights = weights[weight_index]
|
||||
|
||||
# This projection operation is differentiable and affects the gradients.
|
||||
used_weights = GatedLinearNetwork._project_weights(inputs, used_weights,
|
||||
min_sigma_sq)
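    # Precision-weighted (product-of-Gaussians) mixing: the output precision is the
    # weighted sum of the input precisions, and the output mean is the matching
    # precision-weighted combination of the input means.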
|
||||
|
||||
sigma_sq_out = 1. / jnp.sum(used_weights / sigma_sq_in)
|
||||
mu_out = sigma_sq_out * jnp.sum((used_weights * mu_in) / sigma_sq_in)
|
||||
prediction = jnp.hstack((mu_out, sigma_sq_out))
|
||||
return prediction
|
||||
|
||||
@staticmethod
|
||||
def _project_weights(inputs: Array, # [input_size]
|
||||
weights: Array, # [2**context_dim, num_features]
|
||||
min_sigma_sq: float) -> Array:
|
||||
"""Implements hard projection."""
|
||||
|
||||
# This projection should be performed before the sigma related ones.
|
||||
weights = jnp.minimum(jnp.maximum(MIN_WEIGHT, weights), MAX_WEIGHT)
|
||||
_, sigma_sq_in = _unpack_inputs(inputs)
|
||||
|
||||
lambda_in = 1. / sigma_sq_in
|
||||
sigma_sq_out = 1. / weights.dot(lambda_in)
|
||||
|
||||
# If w.dot(x) < U, linearly project w such that w.dot(x) = U.
|
||||
weights = jnp.where(
|
||||
sigma_sq_out < min_sigma_sq, weights - lambda_in *
|
||||
(1. / sigma_sq_out - 1. / min_sigma_sq) / jnp.sum(lambda_in**2),
|
||||
weights)
|
||||
|
||||
# If w.dot(x) > U, linearly project w such that w.dot(x) = U.
|
||||
weights = jnp.where(
|
||||
sigma_sq_out > MAX_SIGMA_SQ, weights - lambda_in *
|
||||
(1. / sigma_sq_out - 1. / MAX_SIGMA_SQ) / jnp.sum(lambda_in**2),
|
||||
weights)
|
||||
|
||||
return weights
|
||||
|
||||
@staticmethod
|
||||
def _update_fn(
|
||||
inputs: Array, # [input_size]
|
||||
side_info: Array, # [side_info_size]
|
||||
weights: Array, # [2**context_dim, num_features]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
target: Array, # []
|
||||
learning_rate: float,
|
||||
min_sigma_sq: float, # needed for inference (weight projection)
|
||||
) -> Tuple[Array, Array, Array]:
|
||||
"""Update step for a single Gaussian neuron."""
|
||||
|
||||
def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias,
|
||||
target):
|
||||
"""Log loss for a single Gaussian neuron."""
|
||||
prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights,
|
||||
hyperplanes,
|
||||
hyperplane_bias,
|
||||
min_sigma_sq)
|
||||
mu, sigma_sq = prediction.T
|
||||
loss = -tfd.Normal(mu, jnp.sqrt(sigma_sq)).log_prob(target)
|
||||
return loss, prediction
|
||||
|
||||
grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True)
|
||||
(log_loss,
|
||||
prediction), dloss_dweights = grad_log_loss(inputs, side_info, weights,
|
||||
hyperplanes, hyperplane_bias,
|
||||
target)
|
||||
|
||||
delta_weights = learning_rate * dloss_dweights
|
||||
return weights - delta_weights, prediction, log_loss
|
||||
|
||||
|
||||
class ConstantInputSigma(base.Mutator):
|
||||
"""Input pre-processing by concatenating a constant sigma^2."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network_factory: Callable[..., GatedLinearNetwork],
|
||||
input_sigma_sq: float,
|
||||
name: Text = "constant_input_sigma",
|
||||
):
|
||||
super(ConstantInputSigma, self).__init__(network_factory, name)
|
||||
self._input_sigma_sq = input_sigma_sq
|
||||
|
||||
def inference(self, inputs, *args, **kwargs):
|
||||
"""ConstantInputSigma inference."""
|
||||
chex.assert_rank(inputs, 1)
|
||||
sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
|
||||
return self._network.inference(_pack_inputs(inputs, sigma_sq), *args,
|
||||
**kwargs)
|
||||
|
||||
def update(self, inputs, *args, **kwargs):
|
||||
"""ConstantInputSigma update."""
|
||||
chex.assert_rank(inputs, 1)
|
||||
sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
|
||||
return self._network.update(_pack_inputs(inputs, sigma_sq), *args, **kwargs)
|
||||
|
||||
|
||||
class LastNeuronAggregator(base.LastNeuronAggregator):
|
||||
"""Gaussian last neuron aggregator, implemented by the super class."""
|
||||
pass
|
||||
gated_linear_networks/gaussian_test.py
@@ -0,0 +1,233 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `gaussian.py`."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
from gated_linear_networks import gaussian
|
||||
|
||||
|
||||
def _get_dataset(input_size, batch_size=None):
|
||||
"""Get mock dataset."""
|
||||
if batch_size:
|
||||
inputs = jnp.ones([batch_size, input_size, 2])
|
||||
side_info = jnp.ones([batch_size, input_size])
|
||||
targets = 0.8 * jnp.ones([batch_size])
|
||||
else:
|
||||
inputs = jnp.ones([input_size, 2])
|
||||
side_info = jnp.ones([input_size])
|
||||
targets = jnp.ones([])
|
||||
|
||||
return inputs, side_info, targets
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
|
||||
def test_packing_identity(self):
|
||||
mu = jnp.array([1., 2., 3., 4., 5.])
|
||||
sigma_sq = jnp.array([6., 7., 8., 9., 10.])
|
||||
|
||||
mu_2, sigma_sq_2 = gaussian._unpack_inputs(
|
||||
gaussian._pack_inputs(mu, sigma_sq))
|
||||
|
||||
np.testing.assert_array_equal(mu, mu_2)
|
||||
np.testing.assert_array_equal(sigma_sq, sigma_sq_2)
|
||||
|
||||
|
||||
class GatedLinearNetworkTest(parameterized.TestCase):
|
||||
|
||||
# TODO(b/170843789): Factor out common test utilities.
|
||||
def setUp(self):
|
||||
super(GatedLinearNetworkTest, self).setUp()
|
||||
self._name = "test_network"
|
||||
self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))
|
||||
|
||||
self._output_sizes = (4, 5, 6)
|
||||
self._context_dim = 2
|
||||
self._bias_len = 3
|
||||
|
||||
def gln_factory():
|
||||
return gaussian.GatedLinearNetwork(
|
||||
output_sizes=self._output_sizes,
|
||||
context_dim=self._context_dim,
|
||||
bias_len=self._bias_len,
|
||||
name=self._name,
|
||||
)
|
||||
|
||||
def inference_fn(inputs, side_info):
|
||||
return gln_factory().inference(inputs, side_info, 0.5)
|
||||
|
||||
def batch_inference_fn(inputs, side_info):
|
||||
return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)
|
||||
|
||||
def update_fn(inputs, side_info, label, learning_rate):
|
||||
params, predictions, unused_loss = gln_factory().update(
|
||||
inputs, side_info, label, learning_rate, 0.5)
|
||||
return predictions, params
|
||||
|
||||
def batch_update_fn(inputs, side_info, label, learning_rate):
|
||||
predictions, params = jax.vmap(
|
||||
update_fn, in_axes=(0, 0, 0, None))(
|
||||
inputs,
|
||||
side_info,
|
||||
label,
|
||||
learning_rate)
|
||||
avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params)
|
||||
return predictions, avg_params
|
||||
|
||||
# Haiku transform functions.
|
||||
self._init_fn, inference_fn_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(inference_fn))
|
||||
self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(batch_inference_fn))
|
||||
_, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
|
||||
_, batch_update_fn_ = hk.without_apply_rng(
|
||||
hk.transform_with_state(batch_update_fn))
|
||||
|
||||
self._inference_fn = jax.jit(inference_fn_)
|
||||
self._batch_inference_fn = jax.jit(batch_inference_fn_)
|
||||
self._update_fn = jax.jit(update_fn_)
|
||||
self._batch_update_fn = jax.jit(batch_update_fn_)
|
||||
|
||||
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
|
||||
def test_shapes(self, batch_size):
|
||||
"""Test shapes in online and batch regimes."""
|
||||
if batch_size is None:
|
||||
init_fn = self._init_fn
|
||||
inference_fn = self._inference_fn
|
||||
else:
|
||||
init_fn = self._batch_init_fn
|
||||
inference_fn = self._batch_inference_fn
|
||||
|
||||
input_size = 10
|
||||
inputs, side_info, _ = _get_dataset(input_size, batch_size)
|
||||
|
||||
# Initialize network.
|
||||
gln_params, gln_state = init_fn(next(self._rng), inputs, side_info)
|
||||
|
||||
# Test shapes of parameters layer-wise.
|
||||
layer_input_size = input_size
|
||||
for layer_idx, output_size in enumerate(self._output_sizes):
|
||||
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
|
||||
weights = gln_params[name]["weights"]
|
||||
expected_shape = (output_size, 2**self._context_dim,
|
||||
layer_input_size + self._bias_len)
|
||||
self.assertEqual(weights.shape, expected_shape)
|
||||
|
||||
layer_input_size = output_size
|
||||
|
||||
# Test shape of output.
|
||||
output_size = sum(self._output_sizes)
|
||||
predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info)
|
||||
expected_shape = (batch_size, output_size,
|
||||
2) if batch_size else (output_size, 2)
|
||||
self.assertEqual(predictions.shape, expected_shape)
|
||||
|
||||
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
|
||||
def test_update(self, batch_size):
|
||||
"""Test network updates in online and batch regimes."""
|
||||
if batch_size is None:
|
||||
init_fn = self._init_fn
|
||||
inference_fn = self._inference_fn
|
||||
update_fn = self._update_fn
|
||||
else:
|
||||
init_fn = self._batch_init_fn
|
||||
inference_fn = self._batch_inference_fn
|
||||
update_fn = self._batch_update_fn
|
||||
|
||||
inputs, side_info, targets = _get_dataset(10, batch_size)
|
||||
|
||||
# Initialize network.
|
||||
initial_params, gln_state = init_fn(next(self._rng), inputs, side_info)
|
||||
|
||||
# Initial predictions.
|
||||
initial_predictions, _ = inference_fn(initial_params, gln_state, inputs,
|
||||
side_info)
|
||||
|
||||
# Test that params remain valid after consecutive updates.
|
||||
gln_params = initial_params
|
||||
|
||||
for _ in range(3):
|
||||
(_, gln_params), _ = update_fn(
|
||||
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4)
|
||||
|
||||
# Check updated weights layer-wise.
|
||||
for layer_idx in range(len(self._output_sizes)):
|
||||
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
|
||||
|
||||
initial_weights = initial_params[name]["weights"]
|
||||
new_weights = gln_params[name]["weights"]
|
||||
|
||||
# Shape consistency.
|
||||
self.assertEqual(new_weights.shape, initial_weights.shape)
|
||||
|
||||
# Check that different weights yield different predictions.
|
||||
new_predictions, _ = inference_fn(gln_params, gln_state, inputs,
|
||||
side_info)
|
||||
self.assertFalse(np.array_equal(new_predictions, initial_predictions))
|
||||
|
||||
def test_batch_consistency(self):
|
||||
"""Test consistency between online and batch updates."""
|
||||
|
||||
batch_size = 3
|
||||
inputs, side_info, targets = _get_dataset(10, batch_size)
|
||||
|
||||
# Initialize network.
|
||||
gln_params, gln_state = self._batch_init_fn(
|
||||
next(self._rng), inputs, side_info)
|
||||
test_layer = "{}/~/{}_layer_0".format(self._name, self._name)
|
||||
|
||||
for _ in range(10):
|
||||
|
||||
# Update on full batch.
|
||||
(expected_predictions, expected_params), _ = self._batch_update_fn(
|
||||
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-3)
|
||||
|
||||
# Average updates across batch and check equivalence.
|
||||
accum_predictions = []
|
||||
accum_weights = []
|
||||
for inputs_, side_info_, targets_ in zip(inputs, side_info, targets):
|
||||
(predictions, params), _ = self._update_fn(
|
||||
gln_params,
|
||||
gln_state,
|
||||
inputs_,
|
||||
side_info_,
|
||||
targets_,
|
||||
learning_rate=1e-3)
|
||||
accum_predictions.append(predictions)
|
||||
accum_weights.append(params[test_layer]["weights"])
|
||||
|
||||
# Check prediction equivalence.
|
||||
actual_predictions = np.stack(accum_predictions, axis=0)
|
||||
np.testing.assert_array_almost_equal(actual_predictions,
|
||||
expected_predictions)
|
||||
|
||||
# Check weight equivalence.
|
||||
actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0)
|
||||
expected_weights = expected_params[test_layer]["weights"]
|
||||
np.testing.assert_array_almost_equal(actual_weights, expected_weights)
|
||||
|
||||
gln_params = expected_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
gated_linear_networks/requirements.txt
@@ -0,0 +1,52 @@
absl-py==0.10.0
aiohttp==3.6.2
astunparse==1.6.3
async-timeout==3.0.1
attrs==20.2.0
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
chex==0.0.2
cloudpickle==1.6.0
decorator==4.4.2
dill==0.3.2
dm-env==1.2
dm-haiku==0.0.2
dm-tree==0.1.5
future==0.18.2
gast==0.3.3
google-auth==1.22.0
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
googleapis-common-protos==1.52.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
jax==0.2.0
jaxlib==0.1.55
Keras-Preprocessing==1.1.2
Markdown==3.2.2
multidict==4.7.6
numpy==1.18.5
oauthlib==3.1.0
opt-einsum==3.3.0
promise==2.3
protobuf==3.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
requests==2.24.0
requests-oauthlib==1.3.0
rlax==0.0.2
rsa==4.6
scipy==1.5.2
six==1.15.0
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-metadata==0.24.0
tensorflow-probability==0.11.1
termcolor==1.1.0
toolz==0.11.1
tqdm==4.50.0
gated_linear_networks/run.sh
@@ -0,0 +1,26 @@
#!/bin/sh
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e

python3 -m venv gln_venv
source gln_venv/bin/activate
pip3 install --upgrade setuptools wheel
pip3 install -r gated_linear_networks/requirements.txt

# Run MNIST example with Bernoulli GLN
python3 -m gated_linear_networks.examples.bernoulli_mnist \
    --num_layers=2 \
    --neurons_per_layer=100 \
    --context_dim=1