mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
49def83d1d
PiperOrigin-RevId: 338219746
198 lines
6.9 KiB
Python
198 lines
6.9 KiB
Python
# Lint as: python3
|
|
# Copyright 2020 DeepMind Technologies Limited.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Gaussian Gated Linear Network."""
|
|
|
|
from typing import Callable, List, Text, Tuple
|
|
|
|
import chex
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import tensorflow_probability as tfp
|
|
|
|
from gated_linear_networks import base
|
|
|
|
# Rebind `tfp` to its JAX substrate so that the distributions used below
# operate on JAX arrays.
tfp = tfp.experimental.substrates.jax
tfd = tfp.distributions

# Shorthand used in the type annotations of this module.
Array = chex.Array

# Not referenced in this module; presumably the variance floor intended for
# the aggregating (last) neuron -- TODO confirm against callers.
MIN_SIGMA_SQ_AGGREGATOR = 0.5
# Largest output variance a neuron may predict; enforced by the weight
# projection in `GatedLinearNetwork._project_weights`.
MAX_SIGMA_SQ = 1e5
# Element-wise bounds that `GatedLinearNetwork._project_weights` clips the
# weights to.
MAX_WEIGHT = 1e3
MIN_WEIGHT = -1e3
|
|
|
|
|
|
def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]:
  """Splits packed Gaussian parameters into means and variances.

  Args:
    inputs: Array that is promoted to at least rank 2 and must then be exactly
      rank 2. It is cut into two equally wide halves along its second axis;
      each half must be one column wide. The first column holds the means and
      the second one the variances.

  Returns:
    A `(mu, sigma_sq)` tuple in which the column axis has been squeezed away.
  """
  packed = jnp.atleast_2d(inputs)
  chex.assert_rank(packed, 2)
  mu_column, sigma_sq_column = jnp.hsplit(packed, 2)
  mu = jnp.squeeze(mu_column, 1)
  sigma_sq = jnp.squeeze(sigma_sq_column, 1)
  return mu, sigma_sq
|
|
|
|
|
|
def _pack_inputs(mu: Array, sigma_sq: Array) -> Array:
  """Packs means and variances into a single two-column array.

  Args:
    mu: Means; promoted to at least rank 1 and required to be exactly rank 1.
    sigma_sq: Variances; same rank requirements as `mu`.

  Returns:
    An array whose first column is `mu` and whose second column is `sigma_sq`.
  """
  means, variances = jnp.atleast_1d(mu, sigma_sq)
  chex.assert_rank([means, variances], 1)
  # Stack as two rows, then flip so that each row is one (mu, sigma_sq) pair.
  stacked_rows = jnp.vstack([means, variances])
  return jnp.transpose(stacked_rows)
