mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 04:45:31 +08:00
Adds the code for "gated_linear_networks" to the files release.bara.sky and README.md for public release on github.
PiperOrigin-RevId: 338219746
This commit is contained in:
@@ -0,0 +1,233 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for `gaussian.py`."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
from gated_linear_networks import gaussian
|
||||
|
||||
|
||||
def _get_dataset(input_size, batch_size=None):
|
||||
"""Get mock dataset."""
|
||||
if batch_size:
|
||||
inputs = jnp.ones([batch_size, input_size, 2])
|
||||
side_info = jnp.ones([batch_size, input_size])
|
||||
targets = 0.8 * jnp.ones([batch_size])
|
||||
else:
|
||||
inputs = jnp.ones([input_size, 2])
|
||||
side_info = jnp.ones([input_size])
|
||||
targets = jnp.ones([])
|
||||
|
||||
return inputs, side_info, targets
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
  """Tests for the private packing helpers of `gaussian`."""

  def test_packing_identity(self):
    """Packing and then unpacking `(mu, sigma_sq)` returns the same arrays."""
    mu = jnp.array([1., 2., 3., 4., 5.])
    sigma_sq = jnp.array([6., 7., 8., 9., 10.])

    # Round-trip through the pack / unpack helpers.
    packed = gaussian._pack_inputs(mu, sigma_sq)
    unpacked_mu, unpacked_sigma_sq = gaussian._unpack_inputs(packed)

    np.testing.assert_array_equal(mu, unpacked_mu)
    np.testing.assert_array_equal(sigma_sq, unpacked_sigma_sq)
