Enable Travis tests for gated_linear_networks

* Restrict the number of training steps to 2,000 in the test run (via the new `--max_train_steps` flag) so that the test finishes faster
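* General cleanup of flags: each flag now has a help string and is read through its module-level flag holder (e.g. `NUM_LAYERS.value`) instead of `FLAGS.<name>`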

PiperOrigin-RevId: 368228930
This commit is contained in:
Alistair Muldal
2021-04-13 17:18:48 +01:00
committed by Diego de Las Casas
parent 582b26587f
commit d4a9a684cb
3 changed files with 53 additions and 16 deletions
@@ -25,21 +25,53 @@ import rlax
from gated_linear_networks import bernoulli
from gated_linear_networks.examples import utils
FLAGS = flags.FLAGS
MAX_TRAIN_STEPS = flags.DEFINE_integer(
name='max_train_steps',
default=None,
help='Maximum number of training steps to perform (None=no limit)',
)
# Small example network, achieves ~95% test set accuracy =======================
# Network parameters.
flags.DEFINE_integer('num_layers', 2, '')
flags.DEFINE_integer('neurons_per_layer', 100, '')
flags.DEFINE_integer('context_dim', 1, '')
NUM_LAYERS = flags.DEFINE_integer(
name='num_layers',
default=2,
help='Number of network layers',
)
NEURONS_PER_LAYER = flags.DEFINE_integer(
name='neurons_per_layer',
default=100,
help='Number of neurons per layer',
)
CONTEXT_DIM = flags.DEFINE_integer(
name='context_dim',
default=1,
help='Context vector size',
)
# Learning rate schedule.
flags.DEFINE_float('max_lr', 0.003, '')
flags.DEFINE_float('lr_constant', 1.0, '')
flags.DEFINE_float('lr_decay', 0.1, '')
MAX_LR = flags.DEFINE_float(
name='max_lr',
default=0.003,
help='Maximum learning rate',
)
LR_CONSTANT = flags.DEFINE_float(
name='lr_constant',
default=1.0,
help='Learning rate constant parameter',
)
LR_DECAY = flags.DEFINE_float(
name='lr_decay',
default=0.1,
help='Learning rate decay parameter',
)
# Logging parameters.
flags.DEFINE_integer('evaluate_every', 1000, '')
EVALUATE_EVERY = flags.DEFINE_integer(
name='evaluate_every',
default=1000,
help='Number of training steps per evaluation epoch',
)
def main(unused_argv):
@@ -58,9 +90,9 @@ def main(unused_argv):
def network_factory():
def gln_factory():
output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1]
output_sizes = [NEURONS_PER_LAYER.value] * NUM_LAYERS.value + [1]
return bernoulli.GatedLinearNetwork(
output_sizes=output_sizes, context_dim=FLAGS.context_dim)
output_sizes=output_sizes, context_dim=CONTEXT_DIM.value)
return bernoulli.LastNeuronAggregator(gln_factory)
@@ -104,7 +136,7 @@ def main(unused_argv):
# Learning rate schedules.
learning_rate = jnp.minimum(
FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step))
MAX_LR.value, LR_CONSTANT.value / (1. + LR_DECAY.value * step))
# Update weights and report log-loss.
targets = hk.one_hot(jnp.asarray(label), num_classes)
@@ -128,7 +160,7 @@ def main(unused_argv):
)
# Evaluate on test split ===================================================
if not step % FLAGS.evaluate_every:
if not step % EVALUATE_EVERY.value:
batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0))
accuracies = batch_accuracy(params, state, test_images, test_labels)
total_accuracy = float(jnp.mean(accuracies))
@@ -139,6 +171,9 @@ def main(unused_argv):
'accuracy': float(total_accuracy),
})
if MAX_TRAIN_STEPS.value is not None and step >= MAX_TRAIN_STEPS.value:
return
if __name__ == '__main__':
app.run(main)
Regular → Executable
+5 -4
View File
@@ -14,13 +14,14 @@
# limitations under the License.
set -e
python3 -m venv gln_venv
source gln_venv/bin/activate
pip3 install --upgrade setuptools wheel
python3 -m venv /tmp/gln_venv
source /tmp/gln_venv/bin/activate
pip3 install --upgrade pip setuptools wheel
pip3 install -r gated_linear_networks/requirements.txt
# Run MNIST example with Bernoulli GLN
python3 -m gated_linear_networks.examples.bernoulli_mnist \
--num_layers=2 \
--neurons_per_layer=100 \
--context_dim=1
--context_dim=1 \
--max_train_steps=2000