Files
deepmind-research/gated_linear_networks/gaussian.py
T

198 lines
6.9 KiB
Python

# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gaussian Gated Linear Network."""
from typing import Callable, List, Text, Tuple
import chex
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
from gated_linear_networks import base
# Rebind `tfp` to its JAX substrate; `tfd` therefore refers to the
# substrate's distributions module.
tfp = tfp.experimental.substrates.jax
tfd = tfp.distributions
# Array type alias used in the annotations of this module.
Array = chex.Array
# NOTE(review): not referenced in the visible part of this file; presumably a
# lower bound on sigma^2 for the aggregating neuron -- confirm against callers.
MIN_SIGMA_SQ_AGGREGATOR = 0.5
# Upper bound on a neuron's output sigma^2, enforced by
# GatedLinearNetwork._project_weights.
MAX_SIGMA_SQ = 1e5
# Element-wise clipping range applied to the weights in
# GatedLinearNetwork._project_weights before the sigma^2 projections.
MAX_WEIGHT = 1e3
MIN_WEIGHT = -1e3
def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]:
  """Splits packed Gaussian parameters into mean and variance vectors.

  Args:
    inputs: Packed (mu, sigma^2) pairs. The array is promoted to rank 2 and
      then split column-wise into two halves, each of which must be a single
      column, i.e. the promoted array has shape [n, 2].

  Returns:
    A `(mu, sigma_sq)` tuple of rank-1 arrays holding the first and second
    column of `inputs` respectively.
  """
  packed = jnp.atleast_2d(inputs)
  chex.assert_rank(packed, 2)
  mu_column, sigma_sq_column = jnp.hsplit(packed, 2)
  # Each half is [n, 1]; drop the trailing singleton axis.
  return jnp.squeeze(mu_column, 1), jnp.squeeze(sigma_sq_column, 1)
def _pack_inputs(mu: Array, sigma_sq: Array) -> Array:
  """Packs mean and variance vectors into one array of (mu, sigma^2) rows.

  Args:
    mu: Means; promoted to rank 1 if given as a scalar.
    sigma_sq: Variances; promoted to rank 1 if given as a scalar.

  Returns:
    An array of shape [n, 2] whose first column is `mu` and whose second
    column is `sigma_sq`.
  """
  mu_vector = jnp.atleast_1d(mu)
  sigma_sq_vector = jnp.atleast_1d(sigma_sq)
  chex.assert_rank([mu_vector, sigma_sq_vector], 1)
  # Stack as two rows ([2, n]) and transpose so each row is one Gaussian.
  as_rows = jnp.vstack([mu_vector, sigma_sq_vector])
  return as_rows.T
