Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+17 -17
View File
@@ -34,7 +34,7 @@ class Split(enum.Enum):
TEST = 4
@classmethod
def from_string(cls, name):
def from_string(cls, name: Text) -> 'Split':
return {
'TRAIN': Split.TRAIN,
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
@@ -60,7 +60,7 @@ class PreprocessMode(enum.Enum):
EVAL = 3 # Generates a single center crop.
def normalize_images(images):
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
"""Normalize the image using ImageNet statistics."""
mean_rgb = (0.485, 0.456, 0.406)
stddev_rgb = (0.229, 0.224, 0.225)
@@ -69,12 +69,12 @@ def normalize_images(images):
return normed_images
def load(split,
def load(split: Split,
*,
preprocess_mode,
batch_dims,
transpose = False,
allow_caching = False):
preprocess_mode: PreprocessMode,
batch_dims: Sequence[int],
transpose: bool = False,
allow_caching: bool = False) -> Generator[Batch, None, None]:
"""Loads the given split of the dataset."""
start, end = _shard(split, jax.host_id(), jax.host_count())
@@ -153,7 +153,7 @@ def load(split,
yield from tfds.as_numpy(ds)
def _to_tfds_split(split):
def _to_tfds_split(split: Split) -> tfds.Split:
"""Returns the TFDS split appropriately sharded."""
# NOTE: Imagenet did not release labels for the test split used in the
# competition, we consider the VALID split the TEST split and reserve
@@ -165,7 +165,7 @@ def _to_tfds_split(split):
return tfds.Split.VALIDATION
def _shard(split, shard_index, num_shards):
def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
"""Returns [start, end) for the given shard index."""
assert shard_index < num_shards
arange = np.arange(split.num_examples)
@@ -180,9 +180,9 @@ def _shard(split, shard_index, num_shards):
def _preprocess_image(
image_bytes,
mode,
):
image_bytes: tf.Tensor,
mode: PreprocessMode,
) -> tf.Tensor:
"""Returns processed and resized images."""
if mode is PreprocessMode.PRETRAIN:
image = _decode_and_random_crop(image_bytes)
@@ -201,7 +201,7 @@ def _preprocess_image(
return image
def _decode_and_random_crop(image_bytes):
def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
"""Make a random crop of 224."""
img_size = tf.image.extract_jpeg_shape(image_bytes)
area = tf.cast(img_size[1] * img_size[0], tf.float32)
@@ -231,7 +231,7 @@ def _decode_and_random_crop(image_bytes):
return image
def transpose_images(batch):
def transpose_images(batch: Batch):
"""Transpose images for TPU training.."""
new_batch = dict(batch) # Avoid mutating in place.
if 'images' in batch:
@@ -243,9 +243,9 @@ def transpose_images(batch):
def _decode_and_center_crop(
image_bytes,
jpeg_shape = None,
):
image_bytes: tf.Tensor,
jpeg_shape: Optional[tf.Tensor] = None,
) -> tf.Tensor:
"""Crops to center of image with padding then scales."""
if jpeg_shape is None:
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)