diff --git a/byol/README.md b/byol/README.md new file mode 100644 index 0000000..6f44c51 --- /dev/null +++ b/byol/README.md @@ -0,0 +1,159 @@ +# Bootstrap Your Own Latent + +This is the implementation of the pre-training and linear evaluation pipeline of +BYOL - https://arxiv.org/abs/2006.07733. + +Using this implementation should achieve a top-1 accuracy on ImageNet between +74.0% and 74.5% after about 8h of training using 512 Cloud TPU v3. + +The main pretraining module is `byol_experiment.py`. By default it uses BYOL to +pretrain a ResNet-50 on ImageNet. In parallel, we train a classifier on top of +the representation to assess its performance during training. This classifier +does not back-propagate any gradient to the ResNet-50. + +The evaluation module is `eval_experiment.py`. It evaluates the performance of +the representation learnt by BYOL (using a given checkpoint). + +## Setup + +To set up a Python virtual environment with the required dependencies, run: + +```shell +python3 -m venv byol_env +source byol_env/bin/activate +pip install --upgrade pip +pip install -r byol/requirements.txt +``` + +The code uses `tensorflow_datasets` to load the ImageNet dataset. Manual +download may be required; see https://www.tensorflow.org/datasets/catalog/imagenet2012 +for details. 
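After installing the requirements, it can help to confirm that the main dependencies are importable before launching a job. The following is a minimal sketch and not part of the repository: the helper `missing_packages` is hypothetical, and the package list is an assumption taken from the top-level imports of `byol_experiment.py` plus `tensorflow_datasets`.

```python
import importlib.util

# Assumed list of top-level packages: the imports used by byol_experiment.py,
# plus tensorflow_datasets, which the README names as the data loader.
REQUIRED = ["absl", "acme", "haiku", "jax", "numpy", "optax",
            "tensorflow_datasets"]


def missing_packages(names):
    """Return the package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```

Run it inside the activated `byol_env` environment. It only checks that each package can be located; it does not verify versions or accelerator support.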
+ +## Running a short pre-training locally + +To run a short (40 epochs) pre-training experiment on a local machine, use: + +```shell +mkdir /tmp/byol_checkpoints +python -m byol.main_loop \ + --experiment_mode='pretrain' \ + --worker_mode='train' \ + --checkpoint_root='/tmp/byol_checkpoints' \ + --pretrain_epochs=40 +``` + +## Full pipeline and presets + +The various parts of the pipeline can be run using: + +```shell +python -m byol.main_loop \ + --experiment_mode=<'pretrain' or 'linear-eval'> \ + --worker_mode=<'train' or 'eval'> \ + --checkpoint_root= \ + --pretrain_epochs=<40, 100, 300 or 1000> +``` + +### Pretraining +Setting `--experiment_mode=pretrain` will configure the main loop for +pretraining; we provide presets for 40, 100, 300 and 1000 epochs. + +Use `--worker_mode=train` for a training job, which will regularly save +checkpoints under `/pretrain.pkl`. To monitor the progress of +the pretraining, you can run a second worker (using `--worker_mode=eval`) with +the same `checkpoint_root` setting. This worker will regularly load the +checkpoint and evaluate the performance of a linear classifier (trained by the +pretraining `train` worker) on the `TEST` set. + +Note that the default settings are set for large-scale training on Cloud TPUs, +with a total batch size of 4096. To avoid the need to re-run the full +experiment, we will make a pre-trained checkpoint available on GCP. + +### Linear evaluation +Setting `--experiment_mode=linear-eval` will configure the main loop for +linear evaluation; we provide a preset for 80 epochs. + +Use `--worker_mode=train` for a training job, which will load the encoder +weights from an existing checkpoint (from a pretrain experiment) located at +`/pretrain.pkl`, and train a linear classifier on top of this +encoder. The weights from the linear classifier trained in the pretraining phase +will be discarded. + +The training job will regularly save checkpoints under +`/linear-eval.pkl`. 
You can run a second worker +(using `--worker_mode=eval`) with the same `checkpoint_root` setting to +regularly load the checkpoint and evaluate the performance of the classifier +(trained by the linear-eval `train` worker) on the `TEST` set. + +Note that the above will run a simplified version of the linear evaluation +pipeline described in the paper, with a single value of the base learning rate +and without using a validation set. To fully reproduce the results from the +paper, one should run 5 instances of the linear-eval `train` worker (using only +the `TRAIN` subset, and each using a different `checkpoint_root`), run the +`eval` worker (using only the `VALID` subset) for each checkpoint, then run a +final `eval` worker on the `TEST` set. + + +## Running on GCP + +Notice: we currently do not recommend running the full experiment on public +Cloud TPUs. We provide an alternative small-scale GPU setup in the next section. + +The experiments from the paper were run on Google's internal infrastructure. +Unfortunately, training JAX models on publicly available Cloud TPUs is currently +in its early stages; in particular, we found training to be heavily +bottlenecked by the data loading pipeline, resulting in significantly slower +training. We will update the code and instructions below once the limitation is +addressed. + +To train the model on TPU using Google Compute Engine, please follow the +instructions at https://cloud.google.com/tpu/docs/imagenet-setup to first +download and preprocess the ImageNet dataset and upload it to a +Cloud Storage Bucket. From a GCE VM, you can check that `tensorflow_datasets` +can correctly load the dataset by running: + +```python +import tensorflow_datasets as tfds +tfds.load('imagenet2012', data_dir='gs://') +``` + +To learn how to use JAX with Cloud TPUs, please follow the instructions here: +https://github.com/google/jax/tree/master/cloud_tpu_colabs. 
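When moving between accelerator setups, keep in mind that the total batch size must divide evenly across devices: `_build_train_input` raises a `ValueError` otherwise. The sketch below mirrors that `divmod` check in plain Python so a batch size can be validated before a job is launched. The function name `per_device_batch_size` is hypothetical, and the example assumes that the "512 Cloud TPU v3" mentioned above corresponds to 512 devices as reported by `jax.device_count()`.

```python
def per_device_batch_size(global_batch_size: int, num_devices: int) -> int:
    """Split a global batch evenly across devices, or raise if it is ragged."""
    per_device, ragged = divmod(global_batch_size, num_devices)
    if ragged:
        raise ValueError(
            f'Global batch size {global_batch_size} must be divisible by '
            f'num devices {num_devices}')
    return per_device


if __name__ == "__main__":
    # Default large-scale setting: total batch size 4096 (512 devices assumed).
    print(per_device_batch_size(4096, 512))
    # Single-GPU setting from the fast-iteration setup (--batch_size=256).
    print(per_device_batch_size(256, 1))
```

In the real pipeline the device count comes from `jax.device_count()`, so the same check applies unchanged to multi-host TPU runs and single-GPU runs.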
+ + +## Setup for fast iteration + +In order to make reproduction easier, it is possible to change the training +setup to use the smaller [imagenette](https://github.com/fastai/imagenette) +dataset (9469 training images with 10 classes). The following setup and +hyperparameters can be used on a machine with a single V100 GPU: + +- in `utils/dataset.py`: + - update `Split.num_examples` with the figures from [tfds](https://www.tensorflow.org/datasets/catalog/imagenette) (with `Split.VALID: 0`) + - use `imagenette/160px-v2` in the call to `tfds.load` + - use 128x128 px images (_i.e._, replace all instances of `224` by `128`) + - it doesn't seem necessary to change the color normalization (make sure to + not replace the value `0.224` by mistake in the previous step). +- in `configs/byol.py`, use: + - `num_classes`: `10` + - `network_config.encoder`: `ResNet18` + - `optimizer_config.weight_decay`: `1e-6` + - `lr_schedule_config.base_learning_rate`: `2.0` + - `evaluation_config.batch_size`: `25` + - other parameters unchanged. + +You can then run: + +```shell +mkdir /tmp/byol_checkpoints +python -m byol.main_loop \ + --experiment_mode='pretrain' \ + --worker_mode='train' \ + --checkpoint_root='/tmp/byol_checkpoints' \ + --batch_size=256 \ + --pretrain_epochs=1000 +``` + +With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the +*online* classifier) in roughly 4 hours. Note that the above parameters were not +finely tuned and may not be optimal. diff --git a/byol/byol_experiment.py b/byol/byol_experiment.py new file mode 100644 index 0000000..ef6143f --- /dev/null +++ b/byol/byol_experiment.py @@ -0,0 +1,533 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BYOL pre-training implementation. + +Use this experiment to pre-train a self-supervised representation. +""" + +import functools +from typing import Any, Generator, Mapping, NamedTuple, Text, Tuple, Union + +from absl import logging +from acme.jax import utils as acme_utils +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import optax + +from byol.utils import augmentations +from byol.utils import checkpointing +from byol.utils import dataset +from byol.utils import helpers +from byol.utils import networks +from byol.utils import optimizers +from byol.utils import schedules + + +# Type declarations. +LogsDict = Mapping[Text, jnp.ndarray] + + +class _ByolExperimentState(NamedTuple): + """Byol's model and optimization parameters and state.""" + online_params: hk.Params + target_params: hk.Params + online_state: hk.State + target_state: hk.State + opt_state: optimizers.LarsState + + +class ByolExperiment: + """Byol's training and evaluation component definition.""" + + def __init__( + self, + random_seed, + num_classes, + batch_size, + max_steps, + enable_double_transpose, + base_target_ema, + network_config, + optimizer_config, + lr_schedule_config, + evaluation_config, + checkpointing_config): + """Constructs the experiment. + + Args: + random_seed: the random seed to use when initializing network weights. + num_classes: the number of classes; used for the online evaluation. + batch_size: the total batch size; should be a multiple of the number of + available accelerators. 
+ max_steps: the number of training steps; used for the lr/target network + ema schedules. + enable_double_transpose: see dataset.py; only has effect on TPU. + base_target_ema: the initial value for the ema decay rate of the target + network. + network_config: the configuration for the network. + optimizer_config: the configuration for the optimizer. + lr_schedule_config: the configuration for the learning rate schedule. + evaluation_config: the evaluation configuration. + checkpointing_config: the configuration for checkpointing. + """ + + self._random_seed = random_seed + self._enable_double_transpose = enable_double_transpose + self._num_classes = num_classes + self._lr_schedule_config = lr_schedule_config + self._batch_size = batch_size + self._max_steps = max_steps + self._base_target_ema = base_target_ema + self._optimizer_config = optimizer_config + self._evaluation_config = evaluation_config + + # Checkpointed experiment state. + self._byol_state = None + + # Input pipelines. + self._train_input = None + self._eval_input = None + + # build the transformed ops + forward_fn = functools.partial(self._forward, **network_config) + self.forward = hk.without_apply_rng(hk.transform_with_state(forward_fn)) + # training can handle multiple devices, thus the pmap + self.update_pmap = jax.pmap(self._update_fn, axis_name='i') + # evaluation can only handle single device + self.eval_batch_jit = jax.jit(self._eval_batch) + + self._checkpointer = checkpointing.Checkpointer(**checkpointing_config) + + def _forward( + self, + inputs, + projector_hidden_size, + projector_output_size, + predictor_hidden_size, + encoder_class, + encoder_config, + bn_config, + is_training, + ): + """Forward application of byol's architecture. + + Args: + inputs: A batch of data, i.e. a dictionary, with either two keys, + (`images` and `labels`) or three keys (`view1`, `view2`, `labels`). + projector_hidden_size: hidden size of the projector MLP. 
+ projector_output_size: output size of the projector and predictor MLPs. + predictor_hidden_size: hidden size of the predictor MLP. + encoder_class: type of the encoder (should match a class in + utils/networks). + encoder_config: passed to the encoder constructor. + bn_config: passed to the hk.BatchNorm constructors. + is_training: Training or evaluating the model? When True, inputs must + contain keys `view1` and `view2`. When False, inputs must contain key + `images`. + + Returns: + All outputs of the model, i.e. a dictionary with projection, prediction + and logits keys, for either the two views, or the image. + """ + encoder = getattr(networks, encoder_class) + net = encoder( + num_classes=None, # Don't build the final linear layer + bn_config=bn_config, + **encoder_config) + + projector = networks.MLP( + name='projector', + hidden_size=projector_hidden_size, + output_size=projector_output_size, + bn_config=bn_config) + predictor = networks.MLP( + name='predictor', + hidden_size=predictor_hidden_size, + output_size=projector_output_size, + bn_config=bn_config) + classifier = hk.Linear( + output_size=self._num_classes, name='classifier') + + def apply_once_fn(images, suffix = ''): + images = dataset.normalize_images(images) + + embedding = net(images, is_training=is_training) + proj_out = projector(embedding, is_training) + pred_out = predictor(proj_out, is_training) + + # Note the stop_gradient: label information is not leaked into the + # main network. 
+ classif_out = classifier(jax.lax.stop_gradient(embedding)) + outputs = {} + outputs['projection' + suffix] = proj_out + outputs['prediction' + suffix] = pred_out + outputs['logits' + suffix] = classif_out + return outputs + + if is_training: + outputs_view1 = apply_once_fn(inputs['view1'], '_view1') + outputs_view2 = apply_once_fn(inputs['view2'], '_view2') + return {**outputs_view1, **outputs_view2} + else: + return apply_once_fn(inputs['images'], '') + + def _optimizer(self, learning_rate): + """Build optimizer from config.""" + return optimizers.lars( + learning_rate, + weight_decay_filter=optimizers.exclude_bias_and_norm, + lars_adaptation_filter=optimizers.exclude_bias_and_norm, + **self._optimizer_config) + + def loss_fn( + self, + online_params, + target_params, + online_state, + target_state, + rng, + inputs, + ): + """Compute BYOL's loss function. + + Args: + online_params: parameters of the online network (the loss is later + differentiated with respect to the online parameters). + target_params: parameters of the target network. + online_state: internal state of online network. + target_state: internal state of target network. + rng: random number generator state. + inputs: inputs, containing two batches of crops from the same images, + view1 and view2 and labels + + Returns: + BYOL's loss, a mapping containing the online and target networks updated + states after processing inputs, and various logs. 
+ """ + if self._should_transpose_images(): + inputs = dataset.transpose_images(inputs) + inputs = augmentations.postprocess(inputs, rng) + labels = inputs['labels'] + + online_network_out, online_state = self.forward.apply( + params=online_params, + state=online_state, + inputs=inputs, + is_training=True) + target_network_out, target_state = self.forward.apply( + params=target_params, + state=target_state, + inputs=inputs, + is_training=True) + + # Representation loss + + # The stop_gradient is not necessary as we explicitly take the gradient with + # respect to online parameters only (`argnums=0` in the `jax.grad` call of + # `_update_fn`). We leave it to indicate that gradients are not + # backpropagated through the target network. + repr_loss = helpers.regression_loss( + online_network_out['prediction_view1'], + jax.lax.stop_gradient(target_network_out['projection_view2'])) + repr_loss = repr_loss + helpers.regression_loss( + online_network_out['prediction_view2'], + jax.lax.stop_gradient(target_network_out['projection_view1'])) + + repr_loss = jnp.mean(repr_loss) + + # Classification loss (with gradients stopped from flowing into the + # ResNet). This is used to provide an evaluation of the representation + # quality during training. 
+ + classif_loss = helpers.softmax_cross_entropy( + logits=online_network_out['logits_view1'], + labels=jax.nn.one_hot(labels, self._num_classes)) + + top1_correct = helpers.topk_accuracy( + online_network_out['logits_view1'], + inputs['labels'], + topk=1, + ) + + top5_correct = helpers.topk_accuracy( + online_network_out['logits_view1'], + inputs['labels'], + topk=5, + ) + + top1_acc = jnp.mean(top1_correct) + top5_acc = jnp.mean(top5_correct) + + classif_loss = jnp.mean(classif_loss) + loss = repr_loss + classif_loss + logs = dict( + loss=loss, + repr_loss=repr_loss, + classif_loss=classif_loss, + top1_accuracy=top1_acc, + top5_accuracy=top5_acc, + ) + + return loss, (dict(online_state=online_state, + target_state=target_state), logs) + + def _should_transpose_images(self): + """Should we transpose images (saves host-to-device time on TPUs).""" + return (self._enable_double_transpose and + jax.local_devices()[0].platform == 'tpu') + + def _update_fn( + self, + byol_state, + global_step, + rng, + inputs, + ): + """Update online and target parameters. + + Args: + byol_state: current BYOL state. + global_step: current training step. + rng: current random number generator + inputs: inputs, containing two batches of crops from the same images, + view1 and view2 and labels + + Returns: + Tuple containing the updated Byol state after processing the inputs, and + various logs. 
+ """ + online_params = byol_state.online_params + target_params = byol_state.target_params + online_state = byol_state.online_state + target_state = byol_state.target_state + opt_state = byol_state.opt_state + + # update online network + grad_fn = jax.grad(self.loss_fn, argnums=0, has_aux=True) + grads, (net_states, logs) = grad_fn(online_params, target_params, + online_state, target_state, rng, inputs) + + # cross-device grad and logs reductions + grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads) + logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs) + + learning_rate = schedules.learning_schedule( + global_step, + batch_size=self._batch_size, + total_steps=self._max_steps, + **self._lr_schedule_config) + updates, opt_state = self._optimizer(learning_rate).update( + grads, opt_state, online_params) + online_params = optax.apply_updates(online_params, updates) + + # update target network + tau = schedules.target_ema( + global_step, + base_ema=self._base_target_ema, + max_steps=self._max_steps) + target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x), + target_params, online_params) + logs['tau'] = tau + logs['learning_rate'] = learning_rate + return _ByolExperimentState( + online_params=online_params, + target_params=target_params, + online_state=net_states['online_state'], + target_state=net_states['target_state'], + opt_state=opt_state), logs + + def _make_initial_state( + self, + rng, + dummy_input, + ): + """BYOL's _ByolExperimentState initialization. + + Args: + rng: random number generator used to initialize parameters. If working in + a multi-device setup, this needs to be a ShardedArray. + dummy_input: a dummy image, used to compute intermediate output shapes. + + Returns: + Initial Byol state. 
+ """ + rng_online, rng_target = jax.random.split(rng) + + if self._should_transpose_images(): + dummy_input = dataset.transpose_images(dummy_input) + + # Online and target parameters are initialized using different rngs; + # in our experiments we did not notice a significant difference when using + # the same rng for both. + online_params, online_state = self.forward.init( + rng_online, + dummy_input, + is_training=True, + ) + target_params, target_state = self.forward.init( + rng_target, + dummy_input, + is_training=True, + ) + opt_state = self._optimizer(0).init(online_params) + return _ByolExperimentState( + online_params=online_params, + target_params=target_params, + opt_state=opt_state, + online_state=online_state, + target_state=target_state, + ) + + def step(self, *, + global_step, + rng): + """Performs a single training step.""" + if self._train_input is None: + self._initialize_train() + + inputs = next(self._train_input) + + self._byol_state, scalars = self.update_pmap( + self._byol_state, + global_step=global_step, + rng=rng, + inputs=inputs, + ) + + return helpers.get_first(scalars) + + def save_checkpoint(self, step, rng): + self._checkpointer.maybe_save_checkpoint( + self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps) + + def load_checkpoint(self): + checkpoint_data = self._checkpointer.maybe_load_checkpoint() + if checkpoint_data is None: + return None + self._byol_state, step, rng = checkpoint_data + return step, rng + + def _initialize_train(self): + """Initialize train. + + This includes initializing the input pipeline and Byol's state. 
+ """ + self._train_input = acme_utils.prefetch(self._build_train_input()) + + # Check we haven't already restored params + if self._byol_state is None: + logging.info( + 'Initializing parameters rather than restoring from checkpoint.') + + # initialize Byol and setup optimizer state + inputs = next(self._train_input) + init_byol = jax.pmap(self._make_initial_state, axis_name='i') + + # Init uses the same RNG key on all hosts+devices to ensure everyone + # computes the same initial state and parameters. + init_rng = jax.random.PRNGKey(self._random_seed) + init_rng = helpers.bcast_local_devices(init_rng) + + self._byol_state = init_byol(rng=init_rng, dummy_input=inputs) + + def _build_train_input(self): + """Loads the (infinitely looping) dataset iterator.""" + num_devices = jax.device_count() + global_batch_size = self._batch_size + per_device_batch_size, ragged = divmod(global_batch_size, num_devices) + + if ragged: + raise ValueError( + f'Global batch size {global_batch_size} must be divisible by ' + f'num devices {num_devices}') + + return dataset.load( + dataset.Split.TRAIN_AND_VALID, + preprocess_mode=dataset.PreprocessMode.PRETRAIN, + transpose=self._should_transpose_images(), + batch_dims=[jax.local_device_count(), per_device_batch_size]) + + def _eval_batch( + self, + params, + state, + batch, + ): + """Evaluates a batch. + + Args: + params: Parameters of the model to evaluate. Typically Byol's online + parameters. + state: State of the model to evaluate. Typically Byol's online state. + batch: Batch of data to evaluate (must contain keys images and labels). + + Returns: + Unreduced evaluation loss and top1 accuracy on the batch. 
+ """ + if self._should_transpose_images(): + batch = dataset.transpose_images(batch) + + outputs, _ = self.forward.apply(params, state, batch, is_training=False) + logits = outputs['logits'] + labels = hk.one_hot(batch['labels'], self._num_classes) + loss = helpers.softmax_cross_entropy(logits, labels, reduction=None) + top1_correct = helpers.topk_accuracy(logits, batch['labels'], topk=1) + top5_correct = helpers.topk_accuracy(logits, batch['labels'], topk=5) + # NOTE: Returned values will be summed and finally divided by num_samples. + return { + 'eval_loss': loss, + 'top1_accuracy': top1_correct, + 'top5_accuracy': top5_correct, + } + + def _eval_epoch(self, subset, batch_size): + """Evaluates an epoch.""" + num_samples = 0. + summed_scalars = None + + params = helpers.get_first(self._byol_state.online_params) + state = helpers.get_first(self._byol_state.online_state) + split = dataset.Split.from_string(subset) + + dataset_iterator = dataset.load( + split, + preprocess_mode=dataset.PreprocessMode.EVAL, + transpose=self._should_transpose_images(), + batch_dims=[batch_size]) + + for inputs in dataset_iterator: + num_samples += inputs['labels'].shape[0] + scalars = self.eval_batch_jit(params, state, inputs) + + # Accumulate the sum of scalars for each step. 
+ scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars) + if summed_scalars is None: + summed_scalars = scalars + else: + summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars) + + mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars) + return mean_scalars + + def evaluate(self, global_step, **unused_args): + """Thin wrapper around _eval_epoch.""" + + global_step = np.array(helpers.get_first(global_step)) + scalars = jax.device_get(self._eval_epoch(**self._evaluation_config)) + + logging.info('[Step %d] Eval scalars: %s', global_step, scalars) + return scalars diff --git a/byol/configs/byol.py b/byol/configs/byol.py new file mode 100644 index 0000000..7088324 --- /dev/null +++ b/byol/configs/byol.py @@ -0,0 +1,78 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Config file for BYOL experiment.""" + +from byol.utils import dataset + + +# Preset values for certain number of training epochs. 
+_LR_PRESETS = {40: 0.45, 100: 0.45, 300: 0.3, 1000: 0.2} +_WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6} +_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996} + + +def get_config(num_epochs, batch_size): + """Return config object, containing all hyperparameters for training.""" + train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples + + assert num_epochs in [40, 100, 300, 1000] + + config = dict( + random_seed=0, + num_classes=1000, + batch_size=batch_size, + max_steps=num_epochs * train_images_per_epoch // batch_size, + enable_double_transpose=True, + base_target_ema=_EMA_PRESETS[num_epochs], + network_config=dict( + projector_hidden_size=4096, + projector_output_size=256, + predictor_hidden_size=4096, + encoder_class='ResNet50', # Should match a class in utils/networks. + encoder_config=dict( + resnet_v2=False, + width_multiplier=1), + bn_config={ + 'decay_rate': .9, + 'eps': 1e-5, + # Accumulate batchnorm statistics across devices. + # This should be equal to the `axis_name` argument passed + # to jax.pmap. + 'cross_replica_axis': 'i', + 'create_scale': True, + 'create_offset': True, + }), + optimizer_config=dict( + weight_decay=_WD_PRESETS[num_epochs], + eta=1e-3, + momentum=.9, + ), + lr_schedule_config=dict( + base_learning_rate=_LR_PRESETS[num_epochs], + warmup_steps=10 * train_images_per_epoch // batch_size, + ), + evaluation_config=dict( + subset='test', + batch_size=100, + ), + checkpointing_config=dict( + use_checkpointing=True, + checkpoint_dir='/tmp/byol', + save_checkpoint_interval=300, + filename='pretrain.pkl' + ), + ) + + return config diff --git a/byol/configs/eval.py b/byol/configs/eval.py new file mode 100644 index 0000000..d01b6dc --- /dev/null +++ b/byol/configs/eval.py @@ -0,0 +1,67 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Config file for evaluation experiment.""" + +from typing import Text + +from byol.utils import dataset + + +def get_config(checkpoint_to_evaluate, batch_size): + """Return config object for training.""" + train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples + + config = dict( + random_seed=0, + enable_double_transpose=True, + max_steps=80 * train_images_per_epoch // batch_size, + num_classes=1000, + batch_size=batch_size, + checkpoint_to_evaluate=checkpoint_to_evaluate, + # If True, allows training without loading a checkpoint. + allow_train_from_scratch=False, + # Whether the backbone should be frozen (linear evaluation) or + # trainable (fine-tuning). + freeze_backbone=True, + optimizer_config=dict( + momentum=0.9, + nesterov=True, + ), + lr_schedule_config=dict( + base_learning_rate=0.2, + warmup_steps=0, + ), + network_config=dict( # Should match the evaluated checkpoint + encoder_class='ResNet50', # Should match a class in utils/networks. + encoder_config=dict( + resnet_v2=False, + width_multiplier=1), + bn_decay_rate=0.9, + ), + evaluation_config=dict( + subset='test', + batch_size=100, + ), + checkpointing_config=dict( + use_checkpointing=True, + checkpoint_dir='/tmp/byol', + save_checkpoint_interval=300, + filename='linear-eval.pkl' + ), + ) + + return config + + diff --git a/byol/eval_experiment.py b/byol/eval_experiment.py new file mode 100644 index 0000000..2755848 --- /dev/null +++ b/byol/eval_experiment.py @@ -0,0 +1,477 @@ +# Copyright 2020 DeepMind Technologies Limited. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Linear evaluation or fine-tuning pipeline. + +Use this experiment to evaluate a checkpoint from byol_experiment. +""" + +import functools +from typing import Any, Generator, Mapping, NamedTuple, Optional, Text, Tuple, Union + +from absl import logging +from acme.jax import utils as acme_utils +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import optax + +from byol.utils import checkpointing +from byol.utils import dataset +from byol.utils import helpers +from byol.utils import networks +from byol.utils import schedules + +# Type declarations. +OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState] +LogsDict = Mapping[Text, jnp.ndarray] + + +class _EvalExperimentState(NamedTuple): + backbone_params: hk.Params + classif_params: hk.Params + backbone_state: hk.State + backbone_opt_state: Union[None, OptState] + classif_opt_state: OptState + + +class EvalExperiment: + """Linear evaluation experiment.""" + + def __init__( + self, + random_seed, + num_classes, + batch_size, + max_steps, + enable_double_transpose, + checkpoint_to_evaluate, + allow_train_from_scratch, + freeze_backbone, + network_config, + optimizer_config, + lr_schedule_config, + evaluation_config, + checkpointing_config): + """Constructs the experiment. + + Args: + random_seed: the random seed to use when initializing network weights. 
+ num_classes: the number of classes; used for the online evaluation. + batch_size: the total batch size; should be a multiple of the number of + available accelerators. + max_steps: the number of training steps; used for the lr/target network + ema schedules. + enable_double_transpose: see dataset.py; only has effect on TPU. + checkpoint_to_evaluate: the path to the checkpoint to evaluate. + allow_train_from_scratch: whether to allow training without specifying a + checkpoint to evaluate (training from scratch). + freeze_backbone: whether the backbone resnet should remain frozen (linear + evaluation) or be trainable (fine-tuning). + network_config: the configuration for the network. + optimizer_config: the configuration for the optimizer. + lr_schedule_config: the configuration for the learning rate schedule. + evaluation_config: the evaluation configuration. + checkpointing_config: the configuration for checkpointing. + """ + + self._random_seed = random_seed + self._enable_double_transpose = enable_double_transpose + self._num_classes = num_classes + self._lr_schedule_config = lr_schedule_config + self._batch_size = batch_size + self._max_steps = max_steps + self._checkpoint_to_evaluate = checkpoint_to_evaluate + self._allow_train_from_scratch = allow_train_from_scratch + self._freeze_backbone = freeze_backbone + self._optimizer_config = optimizer_config + self._evaluation_config = evaluation_config + + # Checkpointed experiment state. + self._experiment_state = None + + # Input pipelines. 
+ self._train_input = None + self._eval_input = None + + backbone_fn = functools.partial(self._backbone_fn, **network_config) + self.forward_backbone = hk.without_apply_rng( + hk.transform_with_state(backbone_fn)) + self.forward_classif = hk.without_apply_rng(hk.transform(self._classif_fn)) + self.update_pmap = jax.pmap(self._update_func, axis_name='i') + self.eval_batch_jit = jax.jit(self._eval_batch) + + self._is_backbone_training = not self._freeze_backbone + + self._checkpointer = checkpointing.Checkpointer(**checkpointing_config) + + def _should_transpose_images(self): + """Should we transpose images (saves host-to-device time on TPUs).""" + return (self._enable_double_transpose and + jax.local_devices()[0].platform == 'tpu') + + def _backbone_fn( + self, + inputs, + encoder_class, + encoder_config, + bn_decay_rate, + is_training, + ): + """Forward of the encoder (backbone).""" + bn_config = {'decay_rate': bn_decay_rate} + encoder = getattr(networks, encoder_class) + model = encoder( + None, + bn_config=bn_config, + **encoder_config) + + if self._should_transpose_images(): + inputs = dataset.transpose_images(inputs) + images = dataset.normalize_images(inputs['images']) + return model(images, is_training=is_training) + + def _classif_fn( + self, + embeddings, + ): + classifier = hk.Linear(output_size=self._num_classes) + return classifier(embeddings) + + # _ _ + # | |_ _ __ __ _(_)_ __ + # | __| '__/ _` | | '_ \ + # | |_| | | (_| | | | | | + # \__|_| \__,_|_|_| |_| + # + + def step(self, *, + global_step, + rng): + """Performs a single training step.""" + + if self._train_input is None: + self._initialize_train(rng) + + inputs = next(self._train_input) + self._experiment_state, scalars = self.update_pmap( + self._experiment_state, global_step, inputs) + + scalars = helpers.get_first(scalars) + return scalars + + def save_checkpoint(self, step, rng): + self._checkpointer.maybe_save_checkpoint( + self._experiment_state, step=step, rng=rng, + is_final=step >= 
self._max_steps) + + def load_checkpoint(self): + checkpoint_data = self._checkpointer.maybe_load_checkpoint() + if checkpoint_data is None: + return None + self._experiment_state, step, rng = checkpoint_data + return step, rng + + def _initialize_train(self, rng): + """BYOL's _ExperimentState initialization. + + Args: + rng: random number generator used to initialize parameters. If working in + a multi-device setup, this needs to be a ShardedArray. + dummy_input: a dummy image, used to compute intermediate output shapes. + + Returns: + Initial EvalExperiment state. + + Raises: + RuntimeError: invalid or empty checkpoint. + """ + self._train_input = acme_utils.prefetch(self._build_train_input()) + + # Check we haven't already restored params + if self._experiment_state is None: + + inputs = next(self._train_input) + + if self._checkpoint_to_evaluate is not None: + # Load params from checkpoint + checkpoint_data = checkpointing.load_checkpoint( + self._checkpoint_to_evaluate) + if checkpoint_data is None: + raise RuntimeError('Invalid checkpoint.') + backbone_params = checkpoint_data['experiment_state'].online_params + backbone_state = checkpoint_data['experiment_state'].online_state + backbone_params = helpers.bcast_local_devices(backbone_params) + backbone_state = helpers.bcast_local_devices(backbone_state) + else: + if not self._allow_train_from_scratch: + raise ValueError( + 'No checkpoint specified, but `allow_train_from_scratch` ' + 'set to False') + # Initialize with random parameters + logging.info( + 'No checkpoint specified, initializing the networks from scratch ' + '(dry run mode)') + backbone_params, backbone_state = jax.pmap( + functools.partial(self.forward_backbone.init, is_training=True), + axis_name='i')(rng=rng, inputs=inputs) + + init_experiment = jax.pmap(self._make_initial_state, axis_name='i') + + # Init uses the same RNG key on all hosts+devices to ensure everyone + # computes the same initial state and parameters. 
+ init_rng = jax.random.PRNGKey(self._random_seed) + init_rng = helpers.bcast_local_devices(init_rng) + self._experiment_state = init_experiment( + rng=init_rng, + dummy_input=inputs, + backbone_params=backbone_params, + backbone_state=backbone_state) + + # Clear the backbone optimizer's state when the backbone is frozen. + if self._freeze_backbone: + self._experiment_state = _EvalExperimentState( + backbone_params=self._experiment_state.backbone_params, + classif_params=self._experiment_state.classif_params, + backbone_state=self._experiment_state.backbone_state, + backbone_opt_state=None, + classif_opt_state=self._experiment_state.classif_opt_state, + ) + + def _make_initial_state( + self, + rng, + dummy_input, + backbone_params, + backbone_state, + ): + """_EvalExperimentState initialization.""" + + # Initialize the backbone params + # Always create the batchnorm weights (is_training=True), they will be + # overwritten when loading the checkpoint. + embeddings, _ = self.forward_backbone.apply( + backbone_params, backbone_state, dummy_input, is_training=True) + backbone_opt_state = self._optimizer(0.).init(backbone_params) + + # Initialize the classifier params and optimizer_state + classif_params = self.forward_classif.init(rng, embeddings) + classif_opt_state = self._optimizer(0.).init(classif_params) + + return _EvalExperimentState( + backbone_params=backbone_params, + classif_params=classif_params, + backbone_state=backbone_state, + backbone_opt_state=backbone_opt_state, + classif_opt_state=classif_opt_state, + ) + + def _build_train_input(self): + """See base class.""" + num_devices = jax.device_count() + global_batch_size = self._batch_size + per_device_batch_size, ragged = divmod(global_batch_size, num_devices) + + if ragged: + raise ValueError( + f'Global batch size {global_batch_size} must be divisible by ' + f'num devices {num_devices}') + + return dataset.load( + dataset.Split.TRAIN_AND_VALID, + preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN, + 
transpose=self._should_transpose_images(), + batch_dims=[jax.local_device_count(), per_device_batch_size]) + + def _optimizer(self, learning_rate): + """Build optimizer from config.""" + return optax.sgd(learning_rate, **self._optimizer_config) + + def _loss_fn( + self, + backbone_params, + classif_params, + backbone_state, + inputs, + ): + """Compute the classification loss function. + + Args: + backbone_params: parameters of the encoder network. + classif_params: parameters of the linear classifier. + backbone_state: internal state of encoder network. + inputs: inputs, containing `images` and `labels`. + + Returns: + The classification loss and various logs. + """ + embeddings, backbone_state = self.forward_backbone.apply( + backbone_params, + backbone_state, + inputs, + is_training=not self._freeze_backbone) + + logits = self.forward_classif.apply(classif_params, embeddings) + labels = hk.one_hot(inputs['labels'], self._num_classes) + loss = helpers.softmax_cross_entropy(logits, labels, reduction='mean') + scaled_loss = loss / jax.device_count() + + return scaled_loss, (loss, backbone_state) + + def _update_func( + self, + experiment_state, + global_step, + inputs, + ): + """Applies an update to parameters and returns new state.""" + # This function computes the gradient of the first output of loss_fn and + # passes through the other arguments unchanged. + + # Gradient of the first output of _loss_fn wrt the backbone (arg 0) and the + # classifier parameters (arg 1). The auxiliary outputs are returned as-is. 
+ grad_loss_fn = jax.grad(self._loss_fn, has_aux=True, argnums=(0, 1)) + + grads, aux_outputs = grad_loss_fn( + experiment_state.backbone_params, + experiment_state.classif_params, + experiment_state.backbone_state, + inputs, + ) + backbone_grads, classifier_grads = grads + train_loss, new_backbone_state = aux_outputs + classifier_grads = jax.lax.psum(classifier_grads, axis_name='i') + + # Compute the decayed learning rate + learning_rate = schedules.learning_schedule( + global_step, + batch_size=self._batch_size, + total_steps=self._max_steps, + **self._lr_schedule_config) + + # Compute and apply updates via our optimizer. + classif_updates, new_classif_opt_state = \ + self._optimizer(learning_rate).update( + classifier_grads, + experiment_state.classif_opt_state) + + new_classif_params = optax.apply_updates(experiment_state.classif_params, + classif_updates) + + if self._freeze_backbone: + del backbone_grads, new_backbone_state # Unused + # The backbone is not updated. + new_backbone_params = experiment_state.backbone_params + new_backbone_opt_state = None + new_backbone_state = experiment_state.backbone_state + else: + backbone_grads = jax.lax.psum(backbone_grads, axis_name='i') + + # Compute and apply updates via our optimizer. + backbone_updates, new_backbone_opt_state = \ + self._optimizer(learning_rate).update( + backbone_grads, + experiment_state.backbone_opt_state) + + new_backbone_params = optax.apply_updates( + experiment_state.backbone_params, backbone_updates) + + experiment_state = _EvalExperimentState( + new_backbone_params, + new_classif_params, + new_backbone_state, + new_backbone_opt_state, + new_classif_opt_state, + ) + + # Scalars to log (note: we log the mean across all hosts/devices). 
+ scalars = {'train_loss': train_loss} + scalars = jax.lax.pmean(scalars, axis_name='i') + + return experiment_state, scalars + + # _ + # _____ ____ _| | + # / _ \ \ / / _` | | + # | __/\ V / (_| | | + # \___| \_/ \__,_|_| + # + + def evaluate(self, global_step, **unused_args): + """See base class.""" + + global_step = np.array(helpers.get_first(global_step)) + scalars = jax.device_get(self._eval_epoch(**self._evaluation_config)) + + logging.info('[Step %d] Eval scalars: %s', global_step, scalars) + return scalars + + def _eval_batch( + self, + backbone_params, + classif_params, + backbone_state, + inputs, + ): + """Evaluates a batch.""" + embeddings, backbone_state = self.forward_backbone.apply( + backbone_params, backbone_state, inputs, is_training=False) + logits = self.forward_classif.apply(classif_params, embeddings) + labels = hk.one_hot(inputs['labels'], self._num_classes) + loss = helpers.softmax_cross_entropy(logits, labels, reduction=None) + top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1) + top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5) + # NOTE: Returned values will be summed and finally divided by num_samples. + return { + 'eval_loss': loss, + 'top1_accuracy': top1_correct, + 'top5_accuracy': top5_correct + } + + def _eval_epoch(self, subset, batch_size): + """Evaluates an epoch.""" + num_samples = 0. 
+ summed_scalars = None + + backbone_params = helpers.get_first(self._experiment_state.backbone_params) + classif_params = helpers.get_first(self._experiment_state.classif_params) + backbone_state = helpers.get_first(self._experiment_state.backbone_state) + split = dataset.Split.from_string(subset) + + dataset_iterator = dataset.load( + split, + preprocess_mode=dataset.PreprocessMode.EVAL, + transpose=self._should_transpose_images(), + batch_dims=[batch_size]) + + for inputs in dataset_iterator: + num_samples += inputs['labels'].shape[0] + scalars = self.eval_batch_jit( + backbone_params, + classif_params, + backbone_state, + inputs, + ) + + # Accumulate the sum of scalars for each step. + scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars) + if summed_scalars is None: + summed_scalars = scalars + else: + summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars) + + mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars) + return mean_scalars diff --git a/byol/main_loop.py b/byol/main_loop.py new file mode 100644 index 0000000..9dc5b4c --- /dev/null +++ b/byol/main_loop.py @@ -0,0 +1,157 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Training and evaluation loops for an experiment.""" + +import time +from typing import Any, Mapping, Text, Type, Union + +from absl import app +from absl import flags +from absl import logging +import jax +import numpy as np + +from byol import byol_experiment +from byol import eval_experiment +from byol.configs import byol as byol_config +from byol.configs import eval as eval_config + +flags.DEFINE_string('experiment_mode', + 'pretrain', 'The experiment, pretrain or linear-eval') +flags.DEFINE_string('worker_mode', 'train', 'The mode, train or eval') +flags.DEFINE_string('worker_tpu_driver', '', 'The tpu driver to use') +flags.DEFINE_integer('pretrain_epochs', 1000, 'Number of pre-training epochs') +flags.DEFINE_integer('batch_size', 4096, 'Total batch size') +flags.DEFINE_string('checkpoint_root', '/tmp/byol', + 'The directory to save checkpoints to.') +flags.DEFINE_integer('log_tensors_interval', 60, 'Log tensors every n seconds.') + +FLAGS = flags.FLAGS + + +Experiment = Union[ + Type[byol_experiment.ByolExperiment], + Type[eval_experiment.EvalExperiment]] + + +def train_loop(experiment_class, config): + """The main training loop. + + This loop periodically saves a checkpoint to be evaluated in the eval_loop. + + Args: + experiment_class: the constructor for the experiment (either byol_experiment + or eval_experiment). + config: the experiment config. 
+ """ + experiment = experiment_class(**config) + rng = jax.random.PRNGKey(0) + step = 0 + + host_id = jax.host_id() + last_logging = time.time() + if config['checkpointing_config']['use_checkpointing']: + checkpoint_data = experiment.load_checkpoint() + if checkpoint_data is None: + step = 0 + else: + step, rng = checkpoint_data + + local_device_count = jax.local_device_count() + while step < config['max_steps']: + step_rng, rng = tuple(jax.random.split(rng)) + # Broadcast the random seeds across the devices + step_rng_device = jax.random.split(step_rng, num=jax.device_count()) + step_rng_device = step_rng_device[ + host_id * local_device_count:(host_id + 1) * local_device_count] + step_device = np.broadcast_to(step, [local_device_count]) + + # Perform a training step and get scalars to log. + scalars = experiment.step(global_step=step_device, rng=step_rng_device) + + # Checkpointing and logging. + if config['checkpointing_config']['use_checkpointing']: + experiment.save_checkpoint(step, rng) + current_time = time.time() + if current_time - last_logging > FLAGS.log_tensors_interval: + logging.info('Step %d: %s', step, scalars) + last_logging = current_time + step += 1 + logging.info('Saving final checkpoint') + logging.info('Step %d: %s', step, scalars) + experiment.save_checkpoint(step, rng) + + +def eval_loop(experiment_class, config): + """The main evaluation loop. + + This loop periodically loads a checkpoint and evaluates its performance on the + test set, by calling experiment.evaluate. + + Args: + experiment_class: the constructor for the experiment (either byol_experiment + or eval_experiment). + config: the experiment config. + """ + experiment = experiment_class(**config) + last_evaluated_step = -1 + while True: + checkpoint_data = experiment.load_checkpoint() + if checkpoint_data is None: + logging.info('No checkpoint found. 
Waiting for 10s.') + time.sleep(10) + continue + step, _ = checkpoint_data + if step <= last_evaluated_step: + logging.info('Checkpoint at step %d already evaluated, waiting.', step) + time.sleep(10) + continue + host_id = jax.host_id() + local_device_count = jax.local_device_count() + step_device = np.broadcast_to(step, [local_device_count]) + scalars = experiment.evaluate(global_step=step_device) + if host_id == 0: # Only perform logging in one host. + logging.info('Evaluation at step %d: %s', step, scalars) + last_evaluated_step = step + if last_evaluated_step >= config['max_steps']: + return + + +def main(_): + if FLAGS.worker_tpu_driver: + jax.config.update('jax_xla_backend', 'tpu_driver') + jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver) + logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices()) + + if FLAGS.experiment_mode == 'pretrain': + experiment_class = byol_experiment.ByolExperiment + config = byol_config.get_config(FLAGS.pretrain_epochs, FLAGS.batch_size) + elif FLAGS.experiment_mode == 'linear-eval': + experiment_class = eval_experiment.EvalExperiment + config = eval_config.get_config(f'{FLAGS.checkpoint_root}/pretrain.pkl', + FLAGS.batch_size) + else: + raise ValueError(f'Unknown experiment mode: {FLAGS.experiment_mode}') + config['checkpointing_config']['checkpoint_dir'] = FLAGS.checkpoint_root + + if FLAGS.worker_mode == 'train': + train_loop(experiment_class, config) + elif FLAGS.worker_mode == 'eval': + eval_loop(experiment_class, config) + + +if __name__ == '__main__': + app.run(main) diff --git a/byol/main_loop_test.py b/byol/main_loop_test.py new file mode 100644 index 0000000..8c39534 --- /dev/null +++ b/byol/main_loop_test.py @@ -0,0 +1,97 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for BYOL's main training loop.""" + +from absl import flags +from absl.testing import absltest +import tensorflow_datasets as tfds + +from byol import byol_experiment +from byol import eval_experiment +from byol import main_loop +from byol.configs import byol as byol_config +from byol.configs import eval as eval_config + + +FLAGS = flags.FLAGS + + +class MainLoopTest(absltest.TestCase): + + def test_pretrain(self): + config = byol_config.get_config(num_epochs=40, batch_size=4) + temp_dir = self.create_tempdir().full_path + + # Override some config fields to make test lighter. + config['network_config']['encoder_class'] = 'TinyResNet' + config['network_config']['projector_hidden_size'] = 256 + config['network_config']['predictor_hidden_size'] = 256 + config['checkpointing_config']['checkpoint_dir'] = temp_dir + config['evaluation_config']['batch_size'] = 16 + config['max_steps'] = 16 + + with tfds.testing.mock_data(num_examples=64): + experiment_class = byol_experiment.ByolExperiment + main_loop.train_loop(experiment_class, config) + main_loop.eval_loop(experiment_class, config) + + def test_linear_eval(self): + config = eval_config.get_config(checkpoint_to_evaluate=None, batch_size=4) + temp_dir = self.create_tempdir().full_path + + # Override some config fields to make test lighter. 
+ config['network_config']['encoder_class'] = 'TinyResNet' + config['allow_train_from_scratch'] = True + config['checkpointing_config']['checkpoint_dir'] = temp_dir + config['evaluation_config']['batch_size'] = 16 + config['max_steps'] = 16 + + with tfds.testing.mock_data(num_examples=64): + experiment_class = eval_experiment.EvalExperiment + main_loop.train_loop(experiment_class, config) + main_loop.eval_loop(experiment_class, config) + + def test_pipeline(self): + b_config = byol_config.get_config(num_epochs=40, batch_size=4) + temp_dir = self.create_tempdir().full_path + + # Override some config fields to make test lighter. + b_config['network_config']['encoder_class'] = 'TinyResNet' + b_config['network_config']['projector_hidden_size'] = 256 + b_config['network_config']['predictor_hidden_size'] = 256 + b_config['checkpointing_config']['checkpoint_dir'] = temp_dir + b_config['evaluation_config']['batch_size'] = 16 + b_config['max_steps'] = 16 + + with tfds.testing.mock_data(num_examples=64): + main_loop.train_loop(byol_experiment.ByolExperiment, b_config) + + e_config = eval_config.get_config( + checkpoint_to_evaluate=f'{temp_dir}/pretrain.pkl', + batch_size=4) + + # Override some config fields to make test lighter. 
+ e_config['network_config']['encoder_class'] = 'TinyResNet' + e_config['allow_train_from_scratch'] = True + e_config['checkpointing_config']['checkpoint_dir'] = temp_dir + e_config['evaluation_config']['batch_size'] = 16 + e_config['max_steps'] = 16 + + with tfds.testing.mock_data(num_examples=64): + main_loop.train_loop(eval_experiment.EvalExperiment, e_config) + + +if __name__ == '__main__': + absltest.main() diff --git a/byol/requirements.txt b/byol/requirements.txt new file mode 100644 index 0000000..5758ba6 --- /dev/null +++ b/byol/requirements.txt @@ -0,0 +1,8 @@ +dm-acme +dm-haiku +dm-tree +jax +jaxlib +numpy>=1.16 +tensorflow +tensorflow_datasets diff --git a/byol/utils/augmentations.py b/byol/utils/augmentations.py new file mode 100644 index 0000000..9be3b49 --- /dev/null +++ b/byol/utils/augmentations.py @@ -0,0 +1,457 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Data preprocessing and augmentation.""" + +import functools +from typing import Any, Mapping, Text + +import jax +import jax.numpy as jnp + +# typing +JaxBatch = Mapping[Text, jnp.ndarray] +ConfigDict = Mapping[Text, Any] + +augment_config = dict( + view1=dict( + random_flip=True, # Random left/right flip + color_transform=dict( + apply_prob=1.0, + # Range of jittering + brightness=0.4, + contrast=0.4, + saturation=0.2, + hue=0.1, + # Probability of applying color jittering + color_jitter_prob=0.8, + # Probability of converting to grayscale + to_grayscale_prob=0.2, + # Shuffle the order of color transforms + shuffle=True), + gaussian_blur=dict( + apply_prob=1.0, + # Kernel size ~ image_size / blur_divider + blur_divider=10., + # Kernel distribution + sigma_min=0.1, + sigma_max=2.0), + solarize=dict(apply_prob=0.0, threshold=0.5), + ), + view2=dict( + random_flip=True, + color_transform=dict( + apply_prob=1.0, + brightness=0.4, + contrast=0.4, + saturation=0.2, + hue=0.1, + color_jitter_prob=0.8, + to_grayscale_prob=0.2, + shuffle=True), + gaussian_blur=dict( + apply_prob=0.1, blur_divider=10., sigma_min=0.1, sigma_max=2.0), + solarize=dict(apply_prob=0.2, threshold=0.5), + )) + + +def postprocess(inputs, rng): + """Apply the image augmentations to crops in inputs (view1 and view2).""" + + def _postprocess_image( + images, + rng, + presets, + ): + """Applies augmentations in post-processing. + + Args: + images: an NHWC tensor (with C=3), with float values in [0, 1]. + rng: a single PRNGKey. + presets: a dict of presets for the augmentations. + + Returns: + A batch of augmented images with shape NHWC, with keys view1, view2 + and labels. 
+ """ + flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4) + out = images + if presets['random_flip']: + out = random_flip(out, flip_rng) + if presets['color_transform']['apply_prob'] > 0: + out = color_transform(out, color_rng, **presets['color_transform']) + if presets['gaussian_blur']['apply_prob'] > 0: + out = gaussian_blur(out, blur_rng, **presets['gaussian_blur']) + if presets['solarize']['apply_prob'] > 0: + out = solarize(out, solarize_rng, **presets['solarize']) + out = jnp.clip(out, 0., 1.) + return jax.lax.stop_gradient(out) + + rng1, rng2 = jax.random.split(rng, num=2) + view1 = _postprocess_image(inputs['view1'], rng1, augment_config['view1']) + view2 = _postprocess_image(inputs['view2'], rng2, augment_config['view2']) + return dict(view1=view1, view2=view2, labels=inputs['labels']) + + +def _maybe_apply(apply_fn, inputs, rng, apply_prob): + should_apply = jax.random.uniform(rng, shape=()) <= apply_prob + return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x) + + +def _depthwise_conv2d(inputs, kernel, strides, padding): + """Computes a depthwise conv2d in Jax. + + Args: + inputs: an NHWC tensor with N=1. + kernel: a [H', W', 1, C] tensor. + strides: a 2d tensor. + padding: "SAME" or "VALID". + + Returns: + The depthwise convolution of inputs with kernel, as [H, W, C]. + """ + return jax.lax.conv_general_dilated( + inputs, + kernel, + strides, + padding, + feature_group_count=inputs.shape[-1], + dimension_numbers=('NHWC', 'HWIO', 'NHWC')) + + +def _gaussian_blur_single_image(image, kernel_size, padding, sigma): + """Applies gaussian blur to a single image, given as NHWC with N=1.""" + radius = int(kernel_size / 2) + kernel_size_ = 2 * radius + 1 + x = jnp.arange(-radius, radius + 1).astype(jnp.float32) + blur_filter = jnp.exp(-x**2 / (2. 
* sigma**2)) + blur_filter = blur_filter / jnp.sum(blur_filter) + blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1]) + blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1]) + num_channels = image.shape[-1] + blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels]) + blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels]) + expand_batch_dim = len(image.shape) == 3 + if expand_batch_dim: + image = image[jnp.newaxis, Ellipsis] + blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding) + blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding) + blurred = jnp.squeeze(blurred, axis=0) + return blurred + + +def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min, + sigma_max, apply_prob): + """Applies a random gaussian blur.""" + apply_rng, transform_rng = jax.random.split(rng) + + def _apply(image): + sigma_rng, = jax.random.split(transform_rng, 1) + sigma = jax.random.uniform( + sigma_rng, + shape=(), + minval=sigma_min, + maxval=sigma_max, + dtype=jnp.float32) + return _gaussian_blur_single_image(image, kernel_size, padding, sigma) + + return _maybe_apply(_apply, image, apply_rng, apply_prob) + + +def rgb_to_hsv(r, g, b): + """Converts R, G, B values to H, S, V values. + + Reference TF implementation: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc + Only input values between 0 and 1 are guaranteed to work properly, but this + function complies with the TF implementation outside of this range. + + Args: + r: A tensor representing the red color component as floats. + g: A tensor representing the green color component as floats. + b: A tensor representing the blue color component as floats. + + Returns: + H, S, V values, each as tensors of shape [...] (same as the input without + the last dimension). + """ + vv = jnp.maximum(jnp.maximum(r, g), b) + range_ = vv - jnp.minimum(jnp.minimum(r, g), b) + sat = jnp.where(vv > 0, range_ / vv, 0.) 
+ norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9) + + hr = norm * (g - b) + hg = norm * (b - r) + 2. / 6. + hb = norm * (r - g) + 4. / 6. + + hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb)) + hue = hue * (range_ > 0) + hue = hue + (hue < 0) + + return hue, sat, vv + + +def hsv_to_rgb(h, s, v): + """Converts H, S, V values to an R, G, B tuple. + + Reference TF implementation: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc + Only input values between 0 and 1 are guaranteed to work properly, but this + function complies with the TF implementation outside of this range. + + Args: + h: A float tensor of arbitrary shape for the hue (0-1 values). + s: A float tensor of the same shape for the saturation (0-1 values). + v: A float tensor of the same shape for the value channel (0-1 values). + + Returns: + An (r, g, b) tuple, each with the same dimension as the inputs. + """ + c = s * v + m = v - c + dh = (h % 1.) * 6. + fmodu = dh % 2. + x = c * (1 - jnp.abs(fmodu - 1)) + hcat = jnp.floor(dh).astype(jnp.int32) + rr = jnp.where( + (hcat == 0) | (hcat == 5), c, jnp.where( + (hcat == 1) | (hcat == 4), x, 0)) + m + gg = jnp.where( + (hcat == 1) | (hcat == 2), c, jnp.where( + (hcat == 0) | (hcat == 3), x, 0)) + m + bb = jnp.where( + (hcat == 3) | (hcat == 4), c, jnp.where( + (hcat == 2) | (hcat == 5), x, 0)) + m + return rr, gg, bb + + +def adjust_brightness(rgb_tuple, delta): + return jax.tree_map(lambda x: x + delta, rgb_tuple) + + +def adjust_contrast(image, factor): + def _adjust_contrast_channel(channel): + mean = jnp.mean(channel, axis=(-2, -1), keepdims=True) + return factor * (channel - mean) + mean + return jax.tree_map(_adjust_contrast_channel, image) + + +def adjust_saturation(h, s, v, factor): + return h, jnp.clip(s * factor, 0., 1.), v + + +def adjust_hue(h, s, v, delta): + # Note: this method exactly matches TF's adjust_hue (combined with the hsv/rgb + # conversions) when running on GPU. 
When running on CPU, the results will be + # different if all RGB values for a pixel are outside of the [0, 1] range. + return (h + delta) % 1.0, s, v + + +def _random_brightness(rgb_tuple, rng, max_delta): + delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta) + return adjust_brightness(rgb_tuple, delta) + + +def _random_contrast(rgb_tuple, rng, max_delta): + factor = jax.random.uniform( + rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta) + return adjust_contrast(rgb_tuple, factor) + + +def _random_saturation(rgb_tuple, rng, max_delta): + h, s, v = rgb_to_hsv(*rgb_tuple) + factor = jax.random.uniform( + rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta) + return hsv_to_rgb(*adjust_saturation(h, s, v, factor)) + + +def _random_hue(rgb_tuple, rng, max_delta): + h, s, v = rgb_to_hsv(*rgb_tuple) + delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta) + return hsv_to_rgb(*adjust_hue(h, s, v, delta)) + + +def _to_grayscale(image): + rgb_weights = jnp.array([0.2989, 0.5870, 0.1140]) + grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis] + return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels. + + +def _color_transform_single_image(image, rng, brightness, contrast, saturation, + hue, to_grayscale_prob, color_jitter_prob, + apply_prob, shuffle): + """Applies color jittering to a single image.""" + apply_rng, transform_rng = jax.random.split(rng) + perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split( + transform_rng, 7) + + # Whether the transform should be applied at all. + should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob + # Whether to apply grayscale transform. + should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob + # Whether to apply color jittering. + should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob + + # Decorator to conditionally apply fn based on an index. 
+ def _make_cond(fn, idx): + + def identity_fn(x, unused_rng, unused_param): + return x + + def cond_fn(args, i): + def clip(args): + return jax.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args) + out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args, + lambda a: clip(fn(*a)), args, + lambda a: identity_fn(*a)) + return jax.lax.stop_gradient(out) + + return cond_fn + + random_brightness_cond = _make_cond(_random_brightness, idx=0) + random_contrast_cond = _make_cond(_random_contrast, idx=1) + random_saturation_cond = _make_cond(_random_saturation, idx=2) + random_hue_cond = _make_cond(_random_hue, idx=3) + + def _color_jitter(x): + rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1))) + if shuffle: + order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32)) + else: + order = range(4) + for idx in order: + if brightness > 0: + rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx) + if contrast > 0: + rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx) + if saturation > 0: + rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx) + if hue > 0: + rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx) + return jnp.stack(rgb_tuple, axis=-1) + + out_apply = _color_jitter(image) + out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply, + _to_grayscale, out_apply, lambda x: x) + return jnp.clip(out_apply, 0., 1.) 
+ + +def _random_flip_single_image(image, rng): + _, flip_rng = jax.random.split(rng) + should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5 + image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x) + return image + + +def random_flip(images, rng): + rngs = jax.random.split(rng, images.shape[0]) + return jax.vmap(_random_flip_single_image)(images, rngs) + + +def color_transform(images, + rng, + brightness=0.8, + contrast=0.8, + saturation=0.8, + hue=0.2, + color_jitter_prob=0.8, + to_grayscale_prob=0.2, + apply_prob=1.0, + shuffle=True): + """Applies color jittering and/or grayscaling to a batch of images. + + Args: + images: an NHWC tensor, with C=3. + rng: a single PRNGKey. + brightness: the range of jitter on brightness. + contrast: the range of jitter on contrast. + saturation: the range of jitter on saturation. + hue: the range of jitter on hue. + color_jitter_prob: the probability of applying color jittering. + to_grayscale_prob: the probability of converting the image to grayscale. + apply_prob: the probability of applying the transform to a batch element. + shuffle: whether to apply the transforms in a random order. + + Returns: + A NHWC tensor of the transformed images. + """ + rngs = jax.random.split(rng, images.shape[0]) + jitter_fn = functools.partial( + _color_transform_single_image, + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue, + color_jitter_prob=color_jitter_prob, + to_grayscale_prob=to_grayscale_prob, + apply_prob=apply_prob, + shuffle=shuffle) + return jax.vmap(jitter_fn)(images, rngs) + + +def gaussian_blur(images, + rng, + blur_divider=10., + sigma_min=0.1, + sigma_max=2.0, + apply_prob=1.0): + """Applies gaussian blur to a batch of images. + + Args: + images: an NHWC tensor, with C=3. + rng: a single PRNGKey. + blur_divider: the blurring kernel will have size H / blur_divider. + sigma_min: the minimum value for sigma in the blurring kernel. 
+ sigma_max: the maximum value for sigma in the blurring kernel. + apply_prob: the probability of applying the transform to a batch element. + + Returns: + A NHWC tensor of the blurred images. + """ + rngs = jax.random.split(rng, images.shape[0]) + kernel_size = images.shape[1] / blur_divider + blur_fn = functools.partial( + _random_gaussian_blur, + kernel_size=kernel_size, + padding='SAME', + sigma_min=sigma_min, + sigma_max=sigma_max, + apply_prob=apply_prob) + return jax.vmap(blur_fn)(images, rngs) + + +def _solarize_single_image(image, rng, threshold, apply_prob): + + def _apply(image): + return jnp.where(image < threshold, image, 1. - image) + + return _maybe_apply(_apply, image, rng, apply_prob) + + +def solarize(images, rng, threshold=0.5, apply_prob=1.0): + """Applies solarization. + + Args: + images: an NHWC tensor (with C=3). + rng: a single PRNGKey. + threshold: the solarization threshold. + apply_prob: the probability of applying the transform to a batch element. + + Returns: + A NHWC tensor of the transformed images. + """ + rngs = jax.random.split(rng, images.shape[0]) + solarize_fn = functools.partial( + _solarize_single_image, threshold=threshold, apply_prob=apply_prob) + return jax.vmap(solarize_fn)(images, rngs) diff --git a/byol/utils/checkpointing.py b/byol/utils/checkpointing.py new file mode 100644 index 0000000..94b1825 --- /dev/null +++ b/byol/utils/checkpointing.py @@ -0,0 +1,105 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Checkpoint saving and restoring utilities.""" + +import os +import time +from typing import Mapping, Text, Tuple, Union + +from absl import logging +import dill +import jax +import jax.numpy as jnp + +from byol.utils import helpers + + +class Checkpointer: + """A checkpoint saving and loading class.""" + + def __init__( + self, + use_checkpointing, + checkpoint_dir, + save_checkpoint_interval, + filename): + if (not use_checkpointing or + checkpoint_dir is None or + save_checkpoint_interval <= 0): + self._checkpoint_enabled = False + return + + self._checkpoint_enabled = True + self._checkpoint_dir = checkpoint_dir + os.makedirs(self._checkpoint_dir, exist_ok=True) + self._filename = filename + self._checkpoint_path = os.path.join(self._checkpoint_dir, filename) + self._last_checkpoint_time = 0 + self._checkpoint_every = save_checkpoint_interval + + def maybe_save_checkpoint( + self, + experiment_state, + step, + rng, + is_final): + """Saves a checkpoint if enough time has passed since the previous one.""" + current_time = time.time() + if (not self._checkpoint_enabled or + jax.host_id() != 0 or # Only checkpoint the first worker. 
+ (not is_final and + current_time - self._last_checkpoint_time < self._checkpoint_every)): + return + checkpoint_data = dict( + experiment_state=jax.tree_map( + lambda x: jax.device_get(x[0]), experiment_state), + step=step, + rng=rng) + with open(self._checkpoint_path + '_tmp', 'wb') as checkpoint_file: + dill.dump(checkpoint_data, checkpoint_file, protocol=2) + try: + os.rename(self._checkpoint_path, self._checkpoint_path + '_old') + remove_old = True + except FileNotFoundError: + remove_old = False # No previous checkpoint to remove + os.rename(self._checkpoint_path + '_tmp', self._checkpoint_path) + if remove_old: + os.remove(self._checkpoint_path + '_old') + self._last_checkpoint_time = current_time + + def maybe_load_checkpoint( + self): + """Loads a checkpoint if any is found.""" + checkpoint_data = load_checkpoint(self._checkpoint_path) + if checkpoint_data is None: + logging.info('No existing checkpoint found at %s', self._checkpoint_path) + return None + step = checkpoint_data['step'] + rng = checkpoint_data['rng'] + experiment_state = jax.tree_map( + helpers.bcast_local_devices, checkpoint_data['experiment_state']) + del checkpoint_data + return experiment_state, step, rng + + +def load_checkpoint(checkpoint_path): + try: + with open(checkpoint_path, 'rb') as checkpoint_file: + checkpoint_data = dill.load(checkpoint_file) + logging.info('Loading checkpoint from %s, saved at step %d', + checkpoint_path, checkpoint_data['step']) + return checkpoint_data + except FileNotFoundError: + return None diff --git a/byol/utils/dataset.py b/byol/utils/dataset.py new file mode 100644 index 0000000..cf58622 --- /dev/null +++ b/byol/utils/dataset.py @@ -0,0 +1,266 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ImageNet dataset with typical pre-processing.""" + +import enum +from typing import Generator, Mapping, Optional, Sequence, Text, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds + +Batch = Mapping[Text, np.ndarray] + + +class Split(enum.Enum): + """Imagenet dataset split.""" + TRAIN = 1 + TRAIN_AND_VALID = 2 + VALID = 3 + TEST = 4 + + @classmethod + def from_string(cls, name): + return { + 'TRAIN': Split.TRAIN, + 'TRAIN_AND_VALID': Split.TRAIN_AND_VALID, + 'VALID': Split.VALID, + 'VALIDATION': Split.VALID, + 'TEST': Split.TEST + }[name.upper()] + + @property + def num_examples(self): + return { + Split.TRAIN_AND_VALID: 1281167, + Split.TRAIN: 1271167, + Split.VALID: 10000, + Split.TEST: 50000 + }[self] + + +class PreprocessMode(enum.Enum): + """Preprocessing modes for the dataset.""" + PRETRAIN = 1 # Generates two augmented views (random crop + augmentations). + LINEAR_TRAIN = 2 # Generates a single random crop. + EVAL = 3 # Generates a single center crop. 
+ + +def normalize_images(images): + """Normalize the image using ImageNet statistics.""" + mean_rgb = (0.485, 0.456, 0.406) + stddev_rgb = (0.229, 0.224, 0.225) + normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3)) + normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3)) + return normed_images + + +def load(split, + *, + preprocess_mode, + batch_dims, + transpose = False, + allow_caching = False): + """Loads the given split of the dataset.""" + start, end = _shard(split, jax.host_id(), jax.host_count()) + + total_batch_size = np.prod(batch_dims) + + tfds_split = tfds.core.ReadInstruction( + _to_tfds_split(split), from_=start, to=end, unit='abs') + ds = tfds.load( + 'imagenet2012:5.*.*', + split=tfds_split, + decoders={'image': tfds.decode.SkipDecoding()}) + + options = ds.options() + options.experimental_threading.private_threadpool_size = 48 + options.experimental_threading.max_intra_op_parallelism = 1 + + if preprocess_mode is not PreprocessMode.EVAL: + options.experimental_deterministic = False + if jax.host_count() > 1 and allow_caching: + # Only cache if we are reading a subset of the dataset. 
+ ds = ds.cache() + ds = ds.repeat() + ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0) + + else: + if split.num_examples % total_batch_size != 0: + raise ValueError(f'Test/valid must be divisible by {total_batch_size}') + + def preprocess_pretrain(example): + view1 = _preprocess_image(example['image'], mode=preprocess_mode) + view2 = _preprocess_image(example['image'], mode=preprocess_mode) + label = tf.cast(example['label'], tf.int32) + return {'view1': view1, 'view2': view2, 'labels': label} + + def preprocess_linear_train(example): + image = _preprocess_image(example['image'], mode=preprocess_mode) + label = tf.cast(example['label'], tf.int32) + return {'images': image, 'labels': label} + + def preprocess_eval(example): + image = _preprocess_image(example['image'], mode=preprocess_mode) + label = tf.cast(example['label'], tf.int32) + return {'images': image, 'labels': label} + + if preprocess_mode is PreprocessMode.PRETRAIN: + ds = ds.map( + preprocess_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE) + elif preprocess_mode is PreprocessMode.LINEAR_TRAIN: + ds = ds.map( + preprocess_linear_train, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + else: + ds = ds.map( + preprocess_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + def transpose_fn(batch): + # We use the double-transpose-trick to improve performance for TPUs. Note + # that this (typically) requires a matching HWCN->NHWC transpose in your + # model code. The compiler cannot make this optimization for us since our + # data pipeline and model are compiled separately. 
+ batch = dict(**batch) + if preprocess_mode is PreprocessMode.PRETRAIN: + batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0)) + batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0)) + else: + batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0)) + return batch + + for i, batch_size in enumerate(reversed(batch_dims)): + ds = ds.batch(batch_size) + if i == 0 and transpose: + ds = ds.map(transpose_fn) # NHWC -> HWCN + + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + + yield from tfds.as_numpy(ds) + + +def _to_tfds_split(split): + """Returns the TFDS split appropriately sharded.""" + # NOTE: Imagenet did not release labels for the test split used in the + # competition, we consider the VALID split the TEST split and reserve + # 10k images from TRAIN for VALID. + if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID): + return tfds.Split.TRAIN + else: + assert split == Split.TEST + return tfds.Split.VALIDATION + + +def _shard(split, shard_index, num_shards): + """Returns [start, end) for the given shard index.""" + assert shard_index < num_shards + arange = np.arange(split.num_examples) + shard_range = np.array_split(arange, num_shards)[shard_index] + start, end = shard_range[0], (shard_range[-1] + 1) + if split == Split.TRAIN: + # Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000]. + offset = Split.VALID.num_examples + start += offset + end += offset + return start, end + + +def _preprocess_image( + image_bytes, + mode, +): + """Returns processed and resized images.""" + if mode is PreprocessMode.PRETRAIN: + image = _decode_and_random_crop(image_bytes) + # Random horizontal flipping is optionally done in augmentations.preprocess. + elif mode is PreprocessMode.LINEAR_TRAIN: + image = _decode_and_random_crop(image_bytes) + image = tf.image.random_flip_left_right(image) + else: + image = _decode_and_center_crop(image_bytes) + # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without + # clamping overshoots. 
This means values returned may fall outside the range + # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]). + assert image.dtype == tf.uint8 + image = tf.image.resize(image, [224, 224], tf.image.ResizeMethod.BICUBIC) + image = tf.clip_by_value(image / 255., 0., 1.) + return image + + +def _decode_and_random_crop(image_bytes): + """Decodes a random crop; the caller resizes it to 224x224.""" + img_size = tf.image.extract_jpeg_shape(image_bytes) + area = tf.cast(img_size[1] * img_size[0], tf.float32) + target_area = tf.random.uniform([], 0.08, 1.0, dtype=tf.float32) * area + + log_ratio = (tf.math.log(3 / 4), tf.math.log(4 / 3)) + aspect_ratio = tf.math.exp( + tf.random.uniform([], *log_ratio, dtype=tf.float32)) + + w = tf.cast(tf.round(tf.sqrt(target_area * aspect_ratio)), tf.int32) + h = tf.cast(tf.round(tf.sqrt(target_area / aspect_ratio)), tf.int32) + + w = tf.minimum(w, img_size[1]) + h = tf.minimum(h, img_size[0]) + + offset_w = tf.random.uniform((), + minval=0, + maxval=img_size[1] - w + 1, + dtype=tf.int32) + offset_h = tf.random.uniform((), + minval=0, + maxval=img_size[0] - h + 1, + dtype=tf.int32) + + crop_window = tf.stack([offset_h, offset_w, h, w]) + image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + return image + + +def transpose_images(batch): + """Transpose images for TPU training.""" + new_batch = dict(batch) # Avoid mutating in place.
+ if 'images' in batch: + new_batch['images'] = jnp.transpose(batch['images'], (3, 0, 1, 2)) + else: + new_batch['view1'] = jnp.transpose(batch['view1'], (3, 0, 1, 2)) + new_batch['view2'] = jnp.transpose(batch['view2'], (3, 0, 1, 2)) + return new_batch + + +def _decode_and_center_crop( + image_bytes, + jpeg_shape = None, +): + """Crops to center of image with padding then scales.""" + if jpeg_shape is None: + jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = jpeg_shape[0] + image_width = jpeg_shape[1] + + padded_center_crop_size = tf.cast( + ((224 / (224 + 32)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([ + offset_height, offset_width, padded_center_crop_size, + padded_center_crop_size + ]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + return image diff --git a/byol/utils/helpers.py b/byol/utils/helpers.py new file mode 100644 index 0000000..18d2cc2 --- /dev/null +++ b/byol/utils/helpers.py @@ -0,0 +1,131 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utility functions.""" + +from typing import Optional, Text +from absl import logging +import jax +import jax.numpy as jnp + + +def topk_accuracy( + logits, + labels, + topk, + ignore_label_above = None, +): + """Top-k accuracy: returns, per example, whether the label is in the top-k logits.""" + assert len(labels.shape) == 1, 'topk expects 1d int labels.' + assert len(logits.shape) == 2, 'topk expects 2d logits.' + + if ignore_label_above is not None: + logits = logits[labels < ignore_label_above, :] + labels = labels[labels < ignore_label_above] + + prds = jnp.argsort(logits, axis=1)[:, ::-1] + prds = prds[:, :topk] + total = jnp.any(prds == jnp.tile(labels[:, jnp.newaxis], [1, topk]), axis=1) + + return total + + +def softmax_cross_entropy( + logits, + labels, + reduction = 'mean', +): + """Computes softmax cross entropy given logits and one-hot class labels. + + Args: + logits: Logit output values. + labels: Ground truth one-hot-encoded labels. + reduction: Type of reduction to apply to loss. + + Returns: + Loss value. If `reduction` is `none`, this has the shape of `labels` + without its last axis; otherwise, it is scalar. + + Raises: + ValueError: If `reduction` is not a supported mode. + """ + loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1) + if reduction == 'sum': + return jnp.sum(loss) + elif reduction == 'mean': + return jnp.mean(loss) + elif reduction == 'none' or reduction is None: + return loss + else: + raise ValueError(f'Incorrect reduction mode {reduction}') + + +def l2_normalize( + x, + axis = None, + epsilon = 1e-12, +): + """l2 normalize a tensor on an axis with numerical stability.""" + square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True) + x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon)) + return x * x_inv_norm + + +def l2_weight_regularizer(params): + """Helper to compute an L2 penalty on weights. + + Args: + params: the entire param set. 
+ + Returns: + Scalar equal to half the sum of squared weights, excluding biases and + normalization parameters. + """ + l2_norm = 0.
+ for mod_name, mod_params in params.items(): + if 'norm' not in mod_name: + for param_k, param_v in mod_params.items(): + if 'b' not in param_k: # Filter out biases + l2_norm += jnp.sum(jnp.square(param_v)) + else: + logging.warning('Excluding %s/%s from optimizer weight decay!', + mod_name, param_k) + else: + logging.warning('Excluding %s from optimizer weight decay!', mod_name) + + return 0.5 * l2_norm + + +def regression_loss(x, y): + """BYOL's regression loss: squared L2 distance between l2-normalized inputs. + + This equals 2 - 2 * cosine_similarity(x, y). + """ + normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1) + return jnp.sum((normed_x - normed_y)**2, axis=-1) + + +def bcast_local_devices(value): + """Broadcasts an object to all local devices.""" + devices = jax.local_devices() + + def _replicate(x): + """Replicate an object on each device.""" + x = jnp.array(x) + aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype) + buffers = [jax.interpreters.xla.device_put(x, d) for d in devices] + return jax.pxla.ShardedDeviceArray(aval, buffers) + + return jax.tree_util.tree_map(_replicate, value) + + +def get_first(xs): + """Gets values from the first device.""" + return jax.tree_map(lambda x: x[0], xs) diff --git a/byol/utils/networks.py b/byol/utils/networks.py new file mode 100644 index 0000000..72c8be7 --- /dev/null +++ b/byol/utils/networks.py @@ -0,0 +1,319 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Networks used in BYOL.""" + +from typing import Any, Mapping, Optional, Sequence, Text + +import haiku as hk +import jax +import jax.numpy as jnp + + +class MLP(hk.Module): + """One hidden layer perceptron, with normalization.""" + + def __init__( + self, + name, + hidden_size, + output_size, + bn_config, + ): + super().__init__(name=name) + self._hidden_size = hidden_size + self._output_size = output_size + self._bn_config = bn_config + + def __call__(self, inputs, is_training): + out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs) + out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training) + out = jax.nn.relu(out) + out = hk.Linear(output_size=self._output_size, with_bias=False)(out) + return out + + +class ResNetTorso(hk.nets.ResNet): + """ResNet model.""" + + def __init__( + self, + blocks_per_group, + num_classes = None, + bn_config = None, + resnet_v2 = False, + bottleneck = True, + channels_per_group = (256, 512, 1024, 2048), + use_projection = (True, True, True, True), + width_multiplier = 1, + name = None, + ): + """Constructs a ResNet model. + + Args: + blocks_per_group: A sequence of length 4 that indicates the number of + blocks created in each group. + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of three elements, `decay_rate`, `eps`, and + `cross_replica_axis`, to be passed on to the `BatchNorm` layers. By + default the `decay_rate` is `0.9` and `eps` is `1e-5`, and the axis is + `None`. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to + False. + bottleneck: Whether the block should bottleneck or not. Defaults to True. + channels_per_group: A sequence of length 4 that indicates the number + of channels used for each block in each group. + use_projection: A sequence of length 4 that indicates whether each + residual block should use projection. + width_multiplier: An integer multiplying the number of channels per group. 
+ name: Name of the module. + """ + channels_per_group = [width_multiplier * channel + for channel in channels_per_group] + super().__init__( + blocks_per_group=blocks_per_group, + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=bottleneck, + channels_per_group=channels_per_group, + name=name) + + def __call__(self, inputs, is_training, test_local_stats=False): + out = inputs + out = self.initial_conv(out) + if not self.resnet_v2: + out = self.initial_batchnorm(out, is_training, test_local_stats) + out = jax.nn.relu(out) + + out = hk.max_pool(out, + window_shape=(1, 3, 3, 1), + strides=(1, 2, 2, 1), + padding='SAME') + + for block_group in self.block_groups: + out = block_group(out, is_training, test_local_stats) + + if self.resnet_v2: + out = self.final_batchnorm(out, is_training, test_local_stats) + out = jax.nn.relu(out) + out = jnp.mean(out, axis=[1, 2]) + return out + + +class TinyResNet(ResNetTorso): + """Tiny resnet for local runs and tests.""" + + def __init__(self, + num_classes = None, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. + + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. + name: Name of the module. + """ + super().__init__(blocks_per_group=(1, 1, 1, 1), + channels_per_group=(8, 8, 8, 8), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=False, + width_multiplier=width_multiplier, + name=name) + + +class ResNet18(ResNetTorso): + """ResNet18.""" + + def __init__(self, + num_classes = None, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. 
+ + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. + name: Name of the module. + """ + super().__init__(blocks_per_group=(2, 2, 2, 2), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=False, + channels_per_group=(64, 128, 256, 512), + width_multiplier=width_multiplier, + name=name) + + +class ResNet34(ResNetTorso): + """ResNet34.""" + + def __init__(self, + num_classes, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. + + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. + name: Name of the module. + """ + super().__init__(blocks_per_group=(3, 4, 6, 3), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=False, + channels_per_group=(64, 128, 256, 512), + width_multiplier=width_multiplier, + name=name) + + +class ResNet50(ResNetTorso): + """ResNet50.""" + + def __init__(self, + num_classes = None, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. + + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. 
+ name: Name of the module. + """ + super().__init__(blocks_per_group=(3, 4, 6, 3), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=True, + width_multiplier=width_multiplier, + name=name) + + +class ResNet101(ResNetTorso): + """ResNet101.""" + + def __init__(self, + num_classes, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. + + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. + name: Name of the module. + """ + super().__init__(blocks_per_group=(3, 4, 23, 3), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=True, + width_multiplier=width_multiplier, + name=name) + + +class ResNet152(ResNetTorso): + """ResNet152.""" + + def __init__(self, + num_classes, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. + + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. + name: Name of the module. + """ + super().__init__(blocks_per_group=(3, 8, 36, 3), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=True, + width_multiplier=width_multiplier, + name=name) + + +class ResNet200(ResNetTorso): + """ResNet200.""" + + def __init__(self, + num_classes, + bn_config = None, + resnet_v2 = False, + width_multiplier = 1, + name = None): + """Constructs a ResNet model. 
+ + Args: + num_classes: The number of classes to classify the inputs into. + bn_config: A dictionary of two elements, `decay_rate` and `eps` to be + passed on to the `BatchNorm` layers. + resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults + to False. + width_multiplier: An integer multiplying the number of channels per group. + name: Name of the module. + """ + super().__init__(blocks_per_group=(3, 24, 36, 3), + num_classes=num_classes, + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=True, + width_multiplier=width_multiplier, + name=name) diff --git a/byol/utils/optimizers.py b/byol/utils/optimizers.py new file mode 100644 index 0000000..bc8f039 --- /dev/null +++ b/byol/utils/optimizers.py @@ -0,0 +1,188 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implementation of LARS Optimizer with optax.""" + +from typing import Any, Callable, List, NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp +import optax +import tree as nest + +# A filter function takes a path and a value as input and outputs True for +# variables that should receive the update and False otherwise. +FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray] + + +def exclude_bias_and_norm(path, val): + """Filter to exclude biases and normalization weights.""" + del val + if path[-1] == "b" or "norm" in path[-2]: + return False + return True + + +def _partial_update(updates, + new_updates, + params, + filter_fn = None): + """Returns new_updates for params where filter_fn is True, else updates.""" + + if filter_fn is None: + return new_updates + + wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y)) + params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params) + + def _update_fn(g, t, m): + m = m.astype(g.dtype) + return g * (1. - m) + t * m + + return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter) + + +class ScaleByLarsState(NamedTuple): + mu: jnp.ndarray + + +def scale_by_lars( + momentum = 0.9, + eta = 0.001, + filter_fn = None): + """Rescales updates according to the LARS algorithm. + + Does not include weight decay. + References: + [You et al, 2017](https://arxiv.org/abs/1708.03888) + + Args: + momentum: momentum coefficient. + eta: LARS coefficient. + filter_fn: an optional filter function. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + + def init_fn(params): + mu = jax.tree_multimap(jnp.zeros_like, params) # momentum + return ScaleByLarsState(mu=mu) + + def update_fn(updates, state, + params): + + def lars_adaptation( + update, + param, + ): + param_norm = jnp.linalg.norm(param) + update_norm = jnp.linalg.norm(update) + return update * jnp.where( + param_norm > 0., + jnp.where(update_norm > 0, + (eta * param_norm / update_norm), 1.0), 1.0) + + adapted_updates = jax.tree_multimap(lars_adaptation, updates, params) + adapted_updates = _partial_update(updates, adapted_updates, params, + filter_fn) + mu = jax.tree_multimap(lambda g, t: momentum * g + t, + state.mu, adapted_updates) + return mu, ScaleByLarsState(mu=mu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class AddWeightDecayState(NamedTuple): + """Stateless transformation.""" + + +def add_weight_decay( + weight_decay, + filter_fn = None): + """Adds a weight decay to the update. + + Args: + weight_decay: weight decay coefficient. + filter_fn: an optional filter function. + + Returns: + An (init_fn, update_fn) tuple. + """ + + def init_fn(_): + return AddWeightDecayState() + + def update_fn( + updates, + state, + params, + ): + new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates, + params) + new_updates = _partial_update(updates, new_updates, params, filter_fn) + return new_updates, state + + return optax.GradientTransformation(init_fn, update_fn) + + +LarsState = List # Type for the lars optimizer + + +def lars( + learning_rate, + weight_decay = 0., + momentum = 0.9, + eta = 0.001, + weight_decay_filter = None, + lars_adaptation_filter = None, +): + """Creates lars optimizer with weight decay. + + References: + [You et al, 2017](https://arxiv.org/abs/1708.03888) + + Args: + learning_rate: learning rate coefficient. + weight_decay: weight decay coefficient. + momentum: momentum coefficient. + eta: LARS coefficient. 
+ weight_decay_filter: optional filter function to only apply the weight + decay on a subset of parameters. The filter function takes as input the + parameter path (as a tuple) and its associated update, and returns True + for params to apply the weight decay to and False for params to skip. + When weight_decay_filter is set to None, the weight decay is applied to + all parameters. Pass `exclude_bias_and_norm` to skip biases, i.e. when + the variable name is 'b', and normalization params, i.e. when the + penultimate path element contains 'norm'. + lars_adaptation_filter: similar to weight decay filter but for lars + adaptation + + Returns: + An optax.GradientTransformation, i.e. a (init_fn, update_fn) tuple. + """ + + if weight_decay_filter is None: + weight_decay_filter = lambda *_: True + if lars_adaptation_filter is None: + lars_adaptation_filter = lambda *_: True + + return optax.chain( + add_weight_decay( + weight_decay=weight_decay, filter_fn=weight_decay_filter), + scale_by_lars( + momentum=momentum, eta=eta, filter_fn=lars_adaptation_filter), + optax.scale(-learning_rate), + ) diff --git a/byol/utils/schedules.py b/byol/utils/schedules.py new file mode 100644 index 0000000..f8817a0 --- /dev/null +++ b/byol/utils/schedules.py @@ -0,0 +1,53 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Learning rate schedules.""" +import jax.numpy as jnp + + +def target_ema(global_step, + base_ema, + max_steps): + decay = _cosine_decay(global_step, max_steps, 1.) + return 1. - (1. - base_ema) * decay + + +def learning_schedule(global_step, + batch_size, + base_learning_rate, + total_steps, + warmup_steps): + """Cosine learning rate scheduler.""" + # Compute LR & Scaled LR + scaled_lr = base_learning_rate * batch_size / 256. + learning_rate = ( + global_step.astype(jnp.float32) / int(warmup_steps) * + scaled_lr if warmup_steps > 0 else scaled_lr) + + # Cosine schedule after warmup. + return jnp.where( + global_step < warmup_steps, learning_rate, + _cosine_decay(global_step - warmup_steps, total_steps - warmup_steps, + scaled_lr)) + + +def _cosine_decay(global_step, + max_steps, + initial_value): + """Simple implementation of cosine decay from TF1.""" + global_step = jnp.minimum(global_step, max_steps) + cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps)) + decayed_learning_rate = initial_value * cosine_decay_value + return decayed_learning_rate
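
The schedules in `byol/utils/schedules.py` can be checked by hand with a scalar sketch. The block below is an illustration only, not part of the patch: it re-implements the same formulas with the standard `math` module instead of `jax.numpy`, and the function names `cosine_decay`, `learning_rate_at` and `target_ema_at`, as well as the numbers in the usage lines, are hypothetical choices made for this example.

```python
import math


def cosine_decay(step, max_steps, initial_value):
    """Scalar cosine decay from initial_value (step 0) to 0 (step >= max_steps)."""
    step = min(step, max_steps)
    return initial_value * 0.5 * (1.0 + math.cos(math.pi * step / max_steps))


def learning_rate_at(step, batch_size, base_learning_rate, total_steps,
                     warmup_steps):
    """Linear warmup to the batch-scaled rate, then cosine decay to zero."""
    # Same scaling rule as learning_schedule: base rate times batch_size / 256.
    scaled_lr = base_learning_rate * batch_size / 256.0
    if step < warmup_steps:
        return step / warmup_steps * scaled_lr
    return cosine_decay(step - warmup_steps, total_steps - warmup_steps,
                        scaled_lr)


def target_ema_at(step, base_ema, max_steps):
    """EMA rate rising from base_ema (step 0) to 1.0 (step max_steps)."""
    return 1.0 - (1.0 - base_ema) * cosine_decay(step, max_steps, 1.0)


# Hypothetical settings, chosen only to make the shape of the curves visible.
for step in (0, 50, 100, 550, 1000):
    lr = learning_rate_at(step, batch_size=4096, base_learning_rate=0.2,
                          total_steps=1000, warmup_steps=100)
    ema = target_ema_at(step, base_ema=0.99, max_steps=1000)
    print(step, round(lr, 4), round(ema, 5))
```

With these assumed settings the learning rate climbs linearly during the first 100 steps, peaks at the scaled rate, and decays to zero at the final step, while the target EMA rate moves from `base_ema` towards 1.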