# Copyright 2020 DeepMind Technologies Limited. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Linear evaluation or fine-tuning pipeline. Use this experiment to evaluate a checkpoint from byol_experiment. """ import functools from typing import Any, Generator, Mapping, NamedTuple, Optional, Text, Tuple, Union from absl import logging from acme.jax import utils as acme_utils import haiku as hk import jax import jax.numpy as jnp import numpy as np import optax from byol.utils import checkpointing from byol.utils import dataset from byol.utils import helpers from byol.utils import networks from byol.utils import schedules # Type declarations. OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState] LogsDict = Mapping[Text, jnp.ndarray] class _EvalExperimentState(NamedTuple): backbone_params: hk.Params classif_params: hk.Params backbone_state: hk.State backbone_opt_state: Union[None, OptState] classif_opt_state: OptState class EvalExperiment: """Linear evaluation experiment.""" def __init__( self, random_seed: int, num_classes: int, batch_size: int, max_steps: int, enable_double_transpose: bool, checkpoint_to_evaluate: Optional[Text], allow_train_from_scratch: bool, freeze_backbone: bool, network_config: Mapping[Text, Any], optimizer_config: Mapping[Text, Any], lr_schedule_config: Mapping[Text, Any], evaluation_config: Mapping[Text, Any], checkpointing_config: Mapping[Text, Any]): """Constructs the experiment. Args: random_seed: the random seed to use when initializing network weights. num_classes: the number of classes; used for the online evaluation. batch_size: the total batch size; should be a multiple of the number of available accelerators. max_steps: the number of training steps; used for the lr/target network ema schedules. enable_double_transpose: see dataset.py; only has effect on TPU. checkpoint_to_evaluate: the path to the checkpoint to evaluate. allow_train_from_scratch: whether to allow training without specifying a checkpoint to evaluate (training from scratch). freeze_backbone: whether the backbone resnet should remain frozen (linear evaluation) or be trainable (fine-tuning). network_config: the configuration for the network. optimizer_config: the configuration for the optimizer. lr_schedule_config: the configuration for the learning rate schedule. evaluation_config: the evaluation configuration. checkpointing_config: the configuration for checkpointing. """ self._random_seed = random_seed self._enable_double_transpose = enable_double_transpose self._num_classes = num_classes self._lr_schedule_config = lr_schedule_config self._batch_size = batch_size self._max_steps = max_steps self._checkpoint_to_evaluate = checkpoint_to_evaluate self._allow_train_from_scratch = allow_train_from_scratch self._freeze_backbone = freeze_backbone self._optimizer_config = optimizer_config self._evaluation_config = evaluation_config # Checkpointed experiment state. self._experiment_state = None # Input pipelines. self._train_input = None self._eval_input = None backbone_fn = functools.partial(self._backbone_fn, **network_config) self.forward_backbone = hk.without_apply_rng( hk.transform_with_state(backbone_fn)) self.forward_classif = hk.without_apply_rng(hk.transform(self._classif_fn)) self.update_pmap = jax.pmap(self._update_func, axis_name='i') self.eval_batch_jit = jax.jit(self._eval_batch) self._is_backbone_training = not self._freeze_backbone self._checkpointer = checkpointing.Checkpointer(**checkpointing_config) def _should_transpose_images(self): """Should we transpose images (saves host-to-device time on TPUs).""" return (self._enable_double_transpose and jax.local_devices()[0].platform == 'tpu') def _backbone_fn( self, inputs: dataset.Batch, encoder_class: Text, encoder_config: Mapping[Text, Any], bn_decay_rate: float, is_training: bool, ) -> jnp.ndarray: """Forward of the encoder (backbone).""" bn_config = {'decay_rate': bn_decay_rate} encoder = getattr(networks, encoder_class) model = encoder( None, bn_config=bn_config, **encoder_config) if self._should_transpose_images(): inputs = dataset.transpose_images(inputs) images = dataset.normalize_images(inputs['images']) return model(images, is_training=is_training) def _classif_fn( self, embeddings: jnp.ndarray, ) -> jnp.ndarray: classifier = hk.Linear(output_size=self._num_classes) return classifier(embeddings) # _ _ # | |_ _ __ __ _(_)_ __ # | __| '__/ _` | | '_ \ # | |_| | | (_| | | | | | # \__|_| \__,_|_|_| |_| # def step(self, *, global_step: jnp.ndarray, rng: jnp.ndarray) -> Mapping[Text, np.ndarray]: """Performs a single training step.""" if self._train_input is None: self._initialize_train(rng) inputs = next(self._train_input) self._experiment_state, scalars = self.update_pmap( self._experiment_state, global_step, inputs) scalars = helpers.get_first(scalars) return scalars def save_checkpoint(self, step: int, rng: jnp.ndarray): self._checkpointer.maybe_save_checkpoint( self._experiment_state, step=step, rng=rng, is_final=step >= self._max_steps) def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]: checkpoint_data = self._checkpointer.maybe_load_checkpoint() if checkpoint_data is None: return None self._experiment_state, step, rng = checkpoint_data return step, rng def _initialize_train(self, rng): """BYOL's _ExperimentState initialization. Args: rng: random number generator used to initialize parameters. If working in a multi device setup, this need to be a ShardedArray. dummy_input: a dummy image, used to compute intermediate outputs shapes. Returns: Initial EvalExperiment state. Raises: RuntimeError: invalid or empty checkpoint. """ self._train_input = acme_utils.prefetch(self._build_train_input()) # Check we haven't already restored params if self._experiment_state is None: inputs = next(self._train_input) if self._checkpoint_to_evaluate is not None: # Load params from checkpoint checkpoint_data = checkpointing.load_checkpoint( self._checkpoint_to_evaluate) if checkpoint_data is None: raise RuntimeError('Invalid checkpoint.') backbone_params = checkpoint_data['experiment_state'].online_params backbone_state = checkpoint_data['experiment_state'].online_state backbone_params = helpers.bcast_local_devices(backbone_params) backbone_state = helpers.bcast_local_devices(backbone_state) else: if not self._allow_train_from_scratch: raise ValueError( 'No checkpoint specified, but `allow_train_from_scratch` ' 'set to False') # Initialize with random parameters logging.info( 'No checkpoint specified, initializing the networks from scratch ' '(dry run mode)') backbone_params, backbone_state = jax.pmap( functools.partial(self.forward_backbone.init, is_training=True), axis_name='i')(rng=rng, inputs=inputs) init_experiment = jax.pmap(self._make_initial_state, axis_name='i') # Init uses the same RNG key on all hosts+devices to ensure everyone # computes the same initial state and parameters. init_rng = jax.random.PRNGKey(self._random_seed) init_rng = helpers.bcast_local_devices(init_rng) self._experiment_state = init_experiment( rng=init_rng, dummy_input=inputs, backbone_params=backbone_params, backbone_state=backbone_state) # Clear the backbone optimizer's state when the backbone is frozen. if self._freeze_backbone: self._experiment_state = _EvalExperimentState( backbone_params=self._experiment_state.backbone_params, classif_params=self._experiment_state.classif_params, backbone_state=self._experiment_state.backbone_state, backbone_opt_state=None, classif_opt_state=self._experiment_state.classif_opt_state, ) def _make_initial_state( self, rng: jnp.ndarray, dummy_input: dataset.Batch, backbone_params: hk.Params, backbone_state: hk.Params, ) -> _EvalExperimentState: """_EvalExperimentState initialization.""" # Initialize the backbone params # Always create the batchnorm weights (is_training=True), they will be # overwritten when loading the checkpoint. embeddings, _ = self.forward_backbone.apply( backbone_params, backbone_state, dummy_input, is_training=True) backbone_opt_state = self._optimizer(0.).init(backbone_params) # Initialize the classifier params and optimizer_state classif_params = self.forward_classif.init(rng, embeddings) classif_opt_state = self._optimizer(0.).init(classif_params) return _EvalExperimentState( backbone_params=backbone_params, classif_params=classif_params, backbone_state=backbone_state, backbone_opt_state=backbone_opt_state, classif_opt_state=classif_opt_state, ) def _build_train_input(self) -> Generator[dataset.Batch, None, None]: """See base class.""" num_devices = jax.device_count() global_batch_size = self._batch_size per_device_batch_size, ragged = divmod(global_batch_size, num_devices) if ragged: raise ValueError( f'Global batch size {global_batch_size} must be divisible by ' f'num devices {num_devices}') return dataset.load( dataset.Split.TRAIN_AND_VALID, preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN, transpose=self._should_transpose_images(), batch_dims=[jax.local_device_count(), per_device_batch_size]) def _optimizer(self, learning_rate: float): """Build optimizer from config.""" return optax.sgd(learning_rate, **self._optimizer_config) def _loss_fn( self, backbone_params: hk.Params, classif_params: hk.Params, backbone_state: hk.State, inputs: dataset.Batch, ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]: """Compute the classification loss function. Args: backbone_params: parameters of the encoder network. classif_params: parameters of the linear classifier. backbone_state: internal state of encoder network. inputs: inputs, containing `images` and `labels`. Returns: The classification loss and various logs. """ embeddings, backbone_state = self.forward_backbone.apply( backbone_params, backbone_state, inputs, is_training=not self._freeze_backbone) logits = self.forward_classif.apply(classif_params, embeddings) labels = hk.one_hot(inputs['labels'], self._num_classes) loss = helpers.softmax_cross_entropy(logits, labels, reduction='mean') scaled_loss = loss / jax.device_count() return scaled_loss, (loss, backbone_state) def _update_func( self, experiment_state: _EvalExperimentState, global_step: jnp.ndarray, inputs: dataset.Batch, ) -> Tuple[_EvalExperimentState, LogsDict]: """Applies an update to parameters and returns new state.""" # This function computes the gradient of the first output of loss_fn and # passes through the other arguments unchanged. # Gradient of the first output of _loss_fn wrt the backbone (arg 0) and the # classifier parameters (arg 1). The auxiliary outputs are returned as-is. grad_loss_fn = jax.grad(self._loss_fn, has_aux=True, argnums=(0, 1)) grads, aux_outputs = grad_loss_fn( experiment_state.backbone_params, experiment_state.classif_params, experiment_state.backbone_state, inputs, ) backbone_grads, classifier_grads = grads train_loss, new_backbone_state = aux_outputs classifier_grads = jax.lax.psum(classifier_grads, axis_name='i') # Compute the decayed learning rate learning_rate = schedules.learning_schedule( global_step, batch_size=self._batch_size, total_steps=self._max_steps, **self._lr_schedule_config) # Compute and apply updates via our optimizer. classif_updates, new_classif_opt_state = \ self._optimizer(learning_rate).update( classifier_grads, experiment_state.classif_opt_state) new_classif_params = optax.apply_updates(experiment_state.classif_params, classif_updates) if self._freeze_backbone: del backbone_grads, new_backbone_state # Unused # The backbone is not updated. new_backbone_params = experiment_state.backbone_params new_backbone_opt_state = None new_backbone_state = experiment_state.backbone_state else: backbone_grads = jax.lax.psum(backbone_grads, axis_name='i') # Compute and apply updates via our optimizer. backbone_updates, new_backbone_opt_state = \ self._optimizer(learning_rate).update( backbone_grads, experiment_state.backbone_opt_state) new_backbone_params = optax.apply_updates( experiment_state.backbone_params, backbone_updates) experiment_state = _EvalExperimentState( new_backbone_params, new_classif_params, new_backbone_state, new_backbone_opt_state, new_classif_opt_state, ) # Scalars to log (note: we log the mean across all hosts/devices). scalars = {'train_loss': train_loss} scalars = jax.lax.pmean(scalars, axis_name='i') return experiment_state, scalars # _ # _____ ____ _| | # / _ \ \ / / _` | | # | __/\ V / (_| | | # \___| \_/ \__,_|_| # def evaluate(self, global_step, **unused_args): """See base class.""" global_step = np.array(helpers.get_first(global_step)) scalars = jax.device_get(self._eval_epoch(**self._evaluation_config)) logging.info('[Step %d] Eval scalars: %s', global_step, scalars) return scalars def _eval_batch( self, backbone_params: hk.Params, classif_params: hk.Params, backbone_state: hk.State, inputs: dataset.Batch, ) -> LogsDict: """Evaluates a batch.""" embeddings, backbone_state = self.forward_backbone.apply( backbone_params, backbone_state, inputs, is_training=False) logits = self.forward_classif.apply(classif_params, embeddings) labels = hk.one_hot(inputs['labels'], self._num_classes) loss = helpers.softmax_cross_entropy(logits, labels, reduction=None) top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1) top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5) # NOTE: Returned values will be summed and finally divided by num_samples. return { 'eval_loss': loss, 'top1_accuracy': top1_correct, 'top5_accuracy': top5_correct } def _eval_epoch(self, subset: Text, batch_size: int): """Evaluates an epoch.""" num_samples = 0. summed_scalars = None backbone_params = helpers.get_first(self._experiment_state.backbone_params) classif_params = helpers.get_first(self._experiment_state.classif_params) backbone_state = helpers.get_first(self._experiment_state.backbone_state) split = dataset.Split.from_string(subset) dataset_iterator = dataset.load( split, preprocess_mode=dataset.PreprocessMode.EVAL, transpose=self._should_transpose_images(), batch_dims=[batch_size]) for inputs in dataset_iterator: num_samples += inputs['labels'].shape[0] scalars = self.eval_batch_jit( backbone_params, classif_params, backbone_state, inputs, ) # Accumulate the sum of scalars for each step. scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars) if summed_scalars is None: summed_scalars = scalars else: summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars) mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars) return mean_scalars