Adds the code for "gated_linear_networks" to the files release.bara.sky and README.md for public release on GitHub.

PiperOrigin-RevId: 338219746
This commit is contained in:
Louise Deason
2020-10-21 08:27:28 +00:00
parent c45af649a7
commit 49def83d1d
12 changed files with 1515 additions and 0 deletions
@@ -0,0 +1,44 @@
# Gated Linear Networks

Gated Linear Networks (GLNs) are a family of backpropagation-free neural
networks. Each neuron in a GLN predicts the target density (or probability
mass) based on the outputs of the previous layer and is trained under a
logarithmic loss.
## GLN variants

Neurons have probabilistic "activation functions". Implementations are provided
for the following distributions:

- Gaussian, for regression.
- Bernoulli, for binary classification and multi-class classification using a
  one-vs-all scheme.
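
Both variants share the same module interface. As a rough construction sketch
(not a drop-in script: GLNs are Haiku modules, so they must be instantiated
inside an `hk.transform_with_state`-wrapped function; the layer sizes below are
arbitrary):

```python
import haiku as hk

from gated_linear_networks import bernoulli


def binary_classifier(inputs, side_info):
  # Three layers; the final single neuron can be read out as the network output.
  gln = bernoulli.GatedLinearNetwork(output_sizes=[8, 8, 1], context_dim=4)
  return gln.inference(inputs, side_info)


init_fn, apply_fn = hk.without_apply_rng(
    hk.transform_with_state(binary_classifier))
```

The Gaussian variant is constructed analogously, with an extra `min_sigma_sq`
argument threaded through `inference` and `update` (see the tests below).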
## Examples

Usage examples are provided in [`examples`](examples).
## Implementation details

### Constraint satisfaction

Because each neuron implements a probability density/mass function, we need to
ensure that it is well defined. For example, the scale parameter of a Gaussian
density needs to be positive. We implement these constraints using linear
projections and clipping.
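
As a minimal sketch of the clipping half (the bound mirrors the
`MAX_SIGMA_SQ` constant in the Gaussian implementation below; the
linear-projection half is implemented in `_project_weights`):

```python
import jax.numpy as jnp

MAX_SIGMA_SQ = 1e5  # upper bound, as in the Gaussian implementation


def clip_sigma_sq(sigma_sq, min_sigma_sq):
  # Keep the Gaussian scale parameter positive and bounded so that the
  # predicted density stays well defined.
  return jnp.clip(sigma_sq, min_sigma_sq, MAX_SIGMA_SQ)
```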
### Aggregation

Because each neuron predicts the target, any neuron's output can serve as the
"network output"; we are not bound to the last layer. The last-layer neuron(s)
are typically the best predictors, although in theory they may take longer to
converge. In this implementation we use a single neuron at the last layer,
whose prediction forms the network output.

There are alternative ways of aggregating; see, e.g., Switching Aggregation in
Appendix D of
[*Gaussian Gated Linear Networks*](https://arxiv.org/pdf/2006.05964.pdf).
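
In code, this corresponds to wrapping a network factory in a
`LastNeuronAggregator`. A minimal sketch, mirroring the MNIST example in this
commit (like the networks themselves, this must be called inside a Haiku
transform):

```python
from gated_linear_networks import bernoulli


def network_factory():
  # The aggregator requires the last layer to have output_size = 1.
  return bernoulli.GatedLinearNetwork(output_sizes=[100, 100, 1],
                                      context_dim=1)


aggregator = bernoulli.LastNeuronAggregator(network_factory)
```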
## References

Coming soon.
@@ -0,0 +1,347 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for Gated Linear Networks."""
import abc
import collections
import functools
import inspect
from typing import Any, Callable, Optional, Sequence, Tuple
import chex
import haiku as hk
import jax
import jax.numpy as jnp
Array = chex.Array
DType = Any
Initializer = hk.initializers.Initializer
Shape = Sequence[int]
EPS = 1e-12
MIN_ALPHA = 1e-5
def _l2_normalize(x: Array, axis: int) -> Array:
return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis, keepdims=True), EPS))
def _wrapped_fn_argnames(fun):
"""Returns list of argnames of a (possibly wrapped) function."""
return tuple(inspect.signature(fun).parameters)
def _vmap(fun, in_axes=0, out_axes=0, parameters=None):
"""JAX vmap with human-friendly axes."""
def _axes(fun, d):
"""Maps dict {kwarg_i, : val_i} to [None, ..., val_i, ..., None]."""
argnames = _wrapped_fn_argnames(fun) if not parameters else parameters
for key in d:
if key not in argnames:
raise ValueError(f"{key} is not a valid axis.")
return tuple(d.get(key, None) for key in argnames)
in_axes = _axes(fun, in_axes) if isinstance(in_axes, dict) else in_axes
return jax.vmap(fun, in_axes, out_axes)
# Map a neuron-level function across a layer.
_layer_vmap = functools.partial(
_vmap,
in_axes=({
"weights": 0,
"hyperplanes": 0,
"hyperplane_bias": 0,
}))
class NormalizedRandomNormal(hk.initializers.RandomNormal):
"""Random normal initializer with l2-normalization."""
def __init__(self,
stddev: float = 1.,
mean: float = 0.,
normalize_axis: int = 0):
super(NormalizedRandomNormal, self).__init__(stddev, mean)
self._normalize_axis = normalize_axis
def __call__(self, shape: Shape, dtype: DType) -> Array:
if self._normalize_axis >= len(shape):
raise ValueError("Cannot normalize axis {} for ndim = {}.".format(
self._normalize_axis, len(shape)))
weights = super(NormalizedRandomNormal, self).__call__(shape, dtype)
return _l2_normalize(weights, axis=self._normalize_axis)
class ShapeScaledConstant(hk.initializers.Initializer):
"""Initializes with a constant dependent on last dimension of input shape."""
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
constant = 1. / shape[-1]
return jnp.broadcast_to(constant, shape).astype(dtype)
class LocalUpdateModule(hk.Module):
"""Abstract base class for GLN variants and utils."""
def __init__(self, name: Optional[str] = None):
if hasattr(self, "__call__"):
raise ValueError("Do not implement `__call__` for a LocalUpdateModule." +
" Implement `inference` and `update` instead.")
super(LocalUpdateModule, self).__init__(name)
@abc.abstractmethod
def inference(self, *args, **kwargs):
"""Module inference step."""
@abc.abstractmethod
def update(self, *args, **kwargs):
"""Module update step."""
@property
@abc.abstractmethod
def output_sizes(self) -> Shape:
"""Returns network output sizes."""
class GatedLinearNetwork(LocalUpdateModule):
"""Abstract base class for a multi-layer Gated Linear Network."""
def __init__(self,
output_sizes: Shape,
context_dim: int,
inference_fn: Callable[..., Array],
update_fn: Callable[..., Array],
init: Initializer,
hyp_w_init: Optional[Initializer] = None,
hyp_b_init: Optional[Initializer] = None,
dtype: DType = jnp.float32,
name: str = "gated_linear_network"):
"""Initialize a GatedLinearNetwork as a sequence of GatedLinearLayers."""
super(GatedLinearNetwork, self).__init__(name=name)
self._layers = []
self._output_sizes = output_sizes
for i, output_size in enumerate(self._output_sizes):
layer = _GatedLinearLayer(
output_size=output_size,
context_dim=context_dim,
update_fn=update_fn,
inference_fn=inference_fn,
init=init,
hyp_w_init=hyp_w_init,
hyp_b_init=hyp_b_init,
dtype=dtype,
name=name + "_layer_{}".format(i))
self._layers.append(layer)
self._name = name
@abc.abstractmethod
def _add_bias(self, inputs):
pass
def inference(self, inputs: Array, side_info: Array, *args,
**kwargs) -> Array:
"""GatedLinearNetwork inference."""
predictions_per_layer = []
predictions = inputs
for layer in self._layers:
predictions = self._add_bias(predictions)
predictions = layer.inference(predictions, side_info, *args, **kwargs)
predictions_per_layer.append(predictions)
return jnp.concatenate(predictions_per_layer, axis=0)
def update(self, inputs, side_info, target, learning_rate, *args, **kwargs):
"""GatedLinearNetwork update."""
all_params = []
all_predictions = []
all_losses = []
predictions = inputs
for layer in self._layers:
predictions = self._add_bias(predictions)
# Note: This is correct because returned predictions are pre-update.
params, predictions, log_loss = layer.update(predictions, side_info,
target, learning_rate, *args,
**kwargs)
all_params.append(params)
all_predictions.append(predictions)
all_losses.append(log_loss)
new_params = dict(collections.ChainMap(*all_params))
predictions = jnp.concatenate(all_predictions, axis=0)
log_loss = jnp.concatenate(all_losses, axis=0)
return new_params, predictions, log_loss
@property
def output_sizes(self):
return self._output_sizes
@staticmethod
def _compute_context(
side_info: Array, # [side_info_size]
hyperplanes: Array, # [context_dim, side_info_size]
hyperplane_bias: Array, # [context_dim]
) -> Array:
# Index weights by side information.
context_dim = hyperplane_bias.shape[0]
proj = jnp.dot(hyperplanes, side_info)
bits = (proj > hyperplane_bias).astype(jnp.int32)
weight_index = jnp.sum(
bits *
jnp.array([2**i for i in range(context_dim)])) if context_dim else 0
return weight_index
class _GatedLinearLayer(LocalUpdateModule):
"""A single layer of a Gated Linear Network."""
def __init__(self,
output_size: int,
context_dim: int,
inference_fn: Callable[..., Array],
update_fn: Callable[..., Array],
init: Initializer,
hyp_w_init: Optional[Initializer] = None,
hyp_b_init: Optional[Initializer] = None,
dtype: DType = jnp.float32,
name: str = "gated_linear_layer"):
"""Initialize a GatedLinearLayer."""
super(_GatedLinearLayer, self).__init__(name=name)
self._output_size = output_size
self._context_dim = context_dim
self._inference_fn = inference_fn
self._update_fn = update_fn
self._init = init
self._hyp_w_init = hyp_w_init
self._hyp_b_init = hyp_b_init
self._dtype = dtype
self._name = name
def _get_weights(self, input_size):
"""Get (or initialize) weight parameters."""
weights = hk.get_parameter(
"weights",
shape=(self._output_size, 2**self._context_dim, input_size),
dtype=self._dtype,
init=self._init,
)
return weights
def _get_hyperplanes(self, side_info_size):
"""Get (or initialize) hyperplane weights and bias."""
hyp_w_init = self._hyp_w_init or NormalizedRandomNormal(
stddev=1., normalize_axis=1)
hyperplanes = hk.get_state(
"hyperplanes",
shape=(self._output_size, self._context_dim, side_info_size),
init=hyp_w_init)
hyp_b_init = self._hyp_b_init or hk.initializers.RandomNormal(stddev=0.05)
hyperplane_bias = hk.get_state(
"hyperplane_bias",
shape=(self._output_size, self._context_dim),
init=hyp_b_init)
return hyperplanes, hyperplane_bias
def inference(self, inputs: Array, side_info: Array, *args,
**kwargs) -> Array:
"""GatedLinearLayer inference."""
# Initialize layer weights.
weights = self._get_weights(inputs.shape[0])
# Initialize fixed random hyperplanes.
side_info_size = side_info.shape[0]
hyperplanes, hyperplane_bias = self._get_hyperplanes(side_info_size)
# Perform layer-wise inference by mapping along output_size (num_neurons).
layer_inference = _layer_vmap(self._inference_fn)
predictions = layer_inference(inputs, side_info, weights, hyperplanes,
hyperplane_bias, *args, **kwargs)
return predictions
def update(self, inputs: Array, side_info: Array, target: Array,
learning_rate: float, *args,
**kwargs) -> Tuple[Array, Array, Array]:
"""GatedLinearLayer update."""
# Fetch layer weights.
weights = self._get_weights(inputs.shape[0])
# Fetch fixed random hyperplanes.
side_info_size = side_info.shape[0]
hyperplanes, hyperplane_bias = self._get_hyperplanes(side_info_size)
# Perform layer-wise update by mapping along output_size (num_neurons).
layer_update = _layer_vmap(self._update_fn)
new_weights, predictions, log_loss = layer_update(inputs, side_info,
weights, hyperplanes,
hyperplane_bias, target,
learning_rate, *args,
**kwargs)
assert new_weights.shape == weights.shape
params = {self.module_name: {"weights": new_weights}}
return params, predictions, log_loss
@property
def output_sizes(self):
return self._output_size
class Mutator(LocalUpdateModule):
"""Abstract base class for GLN Mutators."""
def __init__(
self,
network_factory: Callable[..., LocalUpdateModule],
name: str,
):
super(Mutator, self).__init__(name=name)
self._network = network_factory()
self._name = name
@property
def output_sizes(self):
return self._network.output_sizes
class LastNeuronAggregator(Mutator):
"""Last neuron aggregator: network output is read from the last neuron."""
def __init__(
self,
network_factory: Callable[..., LocalUpdateModule],
name: str = "last_neuron",
):
super(LastNeuronAggregator, self).__init__(network_factory, name)
if self._network.output_sizes[-1] != 1:
raise ValueError(
"LastNeuronAggregator requires the last GLN layer to have"
" output_size = 1.")
def inference(self, *args, **kwargs) -> Array:
predictions = self._network.inference(*args, **kwargs)
return predictions[-1]
def update(self, *args, **kwargs) -> Tuple[Array, Array, Array]:
params_t, predictions_tm1, loss_tm1 = self._network.update(*args, **kwargs)
return params_t, predictions_tm1[-1], loss_tm1[-1]
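
To make the halfspace gating concrete, here is a standalone sketch of what
`GatedLinearNetwork._compute_context` does for one neuron (random values
invented for illustration):

```python
import jax
import jax.numpy as jnp

context_dim, side_info_size = 3, 5
key_h, key_s = jax.random.split(jax.random.PRNGKey(0))
hyperplanes = jax.random.normal(key_h, (context_dim, side_info_size))
hyperplane_bias = jnp.zeros((context_dim,))
side_info = jax.random.normal(key_s, (side_info_size,))

# One bit per hyperplane: which side of the halfspace the side info falls on.
bits = (jnp.dot(hyperplanes, side_info) > hyperplane_bias).astype(jnp.int32)
# Interpret the bits as an integer index into the 2**context_dim weight rows.
weight_index = jnp.sum(bits * jnp.array([2**i for i in range(context_dim)]))
```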
@@ -0,0 +1,107 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bernoulli Gated Linear Network."""
from typing import List, Text, Tuple
import chex
import jax
import jax.numpy as jnp
import rlax
import tensorflow_probability as tfp
from gated_linear_networks import base
tfp = tfp.experimental.substrates.jax
tfd = tfp.distributions
Array = chex.Array
GLN_EPS = 0.01
MAX_WEIGHT = 200.
class GatedLinearNetwork(base.GatedLinearNetwork):
"""Bernoulli Gated Linear Network."""
def __init__(self,
output_sizes: List[int],
context_dim: int,
name: Text = "bernoulli_gln"):
"""Initialize a Bernoulli GLN."""
super(GatedLinearNetwork, self).__init__(
output_sizes,
context_dim,
inference_fn=GatedLinearNetwork._inference_fn,
update_fn=GatedLinearNetwork._update_fn,
init=jnp.zeros,
dtype=jnp.float32,
name=name)
def _add_bias(self, inputs):
return jnp.append(inputs, rlax.sigmoid(1.))
@staticmethod
def _inference_fn(
inputs: Array, # [input_size]
side_info: Array, # [side_info_size]
weights: Array, # [2**context_dim, input_size]
hyperplanes: Array, # [context_dim, side_info_size]
hyperplane_bias: Array, # [context_dim]
) -> Array:
"""Inference step for a single Beurnolli neuron."""
weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
hyperplane_bias)
used_weights = weights[weight_index]
inputs = rlax.logit(jnp.clip(inputs, GLN_EPS, 1. - GLN_EPS))
prediction = rlax.sigmoid(jnp.dot(used_weights, inputs))
return prediction
@staticmethod
def _update_fn(
inputs: Array, # [input_size]
side_info: Array, # [side_info_size]
weights: Array, # [2**context_dim, num_features]
hyperplanes: Array, # [context_dim, side_info_size]
hyperplane_bias: Array, # [context_dim]
target: Array, # []
learning_rate: float,
) -> Tuple[Array, Array, Array]:
"""Update step for a single Bernoulli neuron."""
def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias,
target):
"""Log loss for a single Bernoulli neuron."""
prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights,
hyperplanes,
hyperplane_bias)
prediction = jnp.clip(prediction, GLN_EPS, 1. - GLN_EPS)
return rlax.log_loss(prediction, target), prediction
grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True)
((log_loss, prediction),
dloss_dweights) = grad_log_loss(inputs, side_info, weights, hyperplanes,
hyperplane_bias, target)
delta_weights = learning_rate * dloss_dweights
new_weights = jnp.clip(weights - delta_weights, -MAX_WEIGHT, MAX_WEIGHT)
return new_weights, prediction, log_loss
class LastNeuronAggregator(base.LastNeuronAggregator):
"""Bernoulli last neuron aggregator, implemented by the super class."""
pass
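
For intuition, the Bernoulli neuron's `_inference_fn` is geometric mixing:
logit-transform the incoming probabilities, take a context-gated dot product,
and squash back through a sigmoid. A standalone sketch with invented numbers,
using `jax.scipy.special` in place of the `rlax` helpers above:

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

GLN_EPS = 0.01
inputs = jnp.array([0.2, 0.7, 0.5])        # predecessor probabilities
used_weights = jnp.array([0.4, 0.4, 0.2])  # row of `weights` picked by the context

# Clip away 0/1 so the logit is finite, then mix in logit space.
prediction = expit(
    jnp.dot(used_weights, logit(jnp.clip(inputs, GLN_EPS, 1. - GLN_EPS))))
```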
@@ -0,0 +1,215 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `bernoulli.py`."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import tree
from gated_linear_networks import bernoulli
def _get_dataset(input_size, batch_size=None):
"""Get mock dataset."""
if batch_size:
inputs = jnp.ones([batch_size, input_size])
side_info = jnp.ones([batch_size, input_size])
targets = jnp.ones([batch_size])
else:
inputs = jnp.ones([input_size])
side_info = jnp.ones([input_size])
targets = jnp.ones([])
return inputs, side_info, targets
class GatedLinearNetworkTest(parameterized.TestCase):
# TODO(b/170843789): Factor out common test utilities.
def setUp(self):
super(GatedLinearNetworkTest, self).setUp()
self._name = "test_network"
self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))
self._output_sizes = (4, 5, 6)
self._context_dim = 2
def gln_factory():
return bernoulli.GatedLinearNetwork(
output_sizes=self._output_sizes,
context_dim=self._context_dim,
name=self._name)
def inference_fn(inputs, side_info):
return gln_factory().inference(inputs, side_info)
def batch_inference_fn(inputs, side_info):
return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)
def update_fn(inputs, side_info, label, learning_rate):
params, predictions, unused_loss = gln_factory().update(
inputs, side_info, label, learning_rate)
return predictions, params
def batch_update_fn(inputs, side_info, label, learning_rate):
predictions, params = jax.vmap(
update_fn, in_axes=(0, 0, 0, None))(inputs, side_info, label,
learning_rate)
avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params)
return predictions, avg_params
# Haiku transform functions.
self._init_fn, inference_fn_ = hk.without_apply_rng(
hk.transform_with_state(inference_fn))
self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
hk.transform_with_state(batch_inference_fn))
_, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
_, batch_update_fn_ = hk.without_apply_rng(
hk.transform_with_state(batch_update_fn))
self._inference_fn = jax.jit(inference_fn_)
self._batch_inference_fn = jax.jit(batch_inference_fn_)
self._update_fn = jax.jit(update_fn_)
self._batch_update_fn = jax.jit(batch_update_fn_)
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
def test_shapes(self, batch_size):
"""Test shapes in online and batch regimes."""
if batch_size is None:
init_fn = self._init_fn
inference_fn = self._inference_fn
else:
init_fn = self._batch_init_fn
inference_fn = self._batch_inference_fn
input_size = 10
inputs, side_info, _ = _get_dataset(input_size, batch_size)
input_size = inputs.shape[-1]
# Initialize network.
gln_params, gln_state = init_fn(next(self._rng), inputs, side_info)
# Test shapes of parameters layer-wise.
layer_input_size = input_size
for layer_idx, output_size in enumerate(self._output_sizes):
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
weights = gln_params[name]["weights"]
expected_shape = (output_size, 2**self._context_dim, layer_input_size + 1)
self.assertEqual(weights.shape, expected_shape)
layer_input_size = output_size
# Test shape of output.
output_size = sum(self._output_sizes)
predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info)
expected_shape = (batch_size, output_size) if batch_size else (output_size,)
self.assertEqual(predictions.shape, expected_shape)
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
def test_update(self, batch_size):
"""Test network updates in online and batch regimes."""
if batch_size is None:
init_fn = self._init_fn
inference_fn = self._inference_fn
update_fn = self._update_fn
else:
init_fn = self._batch_init_fn
inference_fn = self._batch_inference_fn
update_fn = self._batch_update_fn
input_size = 10
inputs, side_info, targets = _get_dataset(input_size, batch_size)
# Initialize network.
initial_params, gln_state = init_fn(next(self._rng), inputs, side_info)
# Initial predictions.
initial_predictions, _ = inference_fn(initial_params, gln_state, inputs,
side_info)
# Test that params remain valid after consecutive updates.
gln_params = initial_params
for _ in range(3):
(_, gln_params), gln_state = update_fn(
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4)
# Check updated weights layer-wise.
for layer_idx in range(len(self._output_sizes)):
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
initial_weights = initial_params[name]["weights"]
new_weights = gln_params[name]["weights"]
# Shape consistency.
self.assertEqual(new_weights.shape, initial_weights.shape)
# Check that different weights yield different predictions.
new_predictions, _ = inference_fn(gln_params, gln_state, inputs,
side_info)
self.assertFalse(np.array_equal(new_predictions, initial_predictions))
def test_batch_consistency(self):
"""Test consistency between online and batch updates."""
input_size = 10
batch_size = 3
inputs, side_info, targets = _get_dataset(input_size, batch_size)
# Initialize network.
gln_params, gln_state = self._batch_init_fn(
next(self._rng), inputs, side_info)
test_layer = "{}/~/{}_layer_0".format(self._name, self._name)
for _ in range(10):
# Update on full batch.
(expected_predictions, expected_params), _ = self._batch_update_fn(
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-3)
# Average updates across batch and check equivalence.
accum_predictions = []
accum_weights = []
for inputs_, side_info_, targets_ in zip(inputs, side_info, targets):
(predictions, params), _ = self._update_fn(
gln_params,
gln_state,
inputs_,
side_info_,
targets_,
learning_rate=1e-3)
accum_predictions.append(predictions)
accum_weights.append(params[test_layer]["weights"])
# Check prediction equivalence.
actual_predictions = np.stack(accum_predictions, axis=0)
np.testing.assert_array_almost_equal(actual_predictions,
expected_predictions)
# Check weight equivalence.
actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0)
expected_weights = expected_params[test_layer]["weights"]
np.testing.assert_array_almost_equal(actual_weights, expected_weights)
gln_params = expected_params
if __name__ == "__main__":
absltest.main()
@@ -0,0 +1,144 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Online MNIST classification example with Bernoulli GLN."""
from absl import app
from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
import rlax
from gated_linear_networks import bernoulli
from gated_linear_networks.examples import utils
FLAGS = flags.FLAGS
# Small example network, achieves ~95% test set accuracy =======================
# Network parameters.
flags.DEFINE_integer('num_layers', 2, '')
flags.DEFINE_integer('neurons_per_layer', 100, '')
flags.DEFINE_integer('context_dim', 1, '')
# Learning rate schedule.
flags.DEFINE_float('max_lr', 0.003, '')
flags.DEFINE_float('lr_constant', 1.0, '')
flags.DEFINE_float('lr_decay', 0.1, '')
# Logging parameters.
flags.DEFINE_integer('evaluate_every', 1000, '')
def main(unused_argv):
# Load MNIST dataset =========================================================
mnist_data, info = utils.load_deskewed_mnist(
name='mnist', batch_size=-1, with_info=True)
num_classes = info.features['label'].num_classes
(train_images, train_labels) = (mnist_data['train']['image'],
mnist_data['train']['label'])
(test_images, test_labels) = (mnist_data['test']['image'],
mnist_data['test']['label'])
# Build a (binary) GLN classifier ============================================
def network_factory():
def gln_factory():
output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1]
return bernoulli.GatedLinearNetwork(
output_sizes=output_sizes, context_dim=FLAGS.context_dim)
return bernoulli.LastNeuronAggregator(gln_factory)
def extract_features(image):
mean, stddev = utils.MeanStdEstimator()(image)
standardized_img = (image - mean) / (stddev + 1.)
inputs = rlax.sigmoid(standardized_img)
side_info = standardized_img
return inputs, side_info
def inference_fn(image, *args, **kwargs):
inputs, side_info = extract_features(image)
return network_factory().inference(inputs, side_info, *args, **kwargs)
def update_fn(image, *args, **kwargs):
inputs, side_info = extract_features(image)
return network_factory().update(inputs, side_info, *args, **kwargs)
init_, inference_ = hk.without_apply_rng(
hk.transform_with_state(inference_fn))
_, update_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
# Map along class dimension to create a one-vs-all classifier ================
@jax.jit
def init(dummy_image, key):
"""One-vs-all classifier init fn."""
dummy_images = jnp.stack([dummy_image] * num_classes, axis=0)
keys = jax.random.split(key, num_classes)
return jax.vmap(init_, in_axes=(0, 0))(keys, dummy_images)
@jax.jit
def accuracy(params, state, image, label):
"""One-vs-all classifier inference fn."""
fn = jax.vmap(inference_, in_axes=(0, 0, None))
predictions, unused_state = fn(params, state, image)
return (jnp.argmax(predictions) == label).astype(jnp.float32)
@jax.jit
def update(params, state, step, image, label):
"""One-vs-all classifier update fn."""
# Learning rate schedules.
learning_rate = jnp.minimum(
FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step))
# Update weights and report log-loss.
targets = hk.one_hot(jnp.asarray(label), num_classes)
fn = jax.vmap(update_, in_axes=(0, 0, None, 0, None))
out = fn(params, state, image, targets, learning_rate)
(params, unused_predictions, log_loss), state = out
return (jnp.mean(log_loss), params), state
# Train on train split =======================================================
dummy_image = train_images[0]
params, state = init(dummy_image, jax.random.PRNGKey(42))
for step, (image, label) in enumerate(zip(train_images, train_labels), 1):
(unused_loss, params), state = update(
params,
state,
step,
image,
label,
)
# Evaluate on test split ===================================================
if not step % FLAGS.evaluate_every:
batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0))
accuracies = batch_accuracy(params, state, test_images, test_labels)
total_accuracy = float(jnp.mean(accuracies))
# Report statistics.
print({
'step': step,
'accuracy': float(total_accuracy),
})
if __name__ == '__main__':
app.run(main)
@@ -0,0 +1,97 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Haiku modules for feature processing."""
import copy
from typing import Tuple
import chex
import haiku as hk
import jax.numpy as jnp
import numpy as np
from scipy.ndimage import interpolation
import tensorflow_datasets as tfds
Array = chex.Array
def _moments(image):
"""Compute the first and second moments of a given image."""
c0, c1 = np.mgrid[:image.shape[0], :image.shape[1]]
total_image = np.sum(image)
m0 = np.sum(c0 * image) / total_image
m1 = np.sum(c1 * image) / total_image
m00 = np.sum((c0 - m0)**2 * image) / total_image
m11 = np.sum((c1 - m1)**2 * image) / total_image
m01 = np.sum((c0 - m0) * (c1 - m1) * image) / total_image
mu_vector = np.array([m0, m1])
covariance_matrix = np.array([[m00, m01], [m01, m11]])
return mu_vector, covariance_matrix
def _deskew(image):
"""Image deskew."""
c, v = _moments(image)
alpha = v[0, 1] / v[0, 0]
affine = np.array([[1, 0], [alpha, 1]])
ocenter = np.array(image.shape) / 2.0
offset = c - np.dot(affine, ocenter)
return interpolation.affine_transform(image, affine, offset=offset)
def _deskew_dataset(dataset):
"""Dataset deskew."""
deskewed = copy.deepcopy(dataset)
for k, before in dataset.items():
images = before["image"]
num_images = images.shape[0]
after = np.stack([_deskew(i) for i in np.squeeze(images, axis=-1)], axis=0)
deskewed[k]["image"] = np.reshape(after, (num_images, -1))
return deskewed
def load_deskewed_mnist(*a, **k):
"""Returns deskewed MNIST numpy dataset."""
mnist_data, info = tfds.load(*a, **k)
mnist_data = tfds.as_numpy(mnist_data)
deskewed_data = _deskew_dataset(mnist_data)
return deskewed_data, info
class MeanStdEstimator(hk.Module):
"""Online mean and standard deviation estimator using Welford's algorithm."""
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
if len(sample.shape) > 1:
raise ValueError("sample must be a rank 0 or 1 DeviceArray.")
count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros)
mean = hk.get_state(
"mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
m2 = hk.get_state(
"m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
count += 1
delta = sample - mean
mean += delta / count
delta_2 = sample - mean
m2 += delta * delta_2
hk.set_state("count", count)
hk.set_state("mean", mean)
hk.set_state("m2", m2)
stddev = jnp.sqrt(m2 / count)
return mean, stddev
@@ -0,0 +1,52 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `utils.py`."""
from absl.testing import absltest
import haiku as hk
import jax
import numpy as np
from gated_linear_networks.examples import utils
class MeanStdEstimator(absltest.TestCase):
def test_statistics(self):
num_features = 100
feature_size = 3
samples = np.random.normal(
loc=5., scale=2., size=(num_features, feature_size))
true_mean = np.mean(samples, axis=0)
true_std = np.std(samples, axis=0)
def tick_(sample):
return utils.MeanStdEstimator()(sample)
init_fn, apply_fn = hk.without_apply_rng(hk.transform_with_state(tick_))
tick = jax.jit(apply_fn)
params, state = init_fn(rng=None, sample=samples[0])
for sample in samples:
(mean, std), state = tick(params, state, sample)
np.testing.assert_array_almost_equal(mean, true_mean, decimal=5)
np.testing.assert_array_almost_equal(std, true_std, decimal=5)
if __name__ == '__main__':
absltest.main()
@@ -0,0 +1,197 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gaussian Gated Linear Network."""
from typing import Callable, List, Text, Tuple
import chex
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
from gated_linear_networks import base
tfp = tfp.experimental.substrates.jax
tfd = tfp.distributions
Array = chex.Array
MIN_SIGMA_SQ_AGGREGATOR = 0.5
MAX_SIGMA_SQ = 1e5
MAX_WEIGHT = 1e3
MIN_WEIGHT = -1e3
def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]:
inputs = jnp.atleast_2d(inputs)
chex.assert_rank(inputs, 2)
(mu, sigma_sq) = [jnp.squeeze(x, 1) for x in jnp.hsplit(inputs, 2)]
return mu, sigma_sq
def _pack_inputs(mu: Array, sigma_sq: Array) -> Array:
mu = jnp.atleast_1d(mu)
sigma_sq = jnp.atleast_1d(sigma_sq)
chex.assert_rank([mu, sigma_sq], 1)
return jnp.vstack([mu, sigma_sq]).T
class GatedLinearNetwork(base.GatedLinearNetwork):
"""Gaussian Gated Linear Network."""
def __init__(
self,
output_sizes: List[int],
context_dim: int,
bias_len: int = 3,
bias_max_mu: float = 1.,
bias_sigma_sq: float = 1.,
name: Text = "gaussian_gln"):
"""Initialize a Gaussian GLN."""
super(GatedLinearNetwork, self).__init__(
output_sizes,
context_dim,
inference_fn=GatedLinearNetwork._inference_fn,
update_fn=GatedLinearNetwork._update_fn,
init=base.ShapeScaledConstant(),
dtype=jnp.float64,
name=name)
self._bias_len = bias_len
self._bias_max_mu = bias_max_mu
self._bias_sigma_sq = bias_sigma_sq
def _add_bias(self, inputs):
mu = jnp.linspace(-1. * self._bias_max_mu, self._bias_max_mu,
self._bias_len)
sigma_sq = self._bias_sigma_sq * jnp.ones_like(mu)
bias = _pack_inputs(mu, sigma_sq)
return jnp.concatenate([inputs, bias], axis=0)
@staticmethod
def _inference_fn(
inputs: Array, # [input_size, 2]
side_info: Array, # [side_info_size]
weights: Array, # [2**context_dim, input_size]
hyperplanes: Array, # [context_dim, side_info_size]
hyperplane_bias: Array, # [context_dim]
min_sigma_sq: float,
) -> Array:
"""Inference step for a single Gaussian neuron."""
mu_in, sigma_sq_in = _unpack_inputs(inputs)
weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
hyperplane_bias)
used_weights = weights[weight_index]
# This projection operation is differentiable and affects the gradients.
used_weights = GatedLinearNetwork._project_weights(inputs, used_weights,
min_sigma_sq)
sigma_sq_out = 1. / jnp.sum(used_weights / sigma_sq_in)
mu_out = sigma_sq_out * jnp.sum((used_weights * mu_in) / sigma_sq_in)
prediction = jnp.hstack((mu_out, sigma_sq_out))
return prediction
@staticmethod
def _project_weights(inputs: Array, # [input_size]
weights: Array, # [2**context_dim, num_features]
min_sigma_sq: float) -> Array:
"""Implements hard projection."""
    # This projection should be performed before the sigma-related ones.
weights = jnp.minimum(jnp.maximum(MIN_WEIGHT, weights), MAX_WEIGHT)
_, sigma_sq_in = _unpack_inputs(inputs)
lambda_in = 1. / sigma_sq_in
sigma_sq_out = 1. / weights.dot(lambda_in)
# If w.dot(x) < U, linearly project w such that w.dot(x) = U.
weights = jnp.where(
sigma_sq_out < min_sigma_sq, weights - lambda_in *
(1. / sigma_sq_out - 1. / min_sigma_sq) / jnp.sum(lambda_in**2),
weights)
# If w.dot(x) > U, linearly project w such that w.dot(x) = U.
weights = jnp.where(
sigma_sq_out > MAX_SIGMA_SQ, weights - lambda_in *
(1. / sigma_sq_out - 1. / MAX_SIGMA_SQ) / jnp.sum(lambda_in**2),
weights)
return weights
@staticmethod
def _update_fn(
inputs: Array, # [input_size]
side_info: Array, # [side_info_size]
weights: Array, # [2**context_dim, num_features]
hyperplanes: Array, # [context_dim, side_info_size]
hyperplane_bias: Array, # [context_dim]
target: Array, # []
learning_rate: float,
min_sigma_sq: float, # needed for inference (weight projection)
) -> Tuple[Array, Array, Array]:
"""Update step for a single Gaussian neuron."""
def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias,
target):
"""Log loss for a single Gaussian neuron."""
prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights,
hyperplanes,
hyperplane_bias,
min_sigma_sq)
mu, sigma_sq = prediction.T
loss = -tfd.Normal(mu, jnp.sqrt(sigma_sq)).log_prob(target)
return loss, prediction
grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True)
(log_loss,
prediction), dloss_dweights = grad_log_loss(inputs, side_info, weights,
hyperplanes, hyperplane_bias,
target)
delta_weights = learning_rate * dloss_dweights
return weights - delta_weights, prediction, log_loss
class ConstantInputSigma(base.Mutator):
"""Input pre-processing by concatenating a constant sigma^2."""
def __init__(
self,
network_factory: Callable[..., GatedLinearNetwork],
input_sigma_sq: float,
name: Text = "constant_input_sigma",
):
super(ConstantInputSigma, self).__init__(network_factory, name)
self._input_sigma_sq = input_sigma_sq
def inference(self, inputs, *args, **kwargs):
"""ConstantInputSigma inference."""
chex.assert_rank(inputs, 1)
sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
return self._network.inference(_pack_inputs(inputs, sigma_sq), *args,
**kwargs)
def update(self, inputs, *args, **kwargs):
"""ConstantInputSigma update."""
chex.assert_rank(inputs, 1)
sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
return self._network.update(_pack_inputs(inputs, sigma_sq), *args, **kwargs)
class LastNeuronAggregator(base.LastNeuronAggregator):
"""Gaussian last neuron aggregator, implemented by the super class."""
pass
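
For intuition, the Gaussian neuron's `_inference_fn` computes a weighted
product of Gaussians: under the context-selected weights, precisions
(inverse variances) combine linearly. A standalone sketch with invented
numbers:

```python
import jax.numpy as jnp

mu_in = jnp.array([0.0, 1.0, 2.0])         # input means
sigma_sq_in = jnp.array([1.0, 0.5, 2.0])   # input variances
used_weights = jnp.array([0.3, 0.5, 0.2])  # row of `weights` picked by the context

# Weighted product of Gaussians: precisions (1 / sigma_sq) add linearly.
sigma_sq_out = 1. / jnp.sum(used_weights / sigma_sq_in)
mu_out = sigma_sq_out * jnp.sum(used_weights * mu_in / sigma_sq_in)
```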
@@ -0,0 +1,233 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `gaussian.py`."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import tree
from gated_linear_networks import gaussian
def _get_dataset(input_size, batch_size=None):
"""Get mock dataset."""
if batch_size:
inputs = jnp.ones([batch_size, input_size, 2])
side_info = jnp.ones([batch_size, input_size])
targets = 0.8 * jnp.ones([batch_size])
else:
inputs = jnp.ones([input_size, 2])
side_info = jnp.ones([input_size])
targets = jnp.ones([])
return inputs, side_info, targets
class UtilsTest(absltest.TestCase):
def test_packing_identity(self):
mu = jnp.array([1., 2., 3., 4., 5.])
sigma_sq = jnp.array([6., 7., 8., 9., 10.])
mu_2, sigma_sq_2 = gaussian._unpack_inputs(
gaussian._pack_inputs(mu, sigma_sq))
np.testing.assert_array_equal(mu, mu_2)
np.testing.assert_array_equal(sigma_sq, sigma_sq_2)
class GatedLinearNetworkTest(parameterized.TestCase):
# TODO(b/170843789): Factor out common test utilities.
def setUp(self):
super(GatedLinearNetworkTest, self).setUp()
self._name = "test_network"
self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))
self._output_sizes = (4, 5, 6)
self._context_dim = 2
self._bias_len = 3
def gln_factory():
return gaussian.GatedLinearNetwork(
output_sizes=self._output_sizes,
context_dim=self._context_dim,
bias_len=self._bias_len,
name=self._name,
)
def inference_fn(inputs, side_info):
return gln_factory().inference(inputs, side_info, 0.5)
def batch_inference_fn(inputs, side_info):
return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)
def update_fn(inputs, side_info, label, learning_rate):
params, predictions, unused_loss = gln_factory().update(
inputs, side_info, label, learning_rate, 0.5)
return predictions, params
def batch_update_fn(inputs, side_info, label, learning_rate):
predictions, params = jax.vmap(
update_fn, in_axes=(0, 0, 0, None))(
inputs,
side_info,
label,
learning_rate)
avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params)
return predictions, avg_params
# Haiku transform functions.
self._init_fn, inference_fn_ = hk.without_apply_rng(
hk.transform_with_state(inference_fn))
self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
hk.transform_with_state(batch_inference_fn))
_, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
_, batch_update_fn_ = hk.without_apply_rng(
hk.transform_with_state(batch_update_fn))
self._inference_fn = jax.jit(inference_fn_)
self._batch_inference_fn = jax.jit(batch_inference_fn_)
self._update_fn = jax.jit(update_fn_)
self._batch_update_fn = jax.jit(batch_update_fn_)
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
def test_shapes(self, batch_size):
"""Test shapes in online and batch regimes."""
if batch_size is None:
init_fn = self._init_fn
inference_fn = self._inference_fn
else:
init_fn = self._batch_init_fn
inference_fn = self._batch_inference_fn
input_size = 10
inputs, side_info, _ = _get_dataset(input_size, batch_size)
# Initialize network.
gln_params, gln_state = init_fn(next(self._rng), inputs, side_info)
# Test shapes of parameters layer-wise.
layer_input_size = input_size
for layer_idx, output_size in enumerate(self._output_sizes):
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
weights = gln_params[name]["weights"]
expected_shape = (output_size, 2**self._context_dim,
layer_input_size + self._bias_len)
self.assertEqual(weights.shape, expected_shape)
layer_input_size = output_size
# Test shape of output.
output_size = sum(self._output_sizes)
predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info)
expected_shape = (batch_size, output_size,
2) if batch_size else (output_size, 2)
self.assertEqual(predictions.shape, expected_shape)
@parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
def test_update(self, batch_size):
"""Test network updates in online and batch regimes."""
if batch_size is None:
init_fn = self._init_fn
inference_fn = self._inference_fn
update_fn = self._update_fn
else:
init_fn = self._batch_init_fn
inference_fn = self._batch_inference_fn
update_fn = self._batch_update_fn
inputs, side_info, targets = _get_dataset(10, batch_size)
# Initialize network.
initial_params, gln_state = init_fn(next(self._rng), inputs, side_info)
# Initial predictions.
initial_predictions, _ = inference_fn(initial_params, gln_state, inputs,
side_info)
# Test that params remain valid after consecutive updates.
gln_params = initial_params
for _ in range(3):
(_, gln_params), _ = update_fn(
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4)
# Check updated weights layer-wise.
for layer_idx in range(len(self._output_sizes)):
name = "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)
initial_weights = initial_params[name]["weights"]
new_weights = gln_params[name]["weights"]
# Shape consistency.
self.assertEqual(new_weights.shape, initial_weights.shape)
# Check that different weights yield different predictions.
new_predictions, _ = inference_fn(gln_params, gln_state, inputs,
side_info)
self.assertFalse(np.array_equal(new_predictions, initial_predictions))
def test_batch_consistency(self):
"""Test consistency between online and batch updates."""
batch_size = 3
inputs, side_info, targets = _get_dataset(10, batch_size)
# Initialize network.
gln_params, gln_state = self._batch_init_fn(
next(self._rng), inputs, side_info)
test_layer = "{}/~/{}_layer_0".format(self._name, self._name)
for _ in range(10):
# Update on full batch.
(expected_predictions, expected_params), _ = self._batch_update_fn(
gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-3)
# Average updates across batch and check equivalence.
accum_predictions = []
accum_weights = []
for inputs_, side_info_, targets_ in zip(inputs, side_info, targets):
(predictions, params), _ = self._update_fn(
gln_params,
gln_state,
inputs_,
side_info_,
targets_,
learning_rate=1e-3)
accum_predictions.append(predictions)
accum_weights.append(params[test_layer]["weights"])
# Check prediction equivalence.
actual_predictions = np.stack(accum_predictions, axis=0)
np.testing.assert_array_almost_equal(actual_predictions,
expected_predictions)
# Check weight equivalence.
actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0)
expected_weights = expected_params[test_layer]["weights"]
np.testing.assert_array_almost_equal(actual_weights, expected_weights)
gln_params = expected_params
if __name__ == "__main__":
absltest.main()
@@ -0,0 +1,52 @@
absl-py==0.10.0
aiohttp==3.6.2
astunparse==1.6.3
async-timeout==3.0.1
attrs==20.2.0
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
chex==0.0.2
cloudpickle==1.6.0
decorator==4.4.2
dill==0.3.2
dm-env==1.2
dm-haiku==0.0.2
dm-tree==0.1.5
future==0.18.2
gast==0.3.3
google-auth==1.22.0
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
googleapis-common-protos==1.52.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
jax==0.2.0
jaxlib==0.1.55
Keras-Preprocessing==1.1.2
Markdown==3.2.2
multidict==4.7.6
numpy==1.18.5
oauthlib==3.1.0
opt-einsum==3.3.0
promise==2.3
protobuf==3.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
requests==2.24.0
requests-oauthlib==1.3.0
rlax==0.0.2
rsa==4.6
scipy==1.5.2
six==1.15.0
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-metadata==0.24.0
tensorflow-probability==0.11.1
termcolor==1.1.0
toolz==0.11.1
tqdm==4.50.0
@@ -0,0 +1,26 @@
#!/bin/sh
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
python3 -m venv gln_venv
. gln_venv/bin/activate  # '.' rather than 'source': portable under #!/bin/sh
pip3 install --upgrade setuptools wheel
pip3 install -r gated_linear_networks/requirements.txt
# Run MNIST example with Bernoulli GLN
python3 -m gated_linear_networks.examples.bernoulli_mnist \
--num_layers=2 \
--neurons_per_layer=100 \
--context_dim=1