Example training pipeline for perceiver.

PiperOrigin-RevId: 387673066
This commit is contained in:
David Ding
2021-07-29 22:45:43 +01:00
committed by Diego de Las Casas
parent 8c431d40ea
commit 12971d9a42
7 changed files with 1995 additions and 5 deletions
+50 -3
View File
@@ -43,21 +43,68 @@ First, install dependencies following these instructions:
4. Install other dependencies: `pip install -r requirements.txt` 4. Install other dependencies: `pip install -r requirements.txt`
After installing dependencies, you can open the notebooks in the `colabs` directory After installing dependencies, you can open the notebooks in the `colabs` directory
using Jupyter or Colab. using Jupyter or Colab, and you can run our example training script.
Our colabs and training script assume that you are running from the
`deepmind_research` directory.
### Colabs ### Colabs
We provide the following colabs: We provide the following colabs:
* colabs/masked_language_modelling.ipynb: Colab for running a pre-trained
Perceiver masked-language model (Section 4.1 in [2]).
* colabs/optical_flow.ipynb: Colab for running a pre-trained optical flow * colabs/optical_flow.ipynb: Colab for running a pre-trained optical flow
Perceiver model and visualizing the output flow (Section 4.2 in [2]). Perceiver model and visualizing the output flow (Section 4.2 in [2]).
* colabs/video_autoencoding.ipynb: Colab for running a pre-trained * colabs/video_autoencoding.ipynb: Colab for running a pre-trained
video autoencoding Perceiver model and visualizing video reconstructions video autoencoding Perceiver model and visualizing video reconstructions
(Section 4.3 in [2]). (Section 4.3 in [2]).
### Training scripts
We also provide an example training script to train a Perceiver IO model for
ImageNet classification.
The provided hyperparameters are the settings used to train Perceiver IO
with 2D Fourier position encodings, as described in
section 4.5 and supplemental section I.1 of the paper [2].
To run the script locally and train a miniature Perceiver model,
use the `./launch_local.sh` script: `perceiver/train/launch_local.sh`.
The script would need to be adapted to run on a distributed training setup
in order to train a full-scale model.
## Attributions and Disclaimers
The file `perceiver/train/autoaugment.py` originates from the `tensorflow/tpu`
repository (https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/efficientnet/autoaugment.py),
copyright (c) The Tensorflow Authors.
Sintel data is provided by and available from Sintel.org (https://durian.blender.org/),
copyright (c) Blender Foundation/www.sintel.org.
Imagenet data is provided by and available from https://image-net.org/
(for researchers and educators who wish to use the images for
non-commercial research and/or educational purposes,
see https://image-net.org/about.php for details about access,
conditions and terms).
Video content may include clips provided as part of the THUMOS Challenge datasets,
which may be accessed at http://crcv.ucf.edu/THUMOS14/download.html,
copyrights held by the creators.
All data and parameters included with Perceiver are made available
under the terms of the CC BY 4.0 license,
available at https://creativecommons.org/licenses/by/4.0/legalcode.
This is not an officially supported Google product.
## References ## References
[1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, [1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals,
Joao Carreira. João Carreira.
*Perceiver: General Perception with Iterative Attention*. ICML 2021. *Perceiver: General Perception with Iterative Attention*. ICML 2021.
https://arxiv.org/abs/2103.03206
[2] TODO: Add citation after paper is published on ArXiv. [2] Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch,
Catalin Ionescu, David Ding, Skanda Koppula, Andrew Brock, Evan Shelhamer,
Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals,
João Carreira.
*Perceiver IO: A General Architecture for Structured Inputs & Outputs*.
arXiv, 2021.
+6 -2
View File
@@ -5,12 +5,16 @@ einops==0.3.0
flatbuffers==2.0 flatbuffers==2.0
imageio==2.9.0 imageio==2.9.0
immutabledict==2.0.0 immutabledict==2.0.0
jax==0.2.16 jaxline==0.0.3
jaxlib==0.1.68+cuda111
numpy==1.21.0 numpy==1.21.0
opencv-python==4.5.2.54 opencv-python==4.5.2.54
opt-einsum==3.3.0 opt-einsum==3.3.0
optax==0.0.9
Pillow==8.3.1 Pillow==8.3.1
scipy==1.7.0 scipy==1.7.0
six==1.16.0 six==1.16.0
tabulate==0.8.9 tabulate==0.8.9
tensorflow==2.5.0
tensorflow-addons==0.13.0
tensorflow-datasets==4.3.0
tensorflow-probability==0.13.0
File diff suppressed because it is too large Load Diff
+423
View File
@@ -0,0 +1,423 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet dataset with pre-processing and augmentation.
Deng, et al CVPR 2009 - ImageNet: A large-scale hierarchical image database.
https://image-net.org/
"""
import enum
from typing import Any, Generator, Mapping, Optional, Sequence, Text, Tuple
import jax
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
from perceiver.train import autoaugment
# A batch maps feature names to numpy arrays.
Batch = Mapping[Text, np.ndarray]
# Per-channel (R, G, B) mean and standard deviation, expressed on the 0-255
# pixel scale (hence the `* 255`); consumed by `_normalize_image`.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
# Sentinel asking tf.data to pick parallelism / prefetch sizes itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE
INPUT_DIM = 224  # The number of pixels in the image resize.
class Split(enum.Enum):
  """Named partitions of ImageNet used by this input pipeline."""
  TRAIN = 1
  TRAIN_AND_VALID = 2
  VALID = 3
  TEST = 4

  @classmethod
  def from_string(cls, name: Text) -> 'Split':
    """Parses a case-insensitive split name.

    Args:
      name: one of 'TRAIN', 'TRAIN_AND_VALID', 'VALID', 'VALIDATION' or
        'TEST' in any letter case. 'VALIDATION' is an alias of 'VALID'.

    Returns:
      The matching `Split` member.

    Raises:
      KeyError: if `name` is not one of the accepted names.
    """
    key = name.upper()
    if key == 'VALIDATION':
      key = 'VALID'
    # Every remaining accepted spelling is exactly a member name.
    return cls[key]

  @property
  def num_examples(self):
    """Number of examples contained in this split."""
    sizes_by_name = {
        'TRAIN_AND_VALID': 1281167,
        'TRAIN': 1271167,
        'VALID': 10000,
        'TEST': 50000,
    }
    return sizes_by_name[self.name]
def load(
    split: Split,
    *,
    is_training: bool,
    # batch_dims should be:
    # [device_count, per_device_batch_size] or [total_batch_size]
    batch_dims: Sequence[int],
    augmentation_settings: Mapping[str, Any],
    # The shape to which images are resized.
    im_dim: int = INPUT_DIM,
    threadpool_size: int = 48,
    max_intra_op_parallelism: int = 1,
) -> Generator[Batch, None, None]:
  """Loads the given split of the dataset.

  This is a generator function: no work happens until the result is iterated.

  Args:
    split: which `Split` to read; the per-host example range comes from
      `_shard`.
    is_training: if True, the data is repeated and shuffled, ordering is
      allowed to be non-deterministic, and training-time augmentation
      (including optional cutmix / mixup) is applied.
    batch_dims: batch dimensions, outermost first; the dataset is batched once
      per entry, innermost entry first.
    augmentation_settings: mapping read here through the keys 'cutmix' and
      'mixup_alpha', and forwarded to `_preprocess_image`.
    im_dim: side length of the square images produced.
    threadpool_size: size of the private tf.data thread pool.
    max_intra_op_parallelism: tf.data intra-op parallelism limit.

  Yields:
    Dicts of numpy arrays with keys 'images' and 'labels', plus 'mix_labels'
    and 'ratio' when training with cutmix and/or mixup enabled.

  Raises:
    ValueError: when not training and the split size is not divisible by the
      product of `batch_dims`.
  """
  # Each host reads only its own contiguous [start, end) slice of the split.
  start, end = _shard(split, jax.host_id(), jax.host_count())

  im_size = (im_dim, im_dim)

  total_batch_size = np.prod(batch_dims)

  tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split),
                                         from_=start, to=end, unit='abs')

  # Images are left undecoded here; decoding happens in `_preprocess_image`.
  ds = tfds.load('imagenet2012:5.*.*', split=tfds_split,
                 decoders={'image': tfds.decode.SkipDecoding()})

  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = threadpool_size
  options.experimental_threading.max_intra_op_parallelism = (
      max_intra_op_parallelism)

  options.experimental_optimization.map_parallelization = True
  if is_training:
    options.experimental_deterministic = False
  ds = ds.with_options(options)

  if is_training:
    if jax.host_count() > 1:
      # Only cache if we are reading a subset of the dataset.
      ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)

  else:
    # NOTE(review): this checks the size of the whole split, not the size of
    # this host's shard (end - start) -- confirm that is intended when
    # running on more than one host.
    if split.num_examples % total_batch_size != 0:
      raise ValueError(f'Test/valid must be divisible by {total_batch_size}')

  def crop_augment_preprocess(example):
    # The second return value (crop shape) is not needed here.
    image, _ = _preprocess_image(
        example['image'], is_training, im_size, augmentation_settings)
    label = tf.cast(example['label'], tf.int32)

    out = {'images': image, 'labels': label}

    if is_training:
      # Per-example mixing parameters; they are consumed after batching by
      # my_cutmix / my_mixup / my_mixup_cutmix below.
      if augmentation_settings['cutmix']:
        out['mask'] = cutmix_padding(*im_size)
        out['cutmix_ratio'] = tf.reduce_mean(out['mask'])
      if augmentation_settings['mixup_alpha'] is not None:
        beta = tfp.distributions.Beta(
            augmentation_settings['mixup_alpha'],
            augmentation_settings['mixup_alpha'])
        out['mixup_ratio'] = beta.sample()
    return out

  ds = ds.map(crop_augment_preprocess, num_parallel_calls=AUTOTUNE)

  # Mixup/cutmix by temporarily batching (using the per-device batch size):
  use_cutmix = augmentation_settings['cutmix']
  use_mixup = augmentation_settings['mixup_alpha'] is not None
  if is_training and (use_cutmix or use_mixup):
    inner_batch_size = batch_dims[-1]
    # Apply mixup, cutmix, or mixup + cutmix on batched data.
    # We use data from 2 batches to produce 1 mixed batch.
    ds = ds.batch(inner_batch_size * 2)
    if not use_cutmix and use_mixup:
      ds = ds.map(my_mixup, num_parallel_calls=AUTOTUNE)
    elif use_cutmix and not use_mixup:
      ds = ds.map(my_cutmix, num_parallel_calls=AUTOTUNE)
    elif use_cutmix and use_mixup:
      ds = ds.map(my_mixup_cutmix, num_parallel_calls=AUTOTUNE)

    # Unbatch for further processing.
    ds = ds.unbatch()

  # Batch innermost dimension first so the final leading axes follow
  # `batch_dims` order.
  for batch_size in reversed(batch_dims):
    ds = ds.batch(batch_size)

  ds = ds.prefetch(AUTOTUNE)

  yield from tfds.as_numpy(ds)
# cutmix_padding, my_cutmix, my_mixup, and my_mixup_cutmix taken from:
# https://github.com/deepmind/deepmind-research/blob/master/nfnets/dataset.py
def cutmix_padding(h, w):
  """Samples a random rectangular CutMix mask.

  The mask is 1 inside a randomly placed box (clipped to the image) and 0
  everywhere else.

  Adapted from
  https://github.com/google/edward2/blob/master/experimental/marginalization_mixup/data_utils.py#L367

  Args:
    h: image height.
    w: image width.

  Returns:
    A float mask of shape (h, w, 1).
  """
  # Box centre. Random ops are issued in the order x, y, proportion.
  center_x = tf.random.uniform([], 0, w, tf.int32)
  center_y = tf.random.uniform([], 0, h, tf.int32)

  # Beta dist in paper, but they used Beta(1,1) which is just uniform.
  image1_proportion = tf.random.uniform([])
  side_ratio = tf.math.sqrt(1 - image1_proportion)
  half_w = tf.cast(side_ratio * tf.cast(w, tf.float32), tf.int32) // 2
  half_h = tf.cast(side_ratio * tf.cast(h, tf.float32), tf.int32) // 2

  # Box corners, clipped so the box never leaves the image.
  left = tf.clip_by_value(tf.cast(center_x - half_w, tf.int32), 0, w)
  top = tf.clip_by_value(tf.cast(center_y - half_h, tf.int32), 0, h)
  right = tf.clip_by_value(tf.cast(center_x + half_w, tf.int32), 0, w)
  bottom = tf.clip_by_value(tf.cast(center_y + half_h, tf.int32), 0, h)

  # Build the binary mask: a block of ones zero-padded out to (h, w).
  box = tf.ones((bottom - top, right - left))
  row_padding = [top, tf.maximum(h - bottom, 0)]
  col_padding = [left, tf.maximum(w - right, 0)]
  mask = tf.pad(
      box,
      paddings=[row_padding, col_padding],
      mode='CONSTANT',
      constant_values=0)
  mask.set_shape((h, w))
  return mask[..., None]  # Add channel dim.
def my_cutmix(batch):
  """Applies CutMix (https://arxiv.org/abs/1905.04899) to a double batch.

  The first half of the batch is combined with the second half using the
  per-example masks of the first half.

  Args:
    batch: dict with 'images', 'labels', 'mask' and 'cutmix_ratio' entries,
      all batched along axis 0.

  Returns:
    Dict with 'images', 'labels', 'mix_labels' and 'ratio', each holding half
    as many examples as the input.
  """
  data = dict(**batch)
  images, labels = data['images'], data['labels']
  half = tf.shape(images)[0] // 2

  keep = data['mask'][:half]
  blended = keep * images[:half] + (1.0 - keep) * images[half:]
  return {
      'images': blended,
      'labels': labels[:half],
      'mix_labels': labels[half:],
      'ratio': data['cutmix_ratio'][:half],
  }
def my_mixup(batch):
  """Applies mixup (https://arxiv.org/abs/1710.09412) to a double batch.

  The first half of the batch is linearly blended with the second half using
  the per-example ratios of the first half.

  Args:
    batch: dict with 'images', 'labels' and 'mixup_ratio' entries, all batched
      along axis 0.

  Returns:
    Dict with 'images', 'labels', 'mix_labels' and 'ratio', each holding half
    as many examples as the input.
  """
  data = dict(**batch)
  images, labels = data['images'], data['labels']
  half = tf.shape(images)[0] // 2

  weights = data['mixup_ratio'][:half]
  # Trailing singleton axes let the per-example weight broadcast over pixels.
  weights_4d = weights[:, None, None, None]
  blended = weights_4d * images[:half] + (1.0 - weights_4d) * images[half:]
  return {
      'images': blended,
      'labels': labels[:half],
      'mix_labels': labels[half:],
      'ratio': weights,
  }
def my_mixup_cutmix(batch):
  """Applies mixup to one half of the output batch and cutmix to the other.

  The input is treated as four quarters: quarters 1 and 2 are blended with
  mixup, quarter 3 and the last quarter are combined with cutmix.

  Args:
    batch: dict with 'images', 'labels', 'mask', 'cutmix_ratio' and
      'mixup_ratio' entries, all batched along axis 0.

  Returns:
    Dict with 'images', 'labels', 'mix_labels' and 'ratio'; mixup examples
    come first, cutmix examples second.
  """
  data = dict(**batch)
  images, labels = data['images'], data['labels']
  quarter = tf.shape(images)[0] // 4

  # Mixup on quarters 1 and 2.
  mixup_weights = data['mixup_ratio'][:quarter]
  weights_4d = mixup_weights[:, None, None, None]
  mixup_images = (weights_4d * images[:quarter]
                  + (1.0 - weights_4d) * images[quarter:2*quarter])

  # Cutmix on quarter 3 and the final quarter.
  keep = data['mask'][2*quarter:3*quarter]
  cutmix_images = (keep * images[2*quarter:3*quarter]
                   + (1.0 - keep) * images[-quarter:])
  cutmix_weights = data['cutmix_ratio'][2*quarter:3*quarter]

  return {
      'images': tf.concat([mixup_images, cutmix_images], axis=0),
      'labels': tf.concat(
          [labels[:quarter], labels[2*quarter:3*quarter]], axis=0),
      'mix_labels': tf.concat(
          [labels[quarter:2*quarter], labels[-quarter:]], 0),
      'ratio': tf.concat([mixup_weights, cutmix_weights], axis=0),
  }
def _to_tfds_split(split: Split) -> tfds.Split:
  """Maps a pipeline `Split` onto the TFDS split that stores its examples."""
  # NOTE: Imagenet did not release labels for the test split used in the
  # competition, so it has been typical at DeepMind to consider the VALID
  # split the TEST split and to reserve 10k images from TRAIN for VALID.
  carved_from_train = (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID)
  if split not in carved_from_train:
    assert split == Split.TEST
    return tfds.Split.VALIDATION
  return tfds.Split.TRAIN
def _shard(
    split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
  """Computes the half-open example range [start, end) owned by one shard.

  Args:
    split: the split being partitioned.
    shard_index: index of the requested shard; must be below `num_shards`.
    num_shards: total number of shards.

  Returns:
    (start, end) as absolute example offsets into the underlying TFDS split.
  """
  assert shard_index < num_shards
  all_shards = np.array_split(np.arange(split.num_examples), num_shards)
  owned = all_shards[shard_index]
  start = owned[0]
  end = owned[-1] + 1
  if split == Split.TRAIN:
    # Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
    valid_size = Split.VALID.num_examples
    start, end = start + valid_size, end + valid_size
  return start, end
def _preprocess_image(
    image_bytes: tf.Tensor,
    is_training: bool,
    image_size: Sequence[int],
    augmentation_settings: Mapping[str, Any],
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Crops, optionally augments, resizes and normalizes one image.

  Args:
    image_bytes: the image to process, as handed over by the dataset.
    is_training: if True, use a random crop, a random horizontal flip and
      (when configured) RandAugment; otherwise use a center crop.
    image_size: (height, width) to resize the crop to.
    augmentation_settings: mapping whose 'randaugment' entry is read when
      training; None disables RandAugment, otherwise it must provide
      'num_layers' and 'magnitude'.

  Returns:
    A pair (image, im_shape): the resized, normalized image and the crop
    shape reported by the crop helper.
  """
  # Get the image crop.
  if is_training:
    image, im_shape = _decode_and_random_crop(image_bytes)
    image = tf.image.random_flip_left_right(image)
  else:
    image, im_shape = _decode_and_center_crop(image_bytes)
  assert image.dtype == tf.uint8

  # Optionally apply RandAugment: https://arxiv.org/abs/1909.13719
  # The settings are only consulted in training mode.
  randaugment = augmentation_settings['randaugment'] if is_training else None
  if randaugment is not None:
    # Input and output images are dtype uint8.
    image = autoaugment.distort_image_with_randaugment(
        image,
        num_layers=randaugment['num_layers'],
        magnitude=randaugment['magnitude'])

  # Resize and normalize the image crop.
  # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
  # clamping overshoots. This means values returned will be outside the range
  # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
  resized = tf.image.resize(
      image, image_size, tf.image.ResizeMethod.BICUBIC)
  return _normalize_image(resized), im_shape
def _normalize_image(image: tf.Tensor) -> tf.Tensor:
  """Subtracts the per-channel mean and divides by the per-channel stddev.

  Args:
    image: image tensor whose last axis holds the 3 RGB channels.

  Returns:
    The standardized image, with the same dtype as `image`.
  """
  channel_mean = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
  channel_std = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
  centered = image - channel_mean
  return centered / channel_std
def _distorted_bounding_box_crop(
    image_bytes: tf.Tensor,
    *,
    jpeg_shape: tf.Tensor,
    bbox: tf.Tensor,
    min_object_covered: float,
    aspect_ratio_range: Tuple[float, float],
    area_range: Tuple[float, float],
    max_attempts: int,
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Crops the image to a randomly distorted version of `bbox`.

  Args:
    image_bytes: either an encoded JPEG string tensor or an already decoded
      image tensor.
    jpeg_shape: shape of the full image.
    bbox: bounding boxes handed to `tf.image.sample_distorted_bounding_box`.
    min_object_covered: forwarded to the sampler.
    aspect_ratio_range: forwarded to the sampler.
    area_range: forwarded to the sampler.
    max_attempts: forwarded to the sampler.

  Returns:
    A pair (image, im_shape): the cropped image and its [height, width].
  """
  begin, size, _ = tf.image.sample_distorted_bounding_box(
      jpeg_shape,
      bounding_boxes=bbox,
      min_object_covered=min_object_covered,
      aspect_ratio_range=aspect_ratio_range,
      area_range=area_range,
      max_attempts=max_attempts,
      use_image_if_no_bounding_boxes=True)

  # Crop the image to the specified bounding box.
  top, left, _ = tf.unstack(begin)
  crop_height, crop_width, _ = tf.unstack(size)
  window = [top, left, crop_height, crop_width]

  if image_bytes.dtype != tf.dtypes.string:
    # Already decoded: plain tensor crop.
    cropped = tf.image.crop_to_bounding_box(image_bytes, *window)
  else:
    # Encoded JPEG: decode only the requested window.
    cropped = tf.image.decode_and_crop_jpeg(image_bytes,
                                            tf.stack(window),
                                            channels=3)
  return cropped, tf.stack([crop_height, crop_width])
def _decode_whole_image(image_bytes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Decodes a full JPEG to 3 channels and reads its shape from the header.

  Args:
    image_bytes: encoded JPEG string tensor.

  Returns:
    A pair (image, im_shape): the decoded image and the int32 shape extracted
    from the JPEG data.
  """
  decoded = tf.io.decode_jpeg(image_bytes, channels=3)
  return decoded, tf.io.extract_jpeg_shape(image_bytes, output_type=tf.int32)
def _decode_and_random_crop(
    image_bytes: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Takes a random crop of the image, falling back to a center crop.

  The crop is not resized here.

  Args:
    image_bytes: either an encoded JPEG string tensor or an already decoded
      image tensor.

  Returns:
    A pair (image, im_shape): the cropped image and its [height, width].
  """
  is_encoded = image_bytes.dtype == tf.dtypes.string
  full_shape = (tf.image.extract_jpeg_shape(image_bytes) if is_encoded
                else tf.shape(image_bytes))

  # A single box covering the whole image: the sampler may crop anywhere.
  whole_image_box = tf.constant(
      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image, im_shape = _distorted_bounding_box_crop(
      image_bytes,
      jpeg_shape=full_shape,
      bbox=whole_image_box,
      min_object_covered=0.1,
      aspect_ratio_range=(3 / 4, 4 / 3),
      area_range=(0.08, 1.0),
      max_attempts=10)

  # A crop with the full image's shape is treated as a failed random crop.
  crop_is_whole_image = tf.reduce_all(tf.equal(full_shape, tf.shape(image)))
  if crop_is_whole_image:
    # If the random crop failed fall back to center crop.
    image, im_shape = _decode_and_center_crop(image_bytes, full_shape)
  return image, im_shape
def _center_crop(image, crop_dim):
  """Crops a centered `crop_dim` x `crop_dim` square out of `image`.

  Args:
    image: image tensor whose first two axes are height and width, read from
      its static shape.
    crop_dim: side length of the square crop.

  Returns:
    The cropped image.
  """
  height, width = image.shape[0], image.shape[1]
  # The "+ 1" rounds the offset up when the margin is odd.
  top = ((height - crop_dim) + 1) // 2
  left = ((width - crop_dim) + 1) // 2
  return tf.image.crop_to_bounding_box(image, top, left, crop_dim, crop_dim)
def _decode_and_center_crop(
    image_bytes: tf.Tensor,
    jpeg_shape: Optional[tf.Tensor] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Takes a padded center crop of the image; the crop is not resized here.

  The crop is a square whose side is INPUT_DIM / (INPUT_DIM + 32) times the
  shorter image side.

  Args:
    image_bytes: either an encoded JPEG string tensor or an already decoded
      image tensor.
    jpeg_shape: shape of the full image; computed from `image_bytes` when
      omitted.

  Returns:
    A pair (image, im_shape): the cropped image and its [height, width].
  """
  is_encoded = image_bytes.dtype == tf.dtypes.string
  if jpeg_shape is None:
    jpeg_shape = (tf.image.extract_jpeg_shape(image_bytes) if is_encoded
                  else tf.shape(image_bytes))
  height, width = jpeg_shape[0], jpeg_shape[1]

  crop_fraction = INPUT_DIM / (INPUT_DIM + 32)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_side = tf.cast(crop_fraction * shorter_side, tf.int32)

  # The "+ 1" rounds the offset up when the margin is odd.
  top = ((height - crop_side) + 1) // 2
  left = ((width - crop_side) + 1) // 2
  window = [top, left, crop_side, crop_side]

  if is_encoded:
    # Encoded JPEG: decode only the requested window.
    cropped = tf.image.decode_and_crop_jpeg(image_bytes,
                                            tf.stack(window),
                                            channels=3)
  else:
    cropped = tf.image.crop_to_bounding_box(image_bytes, *window)
  return cropped, tf.stack([crop_side, crop_side])
File diff suppressed because it is too large Load Diff
+20
View File
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Command for running training script.
# The paths below are relative, so launch this from the directory that
# contains the `perceiver` package.
# `${PYTHONPATH:+:...}` appends the existing PYTHONPATH only when it is set,
# which avoids adding an empty path component; the quotes keep paths that
# contain spaces intact.
PYTHONPATH=".${PYTHONPATH:+:${PYTHONPATH}}" python perceiver/train/experiment.py \
  --config=perceiver/train/experiment.py --logtostderr
+242
View File
@@ -0,0 +1,242 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities."""
from typing import Callable, List, Mapping, NamedTuple, Optional, Tuple, Union
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
# A batch maps feature names to numpy arrays.
Batch = Mapping[str, np.ndarray]
# Optimizer state described as a tuple of optax transformation states.
# NOTE(review): `make_optimizer` can chain other transformations as well, so
# this alias may not match every configuration -- confirm before relying on it.
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
# Named scalar metrics.
Scalars = Mapping[str, jnp.ndarray]
ParamsOrState = Union[hk.Params, hk.State]
# Substrings of module names; parameters of modules whose name contains one
# of them are never weight-decayed (see `_weight_decay_exclude`).
NORM_NAMES = ['layer_norm', 'batchnorm']
# any_in and topk_correct taken from
# https://github.com/deepmind/deepmind-research/blob/master/nfnets/utils.py
@jax.vmap
def any_in(prediction, target):
  """Tests, row by row, which elements of `prediction` occur in `target`.

  `jax.vmap` maps this function over the leading axis of both arguments, so
  row i of `prediction` is only compared against row i of `target`.

  Args:
    prediction: array with a leading batch axis.
    target: array with a matching leading batch axis.

  Returns:
    A boolean array shaped like `prediction`. No reduction is done here;
    callers that want "any element matched" reduce over the last axis.
  """
  return jnp.isin(prediction, target)
def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)):
  """Computes per-example top-k correctness for several values of k.

  Args:
    logits: class scores; the last axis indexes classes.
    labels: integer class labels, one per example.
    mask: optional per-example multiplier applied to the correctness values.
    prefix: string prepended to every metric name.
    topk: the values of k to evaluate.

  Returns:
    Dict mapping f'{prefix}top_{k}_acc' to a float32 array that is 1.0 where
    the label is among the k highest-scoring classes and 0.0 otherwise
    (times `mask` when given).
  """
  # Ascending sort: the k best classes are the last k entries of each row.
  ranked_classes = jnp.argsort(logits)

  def _correct_at(k):
    top_classes = ranked_classes[..., -k:]
    # Get the number of examples where the label is in the top-k predictions
    hits = any_in(top_classes, labels).any(axis=-1).astype(jnp.float32)
    return hits if mask is None else hits * mask

  return {f'{prefix}top_{k}_acc': _correct_at(k) for k in topk}
def softmax_cross_entropy(logits, labels):
  """Computes softmax cross entropy from logits and one-hot class labels.

  Thin wrapper around `optax.softmax_cross_entropy` that converts the result
  with `jnp.asarray`.

  Args:
    logits: Logit output values.
    labels: Ground truth one-hot-encoded labels.

  Returns:
    The loss values as a `jnp` array.
    NOTE(review): presumably one value per example (class axis reduced) --
    confirm against the optax documentation.
  """
  loss = optax.softmax_cross_entropy(logits, labels)
  return jnp.asarray(loss)
def _get_batch_scaled_lr(total_batch_size, lr, scale_by_batch=True):
# This is the linear scaling rule in Section 5.1 of
# https://arxiv.org/pdf/1706.02677.pdf.
if scale_by_batch:
lr = (lr * total_batch_size) / 256
return lr
def get_learning_rate_schedule(
    total_batch_size, steps_per_epoch, total_steps, optimizer_config):
  """Build the learning rate schedule function.

  Args:
    total_batch_size: total batch size, used to scale learning rates when
      `optimizer_config.scale_by_batch` is set.
    steps_per_epoch: number of steps per epoch; only used to convert
      `warmup_epochs` into steps for the 'cosine' schedule.
    total_steps: total number of training steps.
    optimizer_config: config providing `base_lr`, `scale_by_batch`,
      `schedule_type` and the `*_kwargs` entry matching the schedule type
      ('steps' -> `step_decay_kwargs`, 'cosine' -> `cosine_decay_kwargs`,
      'constant_cosine' -> `constant_cosine_decay_kwargs`).

  Returns:
    An optax schedule: a function mapping the step count to a learning rate.

  Raises:
    ValueError: if `optimizer_config.schedule_type` is not recognised.
  """
  base_lr = _get_batch_scaled_lr(total_batch_size, optimizer_config.base_lr,
                                 optimizer_config.scale_by_batch)

  schedule_type = optimizer_config.schedule_type
  if schedule_type == 'steps':
    # sorted() returns a new list, so the caller's config is never mutated
    # and boundaries given as a tuple (or another iterable) work too.
    boundaries = sorted(optimizer_config.step_decay_kwargs.decay_boundaries)

    decay_rate = optimizer_config.step_decay_kwargs.decay_rate
    # Boundaries are fractions of the run; convert them to step counts.
    boundaries_and_scales = {
        int(boundary * total_steps): decay_rate for boundary in boundaries}
    schedule_fn = optax.piecewise_constant_schedule(
        init_value=base_lr, boundaries_and_scales=boundaries_and_scales)
  elif schedule_type == 'cosine':
    warmup_steps = (optimizer_config.cosine_decay_kwargs.warmup_epochs
                    * steps_per_epoch)
    # Batch scale the other lr values as well:
    init_value = _get_batch_scaled_lr(
        total_batch_size,
        optimizer_config.cosine_decay_kwargs.init_value,
        optimizer_config.scale_by_batch)
    end_value = _get_batch_scaled_lr(
        total_batch_size,
        optimizer_config.cosine_decay_kwargs.end_value,
        optimizer_config.scale_by_batch)

    schedule_fn = optax.warmup_cosine_decay_schedule(
        init_value=init_value,
        peak_value=base_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=end_value)
  elif schedule_type == 'constant_cosine':
    # Convert end_value to alpha, used by cosine_decay_schedule.
    # NOTE(review): unlike the 'cosine' branch, end_value is not batch-scaled
    # here before being divided by the (scaled) base_lr -- confirm intended.
    alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr

    # Number of steps spent in constant phase.
    constant_steps = int(
        optimizer_config.constant_cosine_decay_kwargs.constant_fraction
        * total_steps)
    decay_steps = total_steps - constant_steps

    constant_phase = optax.constant_schedule(value=base_lr)
    decay_phase = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=decay_steps,
        alpha=alpha)
    schedule_fn = optax.join_schedules(
        schedules=[constant_phase, decay_phase],
        boundaries=[constant_steps])
  else:
    raise ValueError(f'Unknown learning rate schedule: {schedule_type}')

  return schedule_fn
def _weight_decay_exclude(
    exclude_names: Optional[List[str]] = None
) -> Callable[[str, str, jnp.ndarray], bool]:
  """Logic for deciding which parameters to exclude from weight decay.

  Args:
    exclude_names: an optional list of parameter names to exclude from
      weight decay. None or an empty list selects the default, ['b'].

  Returns:
    A predicate `(module_name, name, value)` that returns True for params
    that need to be excluded from weight_decay.
  """
  # By default weight_decay the weights but not the biases.
  if not exclude_names:
    exclude_names = ['b']

  def exclude(module_name: str, name: str, value: jnp.ndarray) -> bool:
    del value  # The decision depends on names only.
    # Do not weight decay the parameters of normalization blocks.
    if any(norm_name in module_name for norm_name in NORM_NAMES):
      return True
    return name in exclude_names

  return exclude
class AddWeightDecayState(NamedTuple):
  """Stateless transformation.

  This tuple has no fields: it only serves as the state object returned by
  the `init_fn` of `add_weight_decay`.
  """
def add_weight_decay(
    weight_decay: float,
    exclude_names: Optional[List[str]] = None) -> optax.GradientTransformation:
  """Builds a transformation adding `weight_decay * param` to the updates.

  Same as optax.add_decayed_weights but can exclude parameters by name.

  Args:
    weight_decay: weight_decay coefficient.
    exclude_names: an optional list of names to exclude for weight_decay. ['b']
      by default.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def _decayed(update, param):
    return update + weight_decay * param

  def init_fn(_):
    return AddWeightDecayState()

  def update_fn(updates, state, params):
    is_excluded = _weight_decay_exclude(exclude_names=exclude_names)
    # Split off the entries that must not be decayed; only the matching
    # subset of params is needed for the rest.
    untouched, to_decay = hk.data_structures.partition(is_excluded, updates)
    _, params_to_decay = hk.data_structures.partition(is_excluded, params)
    to_decay = jax.tree_multimap(_decayed, to_decay, params_to_decay)
    return hk.data_structures.merge(untouched, to_decay), state

  return optax.GradientTransformation(init_fn, update_fn)
def make_optimizer(optimizer_config, lr_schedule):
  """Constructs the optax optimizer for the given config and LR schedule.

  Args:
    optimizer_config: config providing `max_norm`, `optimizer` ('adam' or
      'lamb'), `weight_decay`, the matching `adam_kwargs` / `lamb_kwargs`
      and, optionally, `decay_pos_embs`.
    lr_schedule: schedule function mapping the step count to a learning rate.

  Returns:
    The chained optax gradient transformation.

  Raises:
    ValueError: if `optimizer_config.optimizer` is not 'adam' or 'lamb'.
  """
  decay_pos_embs = optimizer_config.get('decay_pos_embs')
  # Decay learned position embeddings by default.
  if decay_pos_embs is None or decay_pos_embs:
    undecayed_names = ['b']
  else:
    undecayed_names = ['pos_embs', 'b']

  transforms = []
  # Gradient clipping is optional and always comes first.
  if optimizer_config.max_norm > 0:
    transforms.append(optax.clip_by_global_norm(optimizer_config.max_norm))

  if optimizer_config.optimizer == 'adam':
    # See: https://arxiv.org/abs/1412.6980
    transforms.append(optax.scale_by_adam(**optimizer_config.adam_kwargs))
    transforms.append(
        add_weight_decay(
            optimizer_config.weight_decay, exclude_names=undecayed_names))
  elif optimizer_config.optimizer == 'lamb':
    # See: https://arxiv.org/abs/1904.00962
    transforms.append(optax.scale_by_adam(**optimizer_config.lamb_kwargs))
    transforms.append(
        add_weight_decay(
            optimizer_config.weight_decay, exclude_names=undecayed_names))
    transforms.append(optax.scale_by_trust_ratio())
  else:
    raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')

  # Scale by the (negative) learning rate.
  transforms.append(optax.scale_by_schedule(lr_schedule))
  transforms.append(optax.scale(-1))

  return optax.chain(*transforms)