This notebook illustrates the CIFAR-10 experiments in the paper:

PiperOrigin-RevId: 326141025
This commit is contained in:
Clara Huiyi Hu
2020-08-12 00:58:22 +00:00
committed by Louise Deason
parent 84321ad894
commit 923ad3cff0
15 changed files with 3095 additions and 0 deletions
+159
View File
@@ -0,0 +1,159 @@
# Bootstrap Your Own Latent
This is the implementation of the pre-training and linear evaluation pipeline of
BYOL - https://arxiv.org/abs/2006.07733.
Using this implementation should achieve a top-1 accuracy on Imagenet between
74.0% and 74.5% after about 8h of training using 512 Cloud TPU v3.
The main pretraining module is `byol_experiment.py`. By default it uses BYOL to
pretrain a Resnet-50 on Imagenet. In parallel, we train a classifier on top of
the representation to assess its performance during training. This classifier
does not back-propagate any gradient to the ResNet-50.
The evaluation module is `eval_experiment.py`. It evaluates the performance of
the representation learnt by BYOL (using a given checkpoint).
## Setup
To set up a Python virtual environment with the required dependencies, run:
```shell
python3 -m venv byol_env
source byol_env/bin/activate
pip install --upgrade pip
pip install -r byol/requirements.txt
```
The code uses `tensorflow_datasets` to load the Imagenet dataset. Manual
download may be required; see https://www.tensorflow.org/datasets/catalog/imagenet2012
for details.
## Running a short pre-training locally
To run a short (40 epochs) pre-training experiment on a local machine, use:
```shell
mkdir /tmp/byol_checkpoints
python -m byol.main_loop \
--experiment_mode='pretrain' \
--worker_mode='train' \
--checkpoint_root='/tmp/byol_checkpoints' \
--pretrain_epochs=40
```
## Full pipeline and presets
The various parts of the pipeline can be run using:
```shell
python -m byol.main_loop \
--experiment_mode=<'pretrain' or 'linear-eval'> \
--worker_mode=<'train' or 'eval'> \
--checkpoint_root=</path/to/the/checkpointing/folder> \
--pretrain_epochs=<40, 100, 300 or 1000>
```
### Pretraining
Setting `--experiment_mode=pretrain` will configure the main loop for
pretraining; we provide presets for 40, 100, 300 and 1000 epochs.
Use `--worker_mode=train` for a training job, which will regularly save
checkpoints under `<checkpoint_root>/pretrain.pkl`. To monitor the progress of
the pretraining, you can run a second worker (using `--worker_mode=eval`) with
the same `checkpoint_root` setting. This worker will regularly load the
checkpoint and evaluate the performance of a linear classifier (trained by the
pretraining `train` worker) on the `TEST` set.
Note that the default settings are set for large-scale training on Cloud TPUs,
with a total batch size of 4096. To avoid the need to re-run the full
experiment, we will make a pre-trained checkpoint available on GCP.
### Linear evaluation
Setting `--experiment_mode=linear-eval` will configure the main loop for
linear evaluation; we provide presets for 80 epochs.
Use `--worker_mode=train` for a training job, which will load the encoder
weights from an existing checkpoint (form a pretrain experiment) located at
`<checkpoint_root>/pretrain.pkl`, and train a linear classifier on top of this
encoder. The weights from the linear classifier trained in the pretraining phase
will be discarded.
The training job will regularly save checkpoints under
`<checkpoint_root>/linear-eval.pkl`. You can run a second worker
(using `--worker_mode=eval`) with the same `checkpoint_root` setting to
regularly load the checkpoint and evaluate the performance of the classifier
(trained by the linear-eval `train` worker) on the test set.
Note that the above will run a simplified version of the linear evaluation
pipeline described in the paper, with a single value of the base learning rate
and without using a validation set. To fully reproduce the results from the
paper, one should run 5 instances of the linear-eval `train` worker (using only
the `TRAIN` subset, and each using a different `checkpoint_root`), run the
`eval` worker (using only the `VALID` subset) for each checkpoint, then run a
final `eval` worker on the `TEST` set.
## Running on GCP
Notice: we currently do not recommend running the full experiment on public
Cloud TPUs. We provide an alternative small-scale GPU setup in the next section.
The experiments from the paper were run on Google's internal infrastructure.
Unfortunately, training JAX models on publicly available Cloud TPUs is currently
in its early stages; in particular, we found the learning to be heavily
bottlenecked by the data loading pipeline, resulting in a significantly slower
training. We will update the code and instructions below once the limitation is
addressed.
To train the model on TPU using Google Compute Engine, please follow the
instructions at https://cloud.google.com/tpu/docs/imagenet-setup to first
download and preprocess the ImageNet dataset and upload it to a
Cloud Storage Bucket. From a GCE VM, you can check that `tensorflow_datasets`
can correctly load the dataset by running:
```python
import tensorflow_datasets as tfds
tfds.load('imagenet2012', data_dir='gs://<your-bucket-name>')
```
To learn how to use JAX with Cloud TPUs, please follow the instructions here:
https://github.com/google/jax/tree/master/cloud_tpu_colabs.
## Setup for fast iteration
In order to make reproduction easier, it is possible to change the training
setup to use the smaller [imagenette](https://github.com/fastai/imagenette)
dataset (9469 training images with 10 classes). The following setup and
hyperparameters can be used on a machine with a single V100 GPU:
- in `utils/dataset.py`:
- update `Split.num_examples` with the figures from [tfds](https://www.tensorflow.org/datasets/catalog/imagenette) (with `Split.VALID: 0`)
- use `imagenette/160px-v2` in the call to `tfds.load`
- use 128x128 px images (_i.e._, replace all instances of `224` by `128`)
- it doesn't seem necessary to change the color normalization (make sure to
not replace the value `0.224` by mistake in the previous step).
- in `configs/byol.py`, use:
- `num_classes`: `10`
- `network_config.encoder`: `ResNet18`
- `optimizer_config.weight_decay`: `1e-6`
- `lr_schedule_config.base_learning_rate`: `2.0`
- `evaluation_config.batch_size`: `25`
- other parameters unchanged.
You can then run:
```shell
mkdir /tmp/byol_checkpoints
python -m byol.main_loop \
--experiment_mode='pretrain' \
--worker_mode='train' \
--checkpoint_root='/tmp/byol_checkpoints' \
--batch_size=256 \
--pretrain_epochs=1000
```
With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
*online* classifier) in roughly 4 hours. Note that the above parameters were not
finely tuned and may not be optimal.
File diff suppressed because it is too large Load Diff
+78
View File
@@ -0,0 +1,78 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config file for BYOL experiment."""
from byol.utils import dataset
# Preset values for certain number of training epochs.
_LR_PRESETS = {40: 0.45, 100: 0.45, 300: 0.3, 1000: 0.2}
_WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6}
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
def get_config(num_epochs, batch_size):
"""Return config object, containing all hyperparameters for training."""
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
assert num_epochs in [40, 100, 300, 1000]
config = dict(
random_seed=0,
num_classes=1000,
batch_size=batch_size,
max_steps=num_epochs * train_images_per_epoch // batch_size,
enable_double_transpose=True,
base_target_ema=_EMA_PRESETS[num_epochs],
network_config=dict(
projector_hidden_size=4096,
projector_output_size=256,
predictor_hidden_size=4096,
encoder_class='ResNet50', # Should match a class in utils/networks.
encoder_config=dict(
resnet_v2=False,
width_multiplier=1),
bn_config={
'decay_rate': .9,
'eps': 1e-5,
# Accumulate batchnorm statistics across devices.
# This should be equal to the `axis_name` argument passed
# to jax.pmap.
'cross_replica_axis': 'i',
'create_scale': True,
'create_offset': True,
}),
optimizer_config=dict(
weight_decay=_WD_PRESETS[num_epochs],
eta=1e-3,
momentum=.9,
),
lr_schedule_config=dict(
base_learning_rate=_LR_PRESETS[num_epochs],
warmup_steps=10 * train_images_per_epoch // batch_size,
),
evaluation_config=dict(
subset='test',
batch_size=100,
),
checkpointing_config=dict(
use_checkpointing=True,
checkpoint_dir='/tmp/byol',
save_checkpoint_interval=300,
filename='pretrain.pkl'
),
)
return config
+67
View File
@@ -0,0 +1,67 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config file for evaluation experiment."""
from typing import Text
from byol.utils import dataset
def get_config(checkpoint_to_evaluate, batch_size):
"""Return config object for training."""
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
config = dict(
random_seed=0,
enable_double_transpose=True,
max_steps=80 * train_images_per_epoch // batch_size,
num_classes=1000,
batch_size=batch_size,
checkpoint_to_evaluate=checkpoint_to_evaluate,
# If True, allows training without loading a checkpoint.
allow_train_from_scratch=False,
# Whether the backbone should be frozen (linear evaluation) or
# trainable (fine-tuning).
freeze_backbone=True,
optimizer_config=dict(
momentum=0.9,
nesterov=True,
),
lr_schedule_config=dict(
base_learning_rate=0.2,
warmup_steps=0,
),
network_config=dict( # Should match the evaluated checkpoint
encoder_class='ResNet50', # Should match a class in utils/networks.
encoder_config=dict(
resnet_v2=False,
width_multiplier=1),
bn_decay_rate=0.9,
),
evaluation_config=dict(
subset='test',
batch_size=100,
),
checkpointing_config=dict(
use_checkpointing=True,
checkpoint_dir='/tmp/byol',
save_checkpoint_interval=300,
filename='linear-eval.pkl'
),
)
return config
+477
View File
@@ -0,0 +1,477 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear evaluation or fine-tuning pipeline.
Use this experiment to evaluate a checkpoint from byol_experiment.
"""
import functools
from typing import Any, Generator, Mapping, NamedTuple, Optional, Text, Tuple, Union
from absl import logging
from acme.jax import utils as acme_utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from byol.utils import checkpointing
from byol.utils import dataset
from byol.utils import helpers
from byol.utils import networks
from byol.utils import schedules
# Type declarations.
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
LogsDict = Mapping[Text, jnp.ndarray]
class _EvalExperimentState(NamedTuple):
backbone_params: hk.Params
classif_params: hk.Params
backbone_state: hk.State
backbone_opt_state: Union[None, OptState]
classif_opt_state: OptState
class EvalExperiment:
"""Linear evaluation experiment."""
def __init__(
self,
random_seed,
num_classes,
batch_size,
max_steps,
enable_double_transpose,
checkpoint_to_evaluate,
allow_train_from_scratch,
freeze_backbone,
network_config,
optimizer_config,
lr_schedule_config,
evaluation_config,
checkpointing_config):
"""Constructs the experiment.
Args:
random_seed: the random seed to use when initializing network weights.
num_classes: the number of classes; used for the online evaluation.
batch_size: the total batch size; should be a multiple of the number of
available accelerators.
max_steps: the number of training steps; used for the lr/target network
ema schedules.
enable_double_transpose: see dataset.py; only has effect on TPU.
checkpoint_to_evaluate: the path to the checkpoint to evaluate.
allow_train_from_scratch: whether to allow training without specifying a
checkpoint to evaluate (training from scratch).
freeze_backbone: whether the backbone resnet should remain frozen (linear
evaluation) or be trainable (fine-tuning).
network_config: the configuration for the network.
optimizer_config: the configuration for the optimizer.
lr_schedule_config: the configuration for the learning rate schedule.
evaluation_config: the evaluation configuration.
checkpointing_config: the configuration for checkpointing.
"""
self._random_seed = random_seed
self._enable_double_transpose = enable_double_transpose
self._num_classes = num_classes
self._lr_schedule_config = lr_schedule_config
self._batch_size = batch_size
self._max_steps = max_steps
self._checkpoint_to_evaluate = checkpoint_to_evaluate
self._allow_train_from_scratch = allow_train_from_scratch
self._freeze_backbone = freeze_backbone
self._optimizer_config = optimizer_config
self._evaluation_config = evaluation_config
# Checkpointed experiment state.
self._experiment_state = None
# Input pipelines.
self._train_input = None
self._eval_input = None
backbone_fn = functools.partial(self._backbone_fn, **network_config)
self.forward_backbone = hk.without_apply_rng(
hk.transform_with_state(backbone_fn))
self.forward_classif = hk.without_apply_rng(hk.transform(self._classif_fn))
self.update_pmap = jax.pmap(self._update_func, axis_name='i')
self.eval_batch_jit = jax.jit(self._eval_batch)
self._is_backbone_training = not self._freeze_backbone
self._checkpointer = checkpointing.Checkpointer(**checkpointing_config)
def _should_transpose_images(self):
"""Should we transpose images (saves host-to-device time on TPUs)."""
return (self._enable_double_transpose and
jax.local_devices()[0].platform == 'tpu')
def _backbone_fn(
self,
inputs,
encoder_class,
encoder_config,
bn_decay_rate,
is_training,
):
"""Forward of the encoder (backbone)."""
bn_config = {'decay_rate': bn_decay_rate}
encoder = getattr(networks, encoder_class)
model = encoder(
None,
bn_config=bn_config,
**encoder_config)
if self._should_transpose_images():
inputs = dataset.transpose_images(inputs)
images = dataset.normalize_images(inputs['images'])
return model(images, is_training=is_training)
def _classif_fn(
self,
embeddings,
):
classifier = hk.Linear(output_size=self._num_classes)
return classifier(embeddings)
# _ _
# | |_ _ __ __ _(_)_ __
# | __| '__/ _` | | '_ \
# | |_| | | (_| | | | | |
# \__|_| \__,_|_|_| |_|
#
def step(self, *,
global_step,
rng):
"""Performs a single training step."""
if self._train_input is None:
self._initialize_train(rng)
inputs = next(self._train_input)
self._experiment_state, scalars = self.update_pmap(
self._experiment_state, global_step, inputs)
scalars = helpers.get_first(scalars)
return scalars
def save_checkpoint(self, step, rng):
self._checkpointer.maybe_save_checkpoint(
self._experiment_state, step=step, rng=rng,
is_final=step >= self._max_steps)
def load_checkpoint(self):
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
if checkpoint_data is None:
return None
self._experiment_state, step, rng = checkpoint_data
return step, rng
def _initialize_train(self, rng):
"""BYOL's _ExperimentState initialization.
Args:
rng: random number generator used to initialize parameters. If working in
a multi device setup, this need to be a ShardedArray.
dummy_input: a dummy image, used to compute intermediate outputs shapes.
Returns:
Initial EvalExperiment state.
Raises:
RuntimeError: invalid or empty checkpoint.
"""
self._train_input = acme_utils.prefetch(self._build_train_input())
# Check we haven't already restored params
if self._experiment_state is None:
inputs = next(self._train_input)
if self._checkpoint_to_evaluate is not None:
# Load params from checkpoint
checkpoint_data = checkpointing.load_checkpoint(
self._checkpoint_to_evaluate)
if checkpoint_data is None:
raise RuntimeError('Invalid checkpoint.')
backbone_params = checkpoint_data['experiment_state'].online_params
backbone_state = checkpoint_data['experiment_state'].online_state
backbone_params = helpers.bcast_local_devices(backbone_params)
backbone_state = helpers.bcast_local_devices(backbone_state)
else:
if not self._allow_train_from_scratch:
raise ValueError(
'No checkpoint specified, but `allow_train_from_scratch` '
'set to False')
# Initialize with random parameters
logging.info(
'No checkpoint specified, initializing the networks from scratch '
'(dry run mode)')
backbone_params, backbone_state = jax.pmap(
functools.partial(self.forward_backbone.init, is_training=True),
axis_name='i')(rng=rng, inputs=inputs)
init_experiment = jax.pmap(self._make_initial_state, axis_name='i')
# Init uses the same RNG key on all hosts+devices to ensure everyone
# computes the same initial state and parameters.
init_rng = jax.random.PRNGKey(self._random_seed)
init_rng = helpers.bcast_local_devices(init_rng)
self._experiment_state = init_experiment(
rng=init_rng,
dummy_input=inputs,
backbone_params=backbone_params,
backbone_state=backbone_state)
# Clear the backbone optimizer's state when the backbone is frozen.
if self._freeze_backbone:
self._experiment_state = _EvalExperimentState(
backbone_params=self._experiment_state.backbone_params,
classif_params=self._experiment_state.classif_params,
backbone_state=self._experiment_state.backbone_state,
backbone_opt_state=None,
classif_opt_state=self._experiment_state.classif_opt_state,
)
def _make_initial_state(
self,
rng,
dummy_input,
backbone_params,
backbone_state,
):
"""_EvalExperimentState initialization."""
# Initialize the backbone params
# Always create the batchnorm weights (is_training=True), they will be
# overwritten when loading the checkpoint.
embeddings, _ = self.forward_backbone.apply(
backbone_params, backbone_state, dummy_input, is_training=True)
backbone_opt_state = self._optimizer(0.).init(backbone_params)
# Initialize the classifier params and optimizer_state
classif_params = self.forward_classif.init(rng, embeddings)
classif_opt_state = self._optimizer(0.).init(classif_params)
return _EvalExperimentState(
backbone_params=backbone_params,
classif_params=classif_params,
backbone_state=backbone_state,
backbone_opt_state=backbone_opt_state,
classif_opt_state=classif_opt_state,
)
def _build_train_input(self):
"""See base class."""
num_devices = jax.device_count()
global_batch_size = self._batch_size
per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
if ragged:
raise ValueError(
f'Global batch size {global_batch_size} must be divisible by '
f'num devices {num_devices}')
return dataset.load(
dataset.Split.TRAIN_AND_VALID,
preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN,
transpose=self._should_transpose_images(),
batch_dims=[jax.local_device_count(), per_device_batch_size])
def _optimizer(self, learning_rate):
"""Build optimizer from config."""
return optax.sgd(learning_rate, **self._optimizer_config)
def _loss_fn(
self,
backbone_params,
classif_params,
backbone_state,
inputs,
):
"""Compute the classification loss function.
Args:
backbone_params: parameters of the encoder network.
classif_params: parameters of the linear classifier.
backbone_state: internal state of encoder network.
inputs: inputs, containing `images` and `labels`.
Returns:
The classification loss and various logs.
"""
embeddings, backbone_state = self.forward_backbone.apply(
backbone_params,
backbone_state,
inputs,
is_training=not self._freeze_backbone)
logits = self.forward_classif.apply(classif_params, embeddings)
labels = hk.one_hot(inputs['labels'], self._num_classes)
loss = helpers.softmax_cross_entropy(logits, labels, reduction='mean')
scaled_loss = loss / jax.device_count()
return scaled_loss, (loss, backbone_state)
def _update_func(
self,
experiment_state,
global_step,
inputs,
):
"""Applies an update to parameters and returns new state."""
# This function computes the gradient of the first output of loss_fn and
# passes through the other arguments unchanged.
# Gradient of the first output of _loss_fn wrt the backbone (arg 0) and the
# classifier parameters (arg 1). The auxiliary outputs are returned as-is.
grad_loss_fn = jax.grad(self._loss_fn, has_aux=True, argnums=(0, 1))
grads, aux_outputs = grad_loss_fn(
experiment_state.backbone_params,
experiment_state.classif_params,
experiment_state.backbone_state,
inputs,
)
backbone_grads, classifier_grads = grads
train_loss, new_backbone_state = aux_outputs
classifier_grads = jax.lax.psum(classifier_grads, axis_name='i')
# Compute the decayed learning rate
learning_rate = schedules.learning_schedule(
global_step,
batch_size=self._batch_size,
total_steps=self._max_steps,
**self._lr_schedule_config)
# Compute and apply updates via our optimizer.
classif_updates, new_classif_opt_state = \
self._optimizer(learning_rate).update(
classifier_grads,
experiment_state.classif_opt_state)
new_classif_params = optax.apply_updates(experiment_state.classif_params,
classif_updates)
if self._freeze_backbone:
del backbone_grads, new_backbone_state # Unused
# The backbone is not updated.
new_backbone_params = experiment_state.backbone_params
new_backbone_opt_state = None
new_backbone_state = experiment_state.backbone_state
else:
backbone_grads = jax.lax.psum(backbone_grads, axis_name='i')
# Compute and apply updates via our optimizer.
backbone_updates, new_backbone_opt_state = \
self._optimizer(learning_rate).update(
backbone_grads,
experiment_state.backbone_opt_state)
new_backbone_params = optax.apply_updates(
experiment_state.backbone_params, backbone_updates)
experiment_state = _EvalExperimentState(
new_backbone_params,
new_classif_params,
new_backbone_state,
new_backbone_opt_state,
new_classif_opt_state,
)
# Scalars to log (note: we log the mean across all hosts/devices).
scalars = {'train_loss': train_loss}
scalars = jax.lax.pmean(scalars, axis_name='i')
return experiment_state, scalars
# _
# _____ ____ _| |
# / _ \ \ / / _` | |
# | __/\ V / (_| | |
# \___| \_/ \__,_|_|
#
def evaluate(self, global_step, **unused_args):
"""See base class."""
global_step = np.array(helpers.get_first(global_step))
scalars = jax.device_get(self._eval_epoch(**self._evaluation_config))
logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
return scalars
def _eval_batch(
self,
backbone_params,
classif_params,
backbone_state,
inputs,
):
"""Evaluates a batch."""
embeddings, backbone_state = self.forward_backbone.apply(
backbone_params, backbone_state, inputs, is_training=False)
logits = self.forward_classif.apply(classif_params, embeddings)
labels = hk.one_hot(inputs['labels'], self._num_classes)
loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
# NOTE: Returned values will be summed and finally divided by num_samples.
return {
'eval_loss': loss,
'top1_accuracy': top1_correct,
'top5_accuracy': top5_correct
}
def _eval_epoch(self, subset, batch_size):
"""Evaluates an epoch."""
num_samples = 0.
summed_scalars = None
backbone_params = helpers.get_first(self._experiment_state.backbone_params)
classif_params = helpers.get_first(self._experiment_state.classif_params)
backbone_state = helpers.get_first(self._experiment_state.backbone_state)
split = dataset.Split.from_string(subset)
dataset_iterator = dataset.load(
split,
preprocess_mode=dataset.PreprocessMode.EVAL,
transpose=self._should_transpose_images(),
batch_dims=[batch_size])
for inputs in dataset_iterator:
num_samples += inputs['labels'].shape[0]
scalars = self.eval_batch_jit(
backbone_params,
classif_params,
backbone_state,
inputs,
)
# Accumulate the sum of scalars for each step.
scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
if summed_scalars is None:
summed_scalars = scalars
else:
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
return mean_scalars
+157
View File
@@ -0,0 +1,157 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation loops for an experiment."""
import time
from typing import Any, Mapping, Text, Type, Union
from absl import app
from absl import flags
from absl import logging
import jax
import numpy as np
from byol import byol_experiment
from byol import eval_experiment
from byol.configs import byol as byol_config
from byol.configs import eval as eval_config
flags.DEFINE_string('experiment_mode',
'pretrain', 'The experiment, pretrain or linear-eval')
flags.DEFINE_string('worker_mode', 'train', 'The mode, train or eval')
flags.DEFINE_string('worker_tpu_driver', '', 'The tpu driver to use')
flags.DEFINE_integer('pretrain_epochs', 1000, 'Number of pre-training epochs')
flags.DEFINE_integer('batch_size', 4096, 'Total batch size')
flags.DEFINE_string('checkpoint_root', '/tmp/byol',
'The directory to save checkpoints to.')
flags.DEFINE_integer('log_tensors_interval', 60, 'Log tensors every n seconds.')
FLAGS = flags.FLAGS
Experiment = Union[
Type[byol_experiment.ByolExperiment],
Type[eval_experiment.EvalExperiment]]
def train_loop(experiment_class, config):
"""The main training loop.
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
Args:
experiment_class: the constructor for the experiment (either byol_experiment
or eval_experiment).
config: the experiment config.
"""
experiment = experiment_class(**config)
rng = jax.random.PRNGKey(0)
step = 0
host_id = jax.host_id()
last_logging = time.time()
if config['checkpointing_config']['use_checkpointing']:
checkpoint_data = experiment.load_checkpoint()
if checkpoint_data is None:
step = 0
else:
step, rng = checkpoint_data
local_device_count = jax.local_device_count()
while step < config['max_steps']:
step_rng, rng = tuple(jax.random.split(rng))
# Broadcast the random seeds across the devices
step_rng_device = jax.random.split(step_rng, num=jax.device_count())
step_rng_device = step_rng_device[
host_id * local_device_count:(host_id + 1) * local_device_count]
step_device = np.broadcast_to(step, [local_device_count])
# Perform a training step and get scalars to log.
scalars = experiment.step(global_step=step_device, rng=step_rng_device)
# Checkpointing and logging.
if config['checkpointing_config']['use_checkpointing']:
experiment.save_checkpoint(step, rng)
current_time = time.time()
if current_time - last_logging > FLAGS.log_tensors_interval:
logging.info('Step %d: %s', step, scalars)
last_logging = current_time
step += 1
logging.info('Saving final checkpoint')
logging.info('Step %d: %s', step, scalars)
experiment.save_checkpoint(step, rng)
def eval_loop(experiment_class, config):
"""The main evaluation loop.
This loop periodically loads a checkpoint and evaluates its performance on the
test set, by calling experiment.evaluate.
Args:
experiment_class: the constructor for the experiment (either byol_experiment
or eval_experiment).
config: the experiment config.
"""
experiment = experiment_class(**config)
last_evaluated_step = -1
while True:
checkpoint_data = experiment.load_checkpoint()
if checkpoint_data is None:
logging.info('No checkpoint found. Waiting for 10s.')
time.sleep(10)
continue
step, _ = checkpoint_data
if step <= last_evaluated_step:
logging.info('Checkpoint at step %d already evaluated, waiting.', step)
time.sleep(10)
continue
host_id = jax.host_id()
local_device_count = jax.local_device_count()
step_device = np.broadcast_to(step, [local_device_count])
scalars = experiment.evaluate(global_step=step_device)
if host_id == 0: # Only perform logging in one host.
logging.info('Evaluation at step %d: %s', step, scalars)
last_evaluated_step = step
if last_evaluated_step >= config['max_steps']:
return
def main(_):
if FLAGS.worker_tpu_driver:
jax.config.update('jax_xla_backend', 'tpu_driver')
jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver)
logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices())
if FLAGS.experiment_mode == 'pretrain':
experiment_class = byol_experiment.ByolExperiment
config = byol_config.get_config(FLAGS.pretrain_epochs, FLAGS.batch_size)
elif FLAGS.experiment_mode == 'linear-eval':
experiment_class = eval_experiment.EvalExperiment
config = eval_config.get_config(f'{FLAGS.checkpoint_root}/pretrain.pkl',
FLAGS.batch_size)
else:
raise ValueError(f'Unknown experiment mode: {FLAGS.experiment_mode}')
config['checkpointing_config']['checkpoint_dir'] = FLAGS.checkpoint_root
if FLAGS.worker_mode == 'train':
train_loop(experiment_class, config)
elif FLAGS.worker_mode == 'eval':
eval_loop(experiment_class, config)
if __name__ == '__main__':
app.run(main)
+97
View File
@@ -0,0 +1,97 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for BYOL's main training loop."""
from absl import flags
from absl.testing import absltest
import tensorflow_datasets as tfds
from byol import byol_experiment
from byol import eval_experiment
from byol import main_loop
from byol.configs import byol as byol_config
from byol.configs import eval as eval_config
FLAGS = flags.FLAGS
class MainLoopTest(absltest.TestCase):
def test_pretrain(self):
config = byol_config.get_config(num_epochs=40, batch_size=4)
temp_dir = self.create_tempdir().full_path
# Override some config fields to make test lighter.
config['network_config']['encoder_class'] = 'TinyResNet'
config['network_config']['projector_hidden_size'] = 256
config['network_config']['predictor_hidden_size'] = 256
config['checkpointing_config']['checkpoint_dir'] = temp_dir
config['evaluation_config']['batch_size'] = 16
config['max_steps'] = 16
with tfds.testing.mock_data(num_examples=64):
experiment_class = byol_experiment.ByolExperiment
main_loop.train_loop(experiment_class, config)
main_loop.eval_loop(experiment_class, config)
def test_linear_eval(self):
config = eval_config.get_config(checkpoint_to_evaluate=None, batch_size=4)
temp_dir = self.create_tempdir().full_path
# Override some config fields to make test lighter.
config['network_config']['encoder_class'] = 'TinyResNet'
config['allow_train_from_scratch'] = True
config['checkpointing_config']['checkpoint_dir'] = temp_dir
config['evaluation_config']['batch_size'] = 16
config['max_steps'] = 16
with tfds.testing.mock_data(num_examples=64):
experiment_class = eval_experiment.EvalExperiment
main_loop.train_loop(experiment_class, config)
main_loop.eval_loop(experiment_class, config)
def test_pipeline(self):
b_config = byol_config.get_config(num_epochs=40, batch_size=4)
temp_dir = self.create_tempdir().full_path
# Override some config fields to make test lighter.
b_config['network_config']['encoder_class'] = 'TinyResNet'
b_config['network_config']['projector_hidden_size'] = 256
b_config['network_config']['predictor_hidden_size'] = 256
b_config['checkpointing_config']['checkpoint_dir'] = temp_dir
b_config['evaluation_config']['batch_size'] = 16
b_config['max_steps'] = 16
with tfds.testing.mock_data(num_examples=64):
main_loop.train_loop(byol_experiment.ByolExperiment, b_config)
e_config = eval_config.get_config(
checkpoint_to_evaluate=f'{temp_dir}/pretrain.pkl',
batch_size=4)
# Override some config fields to make test lighter.
e_config['network_config']['encoder_class'] = 'TinyResNet'
e_config['allow_train_from_scratch'] = True
e_config['checkpointing_config']['checkpoint_dir'] = temp_dir
e_config['evaluation_config']['batch_size'] = 16
e_config['max_steps'] = 16
with tfds.testing.mock_data(num_examples=64):
main_loop.train_loop(eval_experiment.EvalExperiment, e_config)
if __name__ == '__main__':
absltest.main()
+8
View File
@@ -0,0 +1,8 @@
dm-acme
dm-haiku
dm-tree
jax
jaxlib
numpy>=1.16
tensorflow
tensorflow_datasets
+457
View File
@@ -0,0 +1,457 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data preprocessing and augmentation."""
import functools
from typing import Any, Mapping, Text
import jax
import jax.numpy as jnp
# typing
JaxBatch = Mapping[Text, jnp.ndarray]
ConfigDict = Mapping[Text, Any]
augment_config = dict(
view1=dict(
random_flip=True, # Random left/right flip
color_transform=dict(
apply_prob=1.0,
# Range of jittering
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1,
# Probability of applying color jittering
color_jitter_prob=0.8,
# Probability of converting to grayscale
to_grayscale_prob=0.2,
# Shuffle the order of color transforms
shuffle=True),
gaussian_blur=dict(
apply_prob=1.0,
# Kernel size ~ image_size / blur_divider
blur_divider=10.,
# Kernel distribution
sigma_min=0.1,
sigma_max=2.0),
solarize=dict(apply_prob=0.0, threshold=0.5),
),
view2=dict(
random_flip=True,
color_transform=dict(
apply_prob=1.0,
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1,
color_jitter_prob=0.8,
to_grayscale_prob=0.2,
shuffle=True),
gaussian_blur=dict(
apply_prob=0.1, blur_divider=10., sigma_min=0.1, sigma_max=2.0),
solarize=dict(apply_prob=0.2, threshold=0.5),
))
def postprocess(inputs, rng):
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
def _postprocess_image(
images,
rng,
presets,
):
"""Applies augmentations in post-processing.
Args:
images: an NHWC tensor (with C=3), with float values in [0, 1].
rng: a single PRNGKey.
presets: a dict of presets for the augmentations.
Returns:
A batch of augmented images with shape NHWC, with keys view1, view2
and labels.
"""
flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4)
out = images
if presets['random_flip']:
out = random_flip(out, flip_rng)
if presets['color_transform']['apply_prob'] > 0:
out = color_transform(out, color_rng, **presets['color_transform'])
if presets['gaussian_blur']['apply_prob'] > 0:
out = gaussian_blur(out, blur_rng, **presets['gaussian_blur'])
if presets['solarize']['apply_prob'] > 0:
out = solarize(out, solarize_rng, **presets['solarize'])
out = jnp.clip(out, 0., 1.)
return jax.lax.stop_gradient(out)
rng1, rng2 = jax.random.split(rng, num=2)
view1 = _postprocess_image(inputs['view1'], rng1, augment_config['view1'])
view2 = _postprocess_image(inputs['view2'], rng2, augment_config['view2'])
return dict(view1=view1, view2=view2, labels=inputs['labels'])
def _maybe_apply(apply_fn, inputs, rng, apply_prob):
should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x)
def _depthwise_conv2d(inputs, kernel, strides, padding):
"""Computes a depthwise conv2d in Jax.
Args:
inputs: an NHWC tensor with N=1.
kernel: a [H", W", 1, C] tensor.
strides: a 2d tensor.
padding: "SAME" or "VALID".
Returns:
The depthwise convolution of inputs with kernel, as [H, W, C].
"""
return jax.lax.conv_general_dilated(
inputs,
kernel,
strides,
padding,
feature_group_count=inputs.shape[-1],
dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
"""Applies gaussian blur to a single image, given as NHWC with N=1."""
radius = int(kernel_size / 2)
kernel_size_ = 2 * radius + 1
x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
blur_filter = jnp.exp(-x**2 / (2. * sigma**2))
blur_filter = blur_filter / jnp.sum(blur_filter)
blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
num_channels = image.shape[-1]
blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
expand_batch_dim = len(image.shape) == 3
if expand_batch_dim:
image = image[jnp.newaxis, Ellipsis]
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
blurred = jnp.squeeze(blurred, axis=0)
return blurred
def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min,
sigma_max, apply_prob):
"""Applies a random gaussian blur."""
apply_rng, transform_rng = jax.random.split(rng)
def _apply(image):
sigma_rng, = jax.random.split(transform_rng, 1)
sigma = jax.random.uniform(
sigma_rng,
shape=(),
minval=sigma_min,
maxval=sigma_max,
dtype=jnp.float32)
return _gaussian_blur_single_image(image, kernel_size, padding, sigma)
return _maybe_apply(_apply, image, apply_rng, apply_prob)
def rgb_to_hsv(r, g, b):
"""Converts R, G, B values to H, S, V values.
Reference TF implementation:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
Only input values between 0 and 1 are guaranteed to work properly, but this
function complies with the TF implementation outside of this range.
Args:
r: A tensor representing the red color component as floats.
g: A tensor representing the green color component as floats.
b: A tensor representing the blue color component as floats.
Returns:
H, S, V values, each as tensors of shape [...] (same as the input without
the last dimension).
"""
vv = jnp.maximum(jnp.maximum(r, g), b)
range_ = vv - jnp.minimum(jnp.minimum(r, g), b)
sat = jnp.where(vv > 0, range_ / vv, 0.)
norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9)
hr = norm * (g - b)
hg = norm * (b - r) + 2. / 6.
hb = norm * (r - g) + 4. / 6.
hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb))
hue = hue * (range_ > 0)
hue = hue + (hue < 0)
return hue, sat, vv
def hsv_to_rgb(h, s, v):
"""Converts H, S, V values to an R, G, B tuple.
Reference TF implementation:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
Only input values between 0 and 1 are guaranteed to work properly, but this
function complies with the TF implementation outside of this range.
Args:
h: A float tensor of arbitrary shape for the hue (0-1 values).
s: A float tensor of the same shape for the saturation (0-1 values).
v: A float tensor of the same shape for the value channel (0-1 values).
Returns:
An (r, g, b) tuple, each with the same dimension as the inputs.
"""
c = s * v
m = v - c
dh = (h % 1.) * 6.
fmodu = dh % 2.
x = c * (1 - jnp.abs(fmodu - 1))
hcat = jnp.floor(dh).astype(jnp.int32)
rr = jnp.where(
(hcat == 0) | (hcat == 5), c, jnp.where(
(hcat == 1) | (hcat == 4), x, 0)) + m
gg = jnp.where(
(hcat == 1) | (hcat == 2), c, jnp.where(
(hcat == 0) | (hcat == 3), x, 0)) + m
bb = jnp.where(
(hcat == 3) | (hcat == 4), c, jnp.where(
(hcat == 2) | (hcat == 5), x, 0)) + m
return rr, gg, bb
def adjust_brightness(rgb_tuple, delta):
return jax.tree_map(lambda x: x + delta, rgb_tuple)
def adjust_contrast(image, factor):
def _adjust_contrast_channel(channel):
mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
return factor * (channel - mean) + mean
return jax.tree_map(_adjust_contrast_channel, image)
def adjust_saturation(h, s, v, factor):
return h, jnp.clip(s * factor, 0., 1.), v
def adjust_hue(h, s, v, delta):
# Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb
# conversions) when running on GPU. When running on CPU, the results will be
# different if all RGB values for a pixel are outside of the [0, 1] range.
return (h + delta) % 1.0, s, v
def _random_brightness(rgb_tuple, rng, max_delta):
delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
return adjust_brightness(rgb_tuple, delta)
def _random_contrast(rgb_tuple, rng, max_delta):
factor = jax.random.uniform(
rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
return adjust_contrast(rgb_tuple, factor)
def _random_saturation(rgb_tuple, rng, max_delta):
h, s, v = rgb_to_hsv(*rgb_tuple)
factor = jax.random.uniform(
rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
return hsv_to_rgb(*adjust_saturation(h, s, v, factor))
def _random_hue(rgb_tuple, rng, max_delta):
h, s, v = rgb_to_hsv(*rgb_tuple)
delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
return hsv_to_rgb(*adjust_hue(h, s, v, delta))
def _to_grayscale(image):
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis]
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
def _color_transform_single_image(image, rng, brightness, contrast, saturation,
hue, to_grayscale_prob, color_jitter_prob,
apply_prob, shuffle):
"""Applies color jittering to a single image."""
apply_rng, transform_rng = jax.random.split(rng)
perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
transform_rng, 7)
# Whether the transform should be applied at all.
should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
# Whether to apply grayscale transform.
should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
# Whether to apply color jittering.
should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob
# Decorator to conditionally apply fn based on an index.
def _make_cond(fn, idx):
def identity_fn(x, unused_rng, unused_param):
return x
def cond_fn(args, i):
def clip(args):
return jax.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args)
out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args,
lambda a: clip(fn(*a)), args,
lambda a: identity_fn(*a))
return jax.lax.stop_gradient(out)
return cond_fn
random_brightness_cond = _make_cond(_random_brightness, idx=0)
random_contrast_cond = _make_cond(_random_contrast, idx=1)
random_saturation_cond = _make_cond(_random_saturation, idx=2)
random_hue_cond = _make_cond(_random_hue, idx=3)
def _color_jitter(x):
rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1)))
if shuffle:
order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
else:
order = range(4)
for idx in order:
if brightness > 0:
rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx)
if contrast > 0:
rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx)
if saturation > 0:
rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx)
if hue > 0:
rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx)
return jnp.stack(rgb_tuple, axis=-1)
out_apply = _color_jitter(image)
out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply,
_to_grayscale, out_apply, lambda x: x)
return jnp.clip(out_apply, 0., 1.)
def _random_flip_single_image(image, rng):
_, flip_rng = jax.random.split(rng)
should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
return image
def random_flip(images, rng):
rngs = jax.random.split(rng, images.shape[0])
return jax.vmap(_random_flip_single_image)(images, rngs)
def color_transform(images,
rng,
brightness=0.8,
contrast=0.8,
saturation=0.8,
hue=0.2,
color_jitter_prob=0.8,
to_grayscale_prob=0.2,
apply_prob=1.0,
shuffle=True):
"""Applies color jittering and/or grayscaling to a batch of images.
Args:
images: an NHWC tensor, with C=3.
rng: a single PRNGKey.
brightness: the range of jitter on brightness.
contrast: the range of jitter on contrast.
saturation: the range of jitter on saturation.
hue: the range of jitter on hue.
color_jitter_prob: the probability of applying color jittering.
to_grayscale_prob: the probability of converting the image to grayscale.
apply_prob: the probability of applying the transform to a batch element.
shuffle: whether to apply the transforms in a random order.
Returns:
A NHWC tensor of the transformed images.
"""
rngs = jax.random.split(rng, images.shape[0])
jitter_fn = functools.partial(
_color_transform_single_image,
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue,
color_jitter_prob=color_jitter_prob,
to_grayscale_prob=to_grayscale_prob,
apply_prob=apply_prob,
shuffle=shuffle)
return jax.vmap(jitter_fn)(images, rngs)
def gaussian_blur(images,
rng,
blur_divider=10.,
sigma_min=0.1,
sigma_max=2.0,
apply_prob=1.0):
"""Applies gaussian blur to a batch of images.
Args:
images: an NHWC tensor, with C=3.
rng: a single PRNGKey.
blur_divider: the blurring kernel will have size H / blur_divider.
sigma_min: the minimum value for sigma in the blurring kernel.
sigma_max: the maximum value for sigma in the blurring kernel.
apply_prob: the probability of applying the transform to a batch element.
Returns:
A NHWC tensor of the blurred images.
"""
rngs = jax.random.split(rng, images.shape[0])
kernel_size = images.shape[1] / blur_divider
blur_fn = functools.partial(
_random_gaussian_blur,
kernel_size=kernel_size,
padding='SAME',
sigma_min=sigma_min,
sigma_max=sigma_max,
apply_prob=apply_prob)
return jax.vmap(blur_fn)(images, rngs)
def _solarize_single_image(image, rng, threshold, apply_prob):
def _apply(image):
return jnp.where(image < threshold, image, 1. - image)
return _maybe_apply(_apply, image, rng, apply_prob)
def solarize(images, rng, threshold=0.5, apply_prob=1.0):
"""Applies solarization.
Args:
images: an NHWC tensor (with C=3).
rng: a single PRNGKey.
threshold: the solarization threshold.
apply_prob: the probability of applying the transform to a batch element.
Returns:
A NHWC tensor of the transformed images.
"""
rngs = jax.random.split(rng, images.shape[0])
solarize_fn = functools.partial(
_solarize_single_image, threshold=threshold, apply_prob=apply_prob)
return jax.vmap(solarize_fn)(images, rngs)
+105
View File
@@ -0,0 +1,105 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint saving and restoring utilities."""
import os
import time
from typing import Mapping, Text, Tuple, Union
from absl import logging
import dill
import jax
import jax.numpy as jnp
from byol.utils import helpers
class Checkpointer:
"""A checkpoint saving and loading class."""
def __init__(
self,
use_checkpointing,
checkpoint_dir,
save_checkpoint_interval,
filename):
if (not use_checkpointing or
checkpoint_dir is None or
save_checkpoint_interval <= 0):
self._checkpoint_enabled = False
return
self._checkpoint_enabled = True
self._checkpoint_dir = checkpoint_dir
os.makedirs(self._checkpoint_dir, exist_ok=True)
self._filename = filename
self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)
self._last_checkpoint_time = 0
self._checkpoint_every = save_checkpoint_interval
def maybe_save_checkpoint(
self,
experiment_state,
step,
rng,
is_final):
"""Saves a checkpoint if enough time has passed since the previous one."""
current_time = time.time()
if (not self._checkpoint_enabled or
jax.host_id() != 0 or # Only checkpoint the first worker.
(not is_final and
current_time - self._last_checkpoint_time < self._checkpoint_every)):
return
checkpoint_data = dict(
experiment_state=jax.tree_map(
lambda x: jax.device_get(x[0]), experiment_state),
step=step,
rng=rng)
with open(self._checkpoint_path + '_tmp', 'wb') as checkpoint_file:
dill.dump(checkpoint_data, checkpoint_file, protocol=2)
try:
os.rename(self._checkpoint_path, self._checkpoint_path + '_old')
remove_old = True
except FileNotFoundError:
remove_old = False # No previous checkpoint to remove
os.rename(self._checkpoint_path + '_tmp', self._checkpoint_path)
if remove_old:
os.remove(self._checkpoint_path + '_old')
self._last_checkpoint_time = current_time
def maybe_load_checkpoint(
self):
"""Loads a checkpoint if any is found."""
checkpoint_data = load_checkpoint(self._checkpoint_path)
if checkpoint_data is None:
logging.info('No existing checkpoint found at %s', self._checkpoint_path)
return None
step = checkpoint_data['step']
rng = checkpoint_data['rng']
experiment_state = jax.tree_map(
helpers.bcast_local_devices, checkpoint_data['experiment_state'])
del checkpoint_data
return experiment_state, step, rng
def load_checkpoint(checkpoint_path):
try:
with open(checkpoint_path, 'rb') as checkpoint_file:
checkpoint_data = dill.load(checkpoint_file)
logging.info('Loading checkpoint from %s, saved at step %d',
checkpoint_path, checkpoint_data['step'])
return checkpoint_data
except FileNotFoundError:
return None
+266
View File
@@ -0,0 +1,266 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet dataset with typical pre-processing."""
import enum
from typing import Generator, Mapping, Optional, Sequence, Text, Tuple
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
Batch = Mapping[Text, np.ndarray]
class Split(enum.Enum):
"""Imagenet dataset split."""
TRAIN = 1
TRAIN_AND_VALID = 2
VALID = 3
TEST = 4
@classmethod
def from_string(cls, name):
return {
'TRAIN': Split.TRAIN,
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
'VALID': Split.VALID,
'VALIDATION': Split.VALID,
'TEST': Split.TEST
}[name.upper()]
@property
def num_examples(self):
return {
Split.TRAIN_AND_VALID: 1281167,
Split.TRAIN: 1271167,
Split.VALID: 10000,
Split.TEST: 50000
}[self]
class PreprocessMode(enum.Enum):
"""Preprocessing modes for the dataset."""
PRETRAIN = 1 # Generates two augmented views (random crop + augmentations).
LINEAR_TRAIN = 2 # Generates a single random crop.
EVAL = 3 # Generates a single center crop.
def normalize_images(images):
"""Normalize the image using ImageNet statistics."""
mean_rgb = (0.485, 0.456, 0.406)
stddev_rgb = (0.229, 0.224, 0.225)
normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3))
normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3))
return normed_images
def load(split,
*,
preprocess_mode,
batch_dims,
transpose = False,
allow_caching = False):
"""Loads the given split of the dataset."""
start, end = _shard(split, jax.host_id(), jax.host_count())
total_batch_size = np.prod(batch_dims)
tfds_split = tfds.core.ReadInstruction(
_to_tfds_split(split), from_=start, to=end, unit='abs')
ds = tfds.load(
'imagenet2012:5.*.*',
split=tfds_split,
decoders={'image': tfds.decode.SkipDecoding()})
options = ds.options()
options.experimental_threading.private_threadpool_size = 48
options.experimental_threading.max_intra_op_parallelism = 1
if preprocess_mode is not PreprocessMode.EVAL:
options.experimental_deterministic = False
if jax.host_count() > 1 and allow_caching:
# Only cache if we are reading a subset of the dataset.
ds = ds.cache()
ds = ds.repeat()
ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
else:
if split.num_examples % total_batch_size != 0:
raise ValueError(f'Test/valid must be divisible by {total_batch_size}')
def preprocess_pretrain(example):
view1 = _preprocess_image(example['image'], mode=preprocess_mode)
view2 = _preprocess_image(example['image'], mode=preprocess_mode)
label = tf.cast(example['label'], tf.int32)
return {'view1': view1, 'view2': view2, 'labels': label}
def preprocess_linear_train(example):
image = _preprocess_image(example['image'], mode=preprocess_mode)
label = tf.cast(example['label'], tf.int32)
return {'images': image, 'labels': label}
def preprocess_eval(example):
image = _preprocess_image(example['image'], mode=preprocess_mode)
label = tf.cast(example['label'], tf.int32)
return {'images': image, 'labels': label}
if preprocess_mode is PreprocessMode.PRETRAIN:
ds = ds.map(
preprocess_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE)
elif preprocess_mode is PreprocessMode.LINEAR_TRAIN:
ds = ds.map(
preprocess_linear_train,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
else:
ds = ds.map(
preprocess_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def transpose_fn(batch):
# We use the double-transpose-trick to improve performance for TPUs. Note
# that this (typically) requires a matching HWCN->NHWC transpose in your
# model code. The compiler cannot make this optimization for us since our
# data pipeline and model are compiled separately.
batch = dict(**batch)
if preprocess_mode is PreprocessMode.PRETRAIN:
batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
else:
batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
return batch
for i, batch_size in enumerate(reversed(batch_dims)):
ds = ds.batch(batch_size)
if i == 0 and transpose:
ds = ds.map(transpose_fn) # NHWC -> HWCN
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
yield from tfds.as_numpy(ds)
def _to_tfds_split(split):
"""Returns the TFDS split appropriately sharded."""
# NOTE: Imagenet did not release labels for the test split used in the
# competition, we consider the VALID split the TEST split and reserve
# 10k images from TRAIN for VALID.
if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):
return tfds.Split.TRAIN
else:
assert split == Split.TEST
return tfds.Split.VALIDATION
def _shard(split, shard_index, num_shards):
"""Returns [start, end) for the given shard index."""
assert shard_index < num_shards
arange = np.arange(split.num_examples)
shard_range = np.array_split(arange, num_shards)[shard_index]
start, end = shard_range[0], (shard_range[-1] + 1)
if split == Split.TRAIN:
# Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
offset = Split.VALID.num_examples
start += offset
end += offset
return start, end
def _preprocess_image(
image_bytes,
mode,
):
"""Returns processed and resized images."""
if mode is PreprocessMode.PRETRAIN:
image = _decode_and_random_crop(image_bytes)
# Random horizontal flipping is optionally done in augmentations.preprocess.
elif mode is PreprocessMode.LINEAR_TRAIN:
image = _decode_and_random_crop(image_bytes)
image = tf.image.random_flip_left_right(image)
else:
image = _decode_and_center_crop(image_bytes)
# NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
# clamping overshoots. This means values returned will be outside the range
# [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
assert image.dtype == tf.uint8
image = tf.image.resize(image, [224, 224], tf.image.ResizeMethod.BICUBIC)
image = tf.clip_by_value(image / 255., 0., 1.)
return image
def _decode_and_random_crop(image_bytes):
"""Make a random crop of 224."""
img_size = tf.image.extract_jpeg_shape(image_bytes)
area = tf.cast(img_size[1] * img_size[0], tf.float32)
target_area = tf.random.uniform([], 0.08, 1.0, dtype=tf.float32) * area
log_ratio = (tf.math.log(3 / 4), tf.math.log(4 / 3))
aspect_ratio = tf.math.exp(
tf.random.uniform([], *log_ratio, dtype=tf.float32))
w = tf.cast(tf.round(tf.sqrt(target_area * aspect_ratio)), tf.int32)
h = tf.cast(tf.round(tf.sqrt(target_area / aspect_ratio)), tf.int32)
w = tf.minimum(w, img_size[1])
h = tf.minimum(h, img_size[0])
offset_w = tf.random.uniform((),
minval=0,
maxval=img_size[1] - w + 1,
dtype=tf.int32)
offset_h = tf.random.uniform((),
minval=0,
maxval=img_size[0] - h + 1,
dtype=tf.int32)
crop_window = tf.stack([offset_h, offset_w, h, w])
image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
return image
def transpose_images(batch):
"""Transpose images for TPU training.."""
new_batch = dict(batch) # Avoid mutating in place.
if 'images' in batch:
new_batch['images'] = jnp.transpose(batch['images'], (3, 0, 1, 2))
else:
new_batch['view1'] = jnp.transpose(batch['view1'], (3, 0, 1, 2))
new_batch['view2'] = jnp.transpose(batch['view2'], (3, 0, 1, 2))
return new_batch
def _decode_and_center_crop(
image_bytes,
jpeg_shape = None,
):
"""Crops to center of image with padding then scales."""
if jpeg_shape is None:
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
image_height = jpeg_shape[0]
image_width = jpeg_shape[1]
padded_center_crop_size = tf.cast(
((224 / (224 + 32)) *
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
crop_window = tf.stack([
offset_height, offset_width, padded_center_crop_size,
padded_center_crop_size
])
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
return image
+131
View File
@@ -0,0 +1,131 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions."""
from typing import Optional, Text
from absl import logging
import jax
import jax.numpy as jnp
def topk_accuracy(
logits,
labels,
topk,
ignore_label_above = None,
):
"""Top-num_codes accuracy."""
assert len(labels.shape) == 1, 'topk expects 1d int labels.'
assert len(logits.shape) == 2, 'topk expects 2d logits.'
if ignore_label_above is not None:
logits = logits[labels < ignore_label_above, :]
labels = labels[labels < ignore_label_above]
prds = jnp.argsort(logits, axis=1)[:, ::-1]
prds = prds[:, :topk]
total = jnp.any(prds == jnp.tile(labels[:, jnp.newaxis], [1, topk]), axis=1)
return total
def softmax_cross_entropy(
logits,
labels,
reduction = 'mean',
):
"""Computes softmax cross entropy given logits and one-hot class labels.
Args:
logits: Logit output values.
labels: Ground truth one-hot-encoded labels.
reduction: Type of reduction to apply to loss.
Returns:
Loss value. If `reduction` is `none`, this has the same shape as `labels`;
otherwise, it is scalar.
Raises:
ValueError: If the type of `reduction` is unsupported.
"""
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
if reduction == 'sum':
return jnp.sum(loss)
elif reduction == 'mean':
return jnp.mean(loss)
elif reduction == 'none' or reduction is None:
return loss
else:
raise ValueError(f'Incorrect reduction mode {reduction}')
def l2_normalize(
x,
axis = None,
epsilon = 1e-12,
):
"""l2 normalize a tensor on an axis with numerical stability."""
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
return x * x_inv_norm
def l2_weight_regularizer(params):
"""Helper to do lasso on weights.
Args:
params: the entire param set.
Returns:
Scalar of the l2 norm of the weights.
"""
l2_norm = 0.
for mod_name, mod_params in params.items():
if 'norm' not in mod_name:
for param_k, param_v in mod_params.items():
if param_k != 'b' not in param_k: # Filter out biases
l2_norm += jnp.sum(jnp.square(param_v))
else:
logging.warning('Excluding %s/%s from optimizer weight decay!',
mod_name, param_k)
else:
logging.warning('Excluding %s from optimizer weight decay!', mod_name)
return 0.5 * l2_norm
def regression_loss(x, y):
"""Byol's regression loss. This is a simple cosine similarity."""
normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
return jnp.sum((normed_x - normed_y)**2, axis=-1)
def bcast_local_devices(value):
"""Broadcasts an object to all local devices."""
devices = jax.local_devices()
def _replicate(x):
"""Replicate an object on each device."""
x = jnp.array(x)
aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
buffers = [jax.interpreters.xla.device_put(x, d) for d in devices]
return jax.pxla.ShardedDeviceArray(aval, buffers)
return jax.tree_util.tree_map(_replicate, value)
def get_first(xs):
"""Gets values from the first device."""
return jax.tree_map(lambda x: x[0], xs)
+319
View File
@@ -0,0 +1,319 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Networks used in BYOL."""
from typing import Any, Mapping, Optional, Sequence, Text
import haiku as hk
import jax
import jax.numpy as jnp
class MLP(hk.Module):
"""One hidden layer perceptron, with normalization."""
def __init__(
self,
name,
hidden_size,
output_size,
bn_config,
):
super().__init__(name=name)
self._hidden_size = hidden_size
self._output_size = output_size
self._bn_config = bn_config
def __call__(self, inputs, is_training):
out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
out = jax.nn.relu(out)
out = hk.Linear(output_size=self._output_size, with_bias=False)(out)
return out
class ResNetTorso(hk.nets.ResNet):
"""ResNet model."""
def __init__(
self,
blocks_per_group,
num_classes = None,
bn_config = None,
resnet_v2 = False,
bottleneck = True,
channels_per_group = (256, 512, 1024, 2048),
use_projection = (True, True, True, True),
width_multiplier = 1,
name = None,
):
"""Constructs a ResNet model.
Args:
blocks_per_group: A sequence of length 4 that indicates the number of
blocks created in each group.
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of three elements, `decay_rate`, `eps`, and
`cross_replica_axis`, to be passed on to the `BatchNorm` layers. By
default the `decay_rate` is `0.9` and `eps` is `1e-5`, and the axis is
`None`.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
False.
bottleneck: Whether the block should bottleneck or not. Defaults to True.
channels_per_group: A sequence of length 4 that indicates the number
of channels used for each block in each group.
use_projection: A sequence of length 4 that indicates whether each
residual block should use projection.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
channels_per_group = [width_multiplier * channel
for channel in channels_per_group]
super().__init__(
blocks_per_group=blocks_per_group,
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=bottleneck,
channels_per_group=channels_per_group,
name=name)
def __call__(self, inputs, is_training, test_local_stats=False):
out = inputs
out = self.initial_conv(out)
if not self.resnet_v2:
out = self.initial_batchnorm(out, is_training, test_local_stats)
out = jax.nn.relu(out)
out = hk.max_pool(out,
window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1),
padding='SAME')
for block_group in self.block_groups:
out = block_group(out, is_training, test_local_stats)
if self.resnet_v2:
out = self.final_batchnorm(out, is_training, test_local_stats)
out = jax.nn.relu(out)
out = jnp.mean(out, axis=[1, 2])
return out
class TinyResNet(ResNetTorso):
"""Tiny resnet for local runs and tests."""
def __init__(self,
num_classes = None,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(1, 1, 1, 1),
channels_per_group=(8, 8, 8, 8),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=False,
width_multiplier=width_multiplier,
name=name)
class ResNet18(ResNetTorso):
"""ResNet18."""
def __init__(self,
num_classes = None,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(2, 2, 2, 2),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=False,
channels_per_group=(64, 128, 256, 512),
width_multiplier=width_multiplier,
name=name)
class ResNet34(ResNetTorso):
"""ResNet34."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 4, 6, 3),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=False,
channels_per_group=(64, 128, 256, 512),
width_multiplier=width_multiplier,
name=name)
class ResNet50(ResNetTorso):
"""ResNet50."""
def __init__(self,
num_classes = None,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 4, 6, 3),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
width_multiplier=width_multiplier,
name=name)
class ResNet101(ResNetTorso):
"""ResNet101."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 4, 23, 3),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
width_multiplier=width_multiplier,
name=name)
class ResNet152(ResNetTorso):
"""ResNet152."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 8, 36, 3),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
width_multiplier=width_multiplier,
name=name)
class ResNet200(ResNetTorso):
"""ResNet200."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
"""Constructs a ResNet model.
Args:
num_classes: The number of classes to classify the inputs into.
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
passed on to the `BatchNorm` layers.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
to False.
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 24, 36, 3),
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
width_multiplier=width_multiplier,
name=name)
+188
View File
@@ -0,0 +1,188 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of LARS Optimizer with optax."""
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
import jax
import jax.numpy as jnp
import optax
import tree as nest
# A filter function takes a path and a value as input and outputs True for
# variable to apply update and False not to apply the update
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]
def exclude_bias_and_norm(path, val):
"""Filter to exclude biaises and normalizations weights."""
del val
if path[-1] == "b" or "norm" in path[-2]:
return False
return True
def _partial_update(updates,
new_updates,
params,
filter_fn = None):
"""Returns new_update for params which filter_fn is True else updates."""
if filter_fn is None:
return new_updates
wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)
def _update_fn(g, t, m):
m = m.astype(g.dtype)
return g * (1. - m) + t * m
return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
class ScaleByLarsState(NamedTuple):
mu: jnp.ndarray
def scale_by_lars(
momentum = 0.9,
eta = 0.001,
filter_fn = None):
"""Rescales updates according to the LARS algorithm.
Does not include weight decay.
References:
[You et al, 2017](https://arxiv.org/abs/1708.03888)
Args:
momentum: momentum coeficient.
eta: LARS coefficient.
filter_fn: an optional filter function.
Returns:
An (init_fn, update_fn) tuple.
"""
def init_fn(params):
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
return ScaleByLarsState(mu=mu)
def update_fn(updates, state,
params):
def lars_adaptation(
update,
param,
):
param_norm = jnp.linalg.norm(param)
update_norm = jnp.linalg.norm(update)
return update * jnp.where(
param_norm > 0.,
jnp.where(update_norm > 0,
(eta * param_norm / update_norm), 1.0), 1.0)
adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
adapted_updates = _partial_update(updates, adapted_updates, params,
filter_fn)
mu = jax.tree_multimap(lambda g, t: momentum * g + t,
state.mu, adapted_updates)
return mu, ScaleByLarsState(mu=mu)
return optax.GradientTransformation(init_fn, update_fn)
class AddWeightDecayState(NamedTuple):
"""Stateless transformation."""
def add_weight_decay(
weight_decay,
filter_fn = None):
"""Adds a weight decay to the update.
Args:
weight_decay: weight_decay coeficient.
filter_fn: an optional filter function.
Returns:
An (init_fn, update_fn) tuple.
"""
def init_fn(_):
return AddWeightDecayState()
def update_fn(
updates,
state,
params,
):
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
params)
new_updates = _partial_update(updates, new_updates, params, filter_fn)
return new_updates, state
return optax.GradientTransformation(init_fn, update_fn)
LarsState = List # Type for the lars optimizer
def lars(
learning_rate,
weight_decay = 0.,
momentum = 0.9,
eta = 0.001,
weight_decay_filter = None,
lars_adaptation_filter = None,
):
"""Creates lars optimizer with weight decay.
References:
[You et al, 2017](https://arxiv.org/abs/1708.03888)
Args:
learning_rate: learning rate coefficient.
weight_decay: weight decay coefficient.
momentum: momentum coefficient.
eta: LARS coefficient.
weight_decay_filter: optional filter function to only apply the weight
decay on a subset of parameters. The filter function takes as input the
parameter path (as a tuple) and its associated update, and return a True
for params to apply the weight decay and False for params to not apply
the weight decay. When weight_decay_filter is set to None, the weight
decay is not applied to the bias, i.e. when the variable name is 'b', and
the weight decay is not applied to nornalization params, i.e. the
panultimate path contains 'norm'.
lars_adaptation_filter: similar to weight decay filter but for lars
adaptation
Returns:
An optax.GradientTransformation, i.e. a (init_fn, update_fn) tuple.
"""
if weight_decay_filter is None:
weight_decay_filter = lambda *_: True
if lars_adaptation_filter is None:
lars_adaptation_filter = lambda *_: True
return optax.chain(
add_weight_decay(
weight_decay=weight_decay, filter_fn=weight_decay_filter),
scale_by_lars(
momentum=momentum, eta=eta, filter_fn=lars_adaptation_filter),
optax.scale(-learning_rate),
)
+53
View File
@@ -0,0 +1,53 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Learning rate schedules."""
import jax.numpy as jnp
def target_ema(global_step,
base_ema,
max_steps):
decay = _cosine_decay(global_step, max_steps, 1.)
return 1. - (1. - base_ema) * decay
def learning_schedule(global_step,
batch_size,
base_learning_rate,
total_steps,
warmup_steps):
"""Cosine learning rate scheduler."""
# Compute LR & Scaled LR
scaled_lr = base_learning_rate * batch_size / 256.
learning_rate = (
global_step.astype(jnp.float32) / int(warmup_steps) *
scaled_lr if warmup_steps > 0 else scaled_lr)
# Cosine schedule after warmup.
return jnp.where(
global_step < warmup_steps, learning_rate,
_cosine_decay(global_step - warmup_steps, total_steps - warmup_steps,
scaled_lr))
def _cosine_decay(global_step,
max_steps,
initial_value):
"""Simple implementation of cosine decay from TF1."""
global_step = jnp.minimum(global_step, max_steps)
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
decayed_learning_rate = initial_value * cosine_decay_value
return decayed_learning_rate