mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-01 21:56:38 +08:00
Example training pipeline for perceiver.
PiperOrigin-RevId: 387673066
This commit is contained in:
committed by
Diego de Las Casas
parent
8c431d40ea
commit
12971d9a42
+50
-3
@@ -43,21 +43,68 @@ First, install dependencies following these instructions:
|
|||||||
4. Install other dependencies: `pip install -f requirements.txt`
|
4. Install other dependencies: `pip install -f requirements.txt`
|
||||||
|
|
||||||
After install dependencies, you can open the notebooks in the `colabs` directory
|
After install dependencies, you can open the notebooks in the `colabs` directory
|
||||||
using Jupyter or Colab.
|
using Jupyter or Colab, and you can run our example training script.
|
||||||
|
Our colabs and training script assume that you are running from the
|
||||||
|
`deepmind_research` directory.
|
||||||
|
|
||||||
### Colabs
|
### Colabs
|
||||||
We provide the following colabs:
|
We provide the following colabs:
|
||||||
|
|
||||||
|
* colabs/masked_language_modelling.ipynb: Colab for running a pre-trained
|
||||||
|
Perceiver masked-language model (Section 4.1 in [2]).
|
||||||
* colabs/optical_flow.ipynb: Colab for running a pre-trained optical flow
|
* colabs/optical_flow.ipynb: Colab for running a pre-trained optical flow
|
||||||
Perceiver model and visualizing the output flow (Section 4.2 in [2]).
|
Perceiver model and visualizing the output flow (Section 4.2 in [2]).
|
||||||
* colabs/video_autoencoding.ipynb: Colab for running a pre-trained
|
* colabs/video_autoencoding.ipynb: Colab for running a pre-trained
|
||||||
video autoencoding Perceiver model and visualizing video reconstructions
|
video autoencoding Perceiver model and visualizing video reconstructions
|
||||||
(Section 4.3 in [2]).
|
(Section 4.3 in [2]).
|
||||||
|
|
||||||
|
### Training scripts
|
||||||
|
We also provide an example training script to train a Perceiver IO model for
|
||||||
|
ImageNet classification.
|
||||||
|
The provided hyperparameters are the settings used to train Perceiver IO
|
||||||
|
with 2D Fourier position encodings, as described in
|
||||||
|
section 4.5 and supplemental section I.1 of the paper [2].
|
||||||
|
|
||||||
|
To run the script locally and train a miniature Perceiver model,
|
||||||
|
use the `./launch_local.sh` script: `perceiver/train/launch_local.sh`.
|
||||||
|
The script would need to be adapted to run on a distributed training setup
|
||||||
|
in order to train a full-scale model.
|
||||||
|
|
||||||
|
## Attributions and Disclaimers
|
||||||
|
|
||||||
|
The file `perceiver/train/autoaugment.py` originates from the `tensorflow/tpu`
|
||||||
|
repository (https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/efficientnet/autoaugment.py),
|
||||||
|
copyright (c) The Tensorflow Authors.
|
||||||
|
|
||||||
|
Sintel data is provided by and available from Sintel.org (https://durian.blender.org/),
|
||||||
|
copyright (c) Blender Foundation/www.sintel.org.
|
||||||
|
|
||||||
|
Imagenet data is provided by and available from https://image-net.org/
|
||||||
|
(for researchers and educators who wish to use the images for
|
||||||
|
non-commercial research and/or educational purposes,
|
||||||
|
see https://image-net.org/about.php for details about access,
|
||||||
|
conditions and terms).
|
||||||
|
|
||||||
|
Video content may include clips provided as part of the THUMOS Challenge datasets,
|
||||||
|
which may be accessed at http://crcv.ucf.edu/THUMOS14/download.html,
|
||||||
|
copyrights held by the creators.
|
||||||
|
|
||||||
|
All data and parameters included with Perceiver are made available
|
||||||
|
under the terms of the CC BY 4.0 license,
|
||||||
|
available at https://creativecommons.org/licenses/by/4.0/legalcode.
|
||||||
|
|
||||||
|
This is not an officially supported Google product.
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
[1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals,
|
[1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals,
|
||||||
Joao Carreira.
|
João Carreira.
|
||||||
*Perceiver: General Perception with Iterative Attention*. ICML 2021.
|
*Perceiver: General Perception with Iterative Attention*. ICML 2021.
|
||||||
|
https://arxiv.org/abs/2103.03206
|
||||||
|
|
||||||
[2] TODO: Add citation after paper is published on ArXiv.
|
[2] Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch,
|
||||||
|
Catalin Ionescu, David Ding, Skanda Koppula, Andrew Brock, Evan Shelhamer,
|
||||||
|
Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals,
|
||||||
|
João Carreira.
|
||||||
|
*Perceiver IO: A General Architecture for Structured Inputs & Outputs*.
|
||||||
|
arXiv, 2021.
|
||||||
|
|||||||
@@ -5,12 +5,16 @@ einops==0.3.0
|
|||||||
flatbuffers==2.0
|
flatbuffers==2.0
|
||||||
imageio==2.9.0
|
imageio==2.9.0
|
||||||
immutabledict==2.0.0
|
immutabledict==2.0.0
|
||||||
jax==0.2.16
|
jaxline==0.0.3
|
||||||
jaxlib==0.1.68+cuda111
|
|
||||||
numpy==1.21.0
|
numpy==1.21.0
|
||||||
opencv-python==4.5.2.54
|
opencv-python==4.5.2.54
|
||||||
opt-einsum==3.3.0
|
opt-einsum==3.3.0
|
||||||
|
optax==0.0.9
|
||||||
Pillow==8.3.1
|
Pillow==8.3.1
|
||||||
scipy==1.7.0
|
scipy==1.7.0
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
tabulate==0.8.9
|
tabulate==0.8.9
|
||||||
|
tensorflow==2.5.0
|
||||||
|
tensorflow-addons==0.13.0
|
||||||
|
tensorflow-datasets==4.3.0
|
||||||
|
tensorflow-probability==0.13.0
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,423 @@
|
|||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""ImageNet dataset with pre-processing and augmentation.
|
||||||
|
|
||||||
|
Deng, et al CVPR 2009 - ImageNet: A large-scale hierarchical image database.
|
||||||
|
https://image-net.org/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from typing import Any, Generator, Mapping, Optional, Sequence, Text, Tuple
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v2 as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
import tensorflow_probability as tfp
|
||||||
|
|
||||||
|
from perceiver.train import autoaugment
|
||||||
|
|
||||||
|
|
||||||
|
Batch = Mapping[Text, np.ndarray]
|
||||||
|
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
|
||||||
|
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
|
||||||
|
AUTOTUNE = tf.data.experimental.AUTOTUNE
|
||||||
|
|
||||||
|
INPUT_DIM = 224 # The number of pixels in the image resize.
|
||||||
|
|
||||||
|
|
||||||
|
class Split(enum.Enum):
|
||||||
|
"""ImageNet dataset split."""
|
||||||
|
TRAIN = 1
|
||||||
|
TRAIN_AND_VALID = 2
|
||||||
|
VALID = 3
|
||||||
|
TEST = 4
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, name: Text) -> 'Split':
|
||||||
|
return {'TRAIN': Split.TRAIN, 'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
|
||||||
|
'VALID': Split.VALID, 'VALIDATION': Split.VALID,
|
||||||
|
'TEST': Split.TEST}[name.upper()]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_examples(self):
|
||||||
|
return {Split.TRAIN_AND_VALID: 1281167, Split.TRAIN: 1271167,
|
||||||
|
Split.VALID: 10000, Split.TEST: 50000}[self]
|
||||||
|
|
||||||
|
|
||||||
|
def load(
|
||||||
|
split: Split,
|
||||||
|
*,
|
||||||
|
is_training: bool,
|
||||||
|
# batch_dims should be:
|
||||||
|
# [device_count, per_device_batch_size] or [total_batch_size]
|
||||||
|
batch_dims: Sequence[int],
|
||||||
|
augmentation_settings: Mapping[str, Any],
|
||||||
|
# The shape to which images are resized.
|
||||||
|
im_dim: int = INPUT_DIM,
|
||||||
|
threadpool_size: int = 48,
|
||||||
|
max_intra_op_parallelism: int = 1,
|
||||||
|
) -> Generator[Batch, None, None]:
|
||||||
|
"""Loads the given split of the dataset."""
|
||||||
|
start, end = _shard(split, jax.host_id(), jax.host_count())
|
||||||
|
|
||||||
|
im_size = (im_dim, im_dim)
|
||||||
|
|
||||||
|
total_batch_size = np.prod(batch_dims)
|
||||||
|
|
||||||
|
tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split),
|
||||||
|
from_=start, to=end, unit='abs')
|
||||||
|
|
||||||
|
ds = tfds.load('imagenet2012:5.*.*', split=tfds_split,
|
||||||
|
decoders={'image': tfds.decode.SkipDecoding()})
|
||||||
|
|
||||||
|
options = tf.data.Options()
|
||||||
|
options.experimental_threading.private_threadpool_size = threadpool_size
|
||||||
|
options.experimental_threading.max_intra_op_parallelism = (
|
||||||
|
max_intra_op_parallelism)
|
||||||
|
options.experimental_optimization.map_parallelization = True
|
||||||
|
if is_training:
|
||||||
|
options.experimental_deterministic = False
|
||||||
|
ds = ds.with_options(options)
|
||||||
|
|
||||||
|
if is_training:
|
||||||
|
if jax.host_count() > 1:
|
||||||
|
# Only cache if we are reading a subset of the dataset.
|
||||||
|
ds = ds.cache()
|
||||||
|
ds = ds.repeat()
|
||||||
|
ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if split.num_examples % total_batch_size != 0:
|
||||||
|
raise ValueError(f'Test/valid must be divisible by {total_batch_size}')
|
||||||
|
|
||||||
|
def crop_augment_preprocess(example):
|
||||||
|
image, _ = _preprocess_image(
|
||||||
|
example['image'], is_training, im_size, augmentation_settings)
|
||||||
|
|
||||||
|
label = tf.cast(example['label'], tf.int32)
|
||||||
|
|
||||||
|
out = {'images': image, 'labels': label}
|
||||||
|
|
||||||
|
if is_training:
|
||||||
|
if augmentation_settings['cutmix']:
|
||||||
|
out['mask'] = cutmix_padding(*im_size)
|
||||||
|
out['cutmix_ratio'] = tf.reduce_mean(out['mask'])
|
||||||
|
if augmentation_settings['mixup_alpha'] is not None:
|
||||||
|
beta = tfp.distributions.Beta(
|
||||||
|
augmentation_settings['mixup_alpha'],
|
||||||
|
augmentation_settings['mixup_alpha'])
|
||||||
|
out['mixup_ratio'] = beta.sample()
|
||||||
|
return out
|
||||||
|
|
||||||
|
ds = ds.map(crop_augment_preprocess, num_parallel_calls=AUTOTUNE)
|
||||||
|
|
||||||
|
# Mixup/cutmix by temporarily batching (using the per-device batch size):
|
||||||
|
use_cutmix = augmentation_settings['cutmix']
|
||||||
|
use_mixup = augmentation_settings['mixup_alpha'] is not None
|
||||||
|
if is_training and (use_cutmix or use_mixup):
|
||||||
|
inner_batch_size = batch_dims[-1]
|
||||||
|
# Apply mixup, cutmix, or mixup + cutmix on batched data.
|
||||||
|
# We use data from 2 batches to produce 1 mixed batch.
|
||||||
|
ds = ds.batch(inner_batch_size * 2)
|
||||||
|
if not use_cutmix and use_mixup:
|
||||||
|
ds = ds.map(my_mixup, num_parallel_calls=AUTOTUNE)
|
||||||
|
elif use_cutmix and not use_mixup:
|
||||||
|
ds = ds.map(my_cutmix, num_parallel_calls=AUTOTUNE)
|
||||||
|
elif use_cutmix and use_mixup:
|
||||||
|
ds = ds.map(my_mixup_cutmix, num_parallel_calls=AUTOTUNE)
|
||||||
|
|
||||||
|
# Unbatch for further processing.
|
||||||
|
ds = ds.unbatch()
|
||||||
|
|
||||||
|
for batch_size in reversed(batch_dims):
|
||||||
|
ds = ds.batch(batch_size)
|
||||||
|
|
||||||
|
ds = ds.prefetch(AUTOTUNE)
|
||||||
|
|
||||||
|
yield from tfds.as_numpy(ds)
|
||||||
|
|
||||||
|
|
||||||
|
# cutmix_padding, my_cutmix, my_mixup, and my_mixup_cutmix taken from:
|
||||||
|
# https://github.com/deepmind/deepmind-research/blob/master/nfnets/dataset.py
|
||||||
|
def cutmix_padding(h, w):
|
||||||
|
"""Returns image mask for CutMix.
|
||||||
|
|
||||||
|
Taken from (https://github.com/google/edward2/blob/master/experimental
|
||||||
|
/marginalization_mixup/data_utils.py#L367)
|
||||||
|
Args:
|
||||||
|
h: image height.
|
||||||
|
w: image width.
|
||||||
|
"""
|
||||||
|
r_x = tf.random.uniform([], 0, w, tf.int32)
|
||||||
|
r_y = tf.random.uniform([], 0, h, tf.int32)
|
||||||
|
|
||||||
|
# Beta dist in paper, but they used Beta(1,1) which is just uniform.
|
||||||
|
image1_proportion = tf.random.uniform([])
|
||||||
|
patch_length_ratio = tf.math.sqrt(1 - image1_proportion)
|
||||||
|
r_w = tf.cast(patch_length_ratio * tf.cast(w, tf.float32), tf.int32)
|
||||||
|
r_h = tf.cast(patch_length_ratio * tf.cast(h, tf.float32), tf.int32)
|
||||||
|
bbx1 = tf.clip_by_value(tf.cast(r_x - r_w // 2, tf.int32), 0, w)
|
||||||
|
bby1 = tf.clip_by_value(tf.cast(r_y - r_h // 2, tf.int32), 0, h)
|
||||||
|
bbx2 = tf.clip_by_value(tf.cast(r_x + r_w // 2, tf.int32), 0, w)
|
||||||
|
bby2 = tf.clip_by_value(tf.cast(r_y + r_h // 2, tf.int32), 0, h)
|
||||||
|
|
||||||
|
# Create the binary mask.
|
||||||
|
pad_left = bbx1
|
||||||
|
pad_top = bby1
|
||||||
|
pad_right = tf.maximum(w - bbx2, 0)
|
||||||
|
pad_bottom = tf.maximum(h - bby2, 0)
|
||||||
|
r_h = bby2 - bby1
|
||||||
|
r_w = bbx2 - bbx1
|
||||||
|
|
||||||
|
mask = tf.pad(
|
||||||
|
tf.ones((r_h, r_w)),
|
||||||
|
paddings=[[pad_top, pad_bottom], [pad_left, pad_right]],
|
||||||
|
mode='CONSTANT',
|
||||||
|
constant_values=0)
|
||||||
|
mask.set_shape((h, w))
|
||||||
|
return mask[..., None] # Add channel dim.
|
||||||
|
|
||||||
|
|
||||||
|
def my_cutmix(batch):
|
||||||
|
"""Apply CutMix: https://arxiv.org/abs/1905.04899."""
|
||||||
|
batch = dict(**batch)
|
||||||
|
bs = tf.shape(batch['images'])[0] // 2
|
||||||
|
mask = batch['mask'][:bs]
|
||||||
|
images = (mask * batch['images'][:bs] + (1.0 - mask) * batch['images'][bs:])
|
||||||
|
mix_labels = batch['labels'][bs:]
|
||||||
|
labels = batch['labels'][:bs]
|
||||||
|
ratio = batch['cutmix_ratio'][:bs]
|
||||||
|
return {'images': images, 'labels': labels,
|
||||||
|
'mix_labels': mix_labels, 'ratio': ratio}
|
||||||
|
|
||||||
|
|
||||||
|
def my_mixup(batch):
|
||||||
|
"""Apply mixup: https://arxiv.org/abs/1710.09412."""
|
||||||
|
batch = dict(**batch)
|
||||||
|
bs = tf.shape(batch['images'])[0] // 2
|
||||||
|
ratio = batch['mixup_ratio'][:bs, None, None, None]
|
||||||
|
images = (ratio * batch['images'][:bs] + (1.0 - ratio) * batch['images'][bs:])
|
||||||
|
mix_labels = batch['labels'][bs:]
|
||||||
|
labels = batch['labels'][:bs]
|
||||||
|
ratio = ratio[..., 0, 0, 0] # Unsqueeze
|
||||||
|
return {'images': images, 'labels': labels,
|
||||||
|
'mix_labels': mix_labels, 'ratio': ratio}
|
||||||
|
|
||||||
|
|
||||||
|
def my_mixup_cutmix(batch):
|
||||||
|
"""Apply mixup to half the batch, and cutmix to the other."""
|
||||||
|
batch = dict(**batch)
|
||||||
|
bs = tf.shape(batch['images'])[0] // 4
|
||||||
|
mixup_ratio = batch['mixup_ratio'][:bs, None, None, None]
|
||||||
|
mixup_images = (mixup_ratio * batch['images'][:bs]
|
||||||
|
+ (1.0 - mixup_ratio) * batch['images'][bs:2*bs])
|
||||||
|
mixup_labels = batch['labels'][:bs]
|
||||||
|
mixup_mix_labels = batch['labels'][bs:2*bs]
|
||||||
|
|
||||||
|
cutmix_mask = batch['mask'][2*bs:3*bs]
|
||||||
|
cutmix_images = (cutmix_mask * batch['images'][2*bs:3*bs]
|
||||||
|
+ (1.0 - cutmix_mask) * batch['images'][-bs:])
|
||||||
|
cutmix_labels = batch['labels'][2*bs:3*bs]
|
||||||
|
cutmix_mix_labels = batch['labels'][-bs:]
|
||||||
|
cutmix_ratio = batch['cutmix_ratio'][2*bs : 3*bs]
|
||||||
|
|
||||||
|
return {'images': tf.concat([mixup_images, cutmix_images], axis=0),
|
||||||
|
'labels': tf.concat([mixup_labels, cutmix_labels], axis=0),
|
||||||
|
'mix_labels': tf.concat([mixup_mix_labels, cutmix_mix_labels], 0),
|
||||||
|
'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)}
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tfds_split(split: Split) -> tfds.Split:
|
||||||
|
"""Returns the TFDS split appropriately sharded."""
|
||||||
|
# NOTE: Imagenet did not release labels for the test split used in the
|
||||||
|
# competition, so it has been typical at DeepMind to consider the VALID
|
||||||
|
# split the TEST split and to reserve 10k images from TRAIN for VALID.
|
||||||
|
if split in (
|
||||||
|
Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID):
|
||||||
|
return tfds.Split.TRAIN
|
||||||
|
else:
|
||||||
|
assert split == Split.TEST
|
||||||
|
return tfds.Split.VALIDATION
|
||||||
|
|
||||||
|
|
||||||
|
def _shard(
|
||||||
|
split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
|
||||||
|
"""Returns [start, end) for the given shard index."""
|
||||||
|
assert shard_index < num_shards
|
||||||
|
arange = np.arange(split.num_examples)
|
||||||
|
shard_range = np.array_split(arange, num_shards)[shard_index]
|
||||||
|
start, end = shard_range[0], (shard_range[-1] + 1)
|
||||||
|
if split == Split.TRAIN:
|
||||||
|
# Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
|
||||||
|
offset = Split.VALID.num_examples
|
||||||
|
start += offset
|
||||||
|
end += offset
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_image(
|
||||||
|
image_bytes: tf.Tensor,
|
||||||
|
is_training: bool,
|
||||||
|
image_size: Sequence[int],
|
||||||
|
augmentation_settings: Mapping[str, Any],
|
||||||
|
) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
"""Returns processed and resized images."""
|
||||||
|
|
||||||
|
# Get the image crop.
|
||||||
|
if is_training:
|
||||||
|
image, im_shape = _decode_and_random_crop(image_bytes)
|
||||||
|
image = tf.image.random_flip_left_right(image)
|
||||||
|
else:
|
||||||
|
image, im_shape = _decode_and_center_crop(image_bytes)
|
||||||
|
assert image.dtype == tf.uint8
|
||||||
|
|
||||||
|
# Optionally apply RandAugment: https://arxiv.org/abs/1909.13719
|
||||||
|
if is_training:
|
||||||
|
if augmentation_settings['randaugment'] is not None:
|
||||||
|
# Input and output images are dtype uint8.
|
||||||
|
image = autoaugment.distort_image_with_randaugment(
|
||||||
|
image,
|
||||||
|
num_layers=augmentation_settings['randaugment']['num_layers'],
|
||||||
|
magnitude=augmentation_settings['randaugment']['magnitude'])
|
||||||
|
|
||||||
|
# Resize and normalize the image crop.
|
||||||
|
# NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
|
||||||
|
# clamping overshoots. This means values returned will be outside the range
|
||||||
|
# [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
|
||||||
|
image = tf.image.resize(
|
||||||
|
image, image_size, tf.image.ResizeMethod.BICUBIC)
|
||||||
|
image = _normalize_image(image)
|
||||||
|
|
||||||
|
return image, im_shape
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_image(image: tf.Tensor) -> tf.Tensor:
|
||||||
|
"""Normalize the image to zero mean and unit variance."""
|
||||||
|
image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
|
||||||
|
image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def _distorted_bounding_box_crop(
|
||||||
|
image_bytes: tf.Tensor,
|
||||||
|
*,
|
||||||
|
jpeg_shape: tf.Tensor,
|
||||||
|
bbox: tf.Tensor,
|
||||||
|
min_object_covered: float,
|
||||||
|
aspect_ratio_range: Tuple[float, float],
|
||||||
|
area_range: Tuple[float, float],
|
||||||
|
max_attempts: int,
|
||||||
|
) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
"""Generates cropped_image using one of the bboxes randomly distorted."""
|
||||||
|
bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
|
||||||
|
jpeg_shape,
|
||||||
|
bounding_boxes=bbox,
|
||||||
|
min_object_covered=min_object_covered,
|
||||||
|
aspect_ratio_range=aspect_ratio_range,
|
||||||
|
area_range=area_range,
|
||||||
|
max_attempts=max_attempts,
|
||||||
|
use_image_if_no_bounding_boxes=True)
|
||||||
|
|
||||||
|
# Crop the image to the specified bounding box.
|
||||||
|
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||||
|
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||||
|
crop_window = [offset_y, offset_x, target_height, target_width]
|
||||||
|
|
||||||
|
if image_bytes.dtype == tf.dtypes.string:
|
||||||
|
image = tf.image.decode_and_crop_jpeg(image_bytes,
|
||||||
|
tf.stack(crop_window),
|
||||||
|
channels=3)
|
||||||
|
else:
|
||||||
|
image = tf.image.crop_to_bounding_box(image_bytes, *crop_window)
|
||||||
|
|
||||||
|
im_shape = tf.stack([target_height, target_width])
|
||||||
|
return image, im_shape
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_whole_image(image_bytes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
image = tf.io.decode_jpeg(image_bytes, channels=3)
|
||||||
|
im_shape = tf.io.extract_jpeg_shape(image_bytes, output_type=tf.int32)
|
||||||
|
return image, im_shape
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_and_random_crop(
|
||||||
|
image_bytes: tf.Tensor
|
||||||
|
) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
"""Make a random crop of INPUT_DIM."""
|
||||||
|
|
||||||
|
if image_bytes.dtype == tf.dtypes.string:
|
||||||
|
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
||||||
|
else:
|
||||||
|
jpeg_shape = tf.shape(image_bytes)
|
||||||
|
|
||||||
|
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
|
||||||
|
image, im_shape = _distorted_bounding_box_crop(
|
||||||
|
image_bytes,
|
||||||
|
jpeg_shape=jpeg_shape,
|
||||||
|
bbox=bbox,
|
||||||
|
min_object_covered=0.1,
|
||||||
|
aspect_ratio_range=(3 / 4, 4 / 3),
|
||||||
|
area_range=(0.08, 1.0),
|
||||||
|
max_attempts=10)
|
||||||
|
|
||||||
|
if tf.reduce_all(tf.equal(jpeg_shape, tf.shape(image))):
|
||||||
|
# If the random crop failed fall back to center crop.
|
||||||
|
image, im_shape = _decode_and_center_crop(image_bytes, jpeg_shape)
|
||||||
|
return image, im_shape
|
||||||
|
|
||||||
|
|
||||||
|
def _center_crop(image, crop_dim):
|
||||||
|
"""Center crops an image to a target dimension."""
|
||||||
|
image_height = image.shape[0]
|
||||||
|
image_width = image.shape[1]
|
||||||
|
offset_height = ((image_height - crop_dim) + 1) // 2
|
||||||
|
offset_width = ((image_width - crop_dim) + 1) // 2
|
||||||
|
return tf.image.crop_to_bounding_box(
|
||||||
|
image, offset_height, offset_width, crop_dim, crop_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_and_center_crop(
|
||||||
|
image_bytes: tf.Tensor,
|
||||||
|
jpeg_shape: Optional[tf.Tensor] = None,
|
||||||
|
) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
"""Crops to center of image with padding then scales."""
|
||||||
|
if jpeg_shape is None:
|
||||||
|
if image_bytes.dtype == tf.dtypes.string:
|
||||||
|
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
||||||
|
else:
|
||||||
|
jpeg_shape = tf.shape(image_bytes)
|
||||||
|
|
||||||
|
image_height = jpeg_shape[0]
|
||||||
|
image_width = jpeg_shape[1]
|
||||||
|
|
||||||
|
padded_center_crop_size = tf.cast(
|
||||||
|
((INPUT_DIM / (INPUT_DIM + 32)) *
|
||||||
|
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
|
||||||
|
|
||||||
|
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
||||||
|
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
||||||
|
crop_window = [offset_height, offset_width,
|
||||||
|
padded_center_crop_size, padded_center_crop_size]
|
||||||
|
|
||||||
|
if image_bytes.dtype == tf.dtypes.string:
|
||||||
|
image = tf.image.decode_and_crop_jpeg(image_bytes,
|
||||||
|
tf.stack(crop_window),
|
||||||
|
channels=3)
|
||||||
|
else:
|
||||||
|
image = tf.image.crop_to_bounding_box(image_bytes, *crop_window)
|
||||||
|
|
||||||
|
im_shape = tf.stack([padded_center_crop_size, padded_center_crop_size])
|
||||||
|
return image, im_shape
|
||||||
File diff suppressed because it is too large
Load Diff
Executable
+20
@@ -0,0 +1,20 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Command for running training script.
|
||||||
|
|
||||||
|
PYTHONPATH=.::$PYTHONPATH python perceiver/train/experiment.py \
|
||||||
|
--config=perceiver/train/experiment.py --logtostderr
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Utilities."""
|
||||||
|
|
||||||
|
from typing import Callable, List, Mapping, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import optax
|
||||||
|
|
||||||
|
|
||||||
|
Batch = Mapping[str, np.ndarray]
|
||||||
|
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
|
||||||
|
Scalars = Mapping[str, jnp.ndarray]
|
||||||
|
ParamsOrState = Union[hk.Params, hk.State]
|
||||||
|
|
||||||
|
|
||||||
|
NORM_NAMES = ['layer_norm', 'batchnorm']
|
||||||
|
|
||||||
|
|
||||||
|
# any_in and topk_correct taken from
|
||||||
|
# https://github.com/deepmind/deepmind-research/blob/master/nfnets/utils.py
|
||||||
|
@jax.vmap
|
||||||
|
def any_in(prediction, target):
|
||||||
|
"""For each row in a and b, checks if any element of a is in b."""
|
||||||
|
return jnp.isin(prediction, target)
|
||||||
|
|
||||||
|
|
||||||
|
def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)):
|
||||||
|
"""Calculate top-k error for multiple k values."""
|
||||||
|
metrics = {}
|
||||||
|
argsorted_logits = jnp.argsort(logits)
|
||||||
|
for k in topk:
|
||||||
|
pred_labels = argsorted_logits[..., -k:]
|
||||||
|
# Get the number of examples where the label is in the top-k predictions
|
||||||
|
correct = any_in(pred_labels, labels).any(axis=-1).astype(jnp.float32)
|
||||||
|
if mask is not None:
|
||||||
|
correct *= mask
|
||||||
|
metrics[f'{prefix}top_{k}_acc'] = correct
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_cross_entropy(logits, labels):
|
||||||
|
"""Computes softmax cross entropy given logits and one-hot class labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Logit output values.
|
||||||
|
labels: Ground truth one-hot-encoded labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss value with the same shape as `labels`;
|
||||||
|
"""
|
||||||
|
return jnp.asarray(optax.softmax_cross_entropy(logits, labels))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_batch_scaled_lr(total_batch_size, lr, scale_by_batch=True):
|
||||||
|
# This is the linear scaling rule in Section 5.1 of
|
||||||
|
# https://arxiv.org/pdf/1706.02677.pdf.
|
||||||
|
|
||||||
|
if scale_by_batch:
|
||||||
|
lr = (lr * total_batch_size) / 256
|
||||||
|
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def get_learning_rate_schedule(
|
||||||
|
total_batch_size, steps_per_epoch, total_steps, optimizer_config):
|
||||||
|
"""Build the learning rate schedule function."""
|
||||||
|
base_lr = _get_batch_scaled_lr(total_batch_size, optimizer_config.base_lr,
|
||||||
|
optimizer_config.scale_by_batch)
|
||||||
|
|
||||||
|
schedule_type = optimizer_config.schedule_type
|
||||||
|
if schedule_type == 'steps':
|
||||||
|
boundaries = optimizer_config.step_decay_kwargs.decay_boundaries
|
||||||
|
boundaries.sort()
|
||||||
|
|
||||||
|
decay_rate = optimizer_config.step_decay_kwargs.decay_rate
|
||||||
|
boundaries_and_scales = {
|
||||||
|
int(boundary * total_steps): decay_rate for boundary in boundaries}
|
||||||
|
schedule_fn = optax.piecewise_constant_schedule(
|
||||||
|
init_value=base_lr, boundaries_and_scales=boundaries_and_scales)
|
||||||
|
elif schedule_type == 'cosine':
|
||||||
|
warmup_steps = (optimizer_config.cosine_decay_kwargs.warmup_epochs
|
||||||
|
* steps_per_epoch)
|
||||||
|
# Batch scale the other lr values as well:
|
||||||
|
init_value = _get_batch_scaled_lr(
|
||||||
|
total_batch_size,
|
||||||
|
optimizer_config.cosine_decay_kwargs.init_value,
|
||||||
|
optimizer_config.scale_by_batch)
|
||||||
|
end_value = _get_batch_scaled_lr(
|
||||||
|
total_batch_size,
|
||||||
|
optimizer_config.cosine_decay_kwargs.end_value,
|
||||||
|
optimizer_config.scale_by_batch)
|
||||||
|
|
||||||
|
schedule_fn = optax.warmup_cosine_decay_schedule(
|
||||||
|
init_value=init_value,
|
||||||
|
peak_value=base_lr,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
decay_steps=total_steps,
|
||||||
|
end_value=end_value)
|
||||||
|
elif schedule_type == 'constant_cosine':
|
||||||
|
# Convert end_value to alpha, used by cosine_decay_schedule.
|
||||||
|
alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr
|
||||||
|
|
||||||
|
# Number of steps spent in constant phase.
|
||||||
|
constant_steps = int(
|
||||||
|
optimizer_config.constant_cosine_decay_kwargs.constant_fraction
|
||||||
|
* total_steps)
|
||||||
|
decay_steps = total_steps - constant_steps
|
||||||
|
|
||||||
|
constant_phase = optax.constant_schedule(value=base_lr)
|
||||||
|
decay_phase = optax.cosine_decay_schedule(
|
||||||
|
init_value=base_lr,
|
||||||
|
decay_steps=decay_steps,
|
||||||
|
alpha=alpha)
|
||||||
|
schedule_fn = optax.join_schedules(
|
||||||
|
schedules=[constant_phase, decay_phase],
|
||||||
|
boundaries=[constant_steps])
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown learning rate schedule: {schedule_type}')
|
||||||
|
|
||||||
|
return schedule_fn
|
||||||
|
|
||||||
|
|
||||||
|
def _weight_decay_exclude(
|
||||||
|
exclude_names: Optional[List[str]] = None
|
||||||
|
) -> Callable[[str, str, jnp.ndarray], bool]:
|
||||||
|
"""Logic for deciding which parameters to include for weight decay..
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude_names: an optional list of names to include for weight_decay. ['w']
|
||||||
|
by default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A predicate that returns True for params that need to be excluded from
|
||||||
|
weight_decay.
|
||||||
|
"""
|
||||||
|
# By default weight_decay the weights but not the biases.
|
||||||
|
if not exclude_names:
|
||||||
|
exclude_names = ['b']
|
||||||
|
|
||||||
|
def exclude(module_name: str, name: str, value: jnp.array):
|
||||||
|
del value
|
||||||
|
# Do not weight decay the parameters of normalization blocks.
|
||||||
|
if any([norm_name in module_name for norm_name in NORM_NAMES]):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return name in exclude_names
|
||||||
|
|
||||||
|
return exclude
|
||||||
|
|
||||||
|
|
||||||
|
class AddWeightDecayState(NamedTuple):
|
||||||
|
"""Stateless transformation."""
|
||||||
|
|
||||||
|
|
||||||
|
def add_weight_decay(
|
||||||
|
weight_decay: float,
|
||||||
|
exclude_names: Optional[List[str]] = None) -> optax.GradientTransformation:
|
||||||
|
"""Add parameter scaled by `weight_decay` to the `updates`.
|
||||||
|
|
||||||
|
Same as optax.add_decayed_weights but can exclude parameters by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight_decay: weight_decay coefficient.
|
||||||
|
exclude_names: an optional list of names to exclude for weight_decay. ['b']
|
||||||
|
by default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An (init_fn, update_fn) tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def init_fn(_):
|
||||||
|
return AddWeightDecayState()
|
||||||
|
|
||||||
|
def update_fn(updates, state, params):
|
||||||
|
exclude = _weight_decay_exclude(exclude_names=exclude_names)
|
||||||
|
|
||||||
|
u_ex, u_in = hk.data_structures.partition(exclude, updates)
|
||||||
|
_, p_in = hk.data_structures.partition(exclude, params)
|
||||||
|
u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
|
||||||
|
updates = hk.data_structures.merge(u_ex, u_in)
|
||||||
|
return updates, state
|
||||||
|
|
||||||
|
return optax.GradientTransformation(init_fn, update_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def make_optimizer(optimizer_config, lr_schedule):
|
||||||
|
"""Construct the optax optimizer with given LR schedule."""
|
||||||
|
if (optimizer_config.get('decay_pos_embs') is None or
|
||||||
|
optimizer_config.decay_pos_embs):
|
||||||
|
# Decay learned position embeddings by default.
|
||||||
|
weight_decay_exclude_names = ['b']
|
||||||
|
else:
|
||||||
|
weight_decay_exclude_names = ['pos_embs', 'b']
|
||||||
|
|
||||||
|
optax_chain = []
|
||||||
|
if optimizer_config.max_norm > 0:
|
||||||
|
optax_chain.append(
|
||||||
|
optax.clip_by_global_norm(optimizer_config.max_norm))
|
||||||
|
|
||||||
|
if optimizer_config.optimizer == 'adam':
|
||||||
|
# See: https://arxiv.org/abs/1412.6980
|
||||||
|
optax_chain.extend([
|
||||||
|
optax.scale_by_adam(**optimizer_config.adam_kwargs),
|
||||||
|
add_weight_decay(
|
||||||
|
optimizer_config.weight_decay,
|
||||||
|
exclude_names=weight_decay_exclude_names)
|
||||||
|
])
|
||||||
|
elif optimizer_config.optimizer == 'lamb':
|
||||||
|
# See: https://arxiv.org/abs/1904.00962
|
||||||
|
optax_chain.extend([
|
||||||
|
optax.scale_by_adam(**optimizer_config.lamb_kwargs),
|
||||||
|
add_weight_decay(
|
||||||
|
optimizer_config.weight_decay,
|
||||||
|
exclude_names=weight_decay_exclude_names),
|
||||||
|
optax.scale_by_trust_ratio()
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')
|
||||||
|
|
||||||
|
# Scale by the (negative) learning rate.
|
||||||
|
optax_chain.extend([
|
||||||
|
optax.scale_by_schedule(lr_schedule),
|
||||||
|
optax.scale(-1),
|
||||||
|
])
|
||||||
|
|
||||||
|
return optax.chain(*optax_chain)
|
||||||
Reference in New Issue
Block a user