mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
8457046b2c
PiperOrigin-RevId: 328023346
267 lines
9.3 KiB
Python
267 lines
9.3 KiB
Python
# Copyright 2020 DeepMind Technologies Limited.
|
|
#
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""ImageNet dataset with typical pre-processing."""
|
|
|
|
import enum
|
|
from typing import Generator, Mapping, Optional, Sequence, Text, Tuple
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
import tensorflow.compat.v2 as tf
|
|
import tensorflow_datasets as tfds
|
|
|
|
# Type alias for one batch: a mapping from feature name to a NumPy array.
Batch = Mapping[Text, np.ndarray]
|
|
|
|
|
|
class Split(enum.Enum):
  """Partitions of the ImageNet data exposed by this module."""
  TRAIN = 1
  TRAIN_AND_VALID = 2
  VALID = 3
  TEST = 4

  @classmethod
  def from_string(cls, name: Text) -> 'Split':
    """Parses a split name, ignoring case.

    Args:
      name: one of 'TRAIN', 'TRAIN_AND_VALID', 'VALID', 'VALIDATION' (an
        alias of 'VALID') or 'TEST', in any letter case.

    Returns:
      The matching `Split` member.

    Raises:
      KeyError: if `name` is not one of the names listed above.
    """
    member_name = name.upper()
    if member_name == 'VALIDATION':
      # 'VALIDATION' is accepted as a longer spelling of VALID.
      member_name = 'VALID'
    return cls[member_name]

  @property
  def num_examples(self):
    """Number of examples contained in this split."""
    train_and_valid = 1281167
    valid = 10000
    sizes_by_name = {
        'TRAIN_AND_VALID': train_and_valid,
        # TRAIN is TRAIN_AND_VALID without the held-out VALID examples.
        'TRAIN': train_and_valid - valid,
        'VALID': valid,
        'TEST': 50000,
    }
    return sizes_by_name[self.name]
|
|
|
|
|
|
class PreprocessMode(enum.Enum):
  """Selects how each raw example is turned into model inputs."""
  # Two augmented views per example (random crop + augmentations).
  PRETRAIN = 1
  # One random crop per example.
  LINEAR_TRAIN = 2
  # One center crop per example.
  EVAL = 3
|
|
|
|
|
|
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
  """Standardizes images per channel with the ImageNet mean and stddev.

  Args:
    images: array whose trailing axis holds the three RGB channels. The
      statistics are broadcast against it with shape (1, 1, 1, 3).
      NOTE(review): the constants look like statistics for pixel values
      scaled to [0, 1] -- confirm against the caller.

  Returns:
    `(images - mean) / stddev`, computed channel-wise.
  """
  stats_shape = (1, 1, 1, 3)
  channel_mean = jnp.array((0.485, 0.456, 0.406)).reshape(stats_shape)
  channel_stddev = jnp.array((0.229, 0.224, 0.225)).reshape(stats_shape)
  return (images - channel_mean) / channel_stddev
|
|
|
|
|
|
def load(split: Split,
         *,
         preprocess_mode: PreprocessMode,
         batch_dims: Sequence[int],
         transpose: bool = False,
         allow_caching: bool = False) -> Generator[Batch, None, None]:
  """Loads the given split of the dataset.

  Each host reads only its own contiguous shard of `split` (see `_shard`).
  Images are kept as encoded bytes by TFDS and decoded, cropped and resized
  in `_preprocess_image`.

  Args:
    split: which dataset split to read.
    preprocess_mode: selects the preprocessing applied to every example. Any
      mode other than EVAL is treated as training: the stream is repeated
      indefinitely, shuffled, and non-deterministic ordering is allowed.
    batch_dims: sizes of the leading batch axes, outermost first. The dataset
      is batched once per entry, innermost entry first.
    transpose: if True, image tensors are transposed from NHWC to HWCN right
      after the innermost batching (before any outer batch axes are added).
    allow_caching: if True and there is more than one host, the raw examples
      of this host's shard are cached (training modes only).

  Yields:
    Dicts of NumPy arrays with keys 'view1', 'view2' and 'labels' in PRETRAIN
    mode, and 'images' and 'labels' otherwise.

  Raises:
    ValueError: in EVAL mode, if `split.num_examples` is not divisible by the
      product of `batch_dims`.
  """
  start, end = _shard(split, jax.host_id(), jax.host_count())

  # Plain Python int so that it can be used directly as a buffer size.
  total_batch_size = int(np.prod(batch_dims))

  tfds_split = tfds.core.ReadInstruction(
      _to_tfds_split(split), from_=start, to=end, unit='abs')
  ds = tfds.load(
      'imagenet2012:5.*.*',
      split=tfds_split,
      # Keep the images encoded; cropping is fused with decoding later on.
      decoders={'image': tfds.decode.SkipDecoding()})

  is_training = preprocess_mode is not PreprocessMode.EVAL

  # Options must be attached with `with_options`; mutating the object returned
  # by `ds.options()` is not a supported way to configure a dataset.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  if is_training:
    options.experimental_deterministic = False
  ds = ds.with_options(options)

  if is_training:
    if jax.host_count() > 1 and allow_caching:
      # Only cache if we are reading a subset of the dataset.
      ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
  elif split.num_examples % total_batch_size != 0:
    raise ValueError(f'Test/valid must be divisible by {total_batch_size}')

  def preprocess_two_views(example):
    # Two independent calls yield two independently augmented crops.
    view1 = _preprocess_image(example['image'], mode=preprocess_mode)
    view2 = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'view1': view1, 'view2': view2, 'labels': label}

  def preprocess_single_view(example):
    # Shared by LINEAR_TRAIN and EVAL; `preprocess_mode` picks the crop type.
    image = _preprocess_image(example['image'], mode=preprocess_mode)
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  if preprocess_mode is PreprocessMode.PRETRAIN:
    preprocess_fn = preprocess_two_views
  else:
    preprocess_fn = preprocess_single_view
  ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def transpose_fn(batch):
    # We use the double-transpose-trick to improve performance for TPUs. Note
    # that this (typically) requires a matching HWCN->NHWC transpose in your
    # model code. The compiler cannot make this optimization for us since our
    # data pipeline and model are compiled separately.
    batch = dict(**batch)
    if preprocess_mode is PreprocessMode.PRETRAIN:
      batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
      batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
    else:
      batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
    return batch

  # Batch from the innermost dimension outwards so that the final leading
  # axes appear in the same order as `batch_dims`.
  for i, batch_size in enumerate(reversed(batch_dims)):
    ds = ds.batch(batch_size)
    if i == 0 and transpose:
      ds = ds.map(transpose_fn)  # NHWC -> HWCN

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  yield from tfds.as_numpy(ds)
|
|
|
|
|
|
def _to_tfds_split(split: Split) -> tfds.Split:
  """Returns the TFDS split that backs the given `Split`.

  TRAIN, TRAIN_AND_VALID and VALID are all read from the TFDS train split;
  TEST is read from the TFDS validation split.
  """
  # NOTE: Imagenet did not release labels for the test split used in the
  # competition, we consider the VALID split the TEST split and reserve
  # 10k images from TRAIN for VALID.
  backed_by_tfds_train = (Split.TRAIN, Split.TRAIN_AND_VALID, Split.VALID)
  if split not in backed_by_tfds_train:
    assert split == Split.TEST
    return tfds.Split.VALIDATION
  return tfds.Split.TRAIN
|
|
|
|
|
|
def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
  """Returns [start, end) for the given shard index.

  The `split.num_examples` examples are divided into `num_shards` contiguous
  ranges using the same partition as `np.array_split`: the first
  `num_examples % num_shards` shards receive one extra example. The bounds are
  computed arithmetically, so no index array is allocated and plain Python
  ints are returned.

  Args:
    split: the split being sharded.
    shard_index: index of the requested shard, in [0, num_shards).
    num_shards: total number of shards.

  Returns:
    A `(start, end)` pair of absolute example indices into the underlying
    TFDS split; `end` is exclusive.
  """
  assert 0 <= shard_index < num_shards
  base_size, num_larger_shards = divmod(split.num_examples, num_shards)
  # Every earlier shard contributes `base_size` examples, plus one more for
  # each earlier shard that belongs to the "larger" group.
  start = shard_index * base_size + min(shard_index, num_larger_shards)
  end = start + base_size + (1 if shard_index < num_larger_shards else 0)
  if split == Split.TRAIN:
    # Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
    offset = Split.VALID.num_examples
    start += offset
    end += offset
  return start, end
|
|
|
|
|
|
def _preprocess_image(
    image_bytes: tf.Tensor,
    mode: PreprocessMode,
) -> tf.Tensor:
  """Decodes and crops an encoded image, then resizes it to 224x224.

  Args:
    image_bytes: the encoded image, as passed on to the decode-and-crop
      helpers.
    mode: PRETRAIN takes a random crop; LINEAR_TRAIN takes a random crop
      followed by a random horizontal flip; any other mode takes a center
      crop.

  Returns:
    The resized image divided by 255 and clipped to [0, 1].
  """
  if mode is PreprocessMode.PRETRAIN or mode is PreprocessMode.LINEAR_TRAIN:
    cropped = _decode_and_random_crop(image_bytes)
    if mode is PreprocessMode.LINEAR_TRAIN:
      cropped = tf.image.random_flip_left_right(cropped)
    # For PRETRAIN, random horizontal flipping is optionally done in
    # augmentations.preprocess.
  else:
    cropped = _decode_and_center_crop(image_bytes)
  assert cropped.dtype == tf.uint8
  # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
  # clamping overshoots. This means values returned will be outside the range
  # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]),
  # hence the clipping after rescaling.
  resized = tf.image.resize(
      cropped, [224, 224], method=tf.image.ResizeMethod.BICUBIC)
  return tf.clip_by_value(resized / 255., 0., 1.)
|
|
|
|
|
|
def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
  """Decodes a random rectangular crop directly from an encoded JPEG.

  The crop covers between 8% and 100% of the image area, with an aspect ratio
  (width / height) drawn log-uniformly from [3/4, 4/3]. The crop is clamped to
  the image bounds and placed at a uniformly random offset. The result is not
  resized here.

  Args:
    image_bytes: the JPEG-encoded image.

  Returns:
    The decoded crop with 3 channels.
  """
  # The random ops below are kept in their original order on purpose.
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)  # (height, width, ..)
  full_area = tf.cast(jpeg_shape[1] * jpeg_shape[0], tf.float32)
  area_fraction = tf.random.uniform(
      shape=[], minval=0.08, maxval=1.0, dtype=tf.float32)
  crop_area = area_fraction * full_area

  min_log_ratio = tf.math.log(3 / 4)
  max_log_ratio = tf.math.log(4 / 3)
  ratio = tf.math.exp(
      tf.random.uniform(
          shape=[],
          minval=min_log_ratio,
          maxval=max_log_ratio,
          dtype=tf.float32))

  crop_w = tf.cast(tf.round(tf.sqrt(crop_area * ratio)), tf.int32)
  crop_h = tf.cast(tf.round(tf.sqrt(crop_area / ratio)), tf.int32)

  # Never ask for more pixels than the image has along either axis.
  crop_w = tf.minimum(crop_w, jpeg_shape[1])
  crop_h = tf.minimum(crop_h, jpeg_shape[0])

  # `maxval` is exclusive for integer sampling, hence the `+ 1`.
  left = tf.random.uniform(
      shape=(), minval=0, maxval=jpeg_shape[1] - crop_w + 1, dtype=tf.int32)
  top = tf.random.uniform(
      shape=(), minval=0, maxval=jpeg_shape[0] - crop_h + 1, dtype=tf.int32)

  window = tf.stack([top, left, crop_h, crop_w])
  return tf.io.decode_and_crop_jpeg(image_bytes, window, channels=3)
|
|
|
|
|
|
def transpose_images(batch: Batch):
  """Moves the trailing axis of the image arrays to the front.

  Applies the axis permutation (3, 0, 1, 2), i.e. HWCN -> NHWC, which undoes
  the NHWC -> HWCN transpose performed by `load(..., transpose=True)`.

  Args:
    batch: a mapping holding either an 'images' entry, or both a 'view1' and
      a 'view2' entry. Other entries are passed through untouched.

  Returns:
    A new dict with the image entries transposed; `batch` is not modified.
  """
  image_keys = ('images',) if 'images' in batch else ('view1', 'view2')
  transposed = {
      key: jnp.transpose(batch[key], (3, 0, 1, 2)) for key in image_keys
  }
  return {**batch, **transposed}
|
|
|
|
|
|
def _decode_and_center_crop(
    image_bytes: tf.Tensor,
    jpeg_shape: Optional[tf.Tensor] = None,
) -> tf.Tensor:
  """Decodes a centered square crop directly from an encoded JPEG.

  The square's side is 224 / (224 + 32) of the shorter image side, truncated
  to an integer. The crop is not resized here.

  Args:
    image_bytes: the JPEG-encoded image.
    jpeg_shape: optional precomputed shape of the image, indexed as
      (height, width, ...). It is extracted from `image_bytes` when omitted.

  Returns:
    The decoded crop with 3 channels.
  """
  if jpeg_shape is None:
    jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = jpeg_shape[0], jpeg_shape[1]

  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_fraction = 224 / (224 + 32)
  side = tf.cast(crop_fraction * shorter_side, tf.int32)

  # Center the square; the `+ 1` rounds the offset up when the margin is odd.
  top = ((height - side) + 1) // 2
  left = ((width - side) + 1) // 2

  window = tf.stack([top, left, side, side])
  return tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
|