diff --git a/.travis.yml b/.travis.yml index 1aa0850..d161e2d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ env: - PROJECT="adversarial_robustness" - PROJECT="avae" # - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable + - PROJECT="gated_linear_networks" - PROJECT="iodine" - PROJECT="kfac_ferminet_alpha" - PROJECT="learning_to_simulate" diff --git a/gated_linear_networks/examples/bernoulli_mnist.py b/gated_linear_networks/examples/bernoulli_mnist.py index 41fc304..fc49228 100644 --- a/gated_linear_networks/examples/bernoulli_mnist.py +++ b/gated_linear_networks/examples/bernoulli_mnist.py @@ -25,21 +25,53 @@ import rlax from gated_linear_networks import bernoulli from gated_linear_networks.examples import utils -FLAGS = flags.FLAGS +MAX_TRAIN_STEPS = flags.DEFINE_integer( + name='max_train_steps', + default=None, + help='Maximum number of training steps to perform (None=no limit)', +) # Small example network, achieves ~95% test set accuracy ======================= # Network parameters. -flags.DEFINE_integer('num_layers', 2, '') -flags.DEFINE_integer('neurons_per_layer', 100, '') -flags.DEFINE_integer('context_dim', 1, '') +NUM_LAYERS = flags.DEFINE_integer( + name='num_layers', + default=2, + help='Number of network layers', +) +NEURONS_PER_LAYER = flags.DEFINE_integer( + name='neurons_per_layer', + default=100, + help='Number of neurons per layer', +) +CONTEXT_DIM = flags.DEFINE_integer( + name='context_dim', + default=1, + help='Context vector size', +) # Learning rate schedule. -flags.DEFINE_float('max_lr', 0.003, '') -flags.DEFINE_float('lr_constant', 1.0, '') -flags.DEFINE_float('lr_decay', 0.1, '') +MAX_LR = flags.DEFINE_float( + name='max_lr', + default=0.003, + help='Maximum learning rate', +) +LR_CONSTANT = flags.DEFINE_float( + name='lr_constant', + default=1.0, + help='Learning rate constant parameter', +) +LR_DECAY = flags.DEFINE_float( + name='lr_decay', + default=0.1, + help='Learning rate decay parameter', +) # Logging parameters. -flags.DEFINE_integer('evaluate_every', 1000, '') +EVALUATE_EVERY = flags.DEFINE_integer( + name='evaluate_every', + default=1000, + help='Number of training steps per evaluation epoch', +) def main(unused_argv): @@ -58,9 +90,9 @@ def main(unused_argv): def network_factory(): def gln_factory(): - output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1] + output_sizes = [NEURONS_PER_LAYER.value] * NUM_LAYERS.value + [1] return bernoulli.GatedLinearNetwork( - output_sizes=output_sizes, context_dim=FLAGS.context_dim) + output_sizes=output_sizes, context_dim=CONTEXT_DIM.value) return bernoulli.LastNeuronAggregator(gln_factory) @@ -104,7 +136,7 @@ def main(unused_argv): # Learning rate schedules. learning_rate = jnp.minimum( - FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step)) + MAX_LR.value, LR_CONSTANT.value / (1. + LR_DECAY.value * step)) # Update weights and report log-loss. targets = hk.one_hot(jnp.asarray(label), num_classes) @@ -128,7 +160,7 @@ def main(unused_argv): ) # Evaluate on test split =================================================== - if not step % FLAGS.evaluate_every: + if not step % EVALUATE_EVERY.value: batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0)) accuracies = batch_accuracy(params, state, test_images, test_labels) total_accuracy = float(jnp.mean(accuracies)) @@ -139,6 +171,9 @@ def main(unused_argv): 'accuracy': float(total_accuracy), }) + if MAX_TRAIN_STEPS.value is not None and step >= MAX_TRAIN_STEPS.value: + return + if __name__ == '__main__': app.run(main) diff --git a/gated_linear_networks/run.sh b/gated_linear_networks/run.sh old mode 100644 new mode 100755 index cc4f394..d992a9d --- a/gated_linear_networks/run.sh +++ b/gated_linear_networks/run.sh @@ -14,13 +14,14 @@ # limitations under the License. set -e -python3 -m venv gln_venv -source gln_venv/bin/activate -pip3 install --upgrade setuptools wheel +python3 -m venv /tmp/gln_venv +source /tmp/gln_venv/bin/activate +pip3 install --upgrade pip setuptools wheel pip3 install -r gated_linear_networks/requirements.txt # Run MNIST example with Bernoulli GLN python3 -m gated_linear_networks.examples.bernoulli_mnist \ --num_layers=2 \ --neurons_per_layer=100 \ - --context_dim=1 + --context_dim=1 \ + --max_train_steps=2000