diff --git a/perceiver/README.md b/perceiver/README.md index 7123123..1e21473 100644 --- a/perceiver/README.md +++ b/perceiver/README.md @@ -43,21 +43,68 @@ First, install dependencies following these instructions: 4. Install other dependencies: `pip install -f requirements.txt` After install dependencies, you can open the notebooks in the `colabs` directory -using Jupyter or Colab. +using Jupyter or Colab, and you can run our example training script. +Our colabs and training script assume that you are running from the +`deepmind_research` directory. ### Colabs We provide the following colabs: +* colabs/masked_language_modelling.ipynb: Colab for running a pre-trained + Perceiver masked-language model (Section 4.1 in [2]). * colabs/optical_flow.ipynb: Colab for running a pre-trained optical flow Perceiver model and visualizing the output flow (Section 4.2 in [2]). * colabs/video_autoencoding.ipynb: Colab for running a pre-trained video autoencoding Perceiver model and visualizing video reconstructions (Section 4.3 in [2]). +### Training scripts +We also provide an example training script to train a Perceiver IO model for +ImageNet classification. +The provided hyperparameters are the settings used to train Perceiver IO +with 2D Fourier position encodings, as described in +section 4.5 and supplemental section I.1 of the paper [2]. + +To run the script locally and train a miniature Perceiver model, +use the `./launch_local.sh` script: `perceiver/train/launch_local.sh`. +The script would need to be adapted to run on a distributed training setup +in order to train a full-scale model. + +## Attributions and Disclaimers + +The file `perceiver/train/autoaugment.py` originates from the `tensorflow/tpu` +repository (https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/efficientnet/autoaugment.py), +copyright (c) The Tensorflow Authors. + +Sintel data is provided by and available from Sintel.org (https://durian.blender.org/), +copyright (c) Blender Foundation/www.sintel.org. + +Imagenet data is provided by and available from https://image-net.org/ +(for researchers and educators who wish to use the images for +non-commercial research and/or educational purposes, +see https://image-net.org/about.php for details about access, +conditions and terms). + +Video content may include clips provided as part of the THUMOS Challenge datasets, +which may be accessed at http://crcv.ucf.edu/THUMOS14/download.html, +copyrights held by the creators. + +All data and parameters included with Perceiver are made available +under the terms of the CC BY 4.0 license, +available at https://creativecommons.org/licenses/by/4.0/legalcode. + +This is not an officially supported Google product. + ## References [1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, -Joao Carreira. +João Carreira. *Perceiver: General Perception with Iterative Attention*. ICML 2021. +https://arxiv.org/abs/2103.03206 -[2] TODO: Add citation after paper is published on ArXiv. +[2] Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, +Catalin Ionescu, David Ding, Skanda Koppula, Andrew Brock, Evan Shelhamer, +Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, +João Carreira. +*Perceiver IO: A General Architecture for Structured Inputs & Outputs*. +arXiv, 2021. diff --git a/perceiver/requirements.txt b/perceiver/requirements.txt index b98d6ff..813211e 100644 --- a/perceiver/requirements.txt +++ b/perceiver/requirements.txt @@ -5,12 +5,16 @@ einops==0.3.0 flatbuffers==2.0 imageio==2.9.0 immutabledict==2.0.0 -jax==0.2.16 -jaxlib==0.1.68+cuda111 +jaxline==0.0.3 numpy==1.21.0 opencv-python==4.5.2.54 opt-einsum==3.3.0 +optax==0.0.9 Pillow==8.3.1 scipy==1.7.0 six==1.16.0 tabulate==0.8.9 +tensorflow==2.5.0 +tensorflow-addons==0.13.0 +tensorflow-datasets==4.3.0 +tensorflow-probability==0.13.0 diff --git a/perceiver/train/autoaugment.py b/perceiver/train/autoaugment.py new file mode 100644 index 0000000..39b2075 --- /dev/null +++ b/perceiver/train/autoaugment.py @@ -0,0 +1,716 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""AutoAugment and RandAugment policies for enhanced image preprocessing. + +AutoAugment Reference: https://arxiv.org/abs/1805.09501 +RandAugment Reference: https://arxiv.org/abs/1909.13719 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import math +from ml_collections import config_dict +import tensorflow.compat.v1 as tf +from tensorflow_addons import image as contrib_image +# pylint: disable=deprecated-method + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + + +def policy_v0(): + """Autoaugment policy that was used in AutoAugment Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + return policy + + +def policy_vtest(): + """Autoaugment test policy for debugging.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)], + ] + return policy + + +def blend(image1, image2, factor): + """Blend image1 and image2 using 'factor'. + + Factor can be above 0.0. A value of 0.0 means only image1 is used. + A value of 1.0 means only image2 is used. A value between 0.0 and + 1.0 means we linearly interpolate the pixel values between the two + images. A value greater than 1.0 "extrapolates" the difference + between the two pixel values, and we clip the results to values + between 0 and 255. + + Args: + image1: An image Tensor of type uint8. + image2: An image Tensor of type uint8. + factor: A floating point value above 0.0. + + Returns: + A blended image Tensor of type uint8. + """ + if factor == 0.0: + return tf.convert_to_tensor(image1) + if factor == 1.0: + return tf.convert_to_tensor(image2) + + image1 = tf.to_float(image1) + image2 = tf.to_float(image2) + + difference = image2 - image1 + scaled = factor * difference + + # Do addition in float. + temp = tf.to_float(image1) + scaled + + # Interpolate + if factor > 0.0 and factor < 1.0: + # Interpolation means we always stay within 0 and 255. + return tf.cast(temp, tf.uint8) + + # Extrapolate: + # + # We need to clip and then cast. + return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8) + + +def cutout(image, pad_size, replace=0): + """Apply cutout (https://arxiv.org/abs/1708.04552) to image. + + This operation applies a (2*pad_size x 2*pad_size) mask of zeros to + a random location within `img`. The pixel values filled in will be of the + value `replace`. The located where the mask will be applied is randomly + chosen uniformly over the whole image. + + Args: + image: An image Tensor of type uint8. + pad_size: Specifies how big the zero mask that will be generated is that + is applied to the image. The mask will be of size + (2*pad_size x 2*pad_size). + replace: What pixel value to fill in the image in the area that has + the cutout mask applied to it. + + Returns: + An image Tensor that is of type uint8. + """ + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] + + # Sample the center location in the image where the zero mask will be applied. + cutout_center_height = tf.random_uniform( + shape=[], minval=0, maxval=image_height, + dtype=tf.int32) + + cutout_center_width = tf.random_uniform( + shape=[], minval=0, maxval=image_width, + dtype=tf.int32) + + lower_pad = tf.maximum(0, cutout_center_height - pad_size) + upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = tf.maximum(0, cutout_center_width - pad_size) + right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad)] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = tf.pad( + tf.zeros(cutout_shape, dtype=image.dtype), + padding_dims, constant_values=1) + mask = tf.expand_dims(mask, -1) + mask = tf.tile(mask, [1, 1, 3]) + image = tf.where( + tf.equal(mask, 0), + tf.ones_like(image, dtype=image.dtype) * replace, + image) + return image + + +def solarize(image, threshold=128): + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract the pixel from 255. + return tf.where(image < threshold, image, 255 - image) + + +def solarize_add(image, addition=0, threshold=128): + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) + + +def color(image, factor): + """Equivalent of PIL Color.""" + degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) + return blend(degenerate, image, factor) + + +def contrast(image, factor): + """Equivalent of PIL Contrast.""" + degenerate = tf.image.rgb_to_grayscale(image) + # Cast before calling tf.histogram. + degenerate = tf.cast(degenerate, tf.int32) + + # Compute the grayscale histogram, then compute the mean pixel value, + # and create a constant image size of that value. Use that as the + # blending degenerate target of the original image. + hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) + mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) + return blend(degenerate, image, factor) + + +def brightness(image, factor): + """Equivalent of PIL Brightness.""" + degenerate = tf.zeros_like(image) + return blend(degenerate, image, factor) + + +def posterize(image, bits): + """Equivalent of PIL Posterize.""" + shift = 8 - bits + return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) + + +def rotate(image, degrees, replace): + """Rotates the image by degrees either clockwise or counterclockwise. + + Args: + image: An image Tensor of type uint8. + degrees: Float, a scalar angle in degrees to rotate all images by. If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + replace: A one or three value 1D tensor to fill empty pixels caused by + the rotate operation. + + Returns: + The rotated version of image. + """ + # Convert from degrees to radians. + degrees_to_radians = math.pi / 180.0 + radians = degrees * degrees_to_radians + + # In practice, we should randomize the rotation degrees by flipping + # it negatively half the time, but that's done on 'degrees' outside + # of the function. + image = contrib_image.rotate(wrap(image), radians) + return unwrap(image, replace) + + +def translate_x(image, pixels, replace): + """Equivalent of PIL Translate in X dimension.""" + image = contrib_image.translate(wrap(image), [-pixels, 0]) + return unwrap(image, replace) + + +def translate_y(image, pixels, replace): + """Equivalent of PIL Translate in Y dimension.""" + image = contrib_image.translate(wrap(image), [0, -pixels]) + return unwrap(image, replace) + + +def shear_x(image, level, replace): + """Equivalent of PIL Shearing in X dimension.""" + # Shear parallel to x axis is a projective transform + # with a matrix form of: + # [1 level + # 0 1]. + image = contrib_image.transform( + wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def shear_y(image, level, replace): + """Equivalent of PIL Shearing in Y dimension.""" + # Shear parallel to y axis is a projective transform + # with a matrix form of: + # [1 0 + # level 1]. + image = contrib_image.transform( + wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def autocontrast(image): + """Implements Autocontrast function from PIL using TF ops. + + Args: + image: A 3D uint8 tensor. + + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image): + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. + lo = tf.to_float(tf.reduce_min(image)) + hi = tf.to_float(tf.reduce_max(image)) + + # Scale the image, making the lowest value 0 and the highest value 255. + def scale_values(im): + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = tf.to_float(im) * scale + offset + im = tf.clip_by_value(im, 0.0, 255.0) + return tf.cast(im, tf.uint8) + + result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = tf.stack([s1, s2, s3], 2) + return image + + +def sharpness(image, factor): + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image = tf.cast(image, tf.float32) + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + kernel = tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, + shape=[3, 3, 1, 1]) / 13. + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + with tf.device('/cpu:0'): + # Some augmentation that uses depth-wise conv will cause crashing when + # training on GPU. See (b/156242594) for details. + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding='VALID', rate=[1, 1]) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + + # Blend the final result. + return blend(result, orig_image, factor) + + +def equalize(image): + """Implements Equalize function from PIL using TF ops.""" + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + result = tf.cond(tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im)) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = tf.stack([s1, s2, s3], 2) + return image + + +def invert(image): + """Inverts the image pixels.""" + image = tf.convert_to_tensor(image) + return 255 - image + + +def wrap(image): + """Returns 'image' with an extra channel set to all 1s.""" + shape = tf.shape(image) + extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) + extended = tf.concat([image, extended_channel], 2) + return extended + + +def unwrap(image, replace): + """Unwraps an image produced by wrap. + + Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + + + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = tf.shape(image) + # Flatten the spatial dimensions. + flattened_image = tf.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = flattened_image[:, 3] + + replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. + flattened_image = tf.where( + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image) + + image = tf.reshape(flattened_image, image_shape) + image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) + return image + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + + +def _randomly_negate_tensor(tensor): + """With 50% prob turn the tensor negative.""" + should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool) + final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) + return final_tensor + + +def _rotate_level_to_arg(level): + level = (level/_MAX_LEVEL) * 30. + level = _randomly_negate_tensor(level) + return (level,) + + +def _shrink_level_to_arg(level): + """Converts level to ratio by which we shrink the image content.""" + if level == 0: + return (1.0,) # if level is zero, do not shrink the image + # Maximum shrinking ratio is 2.9. + level = 2. / (_MAX_LEVEL / level) + 0.9 + return (level,) + + +def _enhance_level_to_arg(level): + return ((level/_MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level/_MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _translate_level_to_arg(level, translate_const): + level = (level/_MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),), + # pylint:disable=g-long-lambda + 'TranslateX': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + 'TranslateY': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + # pylint:enable=g-long-lambda + } + + +def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(augmentation_hparams)[name](level) + + # Check to see if prob is passed into function. This is used for operations + # where we alter bboxes independently. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getargspec(func)[0]: + args = tuple([prob] + list(args)) + # pytype:enable=wrong-arg-types + + # Add in replace arg if it is required for the function that is being called. + # pytype:disable=wrong-arg-types + if 'replace' in inspect.getargspec(func)[0]: + # Make sure replace is the final argument + assert 'replace' == inspect.getargspec(func)[0][-1] + args = tuple(list(args) + [replace_value]) + # pytype:enable=wrong-arg-types + + return (func, prob, args) + + +def _apply_func_with_prob(func, image, args, prob): + """Apply `func` to image w/ `args` as input with probability `prob`.""" + assert isinstance(args, tuple) + + # If prob is a function argument, then this randomness is being handled + # inside the function, so make sure it is always called. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getargspec(func)[0]: + prob = 1.0 + # pytype:enable=wrong-arg-types + + # Apply the function with probability `prob`. + should_apply_op = tf.cast( + tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool) + augmented_image = tf.cond( + should_apply_op, + lambda: func(image, *args), + lambda: image) + return augmented_image + + +def select_and_apply_random_policy(policies, image): + """Select a random policy from `policies` and apply it to `image`.""" + policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32) + # Note that using tf.case instead of tf.conds would result in significantly + # larger graphs and would even break export for some larger policies. + for (i, policy) in enumerate(policies): + image = tf.cond( + tf.equal(i, policy_to_select), + lambda selected_policy=policy: selected_policy(image), + lambda: image) + return image + + +def build_and_apply_nas_policy(policies, image, + augmentation_hparams): + """Build a policy from the given policies passed in and apply to image. + + Args: + policies: list of lists of tuples in the form `(func, prob, level)`, `func` + is a string name of the augmentation function, `prob` is the probability + of applying the `func` operation, `level` is the input argument for + `func`. + image: tf.Tensor that the resulting policy will be applied to. + augmentation_hparams: Hparams associated with the NAS learned policy. + + Returns: + A version of image that now has data augmentation applied to it based on + the `policies` pass into the function. + """ + replace_value = [128, 128, 128] + + # func is the string name of the augmentation function, prob is the + # probability of applying the operation and level is the parameter associated + # with the tf op. + + # tf_policies are functions that take in an image and return an augmented + # image. + tf_policies = [] + for policy in policies: + tf_policy = [] + # Link string name to the correct python function and make sure the correct + # argument is passed into that function. + for policy_info in policy: + policy_info = list(policy_info) + [replace_value, augmentation_hparams] + + tf_policy.append(_parse_policy_info(*policy_info)) + # Now build the tf policy that will apply the augmentation procedue + # on image. + def make_final_policy(tf_policy_): + def final_policy(image_): + for func, prob, args in tf_policy_: + image_ = _apply_func_with_prob( + func, image_, args, prob) + return image_ + return final_policy + tf_policies.append(make_final_policy(tf_policy)) + + augmented_image = select_and_apply_random_policy( + tf_policies, image) + return augmented_image + + +def distort_image_with_autoaugment(image, augmentation_name): + """Applies the AutoAugment policy to `image`. + + AutoAugment is from the paper: https://arxiv.org/abs/1805.09501. + + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + augmentation_name: The name of the AutoAugment policy to use. The available + options are `v0` and `test`. `v0` is the policy used for + all of the results in the paper and was found to achieve the best results + on the COCO dataset. `v1`, `v2` and `v3` are additional good policies + found on the COCO dataset that have slight variation in what operations + were used during the search procedure along with how many operations are + applied in parallel to a single image (2 vs 3). + + Returns: + A tuple containing the augmented versions of `image`. + """ + available_policies = {'v0': policy_v0, + 'test': policy_vtest} + if augmentation_name not in available_policies: + raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name)) + + policy = available_policies[augmentation_name]() + # Hparams that will be used for AutoAugment. + augmentation_hparams = config_dict.ConfigDict(dict( + cutout_const=100, translate_const=250)) + + return build_and_apply_nas_policy(policy, image, augmentation_hparams) + + +def distort_image_with_randaugment(image, num_layers, magnitude): + """Applies the RandAugment policy to `image`. + + RandAugment is from the paper https://arxiv.org/abs/1909.13719, + + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Usually best values are in the range + [5, 30]. + + Returns: + The augmented version of `image`. + """ + replace_value = [128] * 3 + tf.logging.info('Using RandAug.') + augmentation_hparams = config_dict.ConfigDict(dict( + cutout_const=40, translate_const=100)) + available_ops = [ + 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', + 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', + 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] + + for layer_num in range(num_layers): + op_to_select = tf.random_uniform( + [], maxval=len(available_ops), dtype=tf.int32) + random_magnitude = float(magnitude) + with tf.name_scope('randaug_layer_{}'.format(layer_num)): + for (i, op_name) in enumerate(available_ops): + prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) + func, _, args = _parse_policy_info(op_name, prob, random_magnitude, + replace_value, augmentation_hparams) + image = tf.cond( + tf.equal(i, op_to_select), + # pylint:disable=g-long-lambda + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), + # pylint:enable=g-long-lambda + lambda: image) + return image diff --git a/perceiver/train/dataset.py b/perceiver/train/dataset.py new file mode 100644 index 0000000..489b07c --- /dev/null +++ b/perceiver/train/dataset.py @@ -0,0 +1,423 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ImageNet dataset with pre-processing and augmentation. + +Deng, et al CVPR 2009 - ImageNet: A large-scale hierarchical image database. +https://image-net.org/ +""" + +import enum +from typing import Any, Generator, Mapping, Optional, Sequence, Text, Tuple + +import jax +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds +import tensorflow_probability as tfp + +from perceiver.train import autoaugment + + +Batch = Mapping[Text, np.ndarray] +MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255) +STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) +AUTOTUNE = tf.data.experimental.AUTOTUNE + +INPUT_DIM = 224 # The number of pixels in the image resize. + + +class Split(enum.Enum): + """ImageNet dataset split.""" + TRAIN = 1 + TRAIN_AND_VALID = 2 + VALID = 3 + TEST = 4 + + @classmethod + def from_string(cls, name: Text) -> 'Split': + return {'TRAIN': Split.TRAIN, 'TRAIN_AND_VALID': Split.TRAIN_AND_VALID, + 'VALID': Split.VALID, 'VALIDATION': Split.VALID, + 'TEST': Split.TEST}[name.upper()] + + @property + def num_examples(self): + return {Split.TRAIN_AND_VALID: 1281167, Split.TRAIN: 1271167, + Split.VALID: 10000, Split.TEST: 50000}[self] + + +def load( + split: Split, + *, + is_training: bool, + # batch_dims should be: + # [device_count, per_device_batch_size] or [total_batch_size] + batch_dims: Sequence[int], + augmentation_settings: Mapping[str, Any], + # The shape to which images are resized. + im_dim: int = INPUT_DIM, + threadpool_size: int = 48, + max_intra_op_parallelism: int = 1, +) -> Generator[Batch, None, None]: + """Loads the given split of the dataset.""" + start, end = _shard(split, jax.host_id(), jax.host_count()) + + im_size = (im_dim, im_dim) + + total_batch_size = np.prod(batch_dims) + + tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split), + from_=start, to=end, unit='abs') + + ds = tfds.load('imagenet2012:5.*.*', split=tfds_split, + decoders={'image': tfds.decode.SkipDecoding()}) + + options = tf.data.Options() + options.experimental_threading.private_threadpool_size = threadpool_size + options.experimental_threading.max_intra_op_parallelism = ( + max_intra_op_parallelism) + options.experimental_optimization.map_parallelization = True + if is_training: + options.experimental_deterministic = False + ds = ds.with_options(options) + + if is_training: + if jax.host_count() > 1: + # Only cache if we are reading a subset of the dataset. + ds = ds.cache() + ds = ds.repeat() + ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0) + + else: + if split.num_examples % total_batch_size != 0: + raise ValueError(f'Test/valid must be divisible by {total_batch_size}') + + def crop_augment_preprocess(example): + image, _ = _preprocess_image( + example['image'], is_training, im_size, augmentation_settings) + + label = tf.cast(example['label'], tf.int32) + + out = {'images': image, 'labels': label} + + if is_training: + if augmentation_settings['cutmix']: + out['mask'] = cutmix_padding(*im_size) + out['cutmix_ratio'] = tf.reduce_mean(out['mask']) + if augmentation_settings['mixup_alpha'] is not None: + beta = tfp.distributions.Beta( + augmentation_settings['mixup_alpha'], + augmentation_settings['mixup_alpha']) + out['mixup_ratio'] = beta.sample() + return out + + ds = ds.map(crop_augment_preprocess, num_parallel_calls=AUTOTUNE) + + # Mixup/cutmix by temporarily batching (using the per-device batch size): + use_cutmix = augmentation_settings['cutmix'] + use_mixup = augmentation_settings['mixup_alpha'] is not None + if is_training and (use_cutmix or use_mixup): + inner_batch_size = batch_dims[-1] + # Apply mixup, cutmix, or mixup + cutmix on batched data. + # We use data from 2 batches to produce 1 mixed batch. + ds = ds.batch(inner_batch_size * 2) + if not use_cutmix and use_mixup: + ds = ds.map(my_mixup, num_parallel_calls=AUTOTUNE) + elif use_cutmix and not use_mixup: + ds = ds.map(my_cutmix, num_parallel_calls=AUTOTUNE) + elif use_cutmix and use_mixup: + ds = ds.map(my_mixup_cutmix, num_parallel_calls=AUTOTUNE) + + # Unbatch for further processing. + ds = ds.unbatch() + + for batch_size in reversed(batch_dims): + ds = ds.batch(batch_size) + + ds = ds.prefetch(AUTOTUNE) + + yield from tfds.as_numpy(ds) + + +# cutmix_padding, my_cutmix, my_mixup, and my_mixup_cutmix taken from: +# https://github.com/deepmind/deepmind-research/blob/master/nfnets/dataset.py +def cutmix_padding(h, w): + """Returns image mask for CutMix. + + Taken from (https://github.com/google/edward2/blob/master/experimental + /marginalization_mixup/data_utils.py#L367) + Args: + h: image height. + w: image width. + """ + r_x = tf.random.uniform([], 0, w, tf.int32) + r_y = tf.random.uniform([], 0, h, tf.int32) + + # Beta dist in paper, but they used Beta(1,1) which is just uniform. + image1_proportion = tf.random.uniform([]) + patch_length_ratio = tf.math.sqrt(1 - image1_proportion) + r_w = tf.cast(patch_length_ratio * tf.cast(w, tf.float32), tf.int32) + r_h = tf.cast(patch_length_ratio * tf.cast(h, tf.float32), tf.int32) + bbx1 = tf.clip_by_value(tf.cast(r_x - r_w // 2, tf.int32), 0, w) + bby1 = tf.clip_by_value(tf.cast(r_y - r_h // 2, tf.int32), 0, h) + bbx2 = tf.clip_by_value(tf.cast(r_x + r_w // 2, tf.int32), 0, w) + bby2 = tf.clip_by_value(tf.cast(r_y + r_h // 2, tf.int32), 0, h) + + # Create the binary mask. + pad_left = bbx1 + pad_top = bby1 + pad_right = tf.maximum(w - bbx2, 0) + pad_bottom = tf.maximum(h - bby2, 0) + r_h = bby2 - bby1 + r_w = bbx2 - bbx1 + + mask = tf.pad( + tf.ones((r_h, r_w)), + paddings=[[pad_top, pad_bottom], [pad_left, pad_right]], + mode='CONSTANT', + constant_values=0) + mask.set_shape((h, w)) + return mask[..., None] # Add channel dim. + + +def my_cutmix(batch): + """Apply CutMix: https://arxiv.org/abs/1905.04899.""" + batch = dict(**batch) + bs = tf.shape(batch['images'])[0] // 2 + mask = batch['mask'][:bs] + images = (mask * batch['images'][:bs] + (1.0 - mask) * batch['images'][bs:]) + mix_labels = batch['labels'][bs:] + labels = batch['labels'][:bs] + ratio = batch['cutmix_ratio'][:bs] + return {'images': images, 'labels': labels, + 'mix_labels': mix_labels, 'ratio': ratio} + + +def my_mixup(batch): + """Apply mixup: https://arxiv.org/abs/1710.09412.""" + batch = dict(**batch) + bs = tf.shape(batch['images'])[0] // 2 + ratio = batch['mixup_ratio'][:bs, None, None, None] + images = (ratio * batch['images'][:bs] + (1.0 - ratio) * batch['images'][bs:]) + mix_labels = batch['labels'][bs:] + labels = batch['labels'][:bs] + ratio = ratio[..., 0, 0, 0] # Unsqueeze + return {'images': images, 'labels': labels, + 'mix_labels': mix_labels, 'ratio': ratio} + + +def my_mixup_cutmix(batch): + """Apply mixup to half the batch, and cutmix to the other.""" + batch = dict(**batch) + bs = tf.shape(batch['images'])[0] // 4 + mixup_ratio = batch['mixup_ratio'][:bs, None, None, None] + mixup_images = (mixup_ratio * batch['images'][:bs] + + (1.0 - mixup_ratio) * batch['images'][bs:2*bs]) + mixup_labels = batch['labels'][:bs] + mixup_mix_labels = batch['labels'][bs:2*bs] + + cutmix_mask = batch['mask'][2*bs:3*bs] + cutmix_images = (cutmix_mask * batch['images'][2*bs:3*bs] + + (1.0 - cutmix_mask) * batch['images'][-bs:]) + cutmix_labels = batch['labels'][2*bs:3*bs] + cutmix_mix_labels = batch['labels'][-bs:] + cutmix_ratio = batch['cutmix_ratio'][2*bs : 3*bs] + + return {'images': tf.concat([mixup_images, cutmix_images], axis=0), + 'labels': tf.concat([mixup_labels, cutmix_labels], axis=0), + 'mix_labels': tf.concat([mixup_mix_labels, cutmix_mix_labels], 0), + 'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)} + + +def _to_tfds_split(split: Split) -> tfds.Split: + """Returns the TFDS split appropriately sharded.""" + # NOTE: Imagenet did not release labels for the test split used in the + # competition, so it has been typical at DeepMind to consider the VALID + # split the TEST split and to reserve 10k images from TRAIN for VALID. + if split in ( + Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID): + return tfds.Split.TRAIN + else: + assert split == Split.TEST + return tfds.Split.VALIDATION + + +def _shard( + split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]: + """Returns [start, end) for the given shard index.""" + assert shard_index < num_shards + arange = np.arange(split.num_examples) + shard_range = np.array_split(arange, num_shards)[shard_index] + start, end = shard_range[0], (shard_range[-1] + 1) + if split == Split.TRAIN: + # Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000]. + offset = Split.VALID.num_examples + start += offset + end += offset + return start, end + + +def _preprocess_image( + image_bytes: tf.Tensor, + is_training: bool, + image_size: Sequence[int], + augmentation_settings: Mapping[str, Any], +) -> Tuple[tf.Tensor, tf.Tensor]: + """Returns processed and resized images.""" + + # Get the image crop. + if is_training: + image, im_shape = _decode_and_random_crop(image_bytes) + image = tf.image.random_flip_left_right(image) + else: + image, im_shape = _decode_and_center_crop(image_bytes) + assert image.dtype == tf.uint8 + + # Optionally apply RandAugment: https://arxiv.org/abs/1909.13719 + if is_training: + if augmentation_settings['randaugment'] is not None: + # Input and output images are dtype uint8. + image = autoaugment.distort_image_with_randaugment( + image, + num_layers=augmentation_settings['randaugment']['num_layers'], + magnitude=augmentation_settings['randaugment']['magnitude']) + + # Resize and normalize the image crop. + # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without + # clamping overshoots. This means values returned will be outside the range + # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]). + image = tf.image.resize( + image, image_size, tf.image.ResizeMethod.BICUBIC) + image = _normalize_image(image) + + return image, im_shape + + +def _normalize_image(image: tf.Tensor) -> tf.Tensor: + """Normalize the image to zero mean and unit variance.""" + image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) + return image + + +def _distorted_bounding_box_crop( + image_bytes: tf.Tensor, + *, + jpeg_shape: tf.Tensor, + bbox: tf.Tensor, + min_object_covered: float, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + max_attempts: int, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Generates cropped_image using one of the bboxes randomly distorted.""" + bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box( + jpeg_shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = [offset_y, offset_x, target_height, target_width] + + if image_bytes.dtype == tf.dtypes.string: + image = tf.image.decode_and_crop_jpeg(image_bytes, + tf.stack(crop_window), + channels=3) + else: + image = tf.image.crop_to_bounding_box(image_bytes, *crop_window) + + im_shape = tf.stack([target_height, target_width]) + return image, im_shape + + +def _decode_whole_image(image_bytes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + image = tf.io.decode_jpeg(image_bytes, channels=3) + im_shape = tf.io.extract_jpeg_shape(image_bytes, output_type=tf.int32) + return image, im_shape + + +def _decode_and_random_crop( + image_bytes: tf.Tensor +) -> Tuple[tf.Tensor, tf.Tensor]: + """Make a random crop of INPUT_DIM.""" + + if image_bytes.dtype == tf.dtypes.string: + jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) + else: + jpeg_shape = tf.shape(image_bytes) + + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image, im_shape = _distorted_bounding_box_crop( + image_bytes, + jpeg_shape=jpeg_shape, + bbox=bbox, + min_object_covered=0.1, + aspect_ratio_range=(3 / 4, 4 / 3), + area_range=(0.08, 1.0), + max_attempts=10) + + if tf.reduce_all(tf.equal(jpeg_shape, tf.shape(image))): + # If the random crop failed fall back to center crop. + image, im_shape = _decode_and_center_crop(image_bytes, jpeg_shape) + return image, im_shape + + +def _center_crop(image, crop_dim): + """Center crops an image to a target dimension.""" + image_height = image.shape[0] + image_width = image.shape[1] + offset_height = ((image_height - crop_dim) + 1) // 2 + offset_width = ((image_width - crop_dim) + 1) // 2 + return tf.image.crop_to_bounding_box( + image, offset_height, offset_width, crop_dim, crop_dim) + + +def _decode_and_center_crop( + image_bytes: tf.Tensor, + jpeg_shape: Optional[tf.Tensor] = None, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Crops to center of image with padding then scales.""" + if jpeg_shape is None: + if image_bytes.dtype == tf.dtypes.string: + jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) + else: + jpeg_shape = tf.shape(image_bytes) + + image_height = jpeg_shape[0] + image_width = jpeg_shape[1] + + padded_center_crop_size = tf.cast( + ((INPUT_DIM / (INPUT_DIM + 32)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = [offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size] + + if image_bytes.dtype == tf.dtypes.string: + image = tf.image.decode_and_crop_jpeg(image_bytes, + tf.stack(crop_window), + channels=3) + else: + image = tf.image.crop_to_bounding_box(image_bytes, *crop_window) + + im_shape = tf.stack([padded_center_crop_size, padded_center_crop_size]) + return image, im_shape diff --git a/perceiver/train/experiment.py b/perceiver/train/experiment.py new file mode 100644 index 0000000..591a3bb --- /dev/null +++ b/perceiver/train/experiment.py @@ -0,0 +1,538 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A reference training pipeline for Perceiver/Perceiver IO on ImageNet. + +We use the Jaxline (https://github.com/deepmind/jaxline) training framework. +Two sets of hyperparameters are provided, the hyperparameters we used for the +Perceiver IO paper, and scaled-down hyperparameters for local testing. +This script should run out-of-the-box with the local hyper parameters. +The scaled-up hyperparameters requires a distributed learning setup to run, +and this script will need to be adapted to your specific setup. +""" + +import functools +from typing import Generator, Mapping, Text, Tuple + +from absl import app +from absl import flags +from absl import logging +import haiku as hk +import jax +import jax.numpy as jnp +from jaxline import base_config +from jaxline import experiment +from jaxline import platform +from jaxline import utils as jl_utils +from ml_collections import config_dict +import numpy as np +import optax + + +from perceiver import io_processors +from perceiver import perceiver +from perceiver.train import dataset +from perceiver.train import utils + +FLAGS = flags.FLAGS + +OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState] +Scalars = Mapping[Text, jnp.ndarray] + + +N_TRAIN_EXAMPLES = dataset.Split.TRAIN_AND_VALID.num_examples +N_CLASSES = 1000 +# Only local/debug parameters are supported out of the box. +# To use the scaled-up hyperparameters, please adapt this script to your +# training setup and set this flag to False +IS_LOCAL = True + + +def get_training_steps(batch_size, n_epochs): + return (N_TRAIN_EXAMPLES * n_epochs) // batch_size + + +def get_config(): + """Return config object for training.""" + use_debug_settings = IS_LOCAL + config = base_config.get_base_config() + + # Experiment config. + local_batch_size = 2 + # Modify this to adapt to your custom distributed learning setup + num_devices = 1 + config.train_batch_size = local_batch_size * num_devices + config.n_epochs = 110 + + def _default_or_debug(default_value, debug_value): + return debug_value if use_debug_settings else default_value + + n_train_examples = N_TRAIN_EXAMPLES + num_classes = N_CLASSES + + config.experiment_kwargs = config_dict.ConfigDict( + dict( + config=dict( + optimizer=dict( + base_lr=5e-4, + max_norm=10.0, # < 0 to turn off. + schedule_type='constant_cosine', + weight_decay=1e-1, + decay_pos_embs=True, + scale_by_batch=True, + cosine_decay_kwargs=dict( + init_value=0.0, + warmup_epochs=0, + end_value=0.0, + ), + step_decay_kwargs=dict( + decay_boundaries=[0.5, 0.8, 0.95], + decay_rate=0.1, + ), + constant_cosine_decay_kwargs=dict( + constant_fraction=0.5, + end_value=0.0, + ), + optimizer='lamb', + # Optimizer-specific kwargs: + adam_kwargs=dict( + b1=0.9, + b2=0.999, + eps=1e-8, + ), + lamb_kwargs=dict( + b1=0.9, + b2=0.999, + eps=1e-6, + ), + ), + # Don't specify output_channels - it's not used for + # classifiers. + model=dict( + perceiver_kwargs=dict( + input_preprocessor=dict( + prep_type='pixels', + # Channels for conv/conv1x1 preprocessing: + num_channels=64, + # ------------------------- + # Position encoding arguments: + # ------------------------- + position_encoding_type='fourier', + concat_or_add_pos='concat', + spatial_downsample=1, + # If >0, project position to this size: + project_pos_dim=-1, + trainable_position_encoding_kwargs=dict( + num_channels=258, # Match default # for Fourier. + init_scale=0.02, + ), + fourier_position_encoding_kwargs=dict( + num_bands=64, + max_resolution=(224, 224), + sine_only=False, + concat_pos=True, + ), + ), + encoder=dict( + num_self_attends_per_block=_default_or_debug(6, 2), + # Weights won't be shared if num_blocks is set to 1. + num_blocks=_default_or_debug(8, 2), + z_index_dim=512, + num_z_channels=1024, + num_cross_attend_heads=1, + num_self_attend_heads=8, + cross_attend_widening_factor=1, + self_attend_widening_factor=1, + dropout_prob=0.0, + # Position encoding for the latent array. + z_pos_enc_init_scale=0.02, + cross_attention_shape_for_attn='kv', + use_query_residual=True, + ), + decoder=dict( + num_z_channels=1024, + use_query_residual=True, + # Position encoding for the output logits. + position_encoding_type='trainable', + trainable_position_encoding_kwargs=dict( + num_channels=1024, + init_scale=0.02, + ), + ), + ), + ), + training=dict( + images_per_epoch=n_train_examples, + label_smoothing=0.1, + n_epochs=config.get_oneway_ref('n_epochs'), + batch_size=config.get_oneway_ref('train_batch_size') + ), + data=dict( + num_classes=num_classes, + # Run on smaller images to debug. + im_dim=_default_or_debug(224, 32), + augmentation=dict( + # Typical randaug params: + # num_layers in [1, 3] + # magnitude in [5, 30] + # Set randaugment to None to disable. + randaugment=dict( + num_layers=4, + magnitude=5), + cutmix=True, + # Mixup alpha should be in [0, 1]. + # Set to None to disable. + mixup_alpha=0.2, + ), + ), + evaluation=dict( + subset='test', + batch_size=2, + ), + ) + ) + ) + + # Training loop config. + config.training_steps = get_training_steps( + config.get_oneway_ref('train_batch_size'), + config.get_oneway_ref('n_epochs')) + config.log_train_data_interval = 60 + config.log_tensors_interval = 60 + config.save_checkpoint_interval = 300 + config.eval_specific_checkpoint_dir = '' + config.best_model_eval_metric = 'eval_top_1_acc' + config.checkpoint_dir = '/tmp/perceiver_imagnet_checkpoints' + config.train_checkpoint_all_hosts = False + + # Prevents accidentally setting keys that aren't recognized (e.g. in tests). + config.lock() + + return config + + +class Experiment(experiment.AbstractExperiment): + """ImageNet experiment.""" + + # A map from object properties that will be checkpointed to their name + # in a checkpoint. Currently we assume that these are all sharded + # device arrays. + CHECKPOINT_ATTRS = { + '_params': 'params', + '_state': 'state', + '_opt_state': 'opt_state', + } + + def __init__(self, mode, init_rng, config): + """Initializes experiment.""" + + super(Experiment, self).__init__(mode=mode, init_rng=init_rng) + + self.mode = mode + self.init_rng = init_rng + self.config = config + + # Checkpointed experiment state. + self._params = None + self._state = None + self._opt_state = None + + # Input pipelines. + self._train_input = None + self._eval_input = None + + self.forward = hk.transform_with_state(self._forward_fn) + + # NOTE: We "donate" the `params, state, opt_state` arguments which allows + # JAX (on some backends) to reuse the device memory associated with these + # inputs to store the outputs of our function (which also start with + # `params, state, opt_state`). + self._update_func = jax.pmap(self._update_func, axis_name='i', + donate_argnums=(0, 1, 2)) + self._eval_batch = jax.jit(self._eval_batch) + + def _forward_fn( + self, + inputs: dataset.Batch, + is_training: bool, + ) -> jnp.ndarray: + + images = inputs['images'] + + perceiver_kwargs = self.config.model.perceiver_kwargs + input_preprocessor = io_processors.ImagePreprocessor( + **perceiver_kwargs['input_preprocessor']) + encoder = perceiver.PerceiverEncoder(**perceiver_kwargs['encoder']) + decoder = perceiver.ClassificationDecoder( + self.config.data.num_classes, + **perceiver_kwargs['decoder']) + model = perceiver.Perceiver( + encoder=encoder, + decoder=decoder, + input_preprocessor=input_preprocessor) + + return model(images, is_training=is_training) + + # _ _ + # | |_ _ __ __ _(_)_ __ + # | __| '__/ _` | | '_ \ + # | |_| | | (_| | | | | | + # \__|_| \__,_|_|_| |_| + # + + def step(self, global_step: int, rng: jnp.ndarray, + *unused_args, **unused_kwargs): + """See base class.""" + + if self._train_input is None: + self._initialize_train() + + inputs = next(self._train_input) + + self._params, self._state, self._opt_state, scalars = ( + self._update_func( + self._params, self._state, self._opt_state, inputs, rng, global_step + )) + + scalars = jl_utils.get_first(scalars) + return scalars + + def _initialize_train(self): + self._train_input = jl_utils.py_prefetch(self._build_train_input) + + total_batch_size = self.config.training.batch_size + steps_per_epoch = ( + self.config.training.images_per_epoch / self.config.training.batch_size) + total_steps = self.config.training.n_epochs * steps_per_epoch + # Scale by the (negative) learning rate. + self._lr_schedule = utils.get_learning_rate_schedule( + total_batch_size, steps_per_epoch, total_steps, self.config.optimizer) + + self._optimizer = utils.make_optimizer( + self.config.optimizer, + self._lr_schedule) + + # Check we haven't already restored params + if self._params is None: + logging.info('Initializing parameters.') + + inputs = next(self._train_input) + + init_net = jax.pmap(lambda *a: self.forward.init(*a, is_training=True)) + init_opt = jax.pmap(self._optimizer.init) + + # Init uses the same RNG key on all hosts+devices to ensure everyone + # computes the same initial state. + init_rng = jl_utils.bcast_local_devices(self.init_rng) + + self._params, self._state = init_net(init_rng, inputs) + self._opt_state = init_opt(self._params) + + def _load_data(self, split, is_training, batch_dims): + """Wrapper for dataset loading.""" + + return dataset.load( + split=split, + is_training=is_training, + batch_dims=batch_dims, + im_dim=self.config.data.im_dim, + augmentation_settings=self.config.data.augmentation, + ) + + def _build_train_input(self) -> Generator[dataset.Batch, None, None]: + """See base class.""" + num_devices = jax.device_count() + global_batch_size = self.config.training.batch_size + per_device_batch_size, ragged = divmod(global_batch_size, num_devices) + + if ragged: + raise ValueError( + f'Global batch size {global_batch_size} must be divisible by ' + f'num devices {num_devices}') + + split = dataset.Split.TRAIN_AND_VALID + + return self._load_data( + split=split, + is_training=True, + batch_dims=[jax.local_device_count(), per_device_batch_size]) + + def _one_hot(self, value): + """One-hot encoding potentially over a sequence of labels.""" + y = jax.nn.one_hot(value, self.config.data.num_classes) + return y + + def _loss_fn( + self, + params: hk.Params, + state: hk.State, + inputs: dataset.Batch, + rng: jnp.ndarray, + ) -> Tuple[jnp.ndarray, Tuple[Scalars, hk.State]]: + logits, state = self.forward.apply( + params, state, rng, inputs, is_training=True) + + label = self._one_hot(inputs['labels']) + # Handle cutmix/mixup label mixing: + if 'mix_labels' in inputs: + logging.info('Using mixup or cutmix!') + mix_label = self._one_hot(inputs['mix_labels']) + mix_ratio = inputs['ratio'][:, None] + label = mix_ratio * label + (1. - mix_ratio) * mix_label + + # Apply label-smoothing to one-hot labels. + label_smoothing = self.config.training.label_smoothing + if not (label_smoothing >= 0. and label_smoothing < 1.): + raise ValueError( + f"'label_smoothing is {label_smoothing} and should be in [0, 1)") + if label_smoothing > 0: + smooth_positives = 1. - label_smoothing + smooth_negatives = label_smoothing / self.config.data.num_classes + label = smooth_positives * label + smooth_negatives + + loss_w_batch = utils.softmax_cross_entropy(logits, label) + loss = jnp.mean(loss_w_batch, dtype=loss_w_batch.dtype) + scaled_loss = loss / jax.device_count() + + metrics = utils.topk_correct(logits, inputs['labels'], prefix='') + metrics = jax.tree_map(jnp.mean, metrics) + + top_1_acc = metrics['top_1_acc'] + top_5_acc = metrics['top_5_acc'] + + loss_scalars = dict( + loss=loss, + top_1_acc=top_1_acc, + top_5_acc=top_5_acc, + ) + + return scaled_loss, (loss_scalars, state) + + def _update_func( + self, + params: hk.Params, + state: hk.State, + opt_state: OptState, + inputs: dataset.Batch, + rng: jnp.ndarray, + global_step: int, + ) -> Tuple[hk.Params, hk.State, OptState, Scalars]: + """Applies an update to parameters and returns new state.""" + # This function computes the gradient of the first output of loss_fn and + # passes through the other arguments unchanged. + grad_loss_fn = jax.grad(self._loss_fn, has_aux=True) + scaled_grads, (loss_scalars, state) = grad_loss_fn( + params, state, inputs, rng) + grads = jax.lax.psum(scaled_grads, axis_name='i') + + # Grab the learning rate to log before performing the step. + learning_rate = self._lr_schedule(global_step) + + # Compute and apply updates via our optimizer. + updates, opt_state = self._optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + + n_params = 0 + for k in params.keys(): + for l in params[k]: + n_params = n_params + np.prod(params[k][l].shape) + + # Scalars to log (note: we log the mean across all hosts/devices). + scalars = {'learning_rate': learning_rate, + 'n_params (M)': float(n_params/1e6), + 'global_gradient_norm': optax.global_norm(grads)} + loss_scalars = {f'train_{k}': v for k, v in loss_scalars.items()} + scalars.update(loss_scalars) + scalars = jax.lax.pmean(scalars, axis_name='i') + + return params, state, opt_state, scalars + + # _ + # _____ ____ _| | + # / _ \ \ / / _` | | + # | __/\ V / (_| | | + # \___| \_/ \__,_|_| + # + + def evaluate(self, global_step, rng, **unused_args): + """See base class.""" + global_step = np.array(jl_utils.get_first(global_step)) + scalars = jax.device_get(self._eval_epoch(jl_utils.get_first(rng))) + + logging.info('[Step %d] Eval scalars: %s', global_step, scalars) + return scalars + + def _eval_batch( + self, + params: hk.Params, + state: hk.State, + inputs: dataset.Batch, + rng: jnp.ndarray, + ) -> Scalars: + """Evaluates a batch.""" + logits, _ = self.forward.apply( + params, state, rng, inputs, is_training=False) + + labels = self._one_hot(inputs['labels']) + loss = utils.softmax_cross_entropy(logits, labels) + + metrics = utils.topk_correct(logits, inputs['labels'], prefix='') + metrics = jax.tree_map(jnp.mean, metrics) + top_1_acc = metrics['top_1_acc'] + top_5_acc = metrics['top_5_acc'] + + bs = logits.shape[0] + + top_1_acc = jnp.expand_dims(top_1_acc, axis=0) * bs + top_5_acc = jnp.expand_dims(top_5_acc, axis=0) * bs + + # NOTE: Returned values will be summed and finally divided by num_samples. + return { + 'eval_loss': loss, + 'eval_top_1_acc': top_1_acc, 'eval_top_5_acc': top_5_acc} + + def _build_eval_input(self) -> Generator[dataset.Batch, None, None]: + split = dataset.Split.from_string(self.config.evaluation.subset) + + return self._load_data( + split=split, + is_training=False, + batch_dims=[self.config.evaluation.batch_size]) + + def _eval_epoch(self, rng): + """Evaluates an epoch.""" + num_samples = 0. + summed_scalars = None + + params = jl_utils.get_first(self._params) + state = jl_utils.get_first(self._state) + + for inputs in self._build_eval_input(): + num_samples += inputs['labels'].shape[0] + scalars = self._eval_batch(params, state, inputs, rng) + + # Accumulate the sum of scalars for each step. + scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars) + if summed_scalars is None: + summed_scalars = scalars + else: + summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars) + + mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars) + return mean_scalars + + +if __name__ == '__main__': + flags.mark_flag_as_required('config') + app.run(functools.partial(platform.main, Experiment)) diff --git a/perceiver/train/launch_local.sh b/perceiver/train/launch_local.sh new file mode 100755 index 0000000..555aef2 --- /dev/null +++ b/perceiver/train/launch_local.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Command for running training script. + +PYTHONPATH=.::$PYTHONPATH python perceiver/train/experiment.py \ + --config=perceiver/train/experiment.py --logtostderr diff --git a/perceiver/train/utils.py b/perceiver/train/utils.py new file mode 100644 index 0000000..9935101 --- /dev/null +++ b/perceiver/train/utils.py @@ -0,0 +1,242 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities.""" + +from typing import Callable, List, Mapping, NamedTuple, Optional, Tuple, Union + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import optax + + +Batch = Mapping[str, np.ndarray] +OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState] +Scalars = Mapping[str, jnp.ndarray] +ParamsOrState = Union[hk.Params, hk.State] + + +NORM_NAMES = ['layer_norm', 'batchnorm'] + + +# any_in and topk_correct taken from +# https://github.com/deepmind/deepmind-research/blob/master/nfnets/utils.py +@jax.vmap +def any_in(prediction, target): + """For each row in a and b, checks if any element of a is in b.""" + return jnp.isin(prediction, target) + + +def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)): + """Calculate top-k error for multiple k values.""" + metrics = {} + argsorted_logits = jnp.argsort(logits) + for k in topk: + pred_labels = argsorted_logits[..., -k:] + # Get the number of examples where the label is in the top-k predictions + correct = any_in(pred_labels, labels).any(axis=-1).astype(jnp.float32) + if mask is not None: + correct *= mask + metrics[f'{prefix}top_{k}_acc'] = correct + return metrics + + +def softmax_cross_entropy(logits, labels): + """Computes softmax cross entropy given logits and one-hot class labels. + + Args: + logits: Logit output values. + labels: Ground truth one-hot-encoded labels. + + Returns: + Loss value with the same shape as `labels`; + """ + return jnp.asarray(optax.softmax_cross_entropy(logits, labels)) + + +def _get_batch_scaled_lr(total_batch_size, lr, scale_by_batch=True): + # This is the linear scaling rule in Section 5.1 of + # https://arxiv.org/pdf/1706.02677.pdf. + + if scale_by_batch: + lr = (lr * total_batch_size) / 256 + + return lr + + +def get_learning_rate_schedule( + total_batch_size, steps_per_epoch, total_steps, optimizer_config): + """Build the learning rate schedule function.""" + base_lr = _get_batch_scaled_lr(total_batch_size, optimizer_config.base_lr, + optimizer_config.scale_by_batch) + + schedule_type = optimizer_config.schedule_type + if schedule_type == 'steps': + boundaries = optimizer_config.step_decay_kwargs.decay_boundaries + boundaries.sort() + + decay_rate = optimizer_config.step_decay_kwargs.decay_rate + boundaries_and_scales = { + int(boundary * total_steps): decay_rate for boundary in boundaries} + schedule_fn = optax.piecewise_constant_schedule( + init_value=base_lr, boundaries_and_scales=boundaries_and_scales) + elif schedule_type == 'cosine': + warmup_steps = (optimizer_config.cosine_decay_kwargs.warmup_epochs + * steps_per_epoch) + # Batch scale the other lr values as well: + init_value = _get_batch_scaled_lr( + total_batch_size, + optimizer_config.cosine_decay_kwargs.init_value, + optimizer_config.scale_by_batch) + end_value = _get_batch_scaled_lr( + total_batch_size, + optimizer_config.cosine_decay_kwargs.end_value, + optimizer_config.scale_by_batch) + + schedule_fn = optax.warmup_cosine_decay_schedule( + init_value=init_value, + peak_value=base_lr, + warmup_steps=warmup_steps, + decay_steps=total_steps, + end_value=end_value) + elif schedule_type == 'constant_cosine': + # Convert end_value to alpha, used by cosine_decay_schedule. + alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr + + # Number of steps spent in constant phase. + constant_steps = int( + optimizer_config.constant_cosine_decay_kwargs.constant_fraction + * total_steps) + decay_steps = total_steps - constant_steps + + constant_phase = optax.constant_schedule(value=base_lr) + decay_phase = optax.cosine_decay_schedule( + init_value=base_lr, + decay_steps=decay_steps, + alpha=alpha) + schedule_fn = optax.join_schedules( + schedules=[constant_phase, decay_phase], + boundaries=[constant_steps]) + else: + raise ValueError(f'Unknown learning rate schedule: {schedule_type}') + + return schedule_fn + + +def _weight_decay_exclude( + exclude_names: Optional[List[str]] = None +) -> Callable[[str, str, jnp.ndarray], bool]: + """Logic for deciding which parameters to include for weight decay.. + + Args: + exclude_names: an optional list of names to include for weight_decay. ['w'] + by default. + + Returns: + A predicate that returns True for params that need to be excluded from + weight_decay. + """ + # By default weight_decay the weights but not the biases. + if not exclude_names: + exclude_names = ['b'] + + def exclude(module_name: str, name: str, value: jnp.array): + del value + # Do not weight decay the parameters of normalization blocks. + if any([norm_name in module_name for norm_name in NORM_NAMES]): + return True + else: + return name in exclude_names + + return exclude + + +class AddWeightDecayState(NamedTuple): + """Stateless transformation.""" + + +def add_weight_decay( + weight_decay: float, + exclude_names: Optional[List[str]] = None) -> optax.GradientTransformation: + """Add parameter scaled by `weight_decay` to the `updates`. + + Same as optax.add_decayed_weights but can exclude parameters by name. + + Args: + weight_decay: weight_decay coefficient. + exclude_names: an optional list of names to exclude for weight_decay. ['b'] + by default. + + Returns: + An (init_fn, update_fn) tuple. + """ + + def init_fn(_): + return AddWeightDecayState() + + def update_fn(updates, state, params): + exclude = _weight_decay_exclude(exclude_names=exclude_names) + + u_ex, u_in = hk.data_structures.partition(exclude, updates) + _, p_in = hk.data_structures.partition(exclude, params) + u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in) + updates = hk.data_structures.merge(u_ex, u_in) + return updates, state + + return optax.GradientTransformation(init_fn, update_fn) + + +def make_optimizer(optimizer_config, lr_schedule): + """Construct the optax optimizer with given LR schedule.""" + if (optimizer_config.get('decay_pos_embs') is None or + optimizer_config.decay_pos_embs): + # Decay learned position embeddings by default. + weight_decay_exclude_names = ['b'] + else: + weight_decay_exclude_names = ['pos_embs', 'b'] + + optax_chain = [] + if optimizer_config.max_norm > 0: + optax_chain.append( + optax.clip_by_global_norm(optimizer_config.max_norm)) + + if optimizer_config.optimizer == 'adam': + # See: https://arxiv.org/abs/1412.6980 + optax_chain.extend([ + optax.scale_by_adam(**optimizer_config.adam_kwargs), + add_weight_decay( + optimizer_config.weight_decay, + exclude_names=weight_decay_exclude_names) + ]) + elif optimizer_config.optimizer == 'lamb': + # See: https://arxiv.org/abs/1904.00962 + optax_chain.extend([ + optax.scale_by_adam(**optimizer_config.lamb_kwargs), + add_weight_decay( + optimizer_config.weight_decay, + exclude_names=weight_decay_exclude_names), + optax.scale_by_trust_ratio() + ]) + else: + raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}') + + # Scale by the (negative) learning rate. + optax_chain.extend([ + optax.scale_by_schedule(lr_schedule), + optax.scale(-1), + ]) + + return optax.chain(*optax_chain)