mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-29 03:35:21 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
+17
-17
@@ -34,7 +34,7 @@ class Split(enum.Enum):
|
||||
TEST = 4
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, name):
|
||||
def from_string(cls, name: Text) -> 'Split':
|
||||
return {
|
||||
'TRAIN': Split.TRAIN,
|
||||
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
|
||||
@@ -60,7 +60,7 @@ class PreprocessMode(enum.Enum):
|
||||
EVAL = 3 # Generates a single center crop.
|
||||
|
||||
|
||||
def normalize_images(images):
|
||||
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Normalize the image using ImageNet statistics."""
|
||||
mean_rgb = (0.485, 0.456, 0.406)
|
||||
stddev_rgb = (0.229, 0.224, 0.225)
|
||||
@@ -69,12 +69,12 @@ def normalize_images(images):
|
||||
return normed_images
|
||||
|
||||
|
||||
def load(split,
|
||||
def load(split: Split,
|
||||
*,
|
||||
preprocess_mode,
|
||||
batch_dims,
|
||||
transpose = False,
|
||||
allow_caching = False):
|
||||
preprocess_mode: PreprocessMode,
|
||||
batch_dims: Sequence[int],
|
||||
transpose: bool = False,
|
||||
allow_caching: bool = False) -> Generator[Batch, None, None]:
|
||||
"""Loads the given split of the dataset."""
|
||||
start, end = _shard(split, jax.host_id(), jax.host_count())
|
||||
|
||||
@@ -153,7 +153,7 @@ def load(split,
|
||||
yield from tfds.as_numpy(ds)
|
||||
|
||||
|
||||
def _to_tfds_split(split):
|
||||
def _to_tfds_split(split: Split) -> tfds.Split:
|
||||
"""Returns the TFDS split appropriately sharded."""
|
||||
# NOTE: Imagenet did not release labels for the test split used in the
|
||||
# competition, we consider the VALID split the TEST split and reserve
|
||||
@@ -165,7 +165,7 @@ def _to_tfds_split(split):
|
||||
return tfds.Split.VALIDATION
|
||||
|
||||
|
||||
def _shard(split, shard_index, num_shards):
|
||||
def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
|
||||
"""Returns [start, end) for the given shard index."""
|
||||
assert shard_index < num_shards
|
||||
arange = np.arange(split.num_examples)
|
||||
@@ -180,9 +180,9 @@ def _shard(split, shard_index, num_shards):
|
||||
|
||||
|
||||
def _preprocess_image(
|
||||
image_bytes,
|
||||
mode,
|
||||
):
|
||||
image_bytes: tf.Tensor,
|
||||
mode: PreprocessMode,
|
||||
) -> tf.Tensor:
|
||||
"""Returns processed and resized images."""
|
||||
if mode is PreprocessMode.PRETRAIN:
|
||||
image = _decode_and_random_crop(image_bytes)
|
||||
@@ -201,7 +201,7 @@ def _preprocess_image(
|
||||
return image
|
||||
|
||||
|
||||
def _decode_and_random_crop(image_bytes):
|
||||
def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
|
||||
"""Make a random crop of 224."""
|
||||
img_size = tf.image.extract_jpeg_shape(image_bytes)
|
||||
area = tf.cast(img_size[1] * img_size[0], tf.float32)
|
||||
@@ -231,7 +231,7 @@ def _decode_and_random_crop(image_bytes):
|
||||
return image
|
||||
|
||||
|
||||
def transpose_images(batch):
|
||||
def transpose_images(batch: Batch):
|
||||
"""Transpose images for TPU training.."""
|
||||
new_batch = dict(batch) # Avoid mutating in place.
|
||||
if 'images' in batch:
|
||||
@@ -243,9 +243,9 @@ def transpose_images(batch):
|
||||
|
||||
|
||||
def _decode_and_center_crop(
|
||||
image_bytes,
|
||||
jpeg_shape = None,
|
||||
):
|
||||
image_bytes: tf.Tensor,
|
||||
jpeg_shape: Optional[tf.Tensor] = None,
|
||||
) -> tf.Tensor:
|
||||
"""Crops to center of image with padding then scales."""
|
||||
if jpeg_shape is None:
|
||||
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
||||
|
||||
Reference in New Issue
Block a user