mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
923ad3cff0
PiperOrigin-RevId: 326141025
158 lines
5.5 KiB
Python
158 lines
5.5 KiB
Python
# Copyright 2020 DeepMind Technologies Limited.
|
|
#
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Training and evaluation loops for an experiment."""
|
|
|
|
import time
|
|
from typing import Any, Mapping, Text, Type, Union
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from absl import logging
|
|
import jax
|
|
import numpy as np
|
|
|
|
from byol import byol_experiment
|
|
from byol import eval_experiment
|
|
from byol.configs import byol as byol_config
|
|
from byol.configs import eval as eval_config
|
|
|
|
# Command-line flags. `experiment_mode` picks which experiment to build and
# `worker_mode` picks which loop (training or evaluation) this process runs.
flags.DEFINE_string(
    name='experiment_mode',
    default='pretrain',
    help='The experiment, pretrain or linear-eval')
flags.DEFINE_string(
    name='worker_mode',
    default='train',
    help='The mode, train or eval')
flags.DEFINE_string(
    name='worker_tpu_driver',
    default='',
    help='The tpu driver to use')
flags.DEFINE_integer(
    name='pretrain_epochs',
    default=1000,
    help='Number of pre-training epochs')
flags.DEFINE_integer(
    name='batch_size',
    default=4096,
    help='Total batch size')
flags.DEFINE_string(
    name='checkpoint_root',
    default='/tmp/byol',
    help='The directory to save checkpoints to.')
flags.DEFINE_integer(
    name='log_tensors_interval',
    default=60,
    help='Log tensors every n seconds.')

FLAGS = flags.FLAGS


# Type of the experiment *class* (not an instance) handed to the loops below.
Experiment = Union[
    Type[byol_experiment.ByolExperiment],
    Type[eval_experiment.EvalExperiment]]
|
|
|
|
|
|
def train_loop(experiment_class, config):
  """The main training loop.

  Builds the experiment, optionally resumes from a checkpoint, and then calls
  `experiment.step` until `config['max_steps']` steps have been performed.
  When checkpointing is enabled, `experiment.save_checkpoint` is called after
  every step so that checkpoints can be evaluated in the eval_loop (how often
  a checkpoint is actually written is decided by the experiment itself). A
  final `save_checkpoint` call is always made once the loop exits.

  Args:
    experiment_class: the constructor for the experiment (either
      byol_experiment or eval_experiment). It is called as
      `experiment_class(**config)`.
    config: the experiment config. This function reads `max_steps` and
      `checkpointing_config['use_checkpointing']` from it.
  """
  experiment = experiment_class(**config)
  use_checkpointing = config['checkpointing_config']['use_checkpointing']
  max_steps = config['max_steps']

  rng = jax.random.PRNGKey(0)
  step = 0

  host_id = jax.host_id()
  last_logging = time.time()
  if use_checkpointing:
    checkpoint_data = experiment.load_checkpoint()
    if checkpoint_data is not None:
      # Resume both the step counter and the random state.
      step, rng = checkpoint_data

  local_device_count = jax.local_device_count()
  # Stays None when the loop body never runs, e.g. when the restored
  # checkpoint is already at `max_steps`; the final log line below would
  # otherwise fail on an unbound name.
  scalars = None
  while step < max_steps:
    step_rng, rng = tuple(jax.random.split(rng))
    # Split the step key into one key per device across all hosts, then keep
    # only the slice that belongs to the devices local to this host.
    step_rng_device = jax.random.split(step_rng, num=jax.device_count())
    step_rng_device = step_rng_device[
        host_id * local_device_count:(host_id + 1) * local_device_count]
    step_device = np.broadcast_to(step, [local_device_count])

    # Perform a training step and get scalars to log.
    scalars = experiment.step(global_step=step_device, rng=step_rng_device)

    # Checkpointing.
    if use_checkpointing:
      experiment.save_checkpoint(step, rng)

    # Logging, throttled to at most once per `log_tensors_interval` seconds of
    # wall-clock time.
    current_time = time.time()
    if current_time - last_logging > FLAGS.log_tensors_interval:
      logging.info('Step %d: %s', step, scalars)
      last_logging = current_time
    step += 1

  logging.info('Saving final checkpoint')
  logging.info('Step %d: %s', step, scalars)
  experiment.save_checkpoint(step, rng)
|
|
|
|
|
|
def eval_loop(experiment_class, config):
  """The main evaluation loop.

  Repeatedly asks the experiment for its latest checkpoint. When no checkpoint
  exists yet, or the latest one was already evaluated, the loop sleeps for ten
  seconds and polls again. Each new checkpoint is scored through
  `experiment.evaluate`, and the resulting scalars are logged on host 0 only.
  The function returns once a checkpoint whose step is at least
  `config['max_steps']` has been evaluated.

  Args:
    experiment_class: the constructor for the experiment (either
      byol_experiment or eval_experiment). It is called as
      `experiment_class(**config)`.
    config: the experiment config; `max_steps` is read from it.
  """
  evaluator = experiment_class(**config)
  newest_evaluated_step = -1
  while True:
    restored = evaluator.load_checkpoint()
    if restored is None:
      logging.info('No checkpoint found. Waiting for 10s.')
      time.sleep(10)
      continue

    # Only the step is needed here; the second checkpoint entry is ignored.
    step, _ = restored
    if newest_evaluated_step >= step:
      logging.info('Checkpoint at step %d already evaluated, waiting.', step)
      time.sleep(10)
      continue

    host_id = jax.host_id()
    # One copy of the step per local device.
    per_device_step = np.broadcast_to(step, [jax.local_device_count()])
    scalars = evaluator.evaluate(global_step=per_device_step)
    if host_id == 0:  # Only perform logging in one host.
      logging.info('Evaluation at step %d: %s', step, scalars)

    newest_evaluated_step = step
    if newest_evaluated_step >= config['max_steps']:
      return
|
|
|
|
|
|
def main(_):
  """Builds the experiment config from flags and runs the requested loop.

  `--experiment_mode` selects the experiment class and its config
  ('pretrain' or 'linear-eval'); for 'linear-eval' the path
  `<checkpoint_root>/pretrain.pkl` is passed to `eval_config.get_config`.
  `--worker_mode` selects whether this process runs `train_loop` or
  `eval_loop`. When `--worker_tpu_driver` is set, JAX is pointed at that
  TPU driver before anything else happens.

  Raises:
    ValueError: if `--experiment_mode` or `--worker_mode` has an unknown
      value.
  """
  if FLAGS.worker_tpu_driver:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver)
    logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices())

  if FLAGS.experiment_mode == 'pretrain':
    experiment_class = byol_experiment.ByolExperiment
    config = byol_config.get_config(FLAGS.pretrain_epochs, FLAGS.batch_size)
  elif FLAGS.experiment_mode == 'linear-eval':
    experiment_class = eval_experiment.EvalExperiment
    config = eval_config.get_config(f'{FLAGS.checkpoint_root}/pretrain.pkl',
                                    FLAGS.batch_size)
  else:
    raise ValueError(f'Unknown experiment mode: {FLAGS.experiment_mode}')
  config['checkpointing_config']['checkpoint_dir'] = FLAGS.checkpoint_root

  if FLAGS.worker_mode == 'train':
    train_loop(experiment_class, config)
  elif FLAGS.worker_mode == 'eval':
    eval_loop(experiment_class, config)
  else:
    # Fail loudly instead of exiting successfully without doing any work,
    # mirroring how an unknown experiment mode is handled above.
    raise ValueError(f'Unknown worker mode: {FLAGS.worker_mode}')


if __name__ == '__main__':
  app.run(main)
|