# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Basic Jaxline ImageNet experiment."""

import importlib
import sys

from absl import flags
from absl import logging

import haiku as hk
import jax
import jax.numpy as jnp

from jaxline import base_config
from jaxline import experiment
from jaxline import platform
from jaxline import utils as jl_utils
from ml_collections import config_dict
import numpy as np

from nfnets import dataset
from nfnets import optim
from nfnets import utils

# pylint: disable=logging-format-interpolation

FLAGS = flags.FLAGS


# We define the experiment launch config in the same file as the experiment to
# keep things self-contained, but one might consider moving the config and/or
# sweep functions to a separate file if necessary.
def get_config():
  """Return config object for training."""
  config = base_config.get_base_config()

  # Experiment config.
  train_batch_size = 1024  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 90
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs)
                           // train_batch_size)
  config.random_seed = 0
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.1,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='ResNet',
              image_size=224,
              use_ema=False,
              ema_decay=0.9999,  # The four EMA options.
              ema_start=0,
              which_ema='tf1_ema',
              augment_name=None,  # 'mixup_cutmix',
              augment_before_mix=True,
              eval_preproc='crop_resize',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              fake_data=False,
              which_loss='softmax_cross_entropy',  # For now, must be softmax.
              transpose=True,  # Use the double-transpose trick?
              bfloat16=False,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(num_steps=config.training_steps,
                              start_val=0,
                              min_val=0,
                              warmup_steps=5 * steps_per_epoch),
              ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD',
                  kwargs={'momentum': 0.9, 'nesterov': True,
                          'weight_decay': 1e-4,},
              ),
              model_kwargs=dict(
                  width=4,
                  which_norm='BatchNorm',
                  norm_kwargs=dict(create_scale=True,
                                   create_offset=True,
                                   decay_rate=0.9,
                                   ),  # cross_replica_axis='i'),
                  variant='ResNet50',
                  activation='relu',
                  drop_rate=0.0,
              ),
          ),))

  # Training loop config: log and checkpoint every minute.
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.save_checkpoint_interval = 60
  config.eval_specific_checkpoint_dir = ''

  return config


class Experiment(experiment.AbstractExperiment):
  """ImageNet experiment."""

  CHECKPOINT_ATTRS = {
      '_params': 'params',
      '_state': 'state',
      '_ema_params': 'ema_params',
      '_ema_state': 'ema_state',
      '_opt_state': 'opt_state',
  }

  def __init__(self, mode, config, init_rng):
    super().__init__(mode=mode)
    self.mode = mode
    self.config = config
    self.init_rng = init_rng

    # Checkpointed experiment state.
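    # These attributes map to checkpoint keys via CHECKPOINT_ATTRS above, so
    # jaxline saves and restores them automatically.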
    self._params = None
    self._state = None
    self._ema_params = None
    self._ema_state = None
    self._opt_state = None

    # Input pipelines.
    self._train_input = None
    self._eval_input = None

    # Get model, loaded in from the zoo.
    self.model_module = importlib.import_module(
        'nfnets.' + self.config.model.lower())
    self.net = hk.transform_with_state(self._forward_fn)

    # Assign image sizes.
    if self.config.get('override_imsize', False):
      self.train_imsize = self.config.image_size
      self.test_imsize = self.config.get('eval_image_size', self.train_imsize)
    else:
      variant_dict = getattr(self.model_module, self.config.model).variant_dict
      variant_dict = variant_dict[self.config.model_kwargs.variant]
      self.train_imsize = variant_dict.get('train_imsize',
                                           self.config.image_size)
      # Test imsize defaults to the model-specific value, then to the config
      # imsize.
      test_imsize = self.config.get('eval_image_size', self.config.image_size)
      self.test_imsize = variant_dict.get('test_imsize', test_imsize)

    donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
    self.train_fn = jax.pmap(self._train_fn, axis_name='i',
                             donate_argnums=donate_argnums)
    self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')

  def _initialize_train(self):
    self._train_input = self._build_train_input()
    # Initialize net and EMA copy of net if no params available.
    if self._params is None:
      inputs = next(self._train_input)
      init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                          axis_name='i')
      init_rng = jl_utils.bcast_local_devices(self.init_rng)
      self._params, self._state = init_net(init_rng, inputs)
      if self.config.use_ema:
        self._ema_params, self._ema_state = init_net(init_rng, inputs)
      num_params = hk.data_structures.tree_size(self._params)
      logging.info(f'Net parameters: {num_params / jax.local_device_count()}')
    self._make_opt()

  def _make_opt(self):
    # Separate conv params and gains/biases.
    def pred(mod, name, val):  # pylint:disable=unused-argument
      return (name in ['scale', 'offset', 'b']
              or 'gain' in name or 'bias' in name)
    gains_biases, weights = hk.data_structures.partition(pred, self._params)
    # LR schedule with batch-based LR scaling.
    if self.config.lr_scale_by_bs:
      max_lr = (self.config.lr * self.config.train_batch_size) / 256
    else:
      max_lr = self.config.lr
    lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
    lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
    # Optimizer; no need to broadcast!
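    # The first parameter group (gains/biases) is exempt from weight decay;
    # only the second group (the remaining weights) is decayed.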
    opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
                           {'params': weights}], **opt_kwargs)
    if self._opt_state is None:
      self._opt_state = self.opt.states()
    else:
      self.opt.plugin(self._opt_state)

  def _forward_fn(self, inputs, is_training):
    net_kwargs = {'num_classes': self.config.num_classes,
                  **self.config.model_kwargs}
    net = getattr(self.model_module, self.config.model)(**net_kwargs)
    if self.config.get('transpose', False):
      images = jnp.transpose(inputs['images'], (3, 0, 1, 2))  # HWCN -> NHWC
    else:
      images = inputs['images']
    if self.config.bfloat16 and self.mode == 'train':
      images = utils.to_bf16(images)
    return net(images, is_training=is_training)['logits']

  def _one_hot(self, value):
    """One-hot encoding potentially over a sequence of labels."""
    y = jax.nn.one_hot(value, self.config.num_classes)
    return y

  def _loss_fn(self, params, state, inputs, rng):
    logits, state = self.net.apply(params, state, rng, inputs,
                                   is_training=True)
    y = self._one_hot(inputs['labels'])
    if 'mix_labels' in inputs:  # Handle cutmix/mixup label mixing.
      logging.info('Using mixup or cutmix!')
      y1 = self._one_hot(inputs['mix_labels'])
      y = inputs['ratio'][:, None] * y + (1. - inputs['ratio'][:, None]) * y1
    if self.config.label_smoothing > 0:  # Apply label smoothing.
      spositives = 1. - self.config.label_smoothing
      snegatives = self.config.label_smoothing / self.config.num_classes
      y = spositives * y + snegatives
    if self.config.bfloat16:  # Cast logits to float32.
      logits = logits.astype(jnp.float32)
    which_loss = getattr(utils, self.config.which_loss)
    loss = which_loss(logits, y, reduction='mean')
    metrics = utils.topk_correct(logits, inputs['labels'], prefix='train_')
    # Average top-1 and top-5 correct labels.
    metrics = jax.tree_map(jnp.mean, metrics)
    metrics['train_loss'] = loss  # Metrics will be pmeaned so don't divide.
    scaled_loss = loss / jax.device_count()  # Grads get psummed so do divide.
    return scaled_loss, (metrics, state)

  def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
                ema_params, ema_states):
    """Runs one batch forward + backward and runs a single opt step."""
    grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
    if self.config.bfloat16:
      in_params, states = jax.tree_map(utils.to_bf16, (params, states))
    else:
      in_params = params
    grads, (metrics, states) = grad_fn(in_params, states, inputs, rng)
    if self.config.bfloat16:
      states, metrics, grads = jax.tree_map(utils.from_bf16,
                                            (states, metrics, grads))
    # Sum gradients and average losses for pmap.
    grads = jax.lax.psum(grads, 'i')
    metrics = jax.lax.pmean(metrics, 'i')
    # Compute updates and update parameters.
    metrics['learning_rate'] = self.opt._hyperparameters['lr'](global_step)  # pylint: disable=protected-access
    params, opt_states = self.opt.step(params, grads, opt_states, global_step)
    if ema_params is not None:
      ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
      ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
      ema_params = jax.tree_multimap(ema, ema_params, params)
      ema_states = jax.tree_multimap(ema, ema_states, states)
    return {'params': params, 'states': states, 'opt_states': opt_states,
            'ema_params': ema_params, 'ema_states': ema_states,
            'metrics': metrics}

  #  _             _
  # | |_ _ __ __ _(_)_ __
  # | __| '__/ _` | | '_ \
  # | |_| | | (_| | | | | |
  #  \__|_|  \__,_|_|_| |_|
  #

  def step(self, global_step, rng, *unused_args, **unused_kwargs):
    if self._train_input is None:
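      # Lazily build the input pipeline and initialize parameters/optimizer on
      # the first step, once the data shapes are known.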
      self._initialize_train()
    inputs = next(self._train_input)
    out = self.train_fn(params=self._params, states=self._state,
                        opt_states=self._opt_state, inputs=inputs,
                        rng=rng, global_step=global_step,
                        ema_params=self._ema_params,
                        ema_states=self._ema_state)
    self._params, self._state = out['params'], out['states']
    self._opt_state = out['opt_states']
    self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
    self.opt.plugin(self._opt_state)
    return jl_utils.get_first(out['metrics'])

  def _build_train_input(self):
    num_devices = jax.device_count()
    global_batch_size = self.config.train_batch_size
    bs_per_device, ragged = divmod(global_batch_size, num_devices)
    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')
    return dataset.load(
        dataset.Split.TRAIN_AND_VALID,
        is_training=True,
        batch_dims=[jax.local_device_count(), bs_per_device],
        transpose=self.config.get('transpose', False),
        image_size=(self.train_imsize,) * 2,
        augment_name=self.config.augment_name,
        augment_before_mix=self.config.get('augment_before_mix', True),
        name=self.config.which_dataset,
        fake_data=self.config.get('fake_data', False))

  #                  _
  #   _____   ____ _| |
  #  / _ \ \ / / _` | |
  # |  __/\ V / (_| | |
  #  \___| \_/ \__,_|_|
  #

  def evaluate(self, global_step, **unused_args):
    metrics = self._eval_epoch(self._params, self._state)
    if self.config.use_ema:
      ema_metrics = self._eval_epoch(self._ema_params, self._ema_state)
      metrics.update({f'ema_{key}': val for key, val in ema_metrics.items()})
    logging.info(f'[Step {global_step}] Eval scalars: {metrics}')
    return metrics

  def _eval_epoch(self, params, state):
    """Evaluates an epoch."""
    num_samples = 0.
    summed_metrics = None
    for inputs in self._build_eval_input():
      num_samples += np.prod(inputs['labels'].shape[:2])  # Account for pmaps.
      metrics = self.eval_fn(params, state, inputs)
      # Accumulate the sum of metrics for each step.
      metrics = jax.tree_map(lambda x: jnp.sum(x[0], axis=0), metrics)
      if summed_metrics is None:
        summed_metrics = metrics
      else:
        summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
    mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
    return jax.device_get(mean_metrics)

  def _eval_fn(self, params, state, inputs):
    """Evaluates a single batch and returns loss and top-k acc."""
    logits, _ = self.net.apply(params, state, None, inputs, is_training=False)
    y = self._one_hot(inputs['labels'])
    which_loss = getattr(utils, self.config.which_loss)
    loss = which_loss(logits, y, reduction=None)
    metrics = utils.topk_correct(logits, inputs['labels'], prefix='eval_')
    metrics['eval_loss'] = loss
    return jax.lax.psum(metrics, 'i')

  def _build_eval_input(self):
    """Builds the evaluation input pipeline."""
    bs_per_device = (self.config.eval_batch_size // jax.local_device_count())
    split = dataset.Split.from_string(self.config.eval_subset)
    eval_preproc = self.config.get('eval_preproc', 'crop_resize')
    return dataset.load(split,
                        is_training=False,
                        batch_dims=[jax.local_device_count(), bs_per_device],
                        transpose=self.config.get('transpose', False),
                        image_size=(self.test_imsize,) * 2,
                        name=self.config.which_dataset,
                        eval_preproc=eval_preproc,
                        fake_data=self.config.get('fake_data', False))


if __name__ == '__main__':
  flags.mark_flag_as_required('config')
  platform.main(Experiment, sys.argv[1:])
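# A minimal launch sketch (a hypothetical invocation, not the only one: jaxline
# reads the experiment config via its --config flag, and the exact module path
# depends on your checkout):
#   python -m nfnets.experiment --config=nfnets/experiment.py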