mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-22 07:11:25 +08:00
This notebook illustrates the CIFAR-10 experiments in the paper:
PiperOrigin-RevId: 326141025
This commit is contained in:
committed by
Louise Deason
parent
84321ad894
commit
923ad3cff0
+159
@@ -0,0 +1,159 @@
|
||||
# Bootstrap Your Own Latent
|
||||
|
||||
This is the implementation of the pre-training and linear evaluation pipeline of
|
||||
BYOL - https://arxiv.org/abs/2006.07733.
|
||||
|
||||
Using this implementation should achieve a top-1 accuracy on Imagenet between
|
||||
74.0% and 74.5% after about 8h of training using 512 Cloud TPU v3.
|
||||
|
||||
The main pretraining module is `byol_experiment.py`. By default it uses BYOL to
|
||||
pretrain a Resnet-50 on Imagenet. In parallel, we train a classifier on top of
|
||||
the representation to assess its performance during training. This classifier
|
||||
does not back-propagate any gradient to the ResNet-50.
|
||||
|
||||
The evaluation module is `eval_experiment.py`. It evaluates the performance of
|
||||
the representation learnt by BYOL (using a given checkpoint).
|
||||
|
||||
## Setup
|
||||
|
||||
To set up a Python virtual environment with the required dependencies, run:
|
||||
|
||||
```shell
|
||||
python3 -m venv byol_env
|
||||
source byol_env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r byol/requirements.txt
|
||||
```
|
||||
|
||||
The code uses `tensorflow_datasets` to load the Imagenet dataset. Manual
|
||||
download may be required; see https://www.tensorflow.org/datasets/catalog/imagenet2012
|
||||
for details.
|
||||
|
||||
## Running a short pre-training locally
|
||||
|
||||
To run a short (40 epochs) pre-training experiment on a local machine, use:
|
||||
|
||||
```shell
|
||||
mkdir /tmp/byol_checkpoints
|
||||
python -m byol.main_loop \
|
||||
--experiment_mode='pretrain' \
|
||||
--worker_mode='train' \
|
||||
--checkpoint_root='/tmp/byol_checkpoints' \
|
||||
--pretrain_epochs=40
|
||||
```
|
||||
|
||||
## Full pipeline and presets
|
||||
|
||||
The various parts of the pipeline can be run using:
|
||||
|
||||
```shell
|
||||
python -m byol.main_loop \
|
||||
--experiment_mode=<'pretrain' or 'linear-eval'> \
|
||||
--worker_mode=<'train' or 'eval'> \
|
||||
--checkpoint_root=</path/to/the/checkpointing/folder> \
|
||||
--pretrain_epochs=<40, 100, 300 or 1000>
|
||||
```
|
||||
|
||||
### Pretraining
|
||||
Setting `--experiment_mode=pretrain` will configure the main loop for
|
||||
pretraining; we provide presets for 40, 100, 300 and 1000 epochs.
|
||||
|
||||
Use `--worker_mode=train` for a training job, which will regularly save
|
||||
checkpoints under `<checkpoint_root>/pretrain.pkl`. To monitor the progress of
|
||||
the pretraining, you can run a second worker (using `--worker_mode=eval`) with
|
||||
the same `checkpoint_root` setting. This worker will regularly load the
|
||||
checkpoint and evaluate the performance of a linear classifier (trained by the
|
||||
pretraining `train` worker) on the `TEST` set.
|
||||
|
||||
Note that the default settings are set for large-scale training on Cloud TPUs,
|
||||
with a total batch size of 4096. To avoid the need to re-run the full
|
||||
experiment, we will make a pre-trained checkpoint available on GCP.
|
||||
|
||||
### Linear evaluation
|
||||
Setting `--experiment_mode=linear-eval` will configure the main loop for
|
||||
linear evaluation; we provide presets for 80 epochs.
|
||||
|
||||
Use `--worker_mode=train` for a training job, which will load the encoder
|
||||
weights from an existing checkpoint (form a pretrain experiment) located at
|
||||
`<checkpoint_root>/pretrain.pkl`, and train a linear classifier on top of this
|
||||
encoder. The weights from the linear classifier trained in the pretraining phase
|
||||
will be discarded.
|
||||
|
||||
The training job will regularly save checkpoints under
|
||||
`<checkpoint_root>/linear-eval.pkl`. You can run a second worker
|
||||
(using `--worker_mode=eval`) with the same `checkpoint_root` setting to
|
||||
regularly load the checkpoint and evaluate the performance of the classifier
|
||||
(trained by the linear-eval `train` worker) on the test set.
|
||||
|
||||
Note that the above will run a simplified version of the linear evaluation
|
||||
pipeline described in the paper, with a single value of the base learning rate
|
||||
and without using a validation set. To fully reproduce the results from the
|
||||
paper, one should run 5 instances of the linear-eval `train` worker (using only
|
||||
the `TRAIN` subset, and each using a different `checkpoint_root`), run the
|
||||
`eval` worker (using only the `VALID` subset) for each checkpoint, then run a
|
||||
final `eval` worker on the `TEST` set.
|
||||
|
||||
|
||||
## Running on GCP
|
||||
|
||||
Notice: we currently do not recommend running the full experiment on public
|
||||
Cloud TPUs. We provide an alternative small-scale GPU setup in the next section.
|
||||
|
||||
The experiments from the paper were run on Google's internal infrastructure.
|
||||
Unfortunately, training JAX models on publicly available Cloud TPUs is currently
|
||||
in its early stages; in particular, we found the learning to be heavily
|
||||
bottlenecked by the data loading pipeline, resulting in a significantly slower
|
||||
training. We will update the code and instructions below once the limitation is
|
||||
addressed.
|
||||
|
||||
To train the model on TPU using Google Compute Engine, please follow the
|
||||
instructions at https://cloud.google.com/tpu/docs/imagenet-setup to first
|
||||
download and preprocess the ImageNet dataset and upload it to a
|
||||
Cloud Storage Bucket. From a GCE VM, you can check that `tensorflow_datasets`
|
||||
can correctly load the dataset by running:
|
||||
|
||||
```python
|
||||
import tensorflow_datasets as tfds
|
||||
tfds.load('imagenet2012', data_dir='gs://<your-bucket-name>')
|
||||
```
|
||||
|
||||
To learn how to use JAX with Cloud TPUs, please follow the instructions here:
|
||||
https://github.com/google/jax/tree/master/cloud_tpu_colabs.
|
||||
|
||||
|
||||
## Setup for fast iteration
|
||||
|
||||
In order to make reproduction easier, it is possible to change the training
|
||||
setup to use the smaller [imagenette](https://github.com/fastai/imagenette)
|
||||
dataset (9469 training images with 10 classes). The following setup and
|
||||
hyperparameters can be used on a machine with a single V100 GPU:
|
||||
|
||||
- in `utils/dataset.py`:
|
||||
- update `Split.num_examples` with the figures from [tfds](https://www.tensorflow.org/datasets/catalog/imagenette) (with `Split.VALID: 0`)
|
||||
- use `imagenette/160px-v2` in the call to `tfds.load`
|
||||
- use 128x128 px images (_i.e._, replace all instances of `224` by `128`)
|
||||
- it doesn't seem necessary to change the color normalization (make sure to
|
||||
not replace the value `0.224` by mistake in the previous step).
|
||||
- in `configs/byol.py`, use:
|
||||
- `num_classes`: `10`
|
||||
- `network_config.encoder`: `ResNet18`
|
||||
- `optimizer_config.weight_decay`: `1e-6`
|
||||
- `lr_schedule_config.base_learning_rate`: `2.0`
|
||||
- `evaluation_config.batch_size`: `25`
|
||||
- other parameters unchanged.
|
||||
|
||||
You can then run:
|
||||
|
||||
```shell
|
||||
mkdir /tmp/byol_checkpoints
|
||||
python -m byol.main_loop \
|
||||
--experiment_mode='pretrain' \
|
||||
--worker_mode='train' \
|
||||
--checkpoint_root='/tmp/byol_checkpoints' \
|
||||
--batch_size=256 \
|
||||
--pretrain_epochs=1000
|
||||
```
|
||||
|
||||
With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
|
||||
*online* classifier) in roughly 4 hours. Note that the above parameters were not
|
||||
finely tuned and may not be optimal.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,78 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Config file for BYOL experiment."""
|
||||
|
||||
from byol.utils import dataset
|
||||
|
||||
|
||||
# Preset values for certain number of training epochs.
|
||||
_LR_PRESETS = {40: 0.45, 100: 0.45, 300: 0.3, 1000: 0.2}
|
||||
_WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6}
|
||||
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
|
||||
|
||||
|
||||
def get_config(num_epochs, batch_size):
|
||||
"""Return config object, containing all hyperparameters for training."""
|
||||
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
||||
|
||||
assert num_epochs in [40, 100, 300, 1000]
|
||||
|
||||
config = dict(
|
||||
random_seed=0,
|
||||
num_classes=1000,
|
||||
batch_size=batch_size,
|
||||
max_steps=num_epochs * train_images_per_epoch // batch_size,
|
||||
enable_double_transpose=True,
|
||||
base_target_ema=_EMA_PRESETS[num_epochs],
|
||||
network_config=dict(
|
||||
projector_hidden_size=4096,
|
||||
projector_output_size=256,
|
||||
predictor_hidden_size=4096,
|
||||
encoder_class='ResNet50', # Should match a class in utils/networks.
|
||||
encoder_config=dict(
|
||||
resnet_v2=False,
|
||||
width_multiplier=1),
|
||||
bn_config={
|
||||
'decay_rate': .9,
|
||||
'eps': 1e-5,
|
||||
# Accumulate batchnorm statistics across devices.
|
||||
# This should be equal to the `axis_name` argument passed
|
||||
# to jax.pmap.
|
||||
'cross_replica_axis': 'i',
|
||||
'create_scale': True,
|
||||
'create_offset': True,
|
||||
}),
|
||||
optimizer_config=dict(
|
||||
weight_decay=_WD_PRESETS[num_epochs],
|
||||
eta=1e-3,
|
||||
momentum=.9,
|
||||
),
|
||||
lr_schedule_config=dict(
|
||||
base_learning_rate=_LR_PRESETS[num_epochs],
|
||||
warmup_steps=10 * train_images_per_epoch // batch_size,
|
||||
),
|
||||
evaluation_config=dict(
|
||||
subset='test',
|
||||
batch_size=100,
|
||||
),
|
||||
checkpointing_config=dict(
|
||||
use_checkpointing=True,
|
||||
checkpoint_dir='/tmp/byol',
|
||||
save_checkpoint_interval=300,
|
||||
filename='pretrain.pkl'
|
||||
),
|
||||
)
|
||||
|
||||
return config
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Config file for evaluation experiment."""
|
||||
|
||||
from typing import Text
|
||||
|
||||
from byol.utils import dataset
|
||||
|
||||
|
||||
def get_config(checkpoint_to_evaluate, batch_size):
|
||||
"""Return config object for training."""
|
||||
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
||||
|
||||
config = dict(
|
||||
random_seed=0,
|
||||
enable_double_transpose=True,
|
||||
max_steps=80 * train_images_per_epoch // batch_size,
|
||||
num_classes=1000,
|
||||
batch_size=batch_size,
|
||||
checkpoint_to_evaluate=checkpoint_to_evaluate,
|
||||
# If True, allows training without loading a checkpoint.
|
||||
allow_train_from_scratch=False,
|
||||
# Whether the backbone should be frozen (linear evaluation) or
|
||||
# trainable (fine-tuning).
|
||||
freeze_backbone=True,
|
||||
optimizer_config=dict(
|
||||
momentum=0.9,
|
||||
nesterov=True,
|
||||
),
|
||||
lr_schedule_config=dict(
|
||||
base_learning_rate=0.2,
|
||||
warmup_steps=0,
|
||||
),
|
||||
network_config=dict( # Should match the evaluated checkpoint
|
||||
encoder_class='ResNet50', # Should match a class in utils/networks.
|
||||
encoder_config=dict(
|
||||
resnet_v2=False,
|
||||
width_multiplier=1),
|
||||
bn_decay_rate=0.9,
|
||||
),
|
||||
evaluation_config=dict(
|
||||
subset='test',
|
||||
batch_size=100,
|
||||
),
|
||||
checkpointing_config=dict(
|
||||
use_checkpointing=True,
|
||||
checkpoint_dir='/tmp/byol',
|
||||
save_checkpoint_interval=300,
|
||||
filename='linear-eval.pkl'
|
||||
),
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -0,0 +1,477 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Linear evaluation or fine-tuning pipeline.
|
||||
|
||||
Use this experiment to evaluate a checkpoint from byol_experiment.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, Generator, Mapping, NamedTuple, Optional, Text, Tuple, Union
|
||||
|
||||
from absl import logging
|
||||
from acme.jax import utils as acme_utils
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
from byol.utils import checkpointing
|
||||
from byol.utils import dataset
|
||||
from byol.utils import helpers
|
||||
from byol.utils import networks
|
||||
from byol.utils import schedules
|
||||
|
||||
# Type declarations.
|
||||
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
|
||||
LogsDict = Mapping[Text, jnp.ndarray]
|
||||
|
||||
|
||||
class _EvalExperimentState(NamedTuple):
|
||||
backbone_params: hk.Params
|
||||
classif_params: hk.Params
|
||||
backbone_state: hk.State
|
||||
backbone_opt_state: Union[None, OptState]
|
||||
classif_opt_state: OptState
|
||||
|
||||
|
||||
class EvalExperiment:
|
||||
"""Linear evaluation experiment."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
random_seed,
|
||||
num_classes,
|
||||
batch_size,
|
||||
max_steps,
|
||||
enable_double_transpose,
|
||||
checkpoint_to_evaluate,
|
||||
allow_train_from_scratch,
|
||||
freeze_backbone,
|
||||
network_config,
|
||||
optimizer_config,
|
||||
lr_schedule_config,
|
||||
evaluation_config,
|
||||
checkpointing_config):
|
||||
"""Constructs the experiment.
|
||||
|
||||
Args:
|
||||
random_seed: the random seed to use when initializing network weights.
|
||||
num_classes: the number of classes; used for the online evaluation.
|
||||
batch_size: the total batch size; should be a multiple of the number of
|
||||
available accelerators.
|
||||
max_steps: the number of training steps; used for the lr/target network
|
||||
ema schedules.
|
||||
enable_double_transpose: see dataset.py; only has effect on TPU.
|
||||
checkpoint_to_evaluate: the path to the checkpoint to evaluate.
|
||||
allow_train_from_scratch: whether to allow training without specifying a
|
||||
checkpoint to evaluate (training from scratch).
|
||||
freeze_backbone: whether the backbone resnet should remain frozen (linear
|
||||
evaluation) or be trainable (fine-tuning).
|
||||
network_config: the configuration for the network.
|
||||
optimizer_config: the configuration for the optimizer.
|
||||
lr_schedule_config: the configuration for the learning rate schedule.
|
||||
evaluation_config: the evaluation configuration.
|
||||
checkpointing_config: the configuration for checkpointing.
|
||||
"""
|
||||
|
||||
self._random_seed = random_seed
|
||||
self._enable_double_transpose = enable_double_transpose
|
||||
self._num_classes = num_classes
|
||||
self._lr_schedule_config = lr_schedule_config
|
||||
self._batch_size = batch_size
|
||||
self._max_steps = max_steps
|
||||
self._checkpoint_to_evaluate = checkpoint_to_evaluate
|
||||
self._allow_train_from_scratch = allow_train_from_scratch
|
||||
self._freeze_backbone = freeze_backbone
|
||||
self._optimizer_config = optimizer_config
|
||||
self._evaluation_config = evaluation_config
|
||||
|
||||
# Checkpointed experiment state.
|
||||
self._experiment_state = None
|
||||
|
||||
# Input pipelines.
|
||||
self._train_input = None
|
||||
self._eval_input = None
|
||||
|
||||
backbone_fn = functools.partial(self._backbone_fn, **network_config)
|
||||
self.forward_backbone = hk.without_apply_rng(
|
||||
hk.transform_with_state(backbone_fn))
|
||||
self.forward_classif = hk.without_apply_rng(hk.transform(self._classif_fn))
|
||||
self.update_pmap = jax.pmap(self._update_func, axis_name='i')
|
||||
self.eval_batch_jit = jax.jit(self._eval_batch)
|
||||
|
||||
self._is_backbone_training = not self._freeze_backbone
|
||||
|
||||
self._checkpointer = checkpointing.Checkpointer(**checkpointing_config)
|
||||
|
||||
def _should_transpose_images(self):
|
||||
"""Should we transpose images (saves host-to-device time on TPUs)."""
|
||||
return (self._enable_double_transpose and
|
||||
jax.local_devices()[0].platform == 'tpu')
|
||||
|
||||
def _backbone_fn(
|
||||
self,
|
||||
inputs,
|
||||
encoder_class,
|
||||
encoder_config,
|
||||
bn_decay_rate,
|
||||
is_training,
|
||||
):
|
||||
"""Forward of the encoder (backbone)."""
|
||||
bn_config = {'decay_rate': bn_decay_rate}
|
||||
encoder = getattr(networks, encoder_class)
|
||||
model = encoder(
|
||||
None,
|
||||
bn_config=bn_config,
|
||||
**encoder_config)
|
||||
|
||||
if self._should_transpose_images():
|
||||
inputs = dataset.transpose_images(inputs)
|
||||
images = dataset.normalize_images(inputs['images'])
|
||||
return model(images, is_training=is_training)
|
||||
|
||||
def _classif_fn(
|
||||
self,
|
||||
embeddings,
|
||||
):
|
||||
classifier = hk.Linear(output_size=self._num_classes)
|
||||
return classifier(embeddings)
|
||||
|
||||
# _ _
|
||||
# | |_ _ __ __ _(_)_ __
|
||||
# | __| '__/ _` | | '_ \
|
||||
# | |_| | | (_| | | | | |
|
||||
# \__|_| \__,_|_|_| |_|
|
||||
#
|
||||
|
||||
def step(self, *,
|
||||
global_step,
|
||||
rng):
|
||||
"""Performs a single training step."""
|
||||
|
||||
if self._train_input is None:
|
||||
self._initialize_train(rng)
|
||||
|
||||
inputs = next(self._train_input)
|
||||
self._experiment_state, scalars = self.update_pmap(
|
||||
self._experiment_state, global_step, inputs)
|
||||
|
||||
scalars = helpers.get_first(scalars)
|
||||
return scalars
|
||||
|
||||
def save_checkpoint(self, step, rng):
|
||||
self._checkpointer.maybe_save_checkpoint(
|
||||
self._experiment_state, step=step, rng=rng,
|
||||
is_final=step >= self._max_steps)
|
||||
|
||||
def load_checkpoint(self):
|
||||
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
||||
if checkpoint_data is None:
|
||||
return None
|
||||
self._experiment_state, step, rng = checkpoint_data
|
||||
return step, rng
|
||||
|
||||
def _initialize_train(self, rng):
|
||||
"""BYOL's _ExperimentState initialization.
|
||||
|
||||
Args:
|
||||
rng: random number generator used to initialize parameters. If working in
|
||||
a multi device setup, this need to be a ShardedArray.
|
||||
dummy_input: a dummy image, used to compute intermediate outputs shapes.
|
||||
|
||||
Returns:
|
||||
Initial EvalExperiment state.
|
||||
|
||||
Raises:
|
||||
RuntimeError: invalid or empty checkpoint.
|
||||
"""
|
||||
self._train_input = acme_utils.prefetch(self._build_train_input())
|
||||
|
||||
# Check we haven't already restored params
|
||||
if self._experiment_state is None:
|
||||
|
||||
inputs = next(self._train_input)
|
||||
|
||||
if self._checkpoint_to_evaluate is not None:
|
||||
# Load params from checkpoint
|
||||
checkpoint_data = checkpointing.load_checkpoint(
|
||||
self._checkpoint_to_evaluate)
|
||||
if checkpoint_data is None:
|
||||
raise RuntimeError('Invalid checkpoint.')
|
||||
backbone_params = checkpoint_data['experiment_state'].online_params
|
||||
backbone_state = checkpoint_data['experiment_state'].online_state
|
||||
backbone_params = helpers.bcast_local_devices(backbone_params)
|
||||
backbone_state = helpers.bcast_local_devices(backbone_state)
|
||||
else:
|
||||
if not self._allow_train_from_scratch:
|
||||
raise ValueError(
|
||||
'No checkpoint specified, but `allow_train_from_scratch` '
|
||||
'set to False')
|
||||
# Initialize with random parameters
|
||||
logging.info(
|
||||
'No checkpoint specified, initializing the networks from scratch '
|
||||
'(dry run mode)')
|
||||
backbone_params, backbone_state = jax.pmap(
|
||||
functools.partial(self.forward_backbone.init, is_training=True),
|
||||
axis_name='i')(rng=rng, inputs=inputs)
|
||||
|
||||
init_experiment = jax.pmap(self._make_initial_state, axis_name='i')
|
||||
|
||||
# Init uses the same RNG key on all hosts+devices to ensure everyone
|
||||
# computes the same initial state and parameters.
|
||||
init_rng = jax.random.PRNGKey(self._random_seed)
|
||||
init_rng = helpers.bcast_local_devices(init_rng)
|
||||
self._experiment_state = init_experiment(
|
||||
rng=init_rng,
|
||||
dummy_input=inputs,
|
||||
backbone_params=backbone_params,
|
||||
backbone_state=backbone_state)
|
||||
|
||||
# Clear the backbone optimizer's state when the backbone is frozen.
|
||||
if self._freeze_backbone:
|
||||
self._experiment_state = _EvalExperimentState(
|
||||
backbone_params=self._experiment_state.backbone_params,
|
||||
classif_params=self._experiment_state.classif_params,
|
||||
backbone_state=self._experiment_state.backbone_state,
|
||||
backbone_opt_state=None,
|
||||
classif_opt_state=self._experiment_state.classif_opt_state,
|
||||
)
|
||||
|
||||
def _make_initial_state(
|
||||
self,
|
||||
rng,
|
||||
dummy_input,
|
||||
backbone_params,
|
||||
backbone_state,
|
||||
):
|
||||
"""_EvalExperimentState initialization."""
|
||||
|
||||
# Initialize the backbone params
|
||||
# Always create the batchnorm weights (is_training=True), they will be
|
||||
# overwritten when loading the checkpoint.
|
||||
embeddings, _ = self.forward_backbone.apply(
|
||||
backbone_params, backbone_state, dummy_input, is_training=True)
|
||||
backbone_opt_state = self._optimizer(0.).init(backbone_params)
|
||||
|
||||
# Initialize the classifier params and optimizer_state
|
||||
classif_params = self.forward_classif.init(rng, embeddings)
|
||||
classif_opt_state = self._optimizer(0.).init(classif_params)
|
||||
|
||||
return _EvalExperimentState(
|
||||
backbone_params=backbone_params,
|
||||
classif_params=classif_params,
|
||||
backbone_state=backbone_state,
|
||||
backbone_opt_state=backbone_opt_state,
|
||||
classif_opt_state=classif_opt_state,
|
||||
)
|
||||
|
||||
def _build_train_input(self):
|
||||
"""See base class."""
|
||||
num_devices = jax.device_count()
|
||||
global_batch_size = self._batch_size
|
||||
per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
|
||||
|
||||
if ragged:
|
||||
raise ValueError(
|
||||
f'Global batch size {global_batch_size} must be divisible by '
|
||||
f'num devices {num_devices}')
|
||||
|
||||
return dataset.load(
|
||||
dataset.Split.TRAIN_AND_VALID,
|
||||
preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN,
|
||||
transpose=self._should_transpose_images(),
|
||||
batch_dims=[jax.local_device_count(), per_device_batch_size])
|
||||
|
||||
def _optimizer(self, learning_rate):
|
||||
"""Build optimizer from config."""
|
||||
return optax.sgd(learning_rate, **self._optimizer_config)
|
||||
|
||||
def _loss_fn(
|
||||
self,
|
||||
backbone_params,
|
||||
classif_params,
|
||||
backbone_state,
|
||||
inputs,
|
||||
):
|
||||
"""Compute the classification loss function.
|
||||
|
||||
Args:
|
||||
backbone_params: parameters of the encoder network.
|
||||
classif_params: parameters of the linear classifier.
|
||||
backbone_state: internal state of encoder network.
|
||||
inputs: inputs, containing `images` and `labels`.
|
||||
|
||||
Returns:
|
||||
The classification loss and various logs.
|
||||
"""
|
||||
embeddings, backbone_state = self.forward_backbone.apply(
|
||||
backbone_params,
|
||||
backbone_state,
|
||||
inputs,
|
||||
is_training=not self._freeze_backbone)
|
||||
|
||||
logits = self.forward_classif.apply(classif_params, embeddings)
|
||||
labels = hk.one_hot(inputs['labels'], self._num_classes)
|
||||
loss = helpers.softmax_cross_entropy(logits, labels, reduction='mean')
|
||||
scaled_loss = loss / jax.device_count()
|
||||
|
||||
return scaled_loss, (loss, backbone_state)
|
||||
|
||||
def _update_func(
|
||||
self,
|
||||
experiment_state,
|
||||
global_step,
|
||||
inputs,
|
||||
):
|
||||
"""Applies an update to parameters and returns new state."""
|
||||
# This function computes the gradient of the first output of loss_fn and
|
||||
# passes through the other arguments unchanged.
|
||||
|
||||
# Gradient of the first output of _loss_fn wrt the backbone (arg 0) and the
|
||||
# classifier parameters (arg 1). The auxiliary outputs are returned as-is.
|
||||
grad_loss_fn = jax.grad(self._loss_fn, has_aux=True, argnums=(0, 1))
|
||||
|
||||
grads, aux_outputs = grad_loss_fn(
|
||||
experiment_state.backbone_params,
|
||||
experiment_state.classif_params,
|
||||
experiment_state.backbone_state,
|
||||
inputs,
|
||||
)
|
||||
backbone_grads, classifier_grads = grads
|
||||
train_loss, new_backbone_state = aux_outputs
|
||||
classifier_grads = jax.lax.psum(classifier_grads, axis_name='i')
|
||||
|
||||
# Compute the decayed learning rate
|
||||
learning_rate = schedules.learning_schedule(
|
||||
global_step,
|
||||
batch_size=self._batch_size,
|
||||
total_steps=self._max_steps,
|
||||
**self._lr_schedule_config)
|
||||
|
||||
# Compute and apply updates via our optimizer.
|
||||
classif_updates, new_classif_opt_state = \
|
||||
self._optimizer(learning_rate).update(
|
||||
classifier_grads,
|
||||
experiment_state.classif_opt_state)
|
||||
|
||||
new_classif_params = optax.apply_updates(experiment_state.classif_params,
|
||||
classif_updates)
|
||||
|
||||
if self._freeze_backbone:
|
||||
del backbone_grads, new_backbone_state # Unused
|
||||
# The backbone is not updated.
|
||||
new_backbone_params = experiment_state.backbone_params
|
||||
new_backbone_opt_state = None
|
||||
new_backbone_state = experiment_state.backbone_state
|
||||
else:
|
||||
backbone_grads = jax.lax.psum(backbone_grads, axis_name='i')
|
||||
|
||||
# Compute and apply updates via our optimizer.
|
||||
backbone_updates, new_backbone_opt_state = \
|
||||
self._optimizer(learning_rate).update(
|
||||
backbone_grads,
|
||||
experiment_state.backbone_opt_state)
|
||||
|
||||
new_backbone_params = optax.apply_updates(
|
||||
experiment_state.backbone_params, backbone_updates)
|
||||
|
||||
experiment_state = _EvalExperimentState(
|
||||
new_backbone_params,
|
||||
new_classif_params,
|
||||
new_backbone_state,
|
||||
new_backbone_opt_state,
|
||||
new_classif_opt_state,
|
||||
)
|
||||
|
||||
# Scalars to log (note: we log the mean across all hosts/devices).
|
||||
scalars = {'train_loss': train_loss}
|
||||
scalars = jax.lax.pmean(scalars, axis_name='i')
|
||||
|
||||
return experiment_state, scalars
|
||||
|
||||
# _
|
||||
# _____ ____ _| |
|
||||
# / _ \ \ / / _` | |
|
||||
# | __/\ V / (_| | |
|
||||
# \___| \_/ \__,_|_|
|
||||
#
|
||||
|
||||
def evaluate(self, global_step, **unused_args):
|
||||
"""See base class."""
|
||||
|
||||
global_step = np.array(helpers.get_first(global_step))
|
||||
scalars = jax.device_get(self._eval_epoch(**self._evaluation_config))
|
||||
|
||||
logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
|
||||
return scalars
|
||||
|
||||
def _eval_batch(
|
||||
self,
|
||||
backbone_params,
|
||||
classif_params,
|
||||
backbone_state,
|
||||
inputs,
|
||||
):
|
||||
"""Evaluates a batch."""
|
||||
embeddings, backbone_state = self.forward_backbone.apply(
|
||||
backbone_params, backbone_state, inputs, is_training=False)
|
||||
logits = self.forward_classif.apply(classif_params, embeddings)
|
||||
labels = hk.one_hot(inputs['labels'], self._num_classes)
|
||||
loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
|
||||
top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
|
||||
top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
|
||||
# NOTE: Returned values will be summed and finally divided by num_samples.
|
||||
return {
|
||||
'eval_loss': loss,
|
||||
'top1_accuracy': top1_correct,
|
||||
'top5_accuracy': top5_correct
|
||||
}
|
||||
|
||||
def _eval_epoch(self, subset, batch_size):
|
||||
"""Evaluates an epoch."""
|
||||
num_samples = 0.
|
||||
summed_scalars = None
|
||||
|
||||
backbone_params = helpers.get_first(self._experiment_state.backbone_params)
|
||||
classif_params = helpers.get_first(self._experiment_state.classif_params)
|
||||
backbone_state = helpers.get_first(self._experiment_state.backbone_state)
|
||||
split = dataset.Split.from_string(subset)
|
||||
|
||||
dataset_iterator = dataset.load(
|
||||
split,
|
||||
preprocess_mode=dataset.PreprocessMode.EVAL,
|
||||
transpose=self._should_transpose_images(),
|
||||
batch_dims=[batch_size])
|
||||
|
||||
for inputs in dataset_iterator:
|
||||
num_samples += inputs['labels'].shape[0]
|
||||
scalars = self.eval_batch_jit(
|
||||
backbone_params,
|
||||
classif_params,
|
||||
backbone_state,
|
||||
inputs,
|
||||
)
|
||||
|
||||
# Accumulate the sum of scalars for each step.
|
||||
scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
|
||||
if summed_scalars is None:
|
||||
summed_scalars = scalars
|
||||
else:
|
||||
summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
|
||||
|
||||
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
|
||||
return mean_scalars
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training and evaluation loops for an experiment."""
|
||||
|
||||
import time
|
||||
from typing import Any, Mapping, Text, Type, Union
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from byol import byol_experiment
|
||||
from byol import eval_experiment
|
||||
from byol.configs import byol as byol_config
|
||||
from byol.configs import eval as eval_config
|
||||
|
||||
flags.DEFINE_string('experiment_mode',
|
||||
'pretrain', 'The experiment, pretrain or linear-eval')
|
||||
flags.DEFINE_string('worker_mode', 'train', 'The mode, train or eval')
|
||||
flags.DEFINE_string('worker_tpu_driver', '', 'The tpu driver to use')
|
||||
flags.DEFINE_integer('pretrain_epochs', 1000, 'Number of pre-training epochs')
|
||||
flags.DEFINE_integer('batch_size', 4096, 'Total batch size')
|
||||
flags.DEFINE_string('checkpoint_root', '/tmp/byol',
|
||||
'The directory to save checkpoints to.')
|
||||
flags.DEFINE_integer('log_tensors_interval', 60, 'Log tensors every n seconds.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
Experiment = Union[
|
||||
Type[byol_experiment.ByolExperiment],
|
||||
Type[eval_experiment.EvalExperiment]]
|
||||
|
||||
|
||||
def train_loop(experiment_class, config):
|
||||
"""The main training loop.
|
||||
|
||||
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
|
||||
|
||||
Args:
|
||||
experiment_class: the constructor for the experiment (either byol_experiment
|
||||
or eval_experiment).
|
||||
config: the experiment config.
|
||||
"""
|
||||
experiment = experiment_class(**config)
|
||||
rng = jax.random.PRNGKey(0)
|
||||
step = 0
|
||||
|
||||
host_id = jax.host_id()
|
||||
last_logging = time.time()
|
||||
if config['checkpointing_config']['use_checkpointing']:
|
||||
checkpoint_data = experiment.load_checkpoint()
|
||||
if checkpoint_data is None:
|
||||
step = 0
|
||||
else:
|
||||
step, rng = checkpoint_data
|
||||
|
||||
local_device_count = jax.local_device_count()
|
||||
while step < config['max_steps']:
|
||||
step_rng, rng = tuple(jax.random.split(rng))
|
||||
# Broadcast the random seeds across the devices
|
||||
step_rng_device = jax.random.split(step_rng, num=jax.device_count())
|
||||
step_rng_device = step_rng_device[
|
||||
host_id * local_device_count:(host_id + 1) * local_device_count]
|
||||
step_device = np.broadcast_to(step, [local_device_count])
|
||||
|
||||
# Perform a training step and get scalars to log.
|
||||
scalars = experiment.step(global_step=step_device, rng=step_rng_device)
|
||||
|
||||
# Checkpointing and logging.
|
||||
if config['checkpointing_config']['use_checkpointing']:
|
||||
experiment.save_checkpoint(step, rng)
|
||||
current_time = time.time()
|
||||
if current_time - last_logging > FLAGS.log_tensors_interval:
|
||||
logging.info('Step %d: %s', step, scalars)
|
||||
last_logging = current_time
|
||||
step += 1
|
||||
logging.info('Saving final checkpoint')
|
||||
logging.info('Step %d: %s', step, scalars)
|
||||
experiment.save_checkpoint(step, rng)
|
||||
|
||||
|
||||
def eval_loop(experiment_class, config):
|
||||
"""The main evaluation loop.
|
||||
|
||||
This loop periodically loads a checkpoint and evaluates its performance on the
|
||||
test set, by calling experiment.evaluate.
|
||||
|
||||
Args:
|
||||
experiment_class: the constructor for the experiment (either byol_experiment
|
||||
or eval_experiment).
|
||||
config: the experiment config.
|
||||
"""
|
||||
experiment = experiment_class(**config)
|
||||
last_evaluated_step = -1
|
||||
while True:
|
||||
checkpoint_data = experiment.load_checkpoint()
|
||||
if checkpoint_data is None:
|
||||
logging.info('No checkpoint found. Waiting for 10s.')
|
||||
time.sleep(10)
|
||||
continue
|
||||
step, _ = checkpoint_data
|
||||
if step <= last_evaluated_step:
|
||||
logging.info('Checkpoint at step %d already evaluated, waiting.', step)
|
||||
time.sleep(10)
|
||||
continue
|
||||
host_id = jax.host_id()
|
||||
local_device_count = jax.local_device_count()
|
||||
step_device = np.broadcast_to(step, [local_device_count])
|
||||
scalars = experiment.evaluate(global_step=step_device)
|
||||
if host_id == 0: # Only perform logging in one host.
|
||||
logging.info('Evaluation at step %d: %s', step, scalars)
|
||||
last_evaluated_step = step
|
||||
if last_evaluated_step >= config['max_steps']:
|
||||
return
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.worker_tpu_driver:
|
||||
jax.config.update('jax_xla_backend', 'tpu_driver')
|
||||
jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver)
|
||||
logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices())
|
||||
|
||||
if FLAGS.experiment_mode == 'pretrain':
|
||||
experiment_class = byol_experiment.ByolExperiment
|
||||
config = byol_config.get_config(FLAGS.pretrain_epochs, FLAGS.batch_size)
|
||||
elif FLAGS.experiment_mode == 'linear-eval':
|
||||
experiment_class = eval_experiment.EvalExperiment
|
||||
config = eval_config.get_config(f'{FLAGS.checkpoint_root}/pretrain.pkl',
|
||||
FLAGS.batch_size)
|
||||
else:
|
||||
raise ValueError(f'Unknown experiment mode: {FLAGS.experiment_mode}')
|
||||
config['checkpointing_config']['checkpoint_dir'] = FLAGS.checkpoint_root
|
||||
|
||||
if FLAGS.worker_mode == 'train':
|
||||
train_loop(experiment_class, config)
|
||||
elif FLAGS.worker_mode == 'eval':
|
||||
eval_loop(experiment_class, config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for BYOL's main training loop."""
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from byol import byol_experiment
|
||||
from byol import eval_experiment
|
||||
from byol import main_loop
|
||||
from byol.configs import byol as byol_config
|
||||
from byol.configs import eval as eval_config
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class MainLoopTest(absltest.TestCase):
|
||||
|
||||
def test_pretrain(self):
|
||||
config = byol_config.get_config(num_epochs=40, batch_size=4)
|
||||
temp_dir = self.create_tempdir().full_path
|
||||
|
||||
# Override some config fields to make test lighter.
|
||||
config['network_config']['encoder_class'] = 'TinyResNet'
|
||||
config['network_config']['projector_hidden_size'] = 256
|
||||
config['network_config']['predictor_hidden_size'] = 256
|
||||
config['checkpointing_config']['checkpoint_dir'] = temp_dir
|
||||
config['evaluation_config']['batch_size'] = 16
|
||||
config['max_steps'] = 16
|
||||
|
||||
with tfds.testing.mock_data(num_examples=64):
|
||||
experiment_class = byol_experiment.ByolExperiment
|
||||
main_loop.train_loop(experiment_class, config)
|
||||
main_loop.eval_loop(experiment_class, config)
|
||||
|
||||
def test_linear_eval(self):
|
||||
config = eval_config.get_config(checkpoint_to_evaluate=None, batch_size=4)
|
||||
temp_dir = self.create_tempdir().full_path
|
||||
|
||||
# Override some config fields to make test lighter.
|
||||
config['network_config']['encoder_class'] = 'TinyResNet'
|
||||
config['allow_train_from_scratch'] = True
|
||||
config['checkpointing_config']['checkpoint_dir'] = temp_dir
|
||||
config['evaluation_config']['batch_size'] = 16
|
||||
config['max_steps'] = 16
|
||||
|
||||
with tfds.testing.mock_data(num_examples=64):
|
||||
experiment_class = eval_experiment.EvalExperiment
|
||||
main_loop.train_loop(experiment_class, config)
|
||||
main_loop.eval_loop(experiment_class, config)
|
||||
|
||||
def test_pipeline(self):
|
||||
b_config = byol_config.get_config(num_epochs=40, batch_size=4)
|
||||
temp_dir = self.create_tempdir().full_path
|
||||
|
||||
# Override some config fields to make test lighter.
|
||||
b_config['network_config']['encoder_class'] = 'TinyResNet'
|
||||
b_config['network_config']['projector_hidden_size'] = 256
|
||||
b_config['network_config']['predictor_hidden_size'] = 256
|
||||
b_config['checkpointing_config']['checkpoint_dir'] = temp_dir
|
||||
b_config['evaluation_config']['batch_size'] = 16
|
||||
b_config['max_steps'] = 16
|
||||
|
||||
with tfds.testing.mock_data(num_examples=64):
|
||||
main_loop.train_loop(byol_experiment.ByolExperiment, b_config)
|
||||
|
||||
e_config = eval_config.get_config(
|
||||
checkpoint_to_evaluate=f'{temp_dir}/pretrain.pkl',
|
||||
batch_size=4)
|
||||
|
||||
# Override some config fields to make test lighter.
|
||||
e_config['network_config']['encoder_class'] = 'TinyResNet'
|
||||
e_config['allow_train_from_scratch'] = True
|
||||
e_config['checkpointing_config']['checkpoint_dir'] = temp_dir
|
||||
e_config['evaluation_config']['batch_size'] = 16
|
||||
e_config['max_steps'] = 16
|
||||
|
||||
with tfds.testing.mock_data(num_examples=64):
|
||||
main_loop.train_loop(eval_experiment.EvalExperiment, e_config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
@@ -0,0 +1,8 @@
|
||||
dm-acme
|
||||
dm-haiku
|
||||
dm-tree
|
||||
jax
|
||||
jaxlib
|
||||
numpy>=1.16
|
||||
tensorflow
|
||||
tensorflow_datasets
|
||||
@@ -0,0 +1,457 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Data preprocessing and augmentation."""
|
||||
|
||||
import functools
|
||||
from typing import Any, Mapping, Text
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
# typing
|
||||
JaxBatch = Mapping[Text, jnp.ndarray]
|
||||
ConfigDict = Mapping[Text, Any]
|
||||
|
||||
augment_config = dict(
|
||||
view1=dict(
|
||||
random_flip=True, # Random left/right flip
|
||||
color_transform=dict(
|
||||
apply_prob=1.0,
|
||||
# Range of jittering
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1,
|
||||
# Probability of applying color jittering
|
||||
color_jitter_prob=0.8,
|
||||
# Probability of converting to grayscale
|
||||
to_grayscale_prob=0.2,
|
||||
# Shuffle the order of color transforms
|
||||
shuffle=True),
|
||||
gaussian_blur=dict(
|
||||
apply_prob=1.0,
|
||||
# Kernel size ~ image_size / blur_divider
|
||||
blur_divider=10.,
|
||||
# Kernel distribution
|
||||
sigma_min=0.1,
|
||||
sigma_max=2.0),
|
||||
solarize=dict(apply_prob=0.0, threshold=0.5),
|
||||
),
|
||||
view2=dict(
|
||||
random_flip=True,
|
||||
color_transform=dict(
|
||||
apply_prob=1.0,
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1,
|
||||
color_jitter_prob=0.8,
|
||||
to_grayscale_prob=0.2,
|
||||
shuffle=True),
|
||||
gaussian_blur=dict(
|
||||
apply_prob=0.1, blur_divider=10., sigma_min=0.1, sigma_max=2.0),
|
||||
solarize=dict(apply_prob=0.2, threshold=0.5),
|
||||
))
|
||||
|
||||
|
||||
def postprocess(inputs, rng):
|
||||
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
|
||||
|
||||
def _postprocess_image(
|
||||
images,
|
||||
rng,
|
||||
presets,
|
||||
):
|
||||
"""Applies augmentations in post-processing.
|
||||
|
||||
Args:
|
||||
images: an NHWC tensor (with C=3), with float values in [0, 1].
|
||||
rng: a single PRNGKey.
|
||||
presets: a dict of presets for the augmentations.
|
||||
|
||||
Returns:
|
||||
A batch of augmented images with shape NHWC, with keys view1, view2
|
||||
and labels.
|
||||
"""
|
||||
flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4)
|
||||
out = images
|
||||
if presets['random_flip']:
|
||||
out = random_flip(out, flip_rng)
|
||||
if presets['color_transform']['apply_prob'] > 0:
|
||||
out = color_transform(out, color_rng, **presets['color_transform'])
|
||||
if presets['gaussian_blur']['apply_prob'] > 0:
|
||||
out = gaussian_blur(out, blur_rng, **presets['gaussian_blur'])
|
||||
if presets['solarize']['apply_prob'] > 0:
|
||||
out = solarize(out, solarize_rng, **presets['solarize'])
|
||||
out = jnp.clip(out, 0., 1.)
|
||||
return jax.lax.stop_gradient(out)
|
||||
|
||||
rng1, rng2 = jax.random.split(rng, num=2)
|
||||
view1 = _postprocess_image(inputs['view1'], rng1, augment_config['view1'])
|
||||
view2 = _postprocess_image(inputs['view2'], rng2, augment_config['view2'])
|
||||
return dict(view1=view1, view2=view2, labels=inputs['labels'])
|
||||
|
||||
|
||||
def _maybe_apply(apply_fn, inputs, rng, apply_prob):
|
||||
should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
|
||||
return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x)
|
||||
|
||||
|
||||
def _depthwise_conv2d(inputs, kernel, strides, padding):
|
||||
"""Computes a depthwise conv2d in Jax.
|
||||
|
||||
Args:
|
||||
inputs: an NHWC tensor with N=1.
|
||||
kernel: a [H", W", 1, C] tensor.
|
||||
strides: a 2d tensor.
|
||||
padding: "SAME" or "VALID".
|
||||
|
||||
Returns:
|
||||
The depthwise convolution of inputs with kernel, as [H, W, C].
|
||||
"""
|
||||
return jax.lax.conv_general_dilated(
|
||||
inputs,
|
||||
kernel,
|
||||
strides,
|
||||
padding,
|
||||
feature_group_count=inputs.shape[-1],
|
||||
dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
|
||||
|
||||
|
||||
def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
|
||||
"""Applies gaussian blur to a single image, given as NHWC with N=1."""
|
||||
radius = int(kernel_size / 2)
|
||||
kernel_size_ = 2 * radius + 1
|
||||
x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
|
||||
blur_filter = jnp.exp(-x**2 / (2. * sigma**2))
|
||||
blur_filter = blur_filter / jnp.sum(blur_filter)
|
||||
blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
|
||||
blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
|
||||
num_channels = image.shape[-1]
|
||||
blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
|
||||
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
|
||||
expand_batch_dim = len(image.shape) == 3
|
||||
if expand_batch_dim:
|
||||
image = image[jnp.newaxis, Ellipsis]
|
||||
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
|
||||
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
|
||||
blurred = jnp.squeeze(blurred, axis=0)
|
||||
return blurred
|
||||
|
||||
|
||||
def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min,
|
||||
sigma_max, apply_prob):
|
||||
"""Applies a random gaussian blur."""
|
||||
apply_rng, transform_rng = jax.random.split(rng)
|
||||
|
||||
def _apply(image):
|
||||
sigma_rng, = jax.random.split(transform_rng, 1)
|
||||
sigma = jax.random.uniform(
|
||||
sigma_rng,
|
||||
shape=(),
|
||||
minval=sigma_min,
|
||||
maxval=sigma_max,
|
||||
dtype=jnp.float32)
|
||||
return _gaussian_blur_single_image(image, kernel_size, padding, sigma)
|
||||
|
||||
return _maybe_apply(_apply, image, apply_rng, apply_prob)
|
||||
|
||||
|
||||
def rgb_to_hsv(r, g, b):
|
||||
"""Converts R, G, B values to H, S, V values.
|
||||
|
||||
Reference TF implementation:
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
|
||||
Only input values between 0 and 1 are guaranteed to work properly, but this
|
||||
function complies with the TF implementation outside of this range.
|
||||
|
||||
Args:
|
||||
r: A tensor representing the red color component as floats.
|
||||
g: A tensor representing the green color component as floats.
|
||||
b: A tensor representing the blue color component as floats.
|
||||
|
||||
Returns:
|
||||
H, S, V values, each as tensors of shape [...] (same as the input without
|
||||
the last dimension).
|
||||
"""
|
||||
vv = jnp.maximum(jnp.maximum(r, g), b)
|
||||
range_ = vv - jnp.minimum(jnp.minimum(r, g), b)
|
||||
sat = jnp.where(vv > 0, range_ / vv, 0.)
|
||||
norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9)
|
||||
|
||||
hr = norm * (g - b)
|
||||
hg = norm * (b - r) + 2. / 6.
|
||||
hb = norm * (r - g) + 4. / 6.
|
||||
|
||||
hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb))
|
||||
hue = hue * (range_ > 0)
|
||||
hue = hue + (hue < 0)
|
||||
|
||||
return hue, sat, vv
|
||||
|
||||
|
||||
def hsv_to_rgb(h, s, v):
|
||||
"""Converts H, S, V values to an R, G, B tuple.
|
||||
|
||||
Reference TF implementation:
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
|
||||
Only input values between 0 and 1 are guaranteed to work properly, but this
|
||||
function complies with the TF implementation outside of this range.
|
||||
|
||||
Args:
|
||||
h: A float tensor of arbitrary shape for the hue (0-1 values).
|
||||
s: A float tensor of the same shape for the saturation (0-1 values).
|
||||
v: A float tensor of the same shape for the value channel (0-1 values).
|
||||
|
||||
Returns:
|
||||
An (r, g, b) tuple, each with the same dimension as the inputs.
|
||||
"""
|
||||
c = s * v
|
||||
m = v - c
|
||||
dh = (h % 1.) * 6.
|
||||
fmodu = dh % 2.
|
||||
x = c * (1 - jnp.abs(fmodu - 1))
|
||||
hcat = jnp.floor(dh).astype(jnp.int32)
|
||||
rr = jnp.where(
|
||||
(hcat == 0) | (hcat == 5), c, jnp.where(
|
||||
(hcat == 1) | (hcat == 4), x, 0)) + m
|
||||
gg = jnp.where(
|
||||
(hcat == 1) | (hcat == 2), c, jnp.where(
|
||||
(hcat == 0) | (hcat == 3), x, 0)) + m
|
||||
bb = jnp.where(
|
||||
(hcat == 3) | (hcat == 4), c, jnp.where(
|
||||
(hcat == 2) | (hcat == 5), x, 0)) + m
|
||||
return rr, gg, bb
|
||||
|
||||
|
||||
def adjust_brightness(rgb_tuple, delta):
|
||||
return jax.tree_map(lambda x: x + delta, rgb_tuple)
|
||||
|
||||
|
||||
def adjust_contrast(image, factor):
|
||||
def _adjust_contrast_channel(channel):
|
||||
mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
|
||||
return factor * (channel - mean) + mean
|
||||
return jax.tree_map(_adjust_contrast_channel, image)
|
||||
|
||||
|
||||
def adjust_saturation(h, s, v, factor):
|
||||
return h, jnp.clip(s * factor, 0., 1.), v
|
||||
|
||||
|
||||
def adjust_hue(h, s, v, delta):
|
||||
# Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb
|
||||
# conversions) when running on GPU. When running on CPU, the results will be
|
||||
# different if all RGB values for a pixel are outside of the [0, 1] range.
|
||||
return (h + delta) % 1.0, s, v
|
||||
|
||||
|
||||
def _random_brightness(rgb_tuple, rng, max_delta):
|
||||
delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
|
||||
return adjust_brightness(rgb_tuple, delta)
|
||||
|
||||
|
||||
def _random_contrast(rgb_tuple, rng, max_delta):
|
||||
factor = jax.random.uniform(
|
||||
rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
|
||||
return adjust_contrast(rgb_tuple, factor)
|
||||
|
||||
|
||||
def _random_saturation(rgb_tuple, rng, max_delta):
|
||||
h, s, v = rgb_to_hsv(*rgb_tuple)
|
||||
factor = jax.random.uniform(
|
||||
rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
|
||||
return hsv_to_rgb(*adjust_saturation(h, s, v, factor))
|
||||
|
||||
|
||||
def _random_hue(rgb_tuple, rng, max_delta):
|
||||
h, s, v = rgb_to_hsv(*rgb_tuple)
|
||||
delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
|
||||
return hsv_to_rgb(*adjust_hue(h, s, v, delta))
|
||||
|
||||
|
||||
def _to_grayscale(image):
|
||||
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
|
||||
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis]
|
||||
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
|
||||
|
||||
|
||||
def _color_transform_single_image(image, rng, brightness, contrast, saturation,
|
||||
hue, to_grayscale_prob, color_jitter_prob,
|
||||
apply_prob, shuffle):
|
||||
"""Applies color jittering to a single image."""
|
||||
apply_rng, transform_rng = jax.random.split(rng)
|
||||
perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
|
||||
transform_rng, 7)
|
||||
|
||||
# Whether the transform should be applied at all.
|
||||
should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
|
||||
# Whether to apply grayscale transform.
|
||||
should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
|
||||
# Whether to apply color jittering.
|
||||
should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob
|
||||
|
||||
# Decorator to conditionally apply fn based on an index.
|
||||
def _make_cond(fn, idx):
|
||||
|
||||
def identity_fn(x, unused_rng, unused_param):
|
||||
return x
|
||||
|
||||
def cond_fn(args, i):
|
||||
def clip(args):
|
||||
return jax.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args)
|
||||
out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args,
|
||||
lambda a: clip(fn(*a)), args,
|
||||
lambda a: identity_fn(*a))
|
||||
return jax.lax.stop_gradient(out)
|
||||
|
||||
return cond_fn
|
||||
|
||||
random_brightness_cond = _make_cond(_random_brightness, idx=0)
|
||||
random_contrast_cond = _make_cond(_random_contrast, idx=1)
|
||||
random_saturation_cond = _make_cond(_random_saturation, idx=2)
|
||||
random_hue_cond = _make_cond(_random_hue, idx=3)
|
||||
|
||||
def _color_jitter(x):
|
||||
rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1)))
|
||||
if shuffle:
|
||||
order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
|
||||
else:
|
||||
order = range(4)
|
||||
for idx in order:
|
||||
if brightness > 0:
|
||||
rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx)
|
||||
if contrast > 0:
|
||||
rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx)
|
||||
if saturation > 0:
|
||||
rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx)
|
||||
if hue > 0:
|
||||
rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx)
|
||||
return jnp.stack(rgb_tuple, axis=-1)
|
||||
|
||||
out_apply = _color_jitter(image)
|
||||
out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply,
|
||||
_to_grayscale, out_apply, lambda x: x)
|
||||
return jnp.clip(out_apply, 0., 1.)
|
||||
|
||||
|
||||
def _random_flip_single_image(image, rng):
|
||||
_, flip_rng = jax.random.split(rng)
|
||||
should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
|
||||
image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
|
||||
return image
|
||||
|
||||
|
||||
def random_flip(images, rng):
|
||||
rngs = jax.random.split(rng, images.shape[0])
|
||||
return jax.vmap(_random_flip_single_image)(images, rngs)
|
||||
|
||||
|
||||
def color_transform(images,
|
||||
rng,
|
||||
brightness=0.8,
|
||||
contrast=0.8,
|
||||
saturation=0.8,
|
||||
hue=0.2,
|
||||
color_jitter_prob=0.8,
|
||||
to_grayscale_prob=0.2,
|
||||
apply_prob=1.0,
|
||||
shuffle=True):
|
||||
"""Applies color jittering and/or grayscaling to a batch of images.
|
||||
|
||||
Args:
|
||||
images: an NHWC tensor, with C=3.
|
||||
rng: a single PRNGKey.
|
||||
brightness: the range of jitter on brightness.
|
||||
contrast: the range of jitter on contrast.
|
||||
saturation: the range of jitter on saturation.
|
||||
hue: the range of jitter on hue.
|
||||
color_jitter_prob: the probability of applying color jittering.
|
||||
to_grayscale_prob: the probability of converting the image to grayscale.
|
||||
apply_prob: the probability of applying the transform to a batch element.
|
||||
shuffle: whether to apply the transforms in a random order.
|
||||
|
||||
Returns:
|
||||
A NHWC tensor of the transformed images.
|
||||
"""
|
||||
rngs = jax.random.split(rng, images.shape[0])
|
||||
jitter_fn = functools.partial(
|
||||
_color_transform_single_image,
|
||||
brightness=brightness,
|
||||
contrast=contrast,
|
||||
saturation=saturation,
|
||||
hue=hue,
|
||||
color_jitter_prob=color_jitter_prob,
|
||||
to_grayscale_prob=to_grayscale_prob,
|
||||
apply_prob=apply_prob,
|
||||
shuffle=shuffle)
|
||||
return jax.vmap(jitter_fn)(images, rngs)
|
||||
|
||||
|
||||
def gaussian_blur(images,
|
||||
rng,
|
||||
blur_divider=10.,
|
||||
sigma_min=0.1,
|
||||
sigma_max=2.0,
|
||||
apply_prob=1.0):
|
||||
"""Applies gaussian blur to a batch of images.
|
||||
|
||||
Args:
|
||||
images: an NHWC tensor, with C=3.
|
||||
rng: a single PRNGKey.
|
||||
blur_divider: the blurring kernel will have size H / blur_divider.
|
||||
sigma_min: the minimum value for sigma in the blurring kernel.
|
||||
sigma_max: the maximum value for sigma in the blurring kernel.
|
||||
apply_prob: the probability of applying the transform to a batch element.
|
||||
|
||||
Returns:
|
||||
A NHWC tensor of the blurred images.
|
||||
"""
|
||||
rngs = jax.random.split(rng, images.shape[0])
|
||||
kernel_size = images.shape[1] / blur_divider
|
||||
blur_fn = functools.partial(
|
||||
_random_gaussian_blur,
|
||||
kernel_size=kernel_size,
|
||||
padding='SAME',
|
||||
sigma_min=sigma_min,
|
||||
sigma_max=sigma_max,
|
||||
apply_prob=apply_prob)
|
||||
return jax.vmap(blur_fn)(images, rngs)
|
||||
|
||||
|
||||
def _solarize_single_image(image, rng, threshold, apply_prob):
|
||||
|
||||
def _apply(image):
|
||||
return jnp.where(image < threshold, image, 1. - image)
|
||||
|
||||
return _maybe_apply(_apply, image, rng, apply_prob)
|
||||
|
||||
|
||||
def solarize(images, rng, threshold=0.5, apply_prob=1.0):
|
||||
"""Applies solarization.
|
||||
|
||||
Args:
|
||||
images: an NHWC tensor (with C=3).
|
||||
rng: a single PRNGKey.
|
||||
threshold: the solarization threshold.
|
||||
apply_prob: the probability of applying the transform to a batch element.
|
||||
|
||||
Returns:
|
||||
A NHWC tensor of the transformed images.
|
||||
"""
|
||||
rngs = jax.random.split(rng, images.shape[0])
|
||||
solarize_fn = functools.partial(
|
||||
_solarize_single_image, threshold=threshold, apply_prob=apply_prob)
|
||||
return jax.vmap(solarize_fn)(images, rngs)
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Checkpoint saving and restoring utilities."""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Mapping, Text, Tuple, Union
|
||||
|
||||
from absl import logging
|
||||
import dill
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from byol.utils import helpers
|
||||
|
||||
|
||||
class Checkpointer:
|
||||
"""A checkpoint saving and loading class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_checkpointing,
|
||||
checkpoint_dir,
|
||||
save_checkpoint_interval,
|
||||
filename):
|
||||
if (not use_checkpointing or
|
||||
checkpoint_dir is None or
|
||||
save_checkpoint_interval <= 0):
|
||||
self._checkpoint_enabled = False
|
||||
return
|
||||
|
||||
self._checkpoint_enabled = True
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
os.makedirs(self._checkpoint_dir, exist_ok=True)
|
||||
self._filename = filename
|
||||
self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)
|
||||
self._last_checkpoint_time = 0
|
||||
self._checkpoint_every = save_checkpoint_interval
|
||||
|
||||
def maybe_save_checkpoint(
|
||||
self,
|
||||
experiment_state,
|
||||
step,
|
||||
rng,
|
||||
is_final):
|
||||
"""Saves a checkpoint if enough time has passed since the previous one."""
|
||||
current_time = time.time()
|
||||
if (not self._checkpoint_enabled or
|
||||
jax.host_id() != 0 or # Only checkpoint the first worker.
|
||||
(not is_final and
|
||||
current_time - self._last_checkpoint_time < self._checkpoint_every)):
|
||||
return
|
||||
checkpoint_data = dict(
|
||||
experiment_state=jax.tree_map(
|
||||
lambda x: jax.device_get(x[0]), experiment_state),
|
||||
step=step,
|
||||
rng=rng)
|
||||
with open(self._checkpoint_path + '_tmp', 'wb') as checkpoint_file:
|
||||
dill.dump(checkpoint_data, checkpoint_file, protocol=2)
|
||||
try:
|
||||
os.rename(self._checkpoint_path, self._checkpoint_path + '_old')
|
||||
remove_old = True
|
||||
except FileNotFoundError:
|
||||
remove_old = False # No previous checkpoint to remove
|
||||
os.rename(self._checkpoint_path + '_tmp', self._checkpoint_path)
|
||||
if remove_old:
|
||||
os.remove(self._checkpoint_path + '_old')
|
||||
self._last_checkpoint_time = current_time
|
||||
|
||||
def maybe_load_checkpoint(
|
||||
self):
|
||||
"""Loads a checkpoint if any is found."""
|
||||
checkpoint_data = load_checkpoint(self._checkpoint_path)
|
||||
if checkpoint_data is None:
|
||||
logging.info('No existing checkpoint found at %s', self._checkpoint_path)
|
||||
return None
|
||||
step = checkpoint_data['step']
|
||||
rng = checkpoint_data['rng']
|
||||
experiment_state = jax.tree_map(
|
||||
helpers.bcast_local_devices, checkpoint_data['experiment_state'])
|
||||
del checkpoint_data
|
||||
return experiment_state, step, rng
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path):
|
||||
try:
|
||||
with open(checkpoint_path, 'rb') as checkpoint_file:
|
||||
checkpoint_data = dill.load(checkpoint_file)
|
||||
logging.info('Loading checkpoint from %s, saved at step %d',
|
||||
checkpoint_path, checkpoint_data['step'])
|
||||
return checkpoint_data
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
@@ -0,0 +1,266 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ImageNet dataset with typical pre-processing."""
|
||||
|
||||
import enum
|
||||
from typing import Generator, Mapping, Optional, Sequence, Text, Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
Batch = Mapping[Text, np.ndarray]
|
||||
|
||||
|
||||
class Split(enum.Enum):
|
||||
"""Imagenet dataset split."""
|
||||
TRAIN = 1
|
||||
TRAIN_AND_VALID = 2
|
||||
VALID = 3
|
||||
TEST = 4
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, name):
|
||||
return {
|
||||
'TRAIN': Split.TRAIN,
|
||||
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
|
||||
'VALID': Split.VALID,
|
||||
'VALIDATION': Split.VALID,
|
||||
'TEST': Split.TEST
|
||||
}[name.upper()]
|
||||
|
||||
@property
|
||||
def num_examples(self):
|
||||
return {
|
||||
Split.TRAIN_AND_VALID: 1281167,
|
||||
Split.TRAIN: 1271167,
|
||||
Split.VALID: 10000,
|
||||
Split.TEST: 50000
|
||||
}[self]
|
||||
|
||||
|
||||
class PreprocessMode(enum.Enum):
|
||||
"""Preprocessing modes for the dataset."""
|
||||
PRETRAIN = 1 # Generates two augmented views (random crop + augmentations).
|
||||
LINEAR_TRAIN = 2 # Generates a single random crop.
|
||||
EVAL = 3 # Generates a single center crop.
|
||||
|
||||
|
||||
def normalize_images(images):
|
||||
"""Normalize the image using ImageNet statistics."""
|
||||
mean_rgb = (0.485, 0.456, 0.406)
|
||||
stddev_rgb = (0.229, 0.224, 0.225)
|
||||
normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3))
|
||||
normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3))
|
||||
return normed_images
|
||||
|
||||
|
||||
def load(split,
|
||||
*,
|
||||
preprocess_mode,
|
||||
batch_dims,
|
||||
transpose = False,
|
||||
allow_caching = False):
|
||||
"""Loads the given split of the dataset."""
|
||||
start, end = _shard(split, jax.host_id(), jax.host_count())
|
||||
|
||||
total_batch_size = np.prod(batch_dims)
|
||||
|
||||
tfds_split = tfds.core.ReadInstruction(
|
||||
_to_tfds_split(split), from_=start, to=end, unit='abs')
|
||||
ds = tfds.load(
|
||||
'imagenet2012:5.*.*',
|
||||
split=tfds_split,
|
||||
decoders={'image': tfds.decode.SkipDecoding()})
|
||||
|
||||
options = ds.options()
|
||||
options.experimental_threading.private_threadpool_size = 48
|
||||
options.experimental_threading.max_intra_op_parallelism = 1
|
||||
|
||||
if preprocess_mode is not PreprocessMode.EVAL:
|
||||
options.experimental_deterministic = False
|
||||
if jax.host_count() > 1 and allow_caching:
|
||||
# Only cache if we are reading a subset of the dataset.
|
||||
ds = ds.cache()
|
||||
ds = ds.repeat()
|
||||
ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
|
||||
|
||||
else:
|
||||
if split.num_examples % total_batch_size != 0:
|
||||
raise ValueError(f'Test/valid must be divisible by {total_batch_size}')
|
||||
|
||||
def preprocess_pretrain(example):
|
||||
view1 = _preprocess_image(example['image'], mode=preprocess_mode)
|
||||
view2 = _preprocess_image(example['image'], mode=preprocess_mode)
|
||||
label = tf.cast(example['label'], tf.int32)
|
||||
return {'view1': view1, 'view2': view2, 'labels': label}
|
||||
|
||||
def preprocess_linear_train(example):
|
||||
image = _preprocess_image(example['image'], mode=preprocess_mode)
|
||||
label = tf.cast(example['label'], tf.int32)
|
||||
return {'images': image, 'labels': label}
|
||||
|
||||
def preprocess_eval(example):
|
||||
image = _preprocess_image(example['image'], mode=preprocess_mode)
|
||||
label = tf.cast(example['label'], tf.int32)
|
||||
return {'images': image, 'labels': label}
|
||||
|
||||
if preprocess_mode is PreprocessMode.PRETRAIN:
|
||||
ds = ds.map(
|
||||
preprocess_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
elif preprocess_mode is PreprocessMode.LINEAR_TRAIN:
|
||||
ds = ds.map(
|
||||
preprocess_linear_train,
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
else:
|
||||
ds = ds.map(
|
||||
preprocess_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
def transpose_fn(batch):
|
||||
# We use the double-transpose-trick to improve performance for TPUs. Note
|
||||
# that this (typically) requires a matching HWCN->NHWC transpose in your
|
||||
# model code. The compiler cannot make this optimization for us since our
|
||||
# data pipeline and model are compiled separately.
|
||||
batch = dict(**batch)
|
||||
if preprocess_mode is PreprocessMode.PRETRAIN:
|
||||
batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
|
||||
batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
|
||||
else:
|
||||
batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
|
||||
return batch
|
||||
|
||||
for i, batch_size in enumerate(reversed(batch_dims)):
|
||||
ds = ds.batch(batch_size)
|
||||
if i == 0 and transpose:
|
||||
ds = ds.map(transpose_fn) # NHWC -> HWCN
|
||||
|
||||
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
|
||||
|
||||
yield from tfds.as_numpy(ds)
|
||||
|
||||
|
||||
def _to_tfds_split(split):
|
||||
"""Returns the TFDS split appropriately sharded."""
|
||||
# NOTE: Imagenet did not release labels for the test split used in the
|
||||
# competition, we consider the VALID split the TEST split and reserve
|
||||
# 10k images from TRAIN for VALID.
|
||||
if split in (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):
|
||||
return tfds.Split.TRAIN
|
||||
else:
|
||||
assert split == Split.TEST
|
||||
return tfds.Split.VALIDATION
|
||||
|
||||
|
||||
def _shard(split, shard_index, num_shards):
|
||||
"""Returns [start, end) for the given shard index."""
|
||||
assert shard_index < num_shards
|
||||
arange = np.arange(split.num_examples)
|
||||
shard_range = np.array_split(arange, num_shards)[shard_index]
|
||||
start, end = shard_range[0], (shard_range[-1] + 1)
|
||||
if split == Split.TRAIN:
|
||||
# Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
|
||||
offset = Split.VALID.num_examples
|
||||
start += offset
|
||||
end += offset
|
||||
return start, end
|
||||
|
||||
|
||||
def _preprocess_image(
|
||||
image_bytes,
|
||||
mode,
|
||||
):
|
||||
"""Returns processed and resized images."""
|
||||
if mode is PreprocessMode.PRETRAIN:
|
||||
image = _decode_and_random_crop(image_bytes)
|
||||
# Random horizontal flipping is optionally done in augmentations.preprocess.
|
||||
elif mode is PreprocessMode.LINEAR_TRAIN:
|
||||
image = _decode_and_random_crop(image_bytes)
|
||||
image = tf.image.random_flip_left_right(image)
|
||||
else:
|
||||
image = _decode_and_center_crop(image_bytes)
|
||||
# NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
|
||||
# clamping overshoots. This means values returned will be outside the range
|
||||
# [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
|
||||
assert image.dtype == tf.uint8
|
||||
image = tf.image.resize(image, [224, 224], tf.image.ResizeMethod.BICUBIC)
|
||||
image = tf.clip_by_value(image / 255., 0., 1.)
|
||||
return image
|
||||
|
||||
|
||||
def _decode_and_random_crop(image_bytes):
|
||||
"""Make a random crop of 224."""
|
||||
img_size = tf.image.extract_jpeg_shape(image_bytes)
|
||||
area = tf.cast(img_size[1] * img_size[0], tf.float32)
|
||||
target_area = tf.random.uniform([], 0.08, 1.0, dtype=tf.float32) * area
|
||||
|
||||
log_ratio = (tf.math.log(3 / 4), tf.math.log(4 / 3))
|
||||
aspect_ratio = tf.math.exp(
|
||||
tf.random.uniform([], *log_ratio, dtype=tf.float32))
|
||||
|
||||
w = tf.cast(tf.round(tf.sqrt(target_area * aspect_ratio)), tf.int32)
|
||||
h = tf.cast(tf.round(tf.sqrt(target_area / aspect_ratio)), tf.int32)
|
||||
|
||||
w = tf.minimum(w, img_size[1])
|
||||
h = tf.minimum(h, img_size[0])
|
||||
|
||||
offset_w = tf.random.uniform((),
|
||||
minval=0,
|
||||
maxval=img_size[1] - w + 1,
|
||||
dtype=tf.int32)
|
||||
offset_h = tf.random.uniform((),
|
||||
minval=0,
|
||||
maxval=img_size[0] - h + 1,
|
||||
dtype=tf.int32)
|
||||
|
||||
crop_window = tf.stack([offset_h, offset_w, h, w])
|
||||
image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
|
||||
return image
|
||||
|
||||
|
||||
def transpose_images(batch):
|
||||
"""Transpose images for TPU training.."""
|
||||
new_batch = dict(batch) # Avoid mutating in place.
|
||||
if 'images' in batch:
|
||||
new_batch['images'] = jnp.transpose(batch['images'], (3, 0, 1, 2))
|
||||
else:
|
||||
new_batch['view1'] = jnp.transpose(batch['view1'], (3, 0, 1, 2))
|
||||
new_batch['view2'] = jnp.transpose(batch['view2'], (3, 0, 1, 2))
|
||||
return new_batch
|
||||
|
||||
|
||||
def _decode_and_center_crop(
|
||||
image_bytes,
|
||||
jpeg_shape = None,
|
||||
):
|
||||
"""Crops to center of image with padding then scales."""
|
||||
if jpeg_shape is None:
|
||||
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
||||
image_height = jpeg_shape[0]
|
||||
image_width = jpeg_shape[1]
|
||||
|
||||
padded_center_crop_size = tf.cast(
|
||||
((224 / (224 + 32)) *
|
||||
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
|
||||
|
||||
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
||||
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
||||
crop_window = tf.stack([
|
||||
offset_height, offset_width, padded_center_crop_size,
|
||||
padded_center_crop_size
|
||||
])
|
||||
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
|
||||
return image
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utility functions."""
|
||||
|
||||
from typing import Optional, Text
|
||||
from absl import logging
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def topk_accuracy(
|
||||
logits,
|
||||
labels,
|
||||
topk,
|
||||
ignore_label_above = None,
|
||||
):
|
||||
"""Top-num_codes accuracy."""
|
||||
assert len(labels.shape) == 1, 'topk expects 1d int labels.'
|
||||
assert len(logits.shape) == 2, 'topk expects 2d logits.'
|
||||
|
||||
if ignore_label_above is not None:
|
||||
logits = logits[labels < ignore_label_above, :]
|
||||
labels = labels[labels < ignore_label_above]
|
||||
|
||||
prds = jnp.argsort(logits, axis=1)[:, ::-1]
|
||||
prds = prds[:, :topk]
|
||||
total = jnp.any(prds == jnp.tile(labels[:, jnp.newaxis], [1, topk]), axis=1)
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def softmax_cross_entropy(
|
||||
logits,
|
||||
labels,
|
||||
reduction = 'mean',
|
||||
):
|
||||
"""Computes softmax cross entropy given logits and one-hot class labels.
|
||||
|
||||
Args:
|
||||
logits: Logit output values.
|
||||
labels: Ground truth one-hot-encoded labels.
|
||||
reduction: Type of reduction to apply to loss.
|
||||
|
||||
Returns:
|
||||
Loss value. If `reduction` is `none`, this has the same shape as `labels`;
|
||||
otherwise, it is scalar.
|
||||
|
||||
Raises:
|
||||
ValueError: If the type of `reduction` is unsupported.
|
||||
"""
|
||||
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
|
||||
if reduction == 'sum':
|
||||
return jnp.sum(loss)
|
||||
elif reduction == 'mean':
|
||||
return jnp.mean(loss)
|
||||
elif reduction == 'none' or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
raise ValueError(f'Incorrect reduction mode {reduction}')
|
||||
|
||||
|
||||
def l2_normalize(
|
||||
x,
|
||||
axis = None,
|
||||
epsilon = 1e-12,
|
||||
):
|
||||
"""l2 normalize a tensor on an axis with numerical stability."""
|
||||
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
|
||||
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
|
||||
return x * x_inv_norm
|
||||
|
||||
|
||||
def l2_weight_regularizer(params):
|
||||
"""Helper to do lasso on weights.
|
||||
|
||||
Args:
|
||||
params: the entire param set.
|
||||
|
||||
Returns:
|
||||
Scalar of the l2 norm of the weights.
|
||||
"""
|
||||
l2_norm = 0.
|
||||
for mod_name, mod_params in params.items():
|
||||
if 'norm' not in mod_name:
|
||||
for param_k, param_v in mod_params.items():
|
||||
if param_k != 'b' not in param_k: # Filter out biases
|
||||
l2_norm += jnp.sum(jnp.square(param_v))
|
||||
else:
|
||||
logging.warning('Excluding %s/%s from optimizer weight decay!',
|
||||
mod_name, param_k)
|
||||
else:
|
||||
logging.warning('Excluding %s from optimizer weight decay!', mod_name)
|
||||
|
||||
return 0.5 * l2_norm
|
||||
|
||||
|
||||
def regression_loss(x, y):
|
||||
"""Byol's regression loss. This is a simple cosine similarity."""
|
||||
normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
|
||||
return jnp.sum((normed_x - normed_y)**2, axis=-1)
|
||||
|
||||
|
||||
def bcast_local_devices(value):
|
||||
"""Broadcasts an object to all local devices."""
|
||||
devices = jax.local_devices()
|
||||
|
||||
def _replicate(x):
|
||||
"""Replicate an object on each device."""
|
||||
x = jnp.array(x)
|
||||
aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
|
||||
buffers = [jax.interpreters.xla.device_put(x, d) for d in devices]
|
||||
return jax.pxla.ShardedDeviceArray(aval, buffers)
|
||||
|
||||
return jax.tree_util.tree_map(_replicate, value)
|
||||
|
||||
|
||||
def get_first(xs):
|
||||
"""Gets values from the first device."""
|
||||
return jax.tree_map(lambda x: x[0], xs)
|
||||
@@ -0,0 +1,319 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Networks used in BYOL."""
|
||||
|
||||
from typing import Any, Mapping, Optional, Sequence, Text
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class MLP(hk.Module):
|
||||
"""One hidden layer perceptron, with normalization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
hidden_size,
|
||||
output_size,
|
||||
bn_config,
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._hidden_size = hidden_size
|
||||
self._output_size = output_size
|
||||
self._bn_config = bn_config
|
||||
|
||||
def __call__(self, inputs, is_training):
|
||||
out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
|
||||
out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
|
||||
out = jax.nn.relu(out)
|
||||
out = hk.Linear(output_size=self._output_size, with_bias=False)(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNetTorso(hk.nets.ResNet):
|
||||
"""ResNet model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
blocks_per_group,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
bottleneck = True,
|
||||
channels_per_group = (256, 512, 1024, 2048),
|
||||
use_projection = (True, True, True, True),
|
||||
width_multiplier = 1,
|
||||
name = None,
|
||||
):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
blocks_per_group: A sequence of length 4 that indicates the number of
|
||||
blocks created in each group.
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of three elements, `decay_rate`, `eps`, and
|
||||
`cross_replica_axis`, to be passed on to the `BatchNorm` layers. By
|
||||
default the `decay_rate` is `0.9` and `eps` is `1e-5`, and the axis is
|
||||
`None`.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
|
||||
False.
|
||||
bottleneck: Whether the block should bottleneck or not. Defaults to True.
|
||||
channels_per_group: A sequence of length 4 that indicates the number
|
||||
of channels used for each block in each group.
|
||||
use_projection: A sequence of length 4 that indicates whether each
|
||||
residual block should use projection.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
channels_per_group = [width_multiplier * channel
|
||||
for channel in channels_per_group]
|
||||
super().__init__(
|
||||
blocks_per_group=blocks_per_group,
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=bottleneck,
|
||||
channels_per_group=channels_per_group,
|
||||
name=name)
|
||||
|
||||
def __call__(self, inputs, is_training, test_local_stats=False):
|
||||
out = inputs
|
||||
out = self.initial_conv(out)
|
||||
if not self.resnet_v2:
|
||||
out = self.initial_batchnorm(out, is_training, test_local_stats)
|
||||
out = jax.nn.relu(out)
|
||||
|
||||
out = hk.max_pool(out,
|
||||
window_shape=(1, 3, 3, 1),
|
||||
strides=(1, 2, 2, 1),
|
||||
padding='SAME')
|
||||
|
||||
for block_group in self.block_groups:
|
||||
out = block_group(out, is_training, test_local_stats)
|
||||
|
||||
if self.resnet_v2:
|
||||
out = self.final_batchnorm(out, is_training, test_local_stats)
|
||||
out = jax.nn.relu(out)
|
||||
out = jnp.mean(out, axis=[1, 2])
|
||||
return out
|
||||
|
||||
|
||||
class TinyResNet(ResNetTorso):
|
||||
"""Tiny resnet for local runs and tests."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(1, 1, 1, 1),
|
||||
channels_per_group=(8, 8, 8, 8),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=False,
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
|
||||
|
||||
class ResNet18(ResNetTorso):
|
||||
"""ResNet18."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(2, 2, 2, 2),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=False,
|
||||
channels_per_group=(64, 128, 256, 512),
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
|
||||
|
||||
class ResNet34(ResNetTorso):
|
||||
"""ResNet34."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(3, 4, 6, 3),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=False,
|
||||
channels_per_group=(64, 128, 256, 512),
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
|
||||
|
||||
class ResNet50(ResNetTorso):
|
||||
"""ResNet50."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(3, 4, 6, 3),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=True,
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
|
||||
|
||||
class ResNet101(ResNetTorso):
|
||||
"""ResNet101."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(3, 4, 23, 3),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=True,
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
|
||||
|
||||
class ResNet152(ResNetTorso):
|
||||
"""ResNet152."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(3, 8, 36, 3),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=True,
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
|
||||
|
||||
class ResNet200(ResNetTorso):
|
||||
"""ResNet200."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
num_classes: The number of classes to classify the inputs into.
|
||||
bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
|
||||
passed on to the `BatchNorm` layers.
|
||||
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
|
||||
to False.
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
super().__init__(blocks_per_group=(3, 24, 36, 3),
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=True,
|
||||
width_multiplier=width_multiplier,
|
||||
name=name)
|
||||
@@ -0,0 +1,188 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implementation of LARS Optimizer with optax."""
|
||||
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import tree as nest
|
||||
|
||||
# A filter function takes a path and a value as input and outputs True for
|
||||
# variable to apply update and False not to apply the update
|
||||
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]
|
||||
|
||||
|
||||
def exclude_bias_and_norm(path, val):
|
||||
"""Filter to exclude biaises and normalizations weights."""
|
||||
del val
|
||||
if path[-1] == "b" or "norm" in path[-2]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _partial_update(updates,
|
||||
new_updates,
|
||||
params,
|
||||
filter_fn = None):
|
||||
"""Returns new_update for params which filter_fn is True else updates."""
|
||||
|
||||
if filter_fn is None:
|
||||
return new_updates
|
||||
|
||||
wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
|
||||
params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)
|
||||
|
||||
def _update_fn(g, t, m):
|
||||
m = m.astype(g.dtype)
|
||||
return g * (1. - m) + t * m
|
||||
|
||||
return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
|
||||
|
||||
|
||||
class ScaleByLarsState(NamedTuple):
|
||||
mu: jnp.ndarray
|
||||
|
||||
|
||||
def scale_by_lars(
|
||||
momentum = 0.9,
|
||||
eta = 0.001,
|
||||
filter_fn = None):
|
||||
"""Rescales updates according to the LARS algorithm.
|
||||
|
||||
Does not include weight decay.
|
||||
References:
|
||||
[You et al, 2017](https://arxiv.org/abs/1708.03888)
|
||||
|
||||
Args:
|
||||
momentum: momentum coeficient.
|
||||
eta: LARS coefficient.
|
||||
filter_fn: an optional filter function.
|
||||
|
||||
Returns:
|
||||
An (init_fn, update_fn) tuple.
|
||||
"""
|
||||
|
||||
def init_fn(params):
|
||||
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
|
||||
return ScaleByLarsState(mu=mu)
|
||||
|
||||
def update_fn(updates, state,
|
||||
params):
|
||||
|
||||
def lars_adaptation(
|
||||
update,
|
||||
param,
|
||||
):
|
||||
param_norm = jnp.linalg.norm(param)
|
||||
update_norm = jnp.linalg.norm(update)
|
||||
return update * jnp.where(
|
||||
param_norm > 0.,
|
||||
jnp.where(update_norm > 0,
|
||||
(eta * param_norm / update_norm), 1.0), 1.0)
|
||||
|
||||
adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
|
||||
adapted_updates = _partial_update(updates, adapted_updates, params,
|
||||
filter_fn)
|
||||
mu = jax.tree_multimap(lambda g, t: momentum * g + t,
|
||||
state.mu, adapted_updates)
|
||||
return mu, ScaleByLarsState(mu=mu)
|
||||
|
||||
return optax.GradientTransformation(init_fn, update_fn)
|
||||
|
||||
|
||||
class AddWeightDecayState(NamedTuple):
|
||||
"""Stateless transformation."""
|
||||
|
||||
|
||||
def add_weight_decay(
|
||||
weight_decay,
|
||||
filter_fn = None):
|
||||
"""Adds a weight decay to the update.
|
||||
|
||||
Args:
|
||||
weight_decay: weight_decay coeficient.
|
||||
filter_fn: an optional filter function.
|
||||
|
||||
Returns:
|
||||
An (init_fn, update_fn) tuple.
|
||||
"""
|
||||
|
||||
def init_fn(_):
|
||||
return AddWeightDecayState()
|
||||
|
||||
def update_fn(
|
||||
updates,
|
||||
state,
|
||||
params,
|
||||
):
|
||||
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
|
||||
params)
|
||||
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
||||
return new_updates, state
|
||||
|
||||
return optax.GradientTransformation(init_fn, update_fn)
|
||||
|
||||
|
||||
LarsState = List # Type for the lars optimizer
|
||||
|
||||
|
||||
def lars(
|
||||
learning_rate,
|
||||
weight_decay = 0.,
|
||||
momentum = 0.9,
|
||||
eta = 0.001,
|
||||
weight_decay_filter = None,
|
||||
lars_adaptation_filter = None,
|
||||
):
|
||||
"""Creates lars optimizer with weight decay.
|
||||
|
||||
References:
|
||||
[You et al, 2017](https://arxiv.org/abs/1708.03888)
|
||||
|
||||
Args:
|
||||
learning_rate: learning rate coefficient.
|
||||
weight_decay: weight decay coefficient.
|
||||
momentum: momentum coefficient.
|
||||
eta: LARS coefficient.
|
||||
weight_decay_filter: optional filter function to only apply the weight
|
||||
decay on a subset of parameters. The filter function takes as input the
|
||||
parameter path (as a tuple) and its associated update, and return a True
|
||||
for params to apply the weight decay and False for params to not apply
|
||||
the weight decay. When weight_decay_filter is set to None, the weight
|
||||
decay is not applied to the bias, i.e. when the variable name is 'b', and
|
||||
the weight decay is not applied to nornalization params, i.e. the
|
||||
panultimate path contains 'norm'.
|
||||
lars_adaptation_filter: similar to weight decay filter but for lars
|
||||
adaptation
|
||||
|
||||
Returns:
|
||||
An optax.GradientTransformation, i.e. a (init_fn, update_fn) tuple.
|
||||
"""
|
||||
|
||||
if weight_decay_filter is None:
|
||||
weight_decay_filter = lambda *_: True
|
||||
if lars_adaptation_filter is None:
|
||||
lars_adaptation_filter = lambda *_: True
|
||||
|
||||
return optax.chain(
|
||||
add_weight_decay(
|
||||
weight_decay=weight_decay, filter_fn=weight_decay_filter),
|
||||
scale_by_lars(
|
||||
momentum=momentum, eta=eta, filter_fn=lars_adaptation_filter),
|
||||
optax.scale(-learning_rate),
|
||||
)
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Learning rate schedules."""
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def target_ema(global_step,
|
||||
base_ema,
|
||||
max_steps):
|
||||
decay = _cosine_decay(global_step, max_steps, 1.)
|
||||
return 1. - (1. - base_ema) * decay
|
||||
|
||||
|
||||
def learning_schedule(global_step,
|
||||
batch_size,
|
||||
base_learning_rate,
|
||||
total_steps,
|
||||
warmup_steps):
|
||||
"""Cosine learning rate scheduler."""
|
||||
# Compute LR & Scaled LR
|
||||
scaled_lr = base_learning_rate * batch_size / 256.
|
||||
learning_rate = (
|
||||
global_step.astype(jnp.float32) / int(warmup_steps) *
|
||||
scaled_lr if warmup_steps > 0 else scaled_lr)
|
||||
|
||||
# Cosine schedule after warmup.
|
||||
return jnp.where(
|
||||
global_step < warmup_steps, learning_rate,
|
||||
_cosine_decay(global_step - warmup_steps, total_steps - warmup_steps,
|
||||
scaled_lr))
|
||||
|
||||
|
||||
def _cosine_decay(global_step,
|
||||
max_steps,
|
||||
initial_value):
|
||||
"""Simple implementation of cosine decay from TF1."""
|
||||
global_step = jnp.minimum(global_step, max_steps)
|
||||
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
|
||||
decayed_learning_rate = initial_value * cosine_decay_value
|
||||
return decayed_learning_rate
|
||||
Reference in New Issue
Block a user