diff --git a/README.md b/README.md
index 6a00fe5..57351b4 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
 
 ## Projects
 
+* [Gated Linear Networks](gated_linear_networks), NeurIPS 2020
 * [Value-driven Hindsight Modelling](himo), NeurIPS 2020
 * [Targeted free energy estimation via learned mappings](learned_free_energy_estimation), Journal of Chemical Physics 2020
 * [Learning to Simulate Complex Physics with Graph Networks](learning_to_simulate), ICML 2020
diff --git a/gated_linear_networks/README.md b/gated_linear_networks/README.md
new file mode 100644
index 0000000..30a8223
--- /dev/null
+++ b/gated_linear_networks/README.md
@@ -0,0 +1,44 @@
+# Gated Linear Networks
+
+Gated Linear Networks (GLNs) are a family of backpropagation-free neural
+networks. Each neuron in a GLN predicts the target density (or probability
+mass) based on the outputs of the previous layer and is trained under a
+logarithmic loss.
+
+## GLN variants
+
+Neurons have probabilistic "activation functions". Implementations are provided
+for the following distributions:
+
+- Gaussian, for regression.
+
+- Bernoulli, for binary classification and multi-class classification using a
+  one-vs-all scheme.
+
+## Examples
+
+Usage examples are provided in [`examples`](examples).
+
+## Implementation details
+
+### Constraint satisfaction
+
+Because each neuron implements a probability density/mass function, we need to
+ensure that these functions are well defined. For example, the scale parameter
+of a Gaussian density must be positive. We enforce these constraints using
+linear projections and clipping.
+
+### Aggregation
+
+Because every neuron predicts the target, any neuron's output can be used as
+the "network output"; we are not bound to the last layer. The last-layer
+neuron(s) are typically the best predictors, although in theory they may take
+longer to converge. In this implementation we use a single neuron at the last
+layer, which forms the network output.
+
+There are alternative ways of aggregating; see, e.g., Switching Aggregation in
+Appendix D of *Gaussian Gated Linear Networks*
+(https://arxiv.org/pdf/2006.05964.pdf).
+
+## References
+
+Coming soon.
diff --git a/gated_linear_networks/base.py b/gated_linear_networks/base.py
new file mode 100644
index 0000000..9342e84
--- /dev/null
+++ b/gated_linear_networks/base.py
@@ -0,0 +1,347 @@
+# Lint as: python3
+# Copyright 2020 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
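+#
+# Usage sketch (illustrative only): the Haiku-transform pattern below is
+# taken from the unit tests in this package; the layer sizes are made up.
+#
+#   import haiku as hk
+#   from gated_linear_networks import gaussian
+#
+#   def inference_fn(inputs, side_info):
+#     gln = gaussian.GatedLinearNetwork(output_sizes=(4, 1), context_dim=2)
+#     return gln.inference(inputs, side_info, 0.5)  # min_sigma_sq = 0.5
+#
+#   init_fn, apply_fn = hk.without_apply_rng(
+#       hk.transform_with_state(inference_fn))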
+"""Base classes for Gated Linear Networks.""" + +import abc +import collections +import functools +import inspect +from typing import Any, Callable, Optional, Sequence, Tuple + +import chex +import haiku as hk +import jax +import jax.numpy as jnp + + +Array = chex.Array +DType = Any +Initializer = hk.initializers.Initializer +Shape = Sequence[int] + +EPS = 1e-12 +MIN_ALPHA = 1e-5 + + +def _l2_normalize(x: Array, axis: int) -> Array: + return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis, keepdims=True), EPS)) + + +def _wrapped_fn_argnames(fun): + """Returns list of argnames of a (possibly wrapped) function.""" + return tuple(inspect.signature(fun).parameters) + + +def _vmap(fun, in_axes=0, out_axes=0, parameters=None): + """JAX vmap with human-friendly axes.""" + + def _axes(fun, d): + """Maps dict {kwarg_i, : val_i} to [None, ..., val_i, ..., None].""" + argnames = _wrapped_fn_argnames(fun) if not parameters else parameters + for key in d: + if key not in argnames: + raise ValueError(f"{key} is not a valid axis.") + return tuple(d.get(key, None) for key in argnames) + + in_axes = _axes(fun, in_axes) if isinstance(in_axes, dict) else in_axes + return jax.vmap(fun, in_axes, out_axes) + +# Map a neuron-level function across a layer. +_layer_vmap = functools.partial( + _vmap, + in_axes=({ + "weights": 0, + "hyperplanes": 0, + "hyperplane_bias": 0, + })) + + +class NormalizedRandomNormal(hk.initializers.RandomNormal): + """Random normal initializer with l2-normalization.""" + + def __init__(self, + stddev: float = 1., + mean: float = 0., + normalize_axis: int = 0): + super(NormalizedRandomNormal, self).__init__(stddev, mean) + self._normalize_axis = normalize_axis + + def __call__(self, shape: Shape, dtype: DType) -> Array: + if self._normalize_axis >= len(shape): + raise ValueError("Cannot normalize axis {} for ndim = {}.".format( + self._normalize_axis, len(shape))) + weights = super(NormalizedRandomNormal, self).__call__(shape, dtype) + return _l2_normalize(weights, axis=self._normalize_axis) + + +class ShapeScaledConstant(hk.initializers.Initializer): + """Initializes with a constant dependent on last dimension of input shape.""" + + def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray: + constant = 1. / shape[-1] + return jnp.broadcast_to(constant, shape).astype(dtype) + + +class LocalUpdateModule(hk.Module): + """Abstract base class for GLN variants and utils.""" + + def __init__(self, name: Optional[str] = None): + if hasattr(self, "__call__"): + raise ValueError("Do not implement `__call__` for a LocalUpdateModule." 
+ + " Implement `inference` and `update` instead.") + super(LocalUpdateModule, self).__init__(name) + + @abc.abstractmethod + def inference(self, *args, **kwargs): + """Module inference step.""" + + @abc.abstractmethod + def update(self, *args, **kwargs): + """Module update step.""" + + @property + @abc.abstractmethod + def output_sizes(self) -> Shape: + """Returns network output sizes.""" + + +class GatedLinearNetwork(LocalUpdateModule): + """Abstract base class for a multi-layer Gated Linear Network.""" + + def __init__(self, + output_sizes: Shape, + context_dim: int, + inference_fn: Callable[..., Array], + update_fn: Callable[..., Array], + init: Initializer, + hyp_w_init: Optional[Initializer] = None, + hyp_b_init: Optional[Initializer] = None, + dtype: DType = jnp.float32, + name: str = "gated_linear_network"): + """Initialize a GatedLinearNetwork as a sequence of GatedLinearLayers.""" + super(GatedLinearNetwork, self).__init__(name=name) + + self._layers = [] + self._output_sizes = output_sizes + for i, output_size in enumerate(self._output_sizes): + layer = _GatedLinearLayer( + output_size=output_size, + context_dim=context_dim, + update_fn=update_fn, + inference_fn=inference_fn, + init=init, + hyp_w_init=hyp_w_init, + hyp_b_init=hyp_b_init, + dtype=dtype, + name=name + "_layer_{}".format(i)) + self._layers.append(layer) + self._name = name + + @abc.abstractmethod + def _add_bias(self, inputs): + pass + + def inference(self, inputs: Array, side_info: Array, *args, + **kwargs) -> Array: + """GatedLinearNetwork inference.""" + predictions_per_layer = [] + predictions = inputs + for layer in self._layers: + predictions = self._add_bias(predictions) + predictions = layer.inference(predictions, side_info, *args, **kwargs) + predictions_per_layer.append(predictions) + + return jnp.concatenate(predictions_per_layer, axis=0) + + def update(self, inputs, side_info, target, learning_rate, *args, **kwargs): + """GatedLinearNetwork update.""" + all_params = [] + all_predictions = [] + all_losses = [] + predictions = inputs + for layer in self._layers: + predictions = self._add_bias(predictions) + + # Note: This is correct because returned predictions are pre-update. + params, predictions, log_loss = layer.update(predictions, side_info, + target, learning_rate, *args, + **kwargs) + all_params.append(params) + all_predictions.append(predictions) + all_losses.append(log_loss) + + new_params = dict(collections.ChainMap(*all_params)) + predictions = jnp.concatenate(all_predictions, axis=0) + log_loss = jnp.concatenate(all_losses, axis=0) + + return new_params, predictions, log_loss + + @property + def output_sizes(self): + return self._output_sizes + + @staticmethod + def _compute_context( + side_info: Array, # [side_info_size] + hyperplanes: Array, # [context_dim, side_info_size] + hyperplane_bias: Array, # [context_dim] + ) -> Array: + # Index weights by side information. 
+ context_dim = hyperplane_bias.shape[0] + proj = jnp.dot(hyperplanes, side_info) + bits = (proj > hyperplane_bias).astype(jnp.int32) + weight_index = jnp.sum( + bits * + jnp.array([2**i for i in range(context_dim)])) if context_dim else 0 + return weight_index + + +class _GatedLinearLayer(LocalUpdateModule): + """A single layer of a Gated Linear Network.""" + + def __init__(self, + output_size: int, + context_dim: int, + inference_fn: Callable[..., Array], + update_fn: Callable[..., Array], + init: Initializer, + hyp_w_init: Optional[Initializer] = None, + hyp_b_init: Optional[Initializer] = None, + dtype: DType = jnp.float32, + name: str = "gated_linear_layer"): + """Initialize a GatedLinearLayer.""" + super(_GatedLinearLayer, self).__init__(name=name) + self._output_size = output_size + self._context_dim = context_dim + self._inference_fn = inference_fn + self._update_fn = update_fn + self._init = init + self._hyp_w_init = hyp_w_init + self._hyp_b_init = hyp_b_init + self._dtype = dtype + self._name = name + + def _get_weights(self, input_size): + """Get (or initialize) weight parameters.""" + weights = hk.get_parameter( + "weights", + shape=(self._output_size, 2**self._context_dim, input_size), + dtype=self._dtype, + init=self._init, + ) + + return weights + + def _get_hyperplanes(self, side_info_size): + """Get (or initialize) hyperplane weights and bias.""" + + hyp_w_init = self._hyp_w_init or NormalizedRandomNormal( + stddev=1., normalize_axis=1) + hyperplanes = hk.get_state( + "hyperplanes", + shape=(self._output_size, self._context_dim, side_info_size), + init=hyp_w_init) + + hyp_b_init = self._hyp_b_init or hk.initializers.RandomNormal(stddev=0.05) + hyperplane_bias = hk.get_state( + "hyperplane_bias", + shape=(self._output_size, self._context_dim), + init=hyp_b_init) + + return hyperplanes, hyperplane_bias + + def inference(self, inputs: Array, side_info: Array, *args, + **kwargs) -> Array: + """GatedLinearLayer inference.""" + # Initialize layer weights. + weights = self._get_weights(inputs.shape[0]) + + # Initialize fixed random hyperplanes. + side_info_size = side_info.shape[0] + hyperplanes, hyperplane_bias = self._get_hyperplanes(side_info_size) + + # Perform layer-wise inference by mapping along output_size (num_neurons). + layer_inference = _layer_vmap(self._inference_fn) + predictions = layer_inference(inputs, side_info, weights, hyperplanes, + hyperplane_bias, *args, **kwargs) + + return predictions + + def update(self, inputs: Array, side_info: Array, target: Array, + learning_rate: float, *args, + **kwargs) -> Tuple[Array, Array, Array]: + """GatedLinearLayer update.""" + # Fetch layer weights. + weights = self._get_weights(inputs.shape[0]) + + # Fetch fixed random hyperplanes. + side_info_size = side_info.shape[0] + hyperplanes, hyperplane_bias = self._get_hyperplanes(side_info_size) + + # Perform layer-wise update by mapping along output_size (num_neurons). 
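+    # (`_layer_vmap` maps the per-neuron function over axis 0 of `weights`,
+    # `hyperplanes` and `hyperplane_bias`, broadcasting the remaining
+    # arguments across neurons.)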
+ layer_update = _layer_vmap(self._update_fn) + new_weights, predictions, log_loss = layer_update(inputs, side_info, + weights, hyperplanes, + hyperplane_bias, target, + learning_rate, *args, + **kwargs) + + assert new_weights.shape == weights.shape + params = {self.module_name: {"weights": new_weights}} + return params, predictions, log_loss + + @property + def output_sizes(self): + return self._output_size + + +class Mutator(LocalUpdateModule): + """Abstract base class for GLN Mutators.""" + + def __init__( + self, + network_factory: Callable[..., LocalUpdateModule], + name: str, + ): + super(Mutator, self).__init__(name=name) + self._network = network_factory() + self._name = name + + @property + def output_sizes(self): + return self._network.output_sizes + + +class LastNeuronAggregator(Mutator): + """Last neuron aggregator: network output is read from the last neuron.""" + + def __init__( + self, + network_factory: Callable[..., LocalUpdateModule], + name: str = "last_neuron", + ): + super(LastNeuronAggregator, self).__init__(network_factory, name) + if self._network.output_sizes[-1] != 1: + raise ValueError( + "LastNeuronAggregator requires the last GLN layer to have" + " output_size = 1.") + + def inference(self, *args, **kwargs) -> Array: + predictions = self._network.inference(*args, **kwargs) + return predictions[-1] + + def update(self, *args, **kwargs) -> Tuple[Array, Array, Array]: + params_t, predictions_tm1, loss_tm1 = self._network.update(*args, **kwargs) + return params_t, predictions_tm1[-1], loss_tm1[-1] diff --git a/gated_linear_networks/bernoulli.py b/gated_linear_networks/bernoulli.py new file mode 100644 index 0000000..74ed799 --- /dev/null +++ b/gated_linear_networks/bernoulli.py @@ -0,0 +1,107 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bernoulli Gated Linear Network.""" + +from typing import List, Text, Tuple + +import chex +import jax +import jax.numpy as jnp +import rlax +import tensorflow_probability as tfp + +from gated_linear_networks import base + +tfp = tfp.experimental.substrates.jax +tfd = tfp.distributions + +Array = chex.Array + +GLN_EPS = 0.01 +MAX_WEIGHT = 200. 
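+
+# A minimal sketch of what one Bernoulli neuron computes (illustrative only;
+# `probs` and `neuron_weights` are hypothetical names, the real code lives in
+# `GatedLinearNetwork._inference_fn` below):
+#
+#   logits = rlax.logit(jnp.clip(probs, GLN_EPS, 1. - GLN_EPS))
+#   prediction = rlax.sigmoid(jnp.dot(neuron_weights, logits))
+#
+# i.e. a weighted geometric mixture of the input probabilities.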
+
+
+class GatedLinearNetwork(base.GatedLinearNetwork):
+  """Bernoulli Gated Linear Network."""
+
+  def __init__(self,
+               output_sizes: List[int],
+               context_dim: int,
+               name: Text = "bernoulli_gln"):
+    """Initialize a Bernoulli GLN."""
+    super(GatedLinearNetwork, self).__init__(
+        output_sizes,
+        context_dim,
+        inference_fn=GatedLinearNetwork._inference_fn,
+        update_fn=GatedLinearNetwork._update_fn,
+        init=jnp.zeros,
+        dtype=jnp.float32,
+        name=name)
+
+  def _add_bias(self, inputs):
+    return jnp.append(inputs, rlax.sigmoid(1.))
+
+  @staticmethod
+  def _inference_fn(
+      inputs: Array,           # [input_size]
+      side_info: Array,        # [side_info_size]
+      weights: Array,          # [2**context_dim, input_size]
+      hyperplanes: Array,      # [context_dim, side_info_size]
+      hyperplane_bias: Array,  # [context_dim]
+  ) -> Array:
+    """Inference step for a single Bernoulli neuron."""
+
+    weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
+                                                       hyperplane_bias)
+    used_weights = weights[weight_index]
+    inputs = rlax.logit(jnp.clip(inputs, GLN_EPS, 1. - GLN_EPS))
+    prediction = rlax.sigmoid(jnp.dot(used_weights, inputs))
+
+    return prediction
+
+  @staticmethod
+  def _update_fn(
+      inputs: Array,           # [input_size]
+      side_info: Array,        # [side_info_size]
+      weights: Array,          # [2**context_dim, num_features]
+      hyperplanes: Array,      # [context_dim, side_info_size]
+      hyperplane_bias: Array,  # [context_dim]
+      target: Array,           # []
+      learning_rate: float,
+  ) -> Tuple[Array, Array, Array]:
+    """Update step for a single Bernoulli neuron."""
+
+    def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias,
+                    target):
+      """Log loss for a single Bernoulli neuron."""
+      prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights,
+                                                    hyperplanes,
+                                                    hyperplane_bias)
+      prediction = jnp.clip(prediction, GLN_EPS, 1. - GLN_EPS)
+      return rlax.log_loss(prediction, target), prediction
+
+    grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True)
+    ((log_loss, prediction),
+     dloss_dweights) = grad_log_loss(inputs, side_info, weights, hyperplanes,
+                                     hyperplane_bias, target)
+
+    delta_weights = learning_rate * dloss_dweights
+    new_weights = jnp.clip(weights - delta_weights, -MAX_WEIGHT, MAX_WEIGHT)
+    return new_weights, prediction, log_loss
+
+
+class LastNeuronAggregator(base.LastNeuronAggregator):
+  """Bernoulli last neuron aggregator, implemented by the super class."""
+  pass
diff --git a/gated_linear_networks/bernoulli_test.py b/gated_linear_networks/bernoulli_test.py
new file mode 100644
index 0000000..a53aae1
--- /dev/null
+++ b/gated_linear_networks/bernoulli_test.py
@@ -0,0 +1,215 @@
+# Lint as: python3
+# Copyright 2020 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for `bernoulli.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tree + +from gated_linear_networks import bernoulli + + +def _get_dataset(input_size, batch_size=None): + """Get mock dataset.""" + if batch_size: + inputs = jnp.ones([batch_size, input_size]) + side_info = jnp.ones([batch_size, input_size]) + targets = jnp.ones([batch_size]) + else: + inputs = jnp.ones([input_size]) + side_info = jnp.ones([input_size]) + targets = jnp.ones([]) + + return inputs, side_info, targets + + +class GatedLinearNetworkTest(parameterized.TestCase): + + # TODO(b/170843789): Factor out common test utilities. + def setUp(self): + super(GatedLinearNetworkTest, self).setUp() + self._name = "test_network" + self._rng = hk.PRNGSequence(jax.random.PRNGKey(42)) + + self._output_sizes = (4, 5, 6) + self._context_dim = 2 + + def gln_factory(): + return bernoulli.GatedLinearNetwork( + output_sizes=self._output_sizes, + context_dim=self._context_dim, + name=self._name) + + def inference_fn(inputs, side_info): + return gln_factory().inference(inputs, side_info) + + def batch_inference_fn(inputs, side_info): + return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info) + + def update_fn(inputs, side_info, label, learning_rate): + params, predictions, unused_loss = gln_factory().update( + inputs, side_info, label, learning_rate) + return predictions, params + + def batch_update_fn(inputs, side_info, label, learning_rate): + predictions, params = jax.vmap( + update_fn, in_axes=(0, 0, 0, None))(inputs, side_info, label, + learning_rate) + avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params) + return predictions, avg_params + + # Haiku transform functions. + self._init_fn, inference_fn_ = hk.without_apply_rng( + hk.transform_with_state(inference_fn)) + self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng( + hk.transform_with_state(batch_inference_fn)) + _, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn)) + _, batch_update_fn_ = hk.without_apply_rng( + hk.transform_with_state(batch_update_fn)) + + self._inference_fn = jax.jit(inference_fn_) + self._batch_inference_fn = jax.jit(batch_inference_fn_) + self._update_fn = jax.jit(update_fn_) + self._batch_update_fn = jax.jit(batch_update_fn_) + + @parameterized.named_parameters(("Online mode", None), ("Batch mode", 3)) + def test_shapes(self, batch_size): + """Test shapes in online and batch regimes.""" + if batch_size is None: + init_fn = self._init_fn + inference_fn = self._inference_fn + else: + init_fn = self._batch_init_fn + inference_fn = self._batch_inference_fn + + input_size = 10 + inputs, side_info, _ = _get_dataset(input_size, batch_size) + input_size = inputs.shape[-1] + + # Initialize network. + gln_params, gln_state = init_fn(next(self._rng), inputs, side_info) + + # Test shapes of parameters layer-wise. + layer_input_size = input_size + for layer_idx, output_size in enumerate(self._output_sizes): + name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx) + weights = gln_params[name]["weights"] + expected_shape = (output_size, 2**self._context_dim, layer_input_size + 1) + self.assertEqual(weights.shape, expected_shape) + + layer_input_size = output_size + + # Test shape of output. 
+ output_size = sum(self._output_sizes) + predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info) + expected_shape = (batch_size, output_size) if batch_size else (output_size,) + self.assertEqual(predictions.shape, expected_shape) + + @parameterized.named_parameters(("Online mode", None), ("Batch mode", 3)) + def test_update(self, batch_size): + """Test network updates in online and batch regimes.""" + if batch_size is None: + init_fn = self._init_fn + inference_fn = self._inference_fn + update_fn = self._update_fn + else: + init_fn = self._batch_init_fn + inference_fn = self._batch_inference_fn + update_fn = self._batch_update_fn + + input_size = 10 + inputs, side_info, targets = _get_dataset(input_size, batch_size) + + # Initialize network. + initial_params, gln_state = init_fn(next(self._rng), inputs, side_info) + + # Initial predictions. + initial_predictions, _ = inference_fn(initial_params, gln_state, inputs, + side_info) + + # Test that params remain valid after consecutive updates. + gln_params = initial_params + + for _ in range(3): + (_, gln_params), gln_state = update_fn( + gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4) + + # Check updated weights layer-wise. + for layer_idx in range(len(self._output_sizes)): + name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx) + + initial_weights = initial_params[name]["weights"] + new_weights = gln_params[name]["weights"] + + # Shape consistency. + self.assertEqual(new_weights.shape, initial_weights.shape) + + # Check that different weights yield different predictions. + new_predictions, _ = inference_fn(gln_params, gln_state, inputs, + side_info) + self.assertFalse(np.array_equal(new_predictions, initial_predictions)) + + def test_batch_consistency(self): + """Test consistency between online and batch updates.""" + + input_size = 10 + batch_size = 3 + inputs, side_info, targets = _get_dataset(input_size, batch_size) + + # Initialize network. + gln_params, gln_state = self._batch_init_fn( + next(self._rng), inputs, side_info) + test_layer = "{}/~/{}_layer_0".format(self._name, self._name) + + for _ in range(10): + + # Update on full batch. + (expected_predictions, expected_params), _ = self._batch_update_fn( + gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-3) + + # Average updates across batch and check equivalence. + accum_predictions = [] + accum_weights = [] + for inputs_, side_info_, targets_ in zip(inputs, side_info, targets): + (predictions, params), _ = self._update_fn( + gln_params, + gln_state, + inputs_, + side_info_, + targets_, + learning_rate=1e-3) + accum_predictions.append(predictions) + accum_weights.append(params[test_layer]["weights"]) + + # Check prediction equivalence. + actual_predictions = np.stack(accum_predictions, axis=0) + np.testing.assert_array_almost_equal(actual_predictions, + expected_predictions) + + # Check weight equivalence. + actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0) + expected_weights = expected_params[test_layer]["weights"] + np.testing.assert_array_almost_equal(actual_weights, expected_weights) + + gln_params = expected_params + + +if __name__ == "__main__": + absltest.main() diff --git a/gated_linear_networks/examples/bernoulli_mnist.py b/gated_linear_networks/examples/bernoulli_mnist.py new file mode 100644 index 0000000..41fc304 --- /dev/null +++ b/gated_linear_networks/examples/bernoulli_mnist.py @@ -0,0 +1,144 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Online MNIST classification example with Bernoulli GLN.""" + +from absl import app +from absl import flags + +import haiku as hk +import jax +import jax.numpy as jnp +import rlax + +from gated_linear_networks import bernoulli +from gated_linear_networks.examples import utils + +FLAGS = flags.FLAGS + +# Small example network, achieves ~95% test set accuracy ======================= +# Network parameters. +flags.DEFINE_integer('num_layers', 2, '') +flags.DEFINE_integer('neurons_per_layer', 100, '') +flags.DEFINE_integer('context_dim', 1, '') + +# Learning rate schedule. +flags.DEFINE_float('max_lr', 0.003, '') +flags.DEFINE_float('lr_constant', 1.0, '') +flags.DEFINE_float('lr_decay', 0.1, '') + +# Logging parameters. +flags.DEFINE_integer('evaluate_every', 1000, '') + + +def main(unused_argv): + # Load MNIST dataset ========================================================= + mnist_data, info = utils.load_deskewed_mnist( + name='mnist', batch_size=-1, with_info=True) + num_classes = info.features['label'].num_classes + + (train_images, train_labels) = (mnist_data['train']['image'], + mnist_data['train']['label']) + + (test_images, test_labels) = (mnist_data['test']['image'], + mnist_data['test']['label']) + + # Build a (binary) GLN classifier ============================================ + def network_factory(): + + def gln_factory(): + output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1] + return bernoulli.GatedLinearNetwork( + output_sizes=output_sizes, context_dim=FLAGS.context_dim) + + return bernoulli.LastNeuronAggregator(gln_factory) + + def extract_features(image): + mean, stddev = utils.MeanStdEstimator()(image) + standardized_img = (image - mean) / (stddev + 1.) 
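+    # (The running mean/stddev are Haiku state from `MeanStdEstimator`; the
+    # +1. in the denominator presumably keeps the division stable while the
+    # stddev estimate is still near zero. This rationale is an inference,
+    # not documented upstream.)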
+ inputs = rlax.sigmoid(standardized_img) + side_info = standardized_img + return inputs, side_info + + def inference_fn(image, *args, **kwargs): + inputs, side_info = extract_features(image) + return network_factory().inference(inputs, side_info, *args, **kwargs) + + def update_fn(image, *args, **kwargs): + inputs, side_info = extract_features(image) + return network_factory().update(inputs, side_info, *args, **kwargs) + + init_, inference_ = hk.without_apply_rng( + hk.transform_with_state(inference_fn)) + _, update_ = hk.without_apply_rng(hk.transform_with_state(update_fn)) + + # Map along class dimension to create a one-vs-all classifier ================ + @jax.jit + def init(dummy_image, key): + """One-vs-all classifier init fn.""" + dummy_images = jnp.stack([dummy_image] * num_classes, axis=0) + keys = jax.random.split(key, num_classes) + return jax.vmap(init_, in_axes=(0, 0))(keys, dummy_images) + + @jax.jit + def accuracy(params, state, image, label): + """One-vs-all classifier inference fn.""" + fn = jax.vmap(inference_, in_axes=(0, 0, None)) + predictions, unused_state = fn(params, state, image) + return (jnp.argmax(predictions) == label).astype(jnp.float32) + + @jax.jit + def update(params, state, step, image, label): + """One-vs-all classifier update fn.""" + + # Learning rate schedules. + learning_rate = jnp.minimum( + FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step)) + + # Update weights and report log-loss. + targets = hk.one_hot(jnp.asarray(label), num_classes) + + fn = jax.vmap(update_, in_axes=(0, 0, None, 0, None)) + out = fn(params, state, image, targets, learning_rate) + (params, unused_predictions, log_loss), state = out + return (jnp.mean(log_loss), params), state + + # Train on train split ======================================================= + dummy_image = train_images[0] + params, state = init(dummy_image, jax.random.PRNGKey(42)) + + for step, (image, label) in enumerate(zip(train_images, train_labels), 1): + (unused_loss, params), state = update( + params, + state, + step, + image, + label, + ) + + # Evaluate on test split =================================================== + if not step % FLAGS.evaluate_every: + batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0)) + accuracies = batch_accuracy(params, state, test_images, test_labels) + total_accuracy = float(jnp.mean(accuracies)) + + # Report statistics. + print({ + 'step': step, + 'accuracy': float(total_accuracy), + }) + + +if __name__ == '__main__': + app.run(main) diff --git a/gated_linear_networks/examples/utils.py b/gated_linear_networks/examples/utils.py new file mode 100644 index 0000000..5c531d5 --- /dev/null +++ b/gated_linear_networks/examples/utils.py @@ -0,0 +1,97 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
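+#
+# Descriptive summary of this module: `_deskew` estimates an image's shear
+# from its second-order moments and removes it with an affine transform;
+# `MeanStdEstimator` keeps Welford running statistics, using the recurrence
+# mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k.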
+"""Haiku modules for feature processing.""" + +import copy +from typing import Tuple + +import chex +import haiku as hk +import jax.numpy as jnp +import numpy as np +from scipy.ndimage import interpolation +import tensorflow_datasets as tfds + +Array = chex.Array + + +def _moments(image): + """Compute the first and second moments of a given image.""" + c0, c1 = np.mgrid[:image.shape[0], :image.shape[1]] + total_image = np.sum(image) + m0 = np.sum(c0 * image) / total_image + m1 = np.sum(c1 * image) / total_image + m00 = np.sum((c0 - m0)**2 * image) / total_image + m11 = np.sum((c1 - m1)**2 * image) / total_image + m01 = np.sum((c0 - m0) * (c1 - m1) * image) / total_image + mu_vector = np.array([m0, m1]) + covariance_matrix = np.array([[m00, m01], [m01, m11]]) + return mu_vector, covariance_matrix + + +def _deskew(image): + """Image deskew.""" + c, v = _moments(image) + alpha = v[0, 1] / v[0, 0] + affine = np.array([[1, 0], [alpha, 1]]) + ocenter = np.array(image.shape) / 2.0 + offset = c - np.dot(affine, ocenter) + return interpolation.affine_transform(image, affine, offset=offset) + + +def _deskew_dataset(dataset): + """Dataset deskew.""" + deskewed = copy.deepcopy(dataset) + for k, before in dataset.items(): + images = before["image"] + num_images = images.shape[0] + after = np.stack([_deskew(i) for i in np.squeeze(images, axis=-1)], axis=0) + deskewed[k]["image"] = np.reshape(after, (num_images, -1)) + return deskewed + + +def load_deskewed_mnist(*a, **k): + """Returns deskewed MNIST numpy dataset.""" + mnist_data, info = tfds.load(*a, **k) + mnist_data = tfds.as_numpy(mnist_data) + deskewed_data = _deskew_dataset(mnist_data) + return deskewed_data, info + + +class MeanStdEstimator(hk.Module): + """Online mean and standard deviation estimator using Welford's algorithm.""" + + def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]: + if len(sample.shape) > 1: + raise ValueError("sample must be a rank 0 or 1 DeviceArray.") + + count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros) + mean = hk.get_state( + "mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros) + m2 = hk.get_state( + "m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros) + + count += 1 + delta = sample - mean + mean += delta / count + delta_2 = sample - mean + m2 += delta * delta_2 + + hk.set_state("count", count) + hk.set_state("mean", mean) + hk.set_state("m2", m2) + + stddev = jnp.sqrt(m2 / count) + return mean, stddev diff --git a/gated_linear_networks/examples/utils_test.py b/gated_linear_networks/examples/utils_test.py new file mode 100644 index 0000000..2674699 --- /dev/null +++ b/gated_linear_networks/examples/utils_test.py @@ -0,0 +1,52 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for `utils.py`.""" + +from absl.testing import absltest + +import haiku as hk +import jax +import numpy as np + +from gated_linear_networks.examples import utils + + +class MeanStdEstimator(absltest.TestCase): + + def test_statistics(self): + num_features = 100 + feature_size = 3 + samples = np.random.normal( + loc=5., scale=2., size=(num_features, feature_size)) + true_mean = np.mean(samples, axis=0) + true_std = np.std(samples, axis=0) + + def tick_(sample): + return utils.MeanStdEstimator()(sample) + + init_fn, apply_fn = hk.without_apply_rng(hk.transform_with_state(tick_)) + tick = jax.jit(apply_fn) + + params, state = init_fn(rng=None, sample=samples[0]) + + for sample in samples: + (mean, std), state = tick(params, state, sample) + + np.testing.assert_array_almost_equal(mean, true_mean, decimal=5) + np.testing.assert_array_almost_equal(std, true_std, decimal=5) + + +if __name__ == '__main__': + absltest.main() diff --git a/gated_linear_networks/gaussian.py b/gated_linear_networks/gaussian.py new file mode 100644 index 0000000..6924370 --- /dev/null +++ b/gated_linear_networks/gaussian.py @@ -0,0 +1,197 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gaussian Gated Linear Network.""" + +from typing import Callable, List, Text, Tuple + +import chex +import jax +import jax.numpy as jnp +import tensorflow_probability as tfp + +from gated_linear_networks import base + +tfp = tfp.experimental.substrates.jax +tfd = tfp.distributions + +Array = chex.Array + +MIN_SIGMA_SQ_AGGREGATOR = 0.5 +MAX_SIGMA_SQ = 1e5 +MAX_WEIGHT = 1e3 +MIN_WEIGHT = -1e3 + + +def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]: + inputs = jnp.atleast_2d(inputs) + chex.assert_rank(inputs, 2) + (mu, sigma_sq) = [jnp.squeeze(x, 1) for x in jnp.hsplit(inputs, 2)] + return mu, sigma_sq + + +def _pack_inputs(mu: Array, sigma_sq: Array) -> Array: + mu = jnp.atleast_1d(mu) + sigma_sq = jnp.atleast_1d(sigma_sq) + chex.assert_rank([mu, sigma_sq], 1) + return jnp.vstack([mu, sigma_sq]).T + + +class GatedLinearNetwork(base.GatedLinearNetwork): + """Gaussian Gated Linear Network.""" + + def __init__( + self, + output_sizes: List[int], + context_dim: int, + bias_len: int = 3, + bias_max_mu: float = 1., + bias_sigma_sq: float = 1., + name: Text = "gaussian_gln"): + """Initialize a Gaussian GLN.""" + super(GatedLinearNetwork, self).__init__( + output_sizes, + context_dim, + inference_fn=GatedLinearNetwork._inference_fn, + update_fn=GatedLinearNetwork._update_fn, + init=base.ShapeScaledConstant(), + dtype=jnp.float64, + name=name) + + self._bias_len = bias_len + self._bias_max_mu = bias_max_mu + self._bias_sigma_sq = bias_sigma_sq + + def _add_bias(self, inputs): + mu = jnp.linspace(-1. 
* self._bias_max_mu, self._bias_max_mu,
+                      self._bias_len)
+    sigma_sq = self._bias_sigma_sq * jnp.ones_like(mu)
+    bias = _pack_inputs(mu, sigma_sq)
+    return jnp.concatenate([inputs, bias], axis=0)
+
+  @staticmethod
+  def _inference_fn(
+      inputs: Array,           # [input_size, 2]
+      side_info: Array,        # [side_info_size]
+      weights: Array,          # [2**context_dim, input_size]
+      hyperplanes: Array,      # [context_dim, side_info_size]
+      hyperplane_bias: Array,  # [context_dim]
+      min_sigma_sq: float,
+  ) -> Array:
+    """Inference step for a single Gaussian neuron."""
+
+    mu_in, sigma_sq_in = _unpack_inputs(inputs)
+    weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
+                                                       hyperplane_bias)
+    used_weights = weights[weight_index]
+
+    # This projection operation is differentiable and affects the gradients.
+    used_weights = GatedLinearNetwork._project_weights(inputs, used_weights,
+                                                       min_sigma_sq)
+
+    sigma_sq_out = 1. / jnp.sum(used_weights / sigma_sq_in)
+    mu_out = sigma_sq_out * jnp.sum((used_weights * mu_in) / sigma_sq_in)
+    prediction = jnp.hstack((mu_out, sigma_sq_out))
+    return prediction
+
+  @staticmethod
+  def _project_weights(inputs: Array,   # [input_size]
+                       weights: Array,  # [2**context_dim, num_features]
+                       min_sigma_sq: float) -> Array:
+    """Implements hard projection."""
+
+    # This projection should be performed before the sigma-related ones.
+    weights = jnp.minimum(jnp.maximum(MIN_WEIGHT, weights), MAX_WEIGHT)
+    _, sigma_sq_in = _unpack_inputs(inputs)
+
+    lambda_in = 1. / sigma_sq_in
+    sigma_sq_out = 1. / weights.dot(lambda_in)
+
+    # If sigma_sq_out < min_sigma_sq, linearly project weights along lambda_in
+    # so that the resulting output variance equals min_sigma_sq.
+    weights = jnp.where(
+        sigma_sq_out < min_sigma_sq, weights - lambda_in *
+        (1. / sigma_sq_out - 1. / min_sigma_sq) / jnp.sum(lambda_in**2),
+        weights)
+
+    # If sigma_sq_out > MAX_SIGMA_SQ, linearly project weights along lambda_in
+    # so that the resulting output variance equals MAX_SIGMA_SQ.
+    weights = jnp.where(
+        sigma_sq_out > MAX_SIGMA_SQ, weights - lambda_in *
+        (1. / sigma_sq_out - 1.
/ MAX_SIGMA_SQ) / jnp.sum(lambda_in**2), + weights) + + return weights + + @staticmethod + def _update_fn( + inputs: Array, # [input_size] + side_info: Array, # [side_info_size] + weights: Array, # [2**context_dim, num_features] + hyperplanes: Array, # [context_dim, side_info_size] + hyperplane_bias: Array, # [context_dim] + target: Array, # [] + learning_rate: float, + min_sigma_sq: float, # needed for inference (weight projection) + ) -> Tuple[Array, Array, Array]: + """Update step for a single Gaussian neuron.""" + + def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias, + target): + """Log loss for a single Gaussian neuron.""" + prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights, + hyperplanes, + hyperplane_bias, + min_sigma_sq) + mu, sigma_sq = prediction.T + loss = -tfd.Normal(mu, jnp.sqrt(sigma_sq)).log_prob(target) + return loss, prediction + + grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True) + (log_loss, + prediction), dloss_dweights = grad_log_loss(inputs, side_info, weights, + hyperplanes, hyperplane_bias, + target) + + delta_weights = learning_rate * dloss_dweights + return weights - delta_weights, prediction, log_loss + + +class ConstantInputSigma(base.Mutator): + """Input pre-processing by concatenating a constant sigma^2.""" + + def __init__( + self, + network_factory: Callable[..., GatedLinearNetwork], + input_sigma_sq: float, + name: Text = "constant_input_sigma", + ): + super(ConstantInputSigma, self).__init__(network_factory, name) + self._input_sigma_sq = input_sigma_sq + + def inference(self, inputs, *args, **kwargs): + """ConstantInputSigma inference.""" + chex.assert_rank(inputs, 1) + sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs) + return self._network.inference(_pack_inputs(inputs, sigma_sq), *args, + **kwargs) + + def update(self, inputs, *args, **kwargs): + """ConstantInputSigma update.""" + chex.assert_rank(inputs, 1) + sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs) + return self._network.update(_pack_inputs(inputs, sigma_sq), *args, **kwargs) + + +class LastNeuronAggregator(base.LastNeuronAggregator): + """Gaussian last neuron aggregator, implemented by the super class.""" + pass diff --git a/gated_linear_networks/gaussian_test.py b/gated_linear_networks/gaussian_test.py new file mode 100644 index 0000000..6cbf575 --- /dev/null +++ b/gated_linear_networks/gaussian_test.py @@ -0,0 +1,233 @@ +# Lint as: python3 +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
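+#
+# (Illustrative note: `inference_fn` below passes the extra positional
+# argument 0.5 through to `GatedLinearNetwork._inference_fn`, where it is
+# bound to `min_sigma_sq`; `update_fn` follows the same pattern.)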
+"""Tests for `gaussian.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tree + +from gated_linear_networks import gaussian + + +def _get_dataset(input_size, batch_size=None): + """Get mock dataset.""" + if batch_size: + inputs = jnp.ones([batch_size, input_size, 2]) + side_info = jnp.ones([batch_size, input_size]) + targets = 0.8 * jnp.ones([batch_size]) + else: + inputs = jnp.ones([input_size, 2]) + side_info = jnp.ones([input_size]) + targets = jnp.ones([]) + + return inputs, side_info, targets + + +class UtilsTest(absltest.TestCase): + + def test_packing_identity(self): + mu = jnp.array([1., 2., 3., 4., 5.]) + sigma_sq = jnp.array([6., 7., 8., 9., 10.]) + + mu_2, sigma_sq_2 = gaussian._unpack_inputs( + gaussian._pack_inputs(mu, sigma_sq)) + + np.testing.assert_array_equal(mu, mu_2) + np.testing.assert_array_equal(sigma_sq, sigma_sq_2) + + +class GatedLinearNetworkTest(parameterized.TestCase): + + # TODO(b/170843789): Factor out common test utilities. + def setUp(self): + super(GatedLinearNetworkTest, self).setUp() + self._name = "test_network" + self._rng = hk.PRNGSequence(jax.random.PRNGKey(42)) + + self._output_sizes = (4, 5, 6) + self._context_dim = 2 + self._bias_len = 3 + + def gln_factory(): + return gaussian.GatedLinearNetwork( + output_sizes=self._output_sizes, + context_dim=self._context_dim, + bias_len=self._bias_len, + name=self._name, + ) + + def inference_fn(inputs, side_info): + return gln_factory().inference(inputs, side_info, 0.5) + + def batch_inference_fn(inputs, side_info): + return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info) + + def update_fn(inputs, side_info, label, learning_rate): + params, predictions, unused_loss = gln_factory().update( + inputs, side_info, label, learning_rate, 0.5) + return predictions, params + + def batch_update_fn(inputs, side_info, label, learning_rate): + predictions, params = jax.vmap( + update_fn, in_axes=(0, 0, 0, None))( + inputs, + side_info, + label, + learning_rate) + avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params) + return predictions, avg_params + + # Haiku transform functions. + self._init_fn, inference_fn_ = hk.without_apply_rng( + hk.transform_with_state(inference_fn)) + self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng( + hk.transform_with_state(batch_inference_fn)) + _, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn)) + _, batch_update_fn_ = hk.without_apply_rng( + hk.transform_with_state(batch_update_fn)) + + self._inference_fn = jax.jit(inference_fn_) + self._batch_inference_fn = jax.jit(batch_inference_fn_) + self._update_fn = jax.jit(update_fn_) + self._batch_update_fn = jax.jit(batch_update_fn_) + + @parameterized.named_parameters(("Online mode", None), ("Batch mode", 3)) + def test_shapes(self, batch_size): + """Test shapes in online and batch regimes.""" + if batch_size is None: + init_fn = self._init_fn + inference_fn = self._inference_fn + else: + init_fn = self._batch_init_fn + inference_fn = self._batch_inference_fn + + input_size = 10 + inputs, side_info, _ = _get_dataset(input_size, batch_size) + + # Initialize network. + gln_params, gln_state = init_fn(next(self._rng), inputs, side_info) + + # Test shapes of parameters layer-wise. 
+ layer_input_size = input_size + for layer_idx, output_size in enumerate(self._output_sizes): + name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx) + weights = gln_params[name]["weights"] + expected_shape = (output_size, 2**self._context_dim, + layer_input_size + self._bias_len) + self.assertEqual(weights.shape, expected_shape) + + layer_input_size = output_size + + # Test shape of output. + output_size = sum(self._output_sizes) + predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info) + expected_shape = (batch_size, output_size, + 2) if batch_size else (output_size, 2) + self.assertEqual(predictions.shape, expected_shape) + + @parameterized.named_parameters(("Online mode", None), ("Batch mode", 3)) + def test_update(self, batch_size): + """Test network updates in online and batch regimes.""" + if batch_size is None: + init_fn = self._init_fn + inference_fn = self._inference_fn + update_fn = self._update_fn + else: + init_fn = self._batch_init_fn + inference_fn = self._batch_inference_fn + update_fn = self._batch_update_fn + + inputs, side_info, targets = _get_dataset(10, batch_size) + + # Initialize network. + initial_params, gln_state = init_fn(next(self._rng), inputs, side_info) + + # Initial predictions. + initial_predictions, _ = inference_fn(initial_params, gln_state, inputs, + side_info) + + # Test that params remain valid after consecutive updates. + gln_params = initial_params + + for _ in range(3): + (_, gln_params), _ = update_fn( + gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4) + + # Check updated weights layer-wise. + for layer_idx in range(len(self._output_sizes)): + name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx) + + initial_weights = initial_params[name]["weights"] + new_weights = gln_params[name]["weights"] + + # Shape consistency. + self.assertEqual(new_weights.shape, initial_weights.shape) + + # Check that different weights yield different predictions. + new_predictions, _ = inference_fn(gln_params, gln_state, inputs, + side_info) + self.assertFalse(np.array_equal(new_predictions, initial_predictions)) + + def test_batch_consistency(self): + """Test consistency between online and batch updates.""" + + batch_size = 3 + inputs, side_info, targets = _get_dataset(10, batch_size) + + # Initialize network. + gln_params, gln_state = self._batch_init_fn( + next(self._rng), inputs, side_info) + test_layer = "{}/~/{}_layer_0".format(self._name, self._name) + + for _ in range(10): + + # Update on full batch. + (expected_predictions, expected_params), _ = self._batch_update_fn( + gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-3) + + # Average updates across batch and check equivalence. + accum_predictions = [] + accum_weights = [] + for inputs_, side_info_, targets_ in zip(inputs, side_info, targets): + (predictions, params), _ = self._update_fn( + gln_params, + gln_state, + inputs_, + side_info_, + targets_, + learning_rate=1e-3) + accum_predictions.append(predictions) + accum_weights.append(params[test_layer]["weights"]) + + # Check prediction equivalence. + actual_predictions = np.stack(accum_predictions, axis=0) + np.testing.assert_array_almost_equal(actual_predictions, + expected_predictions) + + # Check weight equivalence. 
+      actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0)
+      expected_weights = expected_params[test_layer]["weights"]
+      np.testing.assert_array_almost_equal(actual_weights, expected_weights)
+
+      gln_params = expected_params
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/gated_linear_networks/requirements.txt b/gated_linear_networks/requirements.txt
new file mode 100644
index 0000000..e9781de
--- /dev/null
+++ b/gated_linear_networks/requirements.txt
@@ -0,0 +1,52 @@
+absl-py==0.10.0
+aiohttp==3.6.2
+astunparse==1.6.3
+async-timeout==3.0.1
+attrs==20.2.0
+cachetools==4.1.1
+certifi==2020.6.20
+chardet==3.0.4
+chex==0.0.2
+cloudpickle==1.6.0
+decorator==4.4.2
+dill==0.3.2
+dm-env==1.2
+dm-haiku==0.0.2
+dm-tree==0.1.5
+future==0.18.2
+gast==0.3.3
+google-auth==1.22.0
+google-auth-oauthlib==0.4.1
+google-pasta==0.2.0
+googleapis-common-protos==1.52.0
+grpcio==1.32.0
+h5py==2.10.0
+idna==2.10
+jax==0.2.0
+jaxlib==0.1.55
+Keras-Preprocessing==1.1.2
+Markdown==3.2.2
+multidict==4.7.6
+numpy==1.18.5
+oauthlib==3.1.0
+opt-einsum==3.3.0
+promise==2.3
+protobuf==3.13.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+requests==2.24.0
+requests-oauthlib==1.3.0
+rlax==0.0.2
+rsa==4.6
+scipy==1.5.2
+six==1.15.0
+tensorboard==2.3.0
+tensorboard-plugin-wit==1.7.0
+tensorflow==2.3.1
+tensorflow-datasets==3.2.1
+tensorflow-estimator==2.3.0
+tensorflow-metadata==0.24.0
+tensorflow-probability==0.11.1
+termcolor==1.1.0
+toolz==0.11.1
+tqdm==4.50.0
diff --git a/gated_linear_networks/run.sh b/gated_linear_networks/run.sh
new file mode 100644
index 0000000..cc4f394
--- /dev/null
+++ b/gated_linear_networks/run.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright 2020 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+python3 -m venv gln_venv
+source gln_venv/bin/activate
+pip3 install --upgrade setuptools wheel
+pip3 install -r gated_linear_networks/requirements.txt
+
+# Run MNIST example with Bernoulli GLN
+python3 -m gated_linear_networks.examples.bernoulli_mnist \
+  --num_layers=2 \
+  --neurons_per_layer=100 \
+  --context_dim=1
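+
+# (Untested sketch: the same virtualenv should also be able to run the unit
+# tests.)
+# python3 -m gated_linear_networks.bernoulli_test
+# python3 -m gated_linear_networks.gaussian_test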