class GatedLinearNetwork(base.GatedLinearNetwork):
  """Gated Linear Network whose neurons predict Gaussian (mu, sigma^2) pairs."""

  def __init__(
      self,
      output_sizes: List[int],
      context_dim: int,
      bias_len: int = 3,
      bias_max_mu: float = 1.,
      bias_sigma_sq: float = 1.,
      name: Text = "gaussian_gln"):
    """Builds a Gaussian GLN.

    Args:
      output_sizes: Forwarded unchanged to the base class constructor.
      context_dim: Forwarded unchanged to the base class constructor.
      bias_len: Number of constant bias Gaussians appended by `_add_bias`.
      bias_max_mu: The bias means are evenly spaced in
        [-bias_max_mu, bias_max_mu].
      bias_sigma_sq: Variance shared by all bias Gaussians.
      name: Forwarded unchanged to the base class constructor.
    """
    super().__init__(
        output_sizes,
        context_dim,
        inference_fn=GatedLinearNetwork._inference_fn,
        update_fn=GatedLinearNetwork._update_fn,
        init=base.ShapeScaledConstant(),
        dtype=jnp.float64,
        name=name)

    self._bias_len = bias_len
    self._bias_max_mu = bias_max_mu
    self._bias_sigma_sq = bias_sigma_sq

  def _add_bias(self, inputs):
    """Appends the constant bias Gaussians as extra rows below `inputs`."""
    bias_mu = jnp.linspace(-1. * self._bias_max_mu, self._bias_max_mu,
                           self._bias_len)
    bias_sigma_sq = self._bias_sigma_sq * jnp.ones_like(bias_mu)
    bias_rows = _pack_inputs(bias_mu, bias_sigma_sq)
    return jnp.concatenate([inputs, bias_rows], axis=0)

  @staticmethod
  def _inference_fn(
      inputs: Array,           # [input_size, 2]
      side_info: Array,        # [side_info_size]
      weights: Array,          # [2**context_dim, input_size]
      hyperplanes: Array,      # [context_dim, side_info_size]
      hyperplane_bias: Array,  # [context_dim]
      min_sigma_sq: float,
  ) -> Array:
    """Computes the prediction of a single Gaussian neuron.

    The weight row selected by the context is projected and then used to
    combine the input Gaussians: the output precision is the weighted sum of
    the input precisions, and the output mean is the matching
    precision-weighted combination of the input means.

    Returns:
      A length-2 array `[mu_out, sigma_sq_out]`.
    """
    mu_in, sigma_sq_in = _unpack_inputs(inputs)

    context = GatedLinearNetwork._compute_context(side_info, hyperplanes,
                                                  hyperplane_bias)
    # The projection is differentiable, so it contributes to the gradients
    # taken through this function.
    active_weights = GatedLinearNetwork._project_weights(
        inputs, weights[context], min_sigma_sq)

    sigma_sq_out = 1. / jnp.sum(active_weights / sigma_sq_in)
    mu_out = sigma_sq_out * jnp.sum((active_weights * mu_in) / sigma_sq_in)
    return jnp.hstack((mu_out, sigma_sq_out))

  @staticmethod
  def _project_weights(inputs: Array,   # [input_size, 2]
                       weights: Array,  # [input_size]
                       min_sigma_sq: float) -> Array:
    """Hard-projects one weight row so the output sigma^2 stays in bounds.

    Args:
      inputs: Packed (mu, sigma^2) inputs of the neuron.
      weights: The weight row selected for the current context.
      min_sigma_sq: Lower bound on the output sigma^2; the upper bound is the
        module constant `MAX_SIGMA_SQ`.

    Returns:
      The clipped and, where needed, projected weight row.
    """
    # Element-wise clipping must happen before the sigma^2 projections.
    projected = jnp.minimum(jnp.maximum(MIN_WEIGHT, weights), MAX_WEIGHT)

    _, sigma_sq_in = _unpack_inputs(inputs)
    lambda_in = 1. / sigma_sq_in
    # Output variance implied by the clipped weights; both checks below use
    # this same value.
    sigma_sq_out = 1. / projected.dot(lambda_in)
    lambda_sq_norm = jnp.sum(lambda_in**2)

    def _shift_precision_to(current: Array, sigma_sq_bound: float) -> Array:
      """Moves `current` along lambda_in so w.dot(lambda_in) hits the bound."""
      precision_gap = 1. / sigma_sq_out - 1. / sigma_sq_bound
      return current - lambda_in * precision_gap / lambda_sq_norm

    # Variance too small, i.e. w.dot(lambda_in) > 1 / min_sigma_sq: project w
    # linearly so that w.dot(lambda_in) = 1 / min_sigma_sq.
    # NOTE(review): a negative sigma_sq_out (non-positive combined precision)
    # also satisfies this condition -- confirm that this is intended.
    projected = jnp.where(sigma_sq_out < min_sigma_sq,
                          _shift_precision_to(projected, min_sigma_sq),
                          projected)

    # Variance too large, i.e. w.dot(lambda_in) < 1 / MAX_SIGMA_SQ: project w
    # linearly so that w.dot(lambda_in) = 1 / MAX_SIGMA_SQ.
    projected = jnp.where(sigma_sq_out > MAX_SIGMA_SQ,
                          _shift_precision_to(projected, MAX_SIGMA_SQ),
                          projected)
    return projected

  @staticmethod
  def _update_fn(
      inputs: Array,           # [input_size, 2]
      side_info: Array,        # [side_info_size]
      weights: Array,          # [2**context_dim, input_size]
      hyperplanes: Array,      # [context_dim, side_info_size]
      hyperplane_bias: Array,  # [context_dim]
      target: Array,           # []
      learning_rate: float,
      min_sigma_sq: float,     # needed for inference (weight projection)
  ) -> Tuple[Array, Array, Array]:
    """Performs one gradient step on the log loss of a single neuron.

    Returns:
      A `(new_weights, prediction, log_loss)` tuple, where `prediction` and
      `log_loss` are evaluated with the weights passed in (before the step).
    """

    def neg_log_likelihood(candidate_weights):
      """Gaussian negative log-likelihood of `target`, plus the prediction."""
      prediction = GatedLinearNetwork._inference_fn(
          inputs, side_info, candidate_weights, hyperplanes, hyperplane_bias,
          min_sigma_sq)
      mu, sigma_sq = prediction.T
      nll = -tfd.Normal(mu, jnp.sqrt(sigma_sq)).log_prob(target)
      return nll, prediction

    # Differentiate with respect to the whole weight table; the prediction is
    # carried along as auxiliary output.
    loss_and_grad = jax.value_and_grad(neg_log_likelihood, has_aux=True)
    (log_loss, prediction), dloss_dweights = loss_and_grad(weights)

    return weights - learning_rate * dloss_dweights, prediction, log_loss
class ConstantInputSigma(base.Mutator):
  """Mutator that pairs every input value with the same constant sigma^2."""

  def __init__(
      self,
      network_factory: Callable[..., GatedLinearNetwork],
      input_sigma_sq: float,
      name: Text = "constant_input_sigma",
  ):
    """Stores the constant variance and defers the rest to the base class.

    Args:
      network_factory: Forwarded unchanged to the base class constructor.
      input_sigma_sq: Variance attached to every input value.
      name: Forwarded unchanged to the base class constructor.
    """
    super().__init__(network_factory, name)
    self._input_sigma_sq = input_sigma_sq

  def _attach_sigma(self, inputs):
    """Turns a rank-1 vector of means into packed (mu, sigma^2) rows."""
    chex.assert_rank(inputs, 1)
    constant_sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
    return _pack_inputs(inputs, constant_sigma_sq)

  def inference(self, inputs, *args, **kwargs):
    """Runs the wrapped network's inference on the sigma-augmented inputs."""
    packed = self._attach_sigma(inputs)
    return self._network.inference(packed, *args, **kwargs)

  def update(self, inputs, *args, **kwargs):
    """Runs the wrapped network's update on the sigma-augmented inputs."""
    packed = self._attach_sigma(inputs)
    return self._network.update(packed, *args, **kwargs)
class LastNeuronAggregator(base.LastNeuronAggregator):
  """Last-neuron aggregator for Gaussian GLNs; behaviour is fully inherited."""