Files
deepmind-research/perceiver/train/experiment.py
T
2023-06-02 18:03:40 +01:00

539 lines
18 KiB
Python

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A reference training pipeline for Perceiver/Perceiver IO on ImageNet.
We use the Jaxline (https://github.com/deepmind/jaxline) training framework.
Two sets of hyperparameters are provided, the hyperparameters we used for the
Perceiver IO paper, and scaled-down hyperparameters for local testing.
This script should run out-of-the-box with the local hyperparameters.
The scaled-up hyperparameters require a distributed learning setup to run,
and this script will need to be adapted to your specific setup.
"""
import functools
from typing import Generator, Mapping, Text, Tuple
from absl import app
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
from jaxline import base_config
from jaxline import experiment
from jaxline import platform
from jaxline import utils as jl_utils
from ml_collections import config_dict
import numpy as np
import optax
from perceiver import io_processors
from perceiver import perceiver
from perceiver.train import dataset
from perceiver.train import utils
# Handle to absl's global flag registry.
FLAGS = flags.FLAGS
# Type alias used to annotate the optimizer state threaded through the
# parameter-update function.
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
# Type alias for a mapping from scalar names to array values (the metrics that
# the train and eval steps return for logging).
Scalars = Mapping[Text, jnp.ndarray]
# Number of examples in the combined train+valid split; used to derive the
# number of training steps and the images-per-epoch setting.
N_TRAIN_EXAMPLES = dataset.Split.TRAIN_AND_VALID.num_examples
# Number of label classes; used as `config.data.num_classes`.
N_CLASSES = 1000
# Only local/debug parameters are supported out of the box.
# To use the scaled-up hyperparameters, please adapt this script to your
# training setup and set this flag to False.
IS_LOCAL = True
def get_training_steps(batch_size, n_epochs):
  """Returns the number of training steps needed to cover `n_epochs` epochs.

  The arguments may be plain integers or config field references (as passed
  by `get_config`); only `*` and `//` are applied to them.

  Args:
    batch_size: Number of examples consumed per training step.
    n_epochs: Number of passes over the `N_TRAIN_EXAMPLES` training examples.

  Returns:
    `N_TRAIN_EXAMPLES * n_epochs` floor-divided by `batch_size`.
  """
  examples_to_process = N_TRAIN_EXAMPLES * n_epochs
  return examples_to_process // batch_size
def get_config():
  """Builds and returns the locked Jaxline config used for training.

  When `IS_LOCAL` is true the debug values are used for the encoder depth
  (`num_self_attends_per_block`, `num_blocks`) and the image size (`im_dim`);
  every other setting is shared between the local and scaled-up setups.

  Returns:
    The base Jaxline config extended with `experiment_kwargs` (optimizer,
    model, training, data and evaluation settings) and the training-loop
    settings. The config is locked before it is returned.
  """
  debug = IS_LOCAL

  def _default_or_debug(default_value, debug_value):
    """Selects the debug value when running with local/debug settings."""
    if debug:
      return debug_value
    return default_value

  config = base_config.get_base_config()

  # Experiment config.
  local_batch_size = 2
  # Modify this to adapt to your custom distributed learning setup
  num_devices = 1
  config.train_batch_size = local_batch_size * num_devices
  config.n_epochs = 110

  optimizer_config = {
      'base_lr': 5e-4,
      'max_norm': 10.0,  # < 0 to turn off.
      'schedule_type': 'constant_cosine',
      'weight_decay': 1e-1,
      'decay_pos_embs': True,
      'scale_by_batch': True,
      'cosine_decay_kwargs': {
          'init_value': 0.0,
          'warmup_epochs': 0,
          'end_value': 0.0,
      },
      'step_decay_kwargs': {
          'decay_boundaries': [0.5, 0.8, 0.95],
          'decay_rate': 0.1,
      },
      'constant_cosine_decay_kwargs': {
          'constant_fraction': 0.5,
          'end_value': 0.0,
      },
      'optimizer': 'lamb',
      # Optimizer-specific kwargs:
      'adam_kwargs': {
          'b1': 0.9,
          'b2': 0.999,
          'eps': 1e-8,
      },
      'lamb_kwargs': {
          'b1': 0.9,
          'b2': 0.999,
          'eps': 1e-6,
      },
  }

  input_preprocessor_config = {
      'prep_type': 'pixels',
      # Channels for conv/conv1x1 preprocessing:
      'num_channels': 64,
      # -------------------------
      # Position encoding arguments:
      # -------------------------
      'position_encoding_type': 'fourier',
      'concat_or_add_pos': 'concat',
      'spatial_downsample': 1,
      # If >0, project position to this size:
      'project_pos_dim': -1,
      'trainable_position_encoding_kwargs': {
          'num_channels': 258,  # Match default # for Fourier.
          'init_scale': 0.02,
      },
      'fourier_position_encoding_kwargs': {
          'num_bands': 64,
          'max_resolution': (224, 224),
          'sine_only': False,
          'concat_pos': True,
      },
  }

  encoder_config = {
      'num_self_attends_per_block': _default_or_debug(6, 2),
      # Weights won't be shared if num_blocks is set to 1.
      'num_blocks': _default_or_debug(8, 2),
      'z_index_dim': 512,
      'num_z_channels': 1024,
      'num_cross_attend_heads': 1,
      'num_self_attend_heads': 8,
      'cross_attend_widening_factor': 1,
      'self_attend_widening_factor': 1,
      'dropout_prob': 0.0,
      # Position encoding for the latent array.
      'z_pos_enc_init_scale': 0.02,
      'cross_attention_shape_for_attn': 'kv',
      'use_query_residual': True,
  }

  decoder_config = {
      'num_z_channels': 1024,
      'use_query_residual': True,
      # Position encoding for the output logits.
      'position_encoding_type': 'trainable',
      'trainable_position_encoding_kwargs': {
          'num_channels': 1024,
          'init_scale': 0.02,
      },
  }

  # Don't specify output_channels - it's not used for classifiers.
  model_config = {
      'perceiver_kwargs': {
          'input_preprocessor': input_preprocessor_config,
          'encoder': encoder_config,
          'decoder': decoder_config,
      },
  }

  training_config = {
      'images_per_epoch': N_TRAIN_EXAMPLES,
      'label_smoothing': 0.1,
      # One-way references keep these in sync with the top-level settings.
      'n_epochs': config.get_oneway_ref('n_epochs'),
      'batch_size': config.get_oneway_ref('train_batch_size'),
  }

  data_config = {
      'num_classes': N_CLASSES,
      # Run on smaller images to debug.
      'im_dim': _default_or_debug(224, 32),
      'augmentation': {
          # Typical randaug params:
          # num_layers in [1, 3]
          # magnitude in [5, 30]
          # Set randaugment to None to disable.
          'randaugment': {
              'num_layers': 4,
              'magnitude': 5,
          },
          'cutmix': True,
          # Mixup alpha should be in [0, 1].
          # Set to None to disable.
          'mixup_alpha': 0.2,
      },
  }

  evaluation_config = {
      'subset': 'test',
      'batch_size': 2,
  }

  config.experiment_kwargs = config_dict.ConfigDict({
      'config': {
          'optimizer': optimizer_config,
          'model': model_config,
          'training': training_config,
          'data': data_config,
          'evaluation': evaluation_config,
      },
  })

  # Training loop config.
  batch_size_ref = config.get_oneway_ref('train_batch_size')
  n_epochs_ref = config.get_oneway_ref('n_epochs')
  config.training_steps = get_training_steps(batch_size_ref, n_epochs_ref)
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.save_checkpoint_interval = 300
  config.eval_specific_checkpoint_dir = ''
  config.best_model_eval_metric = 'eval_top_1_acc'
  config.checkpoint_dir = '/tmp/perceiver_imagnet_checkpoints'
  config.train_checkpoint_all_hosts = False

  # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
  config.lock()
  return config
class Experiment(experiment.AbstractExperiment):
  """ImageNet classification experiment for Perceiver / Perceiver IO.

  Provides the `step` (training) and `evaluate` entry points. The model is a
  Haiku-transformed Perceiver assembled from `config.model.perceiver_kwargs`.
  Parameter updates run through `jax.pmap` (axis name `'i'`); evaluation runs
  through `jax.jit` on the copy of the parameters returned by
  `jl_utils.get_first`.
  """

  # A map from object properties that will be checkpointed to their name
  # in a checkpoint. Currently we assume that these are all sharded
  # device arrays.
  CHECKPOINT_ATTRS = {
      '_params': 'params',
      '_state': 'state',
      '_opt_state': 'opt_state',
  }

  def __init__(self, mode, init_rng, config):
    """Initializes experiment.

    Args:
      mode: Forwarded to the base class and stored on the instance.
      init_rng: RNG key used to initialise the network parameters.
      config: Experiment config; this class reads its `model`, `data`,
        `training`, `optimizer` and `evaluation` sections.
    """
    super().__init__(mode=mode, init_rng=init_rng)

    self.mode = mode
    self.init_rng = init_rng
    self.config = config

    # Checkpointed experiment state.
    self._params = None
    self._state = None
    self._opt_state = None

    # Input pipelines.
    self._train_input = None
    self._eval_input = None

    self.forward = hk.transform_with_state(self._forward_fn)

    # NOTE: We "donate" the `params, state, opt_state` arguments which allows
    # JAX (on some backends) to reuse the device memory associated with these
    # inputs to store the outputs of our function (which also start with
    # `params, state, opt_state`).
    self._update_func = jax.pmap(self._update_func, axis_name='i',
                                 donate_argnums=(0, 1, 2))
    self._eval_batch = jax.jit(self._eval_batch)

  def _forward_fn(
      self,
      inputs: dataset.Batch,
      is_training: bool,
  ) -> jnp.ndarray:
    """Builds the Perceiver classifier and applies it to `inputs['images']`.

    Args:
      inputs: Batch containing an `'images'` entry.
      is_training: Forwarded to the model call.

    Returns:
      The model output for the images (used as classification logits).
    """
    images = inputs['images']

    perceiver_kwargs = self.config.model.perceiver_kwargs
    input_preprocessor = io_processors.ImagePreprocessor(
        **perceiver_kwargs['input_preprocessor'])
    encoder = perceiver.PerceiverEncoder(**perceiver_kwargs['encoder'])
    decoder = perceiver.ClassificationDecoder(
        self.config.data.num_classes,
        **perceiver_kwargs['decoder'])
    model = perceiver.Perceiver(
        encoder=encoder,
        decoder=decoder,
        input_preprocessor=input_preprocessor)

    return model(images, is_training=is_training)

  # ---------------------------------------------------------------------------
  # Training
  # ---------------------------------------------------------------------------

  def step(self, global_step: int, rng: jnp.ndarray,  # pytype: disable=signature-mismatch  # jax-ndarray
           *unused_args, **unused_kwargs):
    """See base class.

    Lazily sets up the training pipeline on first use, then runs one
    pmapped update on the next training batch.

    Args:
      global_step: Current training step, passed to the update function.
      rng: RNG key(s) passed to the update function.
      *unused_args: Ignored.
      **unused_kwargs: Ignored.

    Returns:
      The training scalars of the first device (already averaged across
      devices inside the update function).
    """
    if self._train_input is None:
      self._initialize_train()

    inputs = next(self._train_input)

    self._params, self._state, self._opt_state, scalars = (
        self._update_func(
            self._params, self._state, self._opt_state, inputs, rng, global_step
        ))

    scalars = jl_utils.get_first(scalars)
    return scalars

  def _initialize_train(self):
    """Creates the input pipeline, LR schedule, optimizer and initial state.

    Parameters, network state and optimizer state are only initialised when
    `self._params` is still `None`, i.e. nothing was restored beforehand.
    """
    self._train_input = jl_utils.py_prefetch(self._build_train_input)

    total_batch_size = self.config.training.batch_size
    steps_per_epoch = (
        self.config.training.images_per_epoch / self.config.training.batch_size)
    total_steps = self.config.training.n_epochs * steps_per_epoch
    # Build the learning-rate schedule and the optimizer that consumes it.
    self._lr_schedule = utils.get_learning_rate_schedule(
        total_batch_size, steps_per_epoch, total_steps, self.config.optimizer)

    self._optimizer = utils.make_optimizer(
        self.config.optimizer,
        self._lr_schedule)

    # Check we haven't already restored params
    if self._params is None:
      logging.info('Initializing parameters.')

      inputs = next(self._train_input)

      init_net = jax.pmap(lambda *a: self.forward.init(*a, is_training=True))
      init_opt = jax.pmap(self._optimizer.init)

      # Init uses the same RNG key on all hosts+devices to ensure everyone
      # computes the same initial state.
      init_rng = jl_utils.bcast_local_devices(self.init_rng)

      self._params, self._state = init_net(init_rng, inputs)
      self._opt_state = init_opt(self._params)

  def _load_data(self, split, is_training, batch_dims):
    """Wrapper for dataset loading.

    Args:
      split: Dataset split to load.
      is_training: Forwarded to `dataset.load`.
      batch_dims: Leading batch dimensions of the produced batches.

    Returns:
      Whatever `dataset.load` returns for the configured image size and
      augmentation settings.
    """
    return dataset.load(
        split=split,
        is_training=is_training,
        batch_dims=batch_dims,
        im_dim=self.config.data.im_dim,
        augmentation_settings=self.config.data.augmentation,
    )

  def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
    """See base class.

    Returns:
      Training batches with leading dimensions
      `[local_device_count, per_device_batch_size]`.

    Raises:
      ValueError: If the global batch size is not divisible by the total
        number of devices.
    """
    num_devices = jax.device_count()
    global_batch_size = self.config.training.batch_size
    per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')

    split = dataset.Split.TRAIN_AND_VALID

    return self._load_data(
        split=split,
        is_training=True,
        batch_dims=[jax.local_device_count(), per_device_batch_size])

  def _one_hot(self, value):
    """One-hot encoding potentially over a sequence of labels."""
    y = jax.nn.one_hot(value, self.config.data.num_classes)
    return y

  def _loss_fn(
      self,
      params: hk.Params,
      state: hk.State,
      inputs: dataset.Batch,
      rng: jnp.ndarray,
  ) -> Tuple[jnp.ndarray, Tuple[Scalars, hk.State]]:
    """Computes the training loss and metrics for one per-device batch.

    Args:
      params: Network parameters.
      state: Network state.
      inputs: Batch with `'images'` and `'labels'`, plus `'mix_labels'` and
        `'ratio'` when mixup/cutmix is active.
      rng: RNG key for the forward pass.

    Returns:
      A tuple `(scaled_loss, (loss_scalars, state))`: the mean loss divided
      by the device count, then the unscaled loss and top-1/top-5 accuracy
      to log together with the updated network state.

    Raises:
      ValueError: If `config.training.label_smoothing` is outside `[0, 1)`.
    """
    logits, state = self.forward.apply(
        params, state, rng, inputs, is_training=True)

    label = self._one_hot(inputs['labels'])
    # Handle cutmix/mixup label mixing:
    if 'mix_labels' in inputs:
      logging.info('Using mixup or cutmix!')
      mix_label = self._one_hot(inputs['mix_labels'])
      # Trailing axis lets the per-example ratio broadcast over the classes.
      mix_ratio = inputs['ratio'][:, None]
      label = mix_ratio * label + (1. - mix_ratio) * mix_label

    # Apply label-smoothing to one-hot labels.
    label_smoothing = self.config.training.label_smoothing
    if not 0. <= label_smoothing < 1.:
      raise ValueError(
          f"'label_smoothing' is {label_smoothing} and should be in [0, 1)")
    if label_smoothing > 0:
      smooth_positives = 1. - label_smoothing
      smooth_negatives = label_smoothing / self.config.data.num_classes
      label = smooth_positives * label + smooth_negatives

    loss_w_batch = utils.softmax_cross_entropy(logits, label)
    loss = jnp.mean(loss_w_batch, dtype=loss_w_batch.dtype)
    # Pre-divide by the device count so that the `psum` of gradients in
    # `_update_func` yields their mean across devices.
    scaled_loss = loss / jax.device_count()

    metrics = utils.topk_correct(logits, inputs['labels'], prefix='')
    metrics = jax.tree_util.tree_map(jnp.mean, metrics)

    top_1_acc = metrics['top_1_acc']
    top_5_acc = metrics['top_5_acc']

    loss_scalars = dict(
        loss=loss,
        top_1_acc=top_1_acc,
        top_5_acc=top_5_acc,
    )

    return scaled_loss, (loss_scalars, state)

  def _update_func(
      self,
      params: hk.Params,
      state: hk.State,
      opt_state: OptState,
      inputs: dataset.Batch,
      rng: jnp.ndarray,
      global_step: int,
  ) -> Tuple[hk.Params, hk.State, OptState, Scalars]:
    """Applies an update to parameters and returns new state.

    Intended to run under `jax.pmap` with axis name `'i'` (see `__init__`).

    Args:
      params: Network parameters.
      state: Network state.
      opt_state: Optimizer state.
      inputs: Per-device training batch.
      rng: RNG key for the forward pass.
      global_step: Current step, used to look up the learning rate to log.

    Returns:
      A tuple `(params, state, opt_state, scalars)` with the updated values
      and the scalars to log, averaged across devices.
    """
    # This function computes the gradient of the first output of loss_fn and
    # passes through the other arguments unchanged.
    grad_loss_fn = jax.grad(self._loss_fn, has_aux=True)
    scaled_grads, (loss_scalars, state) = grad_loss_fn(
        params, state, inputs, rng)
    grads = jax.lax.psum(scaled_grads, axis_name='i')

    # Grab the learning rate to log before performing the step.
    learning_rate = self._lr_schedule(global_step)

    # Compute and apply updates via our optimizer.
    updates, opt_state = self._optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Total number of parameter elements, summed over every leaf array.
    n_params = sum(
        np.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(params))

    # Scalars to log (note: we log the mean across all hosts/devices).
    scalars = {'learning_rate': learning_rate,
               'n_params (M)': float(n_params/1e6),
               'global_gradient_norm': optax.global_norm(grads)}
    loss_scalars = {f'train_{k}': v for k, v in loss_scalars.items()}
    scalars.update(loss_scalars)
    scalars = jax.lax.pmean(scalars, axis_name='i')

    return params, state, opt_state, scalars

  # ---------------------------------------------------------------------------
  # Evaluation
  # ---------------------------------------------------------------------------

  def evaluate(self, global_step, rng, **unused_args):
    """See base class.

    Args:
      global_step: Current step; only used for logging.
      rng: RNG key(s); the first one is used for the whole evaluation epoch.
      **unused_args: Ignored.

    Returns:
      The evaluation scalars averaged over the evaluation split, fetched to
      host memory.
    """
    global_step = np.array(jl_utils.get_first(global_step))
    scalars = jax.device_get(self._eval_epoch(jl_utils.get_first(rng)))

    logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
    return scalars

  def _eval_batch(
      self,
      params: hk.Params,
      state: hk.State,
      inputs: dataset.Batch,
      rng: jnp.ndarray,
  ) -> Scalars:
    """Evaluates a batch.

    Args:
      params: Network parameters.
      state: Network state.
      inputs: Evaluation batch with `'images'` and `'labels'`.
      rng: RNG key for the forward pass.

    Returns:
      A dict with the per-example `eval_loss` and the top-1/top-5 accuracies
      scaled by the batch size, so that summing over batches and dividing by
      the number of samples gives dataset-level means.
    """
    logits, _ = self.forward.apply(
        params, state, rng, inputs, is_training=False)

    labels = self._one_hot(inputs['labels'])
    loss = utils.softmax_cross_entropy(logits, labels)

    metrics = utils.topk_correct(logits, inputs['labels'], prefix='')
    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
    top_1_acc = metrics['top_1_acc']
    top_5_acc = metrics['top_5_acc']

    # Turn the batch means back into batch totals.
    bs = logits.shape[0]
    top_1_acc = jnp.expand_dims(top_1_acc, axis=0) * bs
    top_5_acc = jnp.expand_dims(top_5_acc, axis=0) * bs

    # NOTE: Returned values will be summed and finally divided by num_samples.
    return {
        'eval_loss': loss,
        'eval_top_1_acc': top_1_acc, 'eval_top_5_acc': top_5_acc}

  def _build_eval_input(self) -> Generator[dataset.Batch, None, None]:
    """Returns evaluation batches for the split named in the config."""
    split = dataset.Split.from_string(self.config.evaluation.subset)

    return self._load_data(
        split=split,
        is_training=False,
        batch_dims=[self.config.evaluation.batch_size])

  def _eval_epoch(self, rng):
    """Evaluates an epoch.

    Args:
      rng: RNG key reused for every evaluation batch.

    Returns:
      The scalars from `_eval_batch`, summed over all batches and divided by
      the total number of evaluated samples.
    """
    num_samples = 0.
    summed_scalars = None

    params = jl_utils.get_first(self._params)
    state = jl_utils.get_first(self._state)

    for inputs in self._build_eval_input():
      num_samples += inputs['labels'].shape[0]
      scalars = self._eval_batch(params, state, inputs, rng)

      # Accumulate the sum of scalars for each step.
      scalars = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
      if summed_scalars is None:
        summed_scalars = scalars
      else:
        summed_scalars = jax.tree_util.tree_map(
            jnp.add, summed_scalars, scalars)

    mean_scalars = jax.tree_util.tree_map(
        lambda x: x / num_samples, summed_scalars)
    return mean_scalars
if __name__ == '__main__':
  # The 'config' flag must be supplied on the command line.
  flags.mark_flag_as_required('config')
  # Bind the experiment class to Jaxline's entry point and hand it to absl.
  run_experiment = functools.partial(platform.main, Experiment)
  app.run(run_experiment)