|
|
|
|
|
|
class GatedLinearNetwork(base.GatedLinearNetwork):
  """Gated Linear Network whose neurons predict Gaussians.

  Each neuron receives rows of `(mu, sigma_sq)` pairs and produces a single
  `(mu, sigma_sq)` pair. The per-neuron inference and update rules are the
  static methods `_inference_fn` and `_update_fn`, which are handed to the
  base class at construction time.
  """

  def __init__(
      self,
      output_sizes: List[int],
      context_dim: int,
      bias_len: int = 3,
      bias_max_mu: float = 1.,
      bias_sigma_sq: float = 1.,
      name: Text = "gaussian_gln"):
    """Initializes a Gaussian GLN.

    Args:
      output_sizes: Forwarded unchanged to the base class.
      context_dim: Forwarded unchanged to the base class.
      bias_len: Number of bias Gaussians that `_add_bias` appends.
      bias_max_mu: The bias means are evenly spaced over
        `[-bias_max_mu, bias_max_mu]`.
      bias_sigma_sq: Variance shared by all bias Gaussians.
      name: Forwarded unchanged to the base class.
    """
    super().__init__(
        output_sizes,
        context_dim,
        inference_fn=GatedLinearNetwork._inference_fn,
        update_fn=GatedLinearNetwork._update_fn,
        init=base.ShapeScaledConstant(),
        dtype=jnp.float64,
        name=name)

    self._bias_sigma_sq = bias_sigma_sq
    self._bias_max_mu = bias_max_mu
    self._bias_len = bias_len

  def _add_bias(self, inputs):
    """Appends the constant bias Gaussians as extra rows of `inputs`."""
    bias_mu = jnp.linspace(-1. * self._bias_max_mu, self._bias_max_mu,
                           self._bias_len)
    bias_sigma_sq = self._bias_sigma_sq * jnp.ones_like(bias_mu)
    bias_rows = _pack_inputs(bias_mu, bias_sigma_sq)
    return jnp.concatenate([inputs, bias_rows], axis=0)

  @staticmethod
  def _inference_fn(
      inputs: Array,           # [input_size, 2]
      side_info: Array,        # [side_info_size]
      weights: Array,          # [2**context_dim, input_size]
      hyperplanes: Array,      # [context_dim, side_info_size]
      hyperplane_bias: Array,  # [context_dim]
      min_sigma_sq: float,
  ) -> Array:
    """Inference step for a single Gaussian neuron.

    The weights selected by the context are projected and then used to combine
    the input Gaussians:

      sigma_sq_out = 1 / sum_i(w_i / sigma_sq_i)
      mu_out = sigma_sq_out * sum_i(w_i * mu_i / sigma_sq_i)

    Args:
      inputs: Packed `(mu, sigma_sq)` rows of the input Gaussians.
      side_info: Side information from which the context is computed.
      weights: Weight table of the neuron, indexed by context.
      hyperplanes: Hyperplanes passed to `_compute_context`.
      hyperplane_bias: Hyperplane offsets passed to `_compute_context`.
      min_sigma_sq: Lower bound on the predicted variance, enforced through
        the weight projection.

    Returns:
      The prediction `[mu_out, sigma_sq_out]`.
    """
    mu_in, sigma_sq_in = _unpack_inputs(inputs)
    context_index = GatedLinearNetwork._compute_context(
        side_info, hyperplanes, hyperplane_bias)

    # The projection is differentiable, so it takes part in the gradients
    # computed by `_update_fn`.
    active_weights = GatedLinearNetwork._project_weights(
        inputs, weights[context_index], min_sigma_sq)

    sigma_sq_out = 1. / jnp.sum(active_weights / sigma_sq_in)
    mu_out = sigma_sq_out * jnp.sum((active_weights * mu_in) / sigma_sq_in)
    return jnp.hstack((mu_out, sigma_sq_out))

  @staticmethod
  def _project_weights(inputs: Array,
                       weights: Array,
                       min_sigma_sq: float) -> Array:
    """Implements the hard projection of the weights.

    The weights are first clipped element-wise to `[MIN_WEIGHT, MAX_WEIGHT]`.
    With `lambda_in = 1 / sigma_sq_in`, the variance implied by the clipped
    weights is `1 / weights.dot(lambda_in)`. If it falls outside
    `[min_sigma_sq, MAX_SIGMA_SQ]`, the weights are shifted along `lambda_in`
    so that the implied variance equals the violated bound.

    Args:
      inputs: Packed `(mu, sigma_sq)` rows; only the variances are used.
      weights: Weights to project. As called from `_inference_fn`, this is the
        entry of the weight table selected by the context.
      min_sigma_sq: Lower bound on the implied output variance.

    Returns:
      The projected weights.
    """
    # Weight clipping has to come before the variance-related projections.
    bounded = jnp.minimum(jnp.maximum(MIN_WEIGHT, weights), MAX_WEIGHT)

    _, sigma_sq_in = _unpack_inputs(inputs)
    lambda_in = 1. / sigma_sq_in
    sigma_sq_out = 1. / bounded.dot(lambda_in)
    lambda_sq_sum = jnp.sum(lambda_in**2)

    # Implied variance below `min_sigma_sq`: move the weights so that
    # weights.dot(lambda_in) becomes 1 / min_sigma_sq.
    too_narrow = sigma_sq_out < min_sigma_sq
    narrow_fix = bounded - lambda_in * (
        1. / sigma_sq_out - 1. / min_sigma_sq) / lambda_sq_sum
    projected = jnp.where(too_narrow, narrow_fix, bounded)

    # Implied variance above `MAX_SIGMA_SQ`: move the weights so that
    # weights.dot(lambda_in) becomes 1 / MAX_SIGMA_SQ. The test uses the
    # variance computed before the previous projection.
    too_wide = sigma_sq_out > MAX_SIGMA_SQ
    wide_fix = projected - lambda_in * (
        1. / sigma_sq_out - 1. / MAX_SIGMA_SQ) / lambda_sq_sum
    return jnp.where(too_wide, wide_fix, projected)

  @staticmethod
  def _update_fn(
      inputs: Array,           # [input_size, 2]
      side_info: Array,        # [side_info_size]
      weights: Array,          # [2**context_dim, input_size]
      hyperplanes: Array,      # [context_dim, side_info_size]
      hyperplane_bias: Array,  # [context_dim]
      target: Array,           # []
      learning_rate: float,
      min_sigma_sq: float,     # needed for inference (weight projection)
  ) -> Tuple[Array, Array, Array]:
    """Update step for a single Gaussian neuron.

    Runs inference, scores the prediction with the negative Gaussian
    log-probability of `target`, and takes one gradient step on `weights`.

    Args:
      inputs: Packed `(mu, sigma_sq)` rows of the input Gaussians.
      side_info: Side information from which the context is computed.
      weights: Weight table of the neuron; the gradient is taken with respect
        to this argument.
      hyperplanes: Hyperplanes passed on to inference.
      hyperplane_bias: Hyperplane offsets passed on to inference.
      target: Value whose log-probability is evaluated.
      learning_rate: Step size of the gradient update.
      min_sigma_sq: Lower bound on the predicted variance, passed on to
        inference.

    Returns:
      A tuple `(updated_weights, prediction, log_loss)`, where `prediction`
      and `log_loss` were computed with the weights before the update.
    """

    def neg_log_prob(candidate_weights):
      """Log loss of a single Gaussian neuron as a function of its weights."""
      prediction = GatedLinearNetwork._inference_fn(
          inputs, side_info, candidate_weights, hyperplanes, hyperplane_bias,
          min_sigma_sq)
      mu, sigma_sq = prediction.T
      loss = -tfd.Normal(mu, jnp.sqrt(sigma_sq)).log_prob(target)
      return loss, prediction

    loss_and_grad = jax.value_and_grad(neg_log_prob, has_aux=True)
    (log_loss, prediction), dloss_dweights = loss_and_grad(weights)

    updated_weights = weights - learning_rate * dloss_dweights
    return updated_weights, prediction, log_loss
