mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-30 04:05:27 +08:00
Enable Travis tests for gated_linear_networks
* Restrict the number of training steps to 2k so that the test finishes faster
* General cleanup of flags

PiperOrigin-RevId: 368228930
This commit is contained in:
committed by
Diego de Las Casas
parent
582b26587f
commit
d4a9a684cb
@@ -11,6 +11,7 @@ env:
|
|||||||
- PROJECT="adversarial_robustness"
|
- PROJECT="adversarial_robustness"
|
||||||
- PROJECT="avae"
|
- PROJECT="avae"
|
||||||
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
||||||
|
- PROJECT="gated_linear_networks"
|
||||||
- PROJECT="iodine"
|
- PROJECT="iodine"
|
||||||
- PROJECT="kfac_ferminet_alpha"
|
- PROJECT="kfac_ferminet_alpha"
|
||||||
- PROJECT="learning_to_simulate"
|
- PROJECT="learning_to_simulate"
|
||||||
|
|||||||
@@ -25,21 +25,53 @@ import rlax
|
|||||||
from gated_linear_networks import bernoulli
|
from gated_linear_networks import bernoulli
|
||||||
from gated_linear_networks.examples import utils
|
from gated_linear_networks.examples import utils
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
MAX_TRAIN_STEPS = flags.DEFINE_integer(
|
||||||
|
name='max_train_steps',
|
||||||
|
default=None,
|
||||||
|
help='Maximum number of training steps to perform (None=no limit)',
|
||||||
|
)
|
||||||
|
|
||||||
# Small example network, achieves ~95% test set accuracy =======================
|
# Small example network, achieves ~95% test set accuracy =======================
|
||||||
# Network parameters.
|
# Network parameters.
|
||||||
flags.DEFINE_integer('num_layers', 2, '')
|
NUM_LAYERS = flags.DEFINE_integer(
|
||||||
flags.DEFINE_integer('neurons_per_layer', 100, '')
|
name='num_layers',
|
||||||
flags.DEFINE_integer('context_dim', 1, '')
|
default=2,
|
||||||
|
help='Number of network layers',
|
||||||
|
)
|
||||||
|
NEURONS_PER_LAYER = flags.DEFINE_integer(
|
||||||
|
name='neurons_per_layer',
|
||||||
|
default=100,
|
||||||
|
help='Number of neurons per layer',
|
||||||
|
)
|
||||||
|
CONTEXT_DIM = flags.DEFINE_integer(
|
||||||
|
name='context_dim',
|
||||||
|
default=1,
|
||||||
|
help='Context vector size',
|
||||||
|
)
|
||||||
|
|
||||||
# Learning rate schedule.
|
# Learning rate schedule.
|
||||||
flags.DEFINE_float('max_lr', 0.003, '')
|
MAX_LR = flags.DEFINE_float(
|
||||||
flags.DEFINE_float('lr_constant', 1.0, '')
|
name='max_lr',
|
||||||
flags.DEFINE_float('lr_decay', 0.1, '')
|
default=0.003,
|
||||||
|
help='Maximum learning rate',
|
||||||
|
)
|
||||||
|
LR_CONSTANT = flags.DEFINE_float(
|
||||||
|
name='lr_constant',
|
||||||
|
default=1.0,
|
||||||
|
help='Learning rate constant parameter',
|
||||||
|
)
|
||||||
|
LR_DECAY = flags.DEFINE_float(
|
||||||
|
name='lr_decay',
|
||||||
|
default=0.1,
|
||||||
|
help='Learning rate decay parameter',
|
||||||
|
)
|
||||||
|
|
||||||
# Logging parameters.
|
# Logging parameters.
|
||||||
flags.DEFINE_integer('evaluate_every', 1000, '')
|
EVALUATE_EVERY = flags.DEFINE_integer(
|
||||||
|
name='evaluate_every',
|
||||||
|
default=1000,
|
||||||
|
help='Number of training steps per evaluation epoch',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
@@ -58,9 +90,9 @@ def main(unused_argv):
|
|||||||
def network_factory():
|
def network_factory():
|
||||||
|
|
||||||
def gln_factory():
|
def gln_factory():
|
||||||
output_sizes = [FLAGS.neurons_per_layer] * FLAGS.num_layers + [1]
|
output_sizes = [NEURONS_PER_LAYER.value] * NUM_LAYERS.value + [1]
|
||||||
return bernoulli.GatedLinearNetwork(
|
return bernoulli.GatedLinearNetwork(
|
||||||
output_sizes=output_sizes, context_dim=FLAGS.context_dim)
|
output_sizes=output_sizes, context_dim=CONTEXT_DIM.value)
|
||||||
|
|
||||||
return bernoulli.LastNeuronAggregator(gln_factory)
|
return bernoulli.LastNeuronAggregator(gln_factory)
|
||||||
|
|
||||||
@@ -104,7 +136,7 @@ def main(unused_argv):
|
|||||||
|
|
||||||
# Learning rate schedules.
|
# Learning rate schedules.
|
||||||
learning_rate = jnp.minimum(
|
learning_rate = jnp.minimum(
|
||||||
FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step))
|
MAX_LR.value, LR_CONSTANT.value / (1. + LR_DECAY.value * step))
|
||||||
|
|
||||||
# Update weights and report log-loss.
|
# Update weights and report log-loss.
|
||||||
targets = hk.one_hot(jnp.asarray(label), num_classes)
|
targets = hk.one_hot(jnp.asarray(label), num_classes)
|
||||||
@@ -128,7 +160,7 @@ def main(unused_argv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate on test split ===================================================
|
# Evaluate on test split ===================================================
|
||||||
if not step % FLAGS.evaluate_every:
|
if not step % EVALUATE_EVERY.value:
|
||||||
batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0))
|
batch_accuracy = jax.vmap(accuracy, in_axes=(None, None, 0, 0))
|
||||||
accuracies = batch_accuracy(params, state, test_images, test_labels)
|
accuracies = batch_accuracy(params, state, test_images, test_labels)
|
||||||
total_accuracy = float(jnp.mean(accuracies))
|
total_accuracy = float(jnp.mean(accuracies))
|
||||||
@@ -139,6 +171,9 @@ def main(unused_argv):
|
|||||||
'accuracy': float(total_accuracy),
|
'accuracy': float(total_accuracy),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if MAX_TRAIN_STEPS.value is not None and step >= MAX_TRAIN_STEPS.value:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(main)
|
app.run(main)
|
||||||
|
|||||||
Regular → Executable
+5
-4
@@ -14,13 +14,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
python3 -m venv gln_venv
|
python3 -m venv /tmp/gln_venv
|
||||||
source gln_venv/bin/activate
|
source /tmp/gln_venv/bin/activate
|
||||||
pip3 install --upgrade setuptools wheel
|
pip3 install --upgrade pip setuptools wheel
|
||||||
pip3 install -r gated_linear_networks/requirements.txt
|
pip3 install -r gated_linear_networks/requirements.txt
|
||||||
|
|
||||||
# Run MNIST example with Bernoulli GLN
|
# Run MNIST example with Bernoulli GLN
|
||||||
python3 -m gated_linear_networks.examples.bernoulli_mnist \
|
python3 -m gated_linear_networks.examples.bernoulli_mnist \
|
||||||
--num_layers=2 \
|
--num_layers=2 \
|
||||||
--neurons_per_layer=100 \
|
--neurons_per_layer=100 \
|
||||||
--context_dim=1
|
--context_dim=1 \
|
||||||
|
--max_train_steps=2000
|
||||||
|
|||||||
Reference in New Issue
Block a user