mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-30 20:35:25 +08:00
Adds the code for "gated_linear_networks" to the files release.bara.sky and README.md for public release on github.
PiperOrigin-RevId: 338219746
This commit is contained in:
@@ -0,0 +1,197 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Gaussian Gated Linear Network."""
|
||||
|
||||
from typing import Callable, List, Text, Tuple
|
||||
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
from gated_linear_networks import base
|
||||
|
||||
tfp = tfp.experimental.substrates.jax
|
||||
tfd = tfp.distributions
|
||||
|
||||
Array = chex.Array
|
||||
|
||||
MIN_SIGMA_SQ_AGGREGATOR = 0.5
|
||||
MAX_SIGMA_SQ = 1e5
|
||||
MAX_WEIGHT = 1e3
|
||||
MIN_WEIGHT = -1e3
|
||||
|
||||
|
||||
def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]:
|
||||
inputs = jnp.atleast_2d(inputs)
|
||||
chex.assert_rank(inputs, 2)
|
||||
(mu, sigma_sq) = [jnp.squeeze(x, 1) for x in jnp.hsplit(inputs, 2)]
|
||||
return mu, sigma_sq
|
||||
|
||||
|
||||
def _pack_inputs(mu: Array, sigma_sq: Array) -> Array:
|
||||
mu = jnp.atleast_1d(mu)
|
||||
sigma_sq = jnp.atleast_1d(sigma_sq)
|
||||
chex.assert_rank([mu, sigma_sq], 1)
|
||||
return jnp.vstack([mu, sigma_sq]).T
|
||||
|
||||
|
||||
class GatedLinearNetwork(base.GatedLinearNetwork):
|
||||
"""Gaussian Gated Linear Network."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_sizes: List[int],
|
||||
context_dim: int,
|
||||
bias_len: int = 3,
|
||||
bias_max_mu: float = 1.,
|
||||
bias_sigma_sq: float = 1.,
|
||||
name: Text = "gaussian_gln"):
|
||||
"""Initialize a Gaussian GLN."""
|
||||
super(GatedLinearNetwork, self).__init__(
|
||||
output_sizes,
|
||||
context_dim,
|
||||
inference_fn=GatedLinearNetwork._inference_fn,
|
||||
update_fn=GatedLinearNetwork._update_fn,
|
||||
init=base.ShapeScaledConstant(),
|
||||
dtype=jnp.float64,
|
||||
name=name)
|
||||
|
||||
self._bias_len = bias_len
|
||||
self._bias_max_mu = bias_max_mu
|
||||
self._bias_sigma_sq = bias_sigma_sq
|
||||
|
||||
def _add_bias(self, inputs):
|
||||
mu = jnp.linspace(-1. * self._bias_max_mu, self._bias_max_mu,
|
||||
self._bias_len)
|
||||
sigma_sq = self._bias_sigma_sq * jnp.ones_like(mu)
|
||||
bias = _pack_inputs(mu, sigma_sq)
|
||||
return jnp.concatenate([inputs, bias], axis=0)
|
||||
|
||||
@staticmethod
|
||||
def _inference_fn(
|
||||
inputs: Array, # [input_size, 2]
|
||||
side_info: Array, # [side_info_size]
|
||||
weights: Array, # [2**context_dim, input_size]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
min_sigma_sq: float,
|
||||
) -> Array:
|
||||
"""Inference step for a single Gaussian neuron."""
|
||||
|
||||
mu_in, sigma_sq_in = _unpack_inputs(inputs)
|
||||
weight_index = GatedLinearNetwork._compute_context(side_info, hyperplanes,
|
||||
hyperplane_bias)
|
||||
used_weights = weights[weight_index]
|
||||
|
||||
# This projection operation is differentiable and affects the gradients.
|
||||
used_weights = GatedLinearNetwork._project_weights(inputs, used_weights,
|
||||
min_sigma_sq)
|
||||
|
||||
sigma_sq_out = 1. / jnp.sum(used_weights / sigma_sq_in)
|
||||
mu_out = sigma_sq_out * jnp.sum((used_weights * mu_in) / sigma_sq_in)
|
||||
prediction = jnp.hstack((mu_out, sigma_sq_out))
|
||||
return prediction
|
||||
|
||||
@staticmethod
|
||||
def _project_weights(inputs: Array, # [input_size]
|
||||
weights: Array, # [2**context_dim, num_features]
|
||||
min_sigma_sq: float) -> Array:
|
||||
"""Implements hard projection."""
|
||||
|
||||
# This projection should be performed before the sigma related ones.
|
||||
weights = jnp.minimum(jnp.maximum(MIN_WEIGHT, weights), MAX_WEIGHT)
|
||||
_, sigma_sq_in = _unpack_inputs(inputs)
|
||||
|
||||
lambda_in = 1. / sigma_sq_in
|
||||
sigma_sq_out = 1. / weights.dot(lambda_in)
|
||||
|
||||
# If w.dot(x) < U, linearly project w such that w.dot(x) = U.
|
||||
weights = jnp.where(
|
||||
sigma_sq_out < min_sigma_sq, weights - lambda_in *
|
||||
(1. / sigma_sq_out - 1. / min_sigma_sq) / jnp.sum(lambda_in**2),
|
||||
weights)
|
||||
|
||||
# If w.dot(x) > U, linearly project w such that w.dot(x) = U.
|
||||
weights = jnp.where(
|
||||
sigma_sq_out > MAX_SIGMA_SQ, weights - lambda_in *
|
||||
(1. / sigma_sq_out - 1. / MAX_SIGMA_SQ) / jnp.sum(lambda_in**2),
|
||||
weights)
|
||||
|
||||
return weights
|
||||
|
||||
@staticmethod
|
||||
def _update_fn(
|
||||
inputs: Array, # [input_size]
|
||||
side_info: Array, # [side_info_size]
|
||||
weights: Array, # [2**context_dim, num_features]
|
||||
hyperplanes: Array, # [context_dim, side_info_size]
|
||||
hyperplane_bias: Array, # [context_dim]
|
||||
target: Array, # []
|
||||
learning_rate: float,
|
||||
min_sigma_sq: float, # needed for inference (weight projection)
|
||||
) -> Tuple[Array, Array, Array]:
|
||||
"""Update step for a single Gaussian neuron."""
|
||||
|
||||
def log_loss_fn(inputs, side_info, weights, hyperplanes, hyperplane_bias,
|
||||
target):
|
||||
"""Log loss for a single Gaussian neuron."""
|
||||
prediction = GatedLinearNetwork._inference_fn(inputs, side_info, weights,
|
||||
hyperplanes,
|
||||
hyperplane_bias,
|
||||
min_sigma_sq)
|
||||
mu, sigma_sq = prediction.T
|
||||
loss = -tfd.Normal(mu, jnp.sqrt(sigma_sq)).log_prob(target)
|
||||
return loss, prediction
|
||||
|
||||
grad_log_loss = jax.value_and_grad(log_loss_fn, argnums=2, has_aux=True)
|
||||
(log_loss,
|
||||
prediction), dloss_dweights = grad_log_loss(inputs, side_info, weights,
|
||||
hyperplanes, hyperplane_bias,
|
||||
target)
|
||||
|
||||
delta_weights = learning_rate * dloss_dweights
|
||||
return weights - delta_weights, prediction, log_loss
|
||||
|
||||
|
||||
class ConstantInputSigma(base.Mutator):
|
||||
"""Input pre-processing by concatenating a constant sigma^2."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network_factory: Callable[..., GatedLinearNetwork],
|
||||
input_sigma_sq: float,
|
||||
name: Text = "constant_input_sigma",
|
||||
):
|
||||
super(ConstantInputSigma, self).__init__(network_factory, name)
|
||||
self._input_sigma_sq = input_sigma_sq
|
||||
|
||||
def inference(self, inputs, *args, **kwargs):
|
||||
"""ConstantInputSigma inference."""
|
||||
chex.assert_rank(inputs, 1)
|
||||
sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
|
||||
return self._network.inference(_pack_inputs(inputs, sigma_sq), *args,
|
||||
**kwargs)
|
||||
|
||||
def update(self, inputs, *args, **kwargs):
|
||||
"""ConstantInputSigma update."""
|
||||
chex.assert_rank(inputs, 1)
|
||||
sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
|
||||
return self._network.update(_pack_inputs(inputs, sigma_sq), *args, **kwargs)
|
||||
|
||||
|
||||
class LastNeuronAggregator(base.LastNeuronAggregator):
|
||||
"""Gaussian last neuron aggregator, implemented by the super class."""
|
||||
pass
|
||||
Reference in New Issue
Block a user