mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 12:37:43 +08:00
Enable Travis tests for gated_linear_networks
* Restrict the number of training steps to 2k so that the test finishes faster
* General cleanup of flags

PiperOrigin-RevId: 368228930
This commit is contained in:
committed by
Diego de Las Casas
parent
582b26587f
commit
d4a9a684cb
@@ -11,6 +11,7 @@ env:
|
||||
- PROJECT="adversarial_robustness"
|
||||
- PROJECT="avae"
|
||||
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
||||
- PROJECT="gated_linear_networks"
|
||||
- PROJECT="iodine"
|
||||
- PROJECT="kfac_ferminet_alpha"
|
||||
- PROJECT="learning_to_simulate"
|
||||
|
||||
@@ -25,21 +25,53 @@ import rlax
|
||||
from gated_linear_networks import bernoulli
|
||||
from gated_linear_networks.examples import utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
MAX_TRAIN_STEPS = flags.DEFINE_integer(
|
||||
name='max_train_steps',
|
||||
default=None,
|
||||
help='Maximum number of training steps to perform (None=no limit)',
|
||||
)
|
||||
|
||||
# Small example network, achieves ~95% test set accuracy =======================
|
||||
# Network parameters.
|
||||
flags.DEFINE_integer('num_layers', 2, '')
|
||||
flags.DEFINE_integer('neurons_per_layer', 100, '')
|
||||
flags.DEFINE_integer('context_dim', 1, '')
|
||||
NUM_LAYERS = flags.DEFINE_integer(
|
||||
name='num_layers',
|
||||
default=2,
|
||||
help='Number of network layers',
|
||||
)
|
||||
NEURONS_PER_LAYER = flags.DEFINE_integer(
|
||||
name='neurons_per_layer',
|
||||
default=100,
|
||||
help='Number of neurons per layer',
|
||||
)
|
||||
CONTEXT_DIM = flags.DEFINE_integer(
|
||||
name='context_dim',
|
||||
default=1,
|
||||
help='Context vector size',
|
||||
)
|
||||
|
||||
# Learning rate schedule.
|
||||
flags.DEFINE_float('max_lr', 0.003, '')
|
||||
flags.DEFINE_float('lr_constant', 1.0, '')
|
||||
flags.DEFINE_float('lr_decay', 0.1, '')
|
||||
MAX_LR = flags.DEFINE_float(
|
||||
name='max_lr',
|
||||
default=0.003,
|
||||
help='Maximum learning rate',
|
||||
)
|
||||
LR_CONSTANT = flags.DEFINE_float(
|
||||
name='lr_constant',
|
||||
default=1.0,
|
||||
help='Learning rate constant parameter',
|
||||
)
|
||||
LR_DECAY = flags.DEFINE_float(
|
||||
name='lr_decay',
|
||||
default=0.1,
|
||||
help='Learning rate decay parameter',
|
||||
)
|
||||
|
||||
# Logging parameters.
|
||||
flags.DEFINE_integer('evaluate_every', 1000, '')
|
||||
EVALUATE_EVERY = flags.DEFINE_integer(
|
||||
name='evaluate_every',
|
||||
default=1000,
|
||||
help='Number of training steps per evaluation epoch',
|
||||
)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
@@ -58,9 +90,9 @@ def main(unused_argv):
|
||||
def network_factory():
|
||||
|
||||
def gln_factory():
|
||||
output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1]
|
||||
output_sizes = [NEURONS_PER_LAYER.value] * NUM_LAYERS.value + [1]
|
||||
return bernoulli.GatedLinearNetwork(
|
||||
output_sizes=output_sizes, context_dim=FLAGS.context_dim)
|
||||
output_sizes=output_sizes, context_dim=CONTEXT_DIM.value)
|
||||
|
||||
return bernoulli.LastNeuronAggregator(gln_factory)
|
||||
|
||||
@@ -104,7 +136,7 @@ def main(unused_argv):
|
||||
|
||||
# Learning rate schedules.
|
||||
learning_rate = jnp.minimum(
|
||||
FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step))
|
||||
MAX_LR.value, LR_CONSTANT.value / (1. + LR_DECAY.value * step))
|
||||
|
||||
# Update weights and report log-loss.
|
||||
targets = hk.one_hot(jnp.asarray(label), num_classes)
|
||||
@@ -128,7 +160,7 @@ def main(unused_argv):
|
||||
)
|
||||
|
||||
# Evaluate on test split ===================================================
|
||||
if not step % FLAGS.evaluate_every:
|
||||
if not step % EVALUATE_EVERY.value:
|
||||
batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0))
|
||||
accuracies = batch_accuracy(params, state, test_images, test_labels)
|
||||
total_accuracy = float(jnp.mean(accuracies))
|
||||
@@ -139,6 +171,9 @@ def main(unused_argv):
|
||||
'accuracy': float(total_accuracy),
|
||||
})
|
||||
|
||||
if MAX_TRAIN_STEPS.value is not None and step >= MAX_TRAIN_STEPS.value:
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
|
||||
Regular → Executable
+5
-4
@@ -14,13 +14,14 @@
|
||||
# limitations under the License.
|
||||
set -e
|
||||
|
||||
python3 -m venv gln_venv
|
||||
source gln_venv/bin/activate
|
||||
pip3 install --upgrade setuptools wheel
|
||||
python3 -m venv /tmp/gln_venv
|
||||
source /tmp/gln_venv/bin/activate
|
||||
pip3 install --upgrade pip setuptools wheel
|
||||
pip3 install -r gated_linear_networks/requirements.txt
|
||||
|
||||
# Run MNIST example with Bernoulli GLN
|
||||
python3 -m gated_linear_networks.examples.bernoulli_mnist \
|
||||
--num_layers=2 \
|
||||
--neurons_per_layer=100 \
|
||||
--context_dim=1
|
||||
--context_dim=1 \
|
||||
--max_train_steps=2000
|
||||
|
||||
Reference in New Issue
Block a user