|
||||
|
||||
|
||||
class GatedLinearNetworkTest(parameterized.TestCase):
  """Shape, update and batch-consistency tests for `GatedLinearNetwork`."""

  # TODO(b/170843789): Factor out common test utilities.
  def setUp(self):
    """Builds the jitted Haiku init / inference / update functions.

    Four function families are created: online inference, batched inference,
    online update and batched update. The batched variants `jax.vmap` the
    online ones over axis 0 of their array arguments.
    """
    super().setUp()
    self._name = "test_network"
    self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))

    self._output_sizes = (4, 5, 6)
    self._context_dim = 2
    self._bias_len = 3

    def gln_factory():
      return gaussian.GatedLinearNetwork(
          output_sizes=self._output_sizes,
          context_dim=self._context_dim,
          bias_len=self._bias_len,
          name=self._name,
      )

    # NOTE(review): the trailing 0.5 passed to `inference` / `update` is
    # defined by `gaussian.GatedLinearNetwork`, which is not visible here.
    def inference_fn(inputs, side_info):
      return gln_factory().inference(inputs, side_info, 0.5)

    def batch_inference_fn(inputs, side_info):
      return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)

    def update_fn(inputs, side_info, label, learning_rate):
      params, predictions, unused_loss = gln_factory().update(
          inputs, side_info, label, learning_rate, 0.5)
      return predictions, params

    def batch_update_fn(inputs, side_info, label, learning_rate):
      # The learning rate is shared across the batch (in_axes=None).
      predictions, params = jax.vmap(
          update_fn, in_axes=(0, 0, 0, None))(
              inputs,
              side_info,
              label,
              learning_rate)
      # Per-example updated params are averaged over the batch axis.
      avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0), params)
      return predictions, avg_params

    # Haiku transform functions.
    self._init_fn, inference_fn_ = hk.without_apply_rng(
        hk.transform_with_state(inference_fn))
    self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
        hk.transform_with_state(batch_inference_fn))
    _, update_fn_ = hk.without_apply_rng(hk.transform_with_state(update_fn))
    _, batch_update_fn_ = hk.without_apply_rng(
        hk.transform_with_state(batch_update_fn))

    self._inference_fn = jax.jit(inference_fn_)
    self._batch_inference_fn = jax.jit(batch_inference_fn_)
    self._update_fn = jax.jit(update_fn_)
    self._batch_update_fn = jax.jit(batch_update_fn_)

  def _layer_name(self, layer_idx):
    """Returns the parameter-dict key used by the tests for `layer_idx`."""
    return "{}/~/{}_layer_{}".format(self._name, self._name, layer_idx)

  def _select_fns(self, batch_size):
    """Returns `(init_fn, inference_fn, update_fn)` for the requested mode.

    Args:
      batch_size: `None` selects the online functions; any other value selects
        the batched (vmapped) functions.
    """
    if batch_size is None:
      return self._init_fn, self._inference_fn, self._update_fn
    return (self._batch_init_fn, self._batch_inference_fn,
            self._batch_update_fn)

  @parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
  def test_shapes(self, batch_size):
    """Test shapes in online and batch regimes."""
    init_fn, inference_fn, _ = self._select_fns(batch_size)

    input_size = 10
    inputs, side_info, _ = _get_dataset(input_size, batch_size)

    # Initialize network.
    gln_params, gln_state = init_fn(next(self._rng), inputs, side_info)

    # Test shapes of parameters layer-wise.
    layer_input_size = input_size
    for layer_idx, output_size in enumerate(self._output_sizes):
      weights = gln_params[self._layer_name(layer_idx)]["weights"]
      expected_shape = (output_size, 2**self._context_dim,
                        layer_input_size + self._bias_len)
      self.assertEqual(weights.shape, expected_shape)

      # Each layer's output feeds the next layer.
      layer_input_size = output_size

    # Test shape of output.
    output_size = sum(self._output_sizes)
    predictions, _ = inference_fn(gln_params, gln_state, inputs, side_info)
    if batch_size is None:
      expected_shape = (output_size, 2)
    else:
      expected_shape = (batch_size, output_size, 2)
    self.assertEqual(predictions.shape, expected_shape)

  @parameterized.named_parameters(("Online mode", None), ("Batch mode", 3))
  def test_update(self, batch_size):
    """Test network updates in online and batch regimes."""
    init_fn, inference_fn, update_fn = self._select_fns(batch_size)

    inputs, side_info, targets = _get_dataset(10, batch_size)

    # Initialize network.
    initial_params, gln_state = init_fn(next(self._rng), inputs, side_info)

    # Initial predictions.
    initial_predictions, _ = inference_fn(initial_params, gln_state, inputs,
                                          side_info)

    # Test that params remain valid after consecutive updates.
    gln_params = initial_params
    for _ in range(3):
      (_, gln_params), _ = update_fn(
          gln_params, gln_state, inputs, side_info, targets, learning_rate=1e-4)

      # Check updated weights layer-wise.
      for layer_idx in range(len(self._output_sizes)):
        name = self._layer_name(layer_idx)
        initial_weights = initial_params[name]["weights"]
        new_weights = gln_params[name]["weights"]

        # Shape consistency.
        self.assertEqual(new_weights.shape, initial_weights.shape)

      # Check that different weights yield different predictions.
      new_predictions, _ = inference_fn(gln_params, gln_state, inputs,
                                        side_info)
      self.assertFalse(np.array_equal(new_predictions, initial_predictions))

  def test_batch_consistency(self):
    """Test consistency between online and batch updates."""
    batch_size = 3
    # Shared by the batch and online updates so the two paths stay comparable.
    learning_rate = 1e-3
    inputs, side_info, targets = _get_dataset(10, batch_size)

    # Initialize network.
    gln_params, gln_state = self._batch_init_fn(
        next(self._rng), inputs, side_info)
    test_layer = self._layer_name(0)

    for _ in range(10):
      # Update on full batch.
      (expected_predictions, expected_params), _ = self._batch_update_fn(
          gln_params, gln_state, inputs, side_info, targets,
          learning_rate=learning_rate)

      # Average updates across batch and check equivalence.
      accum_predictions = []
      accum_weights = []
      for inputs_, side_info_, targets_ in zip(inputs, side_info, targets):
        (predictions, params), _ = self._update_fn(
            gln_params,
            gln_state,
            inputs_,
            side_info_,
            targets_,
            learning_rate=learning_rate)
        accum_predictions.append(predictions)
        accum_weights.append(params[test_layer]["weights"])

      # Check prediction equivalence.
      actual_predictions = np.stack(accum_predictions, axis=0)
      np.testing.assert_array_almost_equal(actual_predictions,
                                           expected_predictions)

      # Check weight equivalence.
      actual_weights = np.mean(np.stack(accum_weights, axis=0), axis=0)
      expected_weights = expected_params[test_layer]["weights"]
      np.testing.assert_array_almost_equal(actual_weights, expected_weights)

      # Continue from the batch-updated parameters on the next iteration.
      gln_params = expected_params
|
||||
|
||||
|
||||
# Run the absl test runner when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
|
||||
Reference in New Issue
Block a user