|
|
|
|
|
|
class ConstantInputSigma(base.Mutator):
  """Input pre-processing that attaches a constant sigma^2 to every input.

  The rank-1 inputs are treated as means. They are paired with a fixed
  variance and packed into `(mu, sigma_sq)` rows before being handed to the
  wrapped network.
  """

  def __init__(
      self,
      network_factory: Callable[..., GatedLinearNetwork],
      input_sigma_sq: float,
      name: Text = "constant_input_sigma",
  ):
    """Initializes the mutator.

    Args:
      network_factory: Forwarded to the base class; presumably used there to
        build `self._network` -- TODO confirm.
      input_sigma_sq: Variance assigned to every input element.
      name: Forwarded to the base class.
    """
    super().__init__(network_factory, name)
    self._input_sigma_sq = input_sigma_sq

  def inference(self, inputs, *args, **kwargs):
    """Runs the wrapped network's inference on the variance-augmented inputs.

    Args:
      inputs: Rank-1 array of input means.
      *args: Passed through to the wrapped network.
      **kwargs: Passed through to the wrapped network.

    Returns:
      Whatever the wrapped network's `inference` returns.
    """
    chex.assert_rank(inputs, 1)
    constant_sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
    packed = _pack_inputs(inputs, constant_sigma_sq)
    return self._network.inference(packed, *args, **kwargs)

  def update(self, inputs, *args, **kwargs):
    """Runs the wrapped network's update on the variance-augmented inputs.

    Args:
      inputs: Rank-1 array of input means.
      *args: Passed through to the wrapped network.
      **kwargs: Passed through to the wrapped network.

    Returns:
      Whatever the wrapped network's `update` returns.
    """
    chex.assert_rank(inputs, 1)
    constant_sigma_sq = self._input_sigma_sq * jnp.ones_like(inputs)
    packed = _pack_inputs(inputs, constant_sigma_sq)
    return self._network.update(packed, *args, **kwargs)
|
|
|
|
|
|
class LastNeuronAggregator(base.LastNeuronAggregator):
  """Last neuron aggregator for Gaussian GLNs.

  Adds nothing of its own; all behaviour is inherited from the super class.
  """
|