Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+34
View File
@@ -176,3 +176,37 @@ python -m byol.main_loop \
With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
*online* classifier) in roughly 4 hours. Note that the above parameters were not
finely tuned and may not be optimal.
## Additional checkpoints
Alongside the pretrained ResNet-50 and ResNet-200 2x, we provide the
following checkpoints from our ablation study. They all correspond to a
ResNet-50 1x pre-trained over 300 epochs and were randomly selected from among
the three seeds; each file is roughly 640MB.
- [Baseline](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_baseline.pkl)
- Smaller batch sizes (figure 3a):
- [Batch size 2048](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_2048.pkl)
- [Batch size 1024](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_1024.pkl)
- [Batch size 512](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_512.pkl)
- [Batch size 256](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_256.pkl)
- [Batch size 128](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_128.pkl)
- [Batch size 64](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_64.pkl)
- Ablation on transformations (figure 3b):
- [Remove grayscale](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_grayscale.pkl)
- [Remove color](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_color.pkl)
- [Crop and blur only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_blur_only.pkl)
- [Crop only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_only.pkl)
- (from Table 18) [Crop and color only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_color_only.pkl)
## License
While the code is licensed under the Apache 2.0 License, the checkpoint weights
are made available for non-commercial use only under the terms of the
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
license. You can find details at:
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
+47 -47
View File
@@ -56,17 +56,17 @@ class ByolExperiment:
def __init__(
self,
random_seed,
num_classes,
batch_size,
max_steps,
enable_double_transpose,
base_target_ema,
network_config,
optimizer_config,
lr_schedule_config,
evaluation_config,
checkpointing_config):
random_seed: int,
num_classes: int,
batch_size: int,
max_steps: int,
enable_double_transpose: bool,
base_target_ema: float,
network_config: Mapping[Text, Any],
optimizer_config: Mapping[Text, Any],
lr_schedule_config: Mapping[Text, Any],
evaluation_config: Mapping[Text, Any],
checkpointing_config: Mapping[Text, Any]):
"""Constructs the experiment.
Args:
@@ -115,15 +115,15 @@ class ByolExperiment:
def _forward(
self,
inputs,
projector_hidden_size,
projector_output_size,
predictor_hidden_size,
encoder_class,
encoder_config,
bn_config,
is_training,
):
inputs: dataset.Batch,
projector_hidden_size: int,
projector_output_size: int,
predictor_hidden_size: int,
encoder_class: Text,
encoder_config: Mapping[Text, Any],
bn_config: Mapping[Text, Any],
is_training: bool,
) -> Mapping[Text, jnp.ndarray]:
"""Forward application of byol's architecture.
Args:
@@ -163,7 +163,7 @@ class ByolExperiment:
classifier = hk.Linear(
output_size=self._num_classes, name='classifier')
def apply_once_fn(images, suffix = ''):
def apply_once_fn(images: jnp.ndarray, suffix: Text = ''):
images = dataset.normalize_images(images)
embedding = net(images, is_training=is_training)
@@ -186,7 +186,7 @@ class ByolExperiment:
else:
return apply_once_fn(inputs['images'], '')
def _optimizer(self, learning_rate):
def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
"""Build optimizer from config."""
return optimizers.lars(
learning_rate,
@@ -196,13 +196,13 @@ class ByolExperiment:
def loss_fn(
self,
online_params,
target_params,
online_state,
target_state,
rng,
inputs,
):
online_params: hk.Params,
target_params: hk.Params,
online_state: hk.State,
target_state: hk.Params,
rng: jnp.ndarray,
inputs: dataset.Batch,
) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
"""Compute BYOL's loss function.
Args:
@@ -292,11 +292,11 @@ class ByolExperiment:
def _update_fn(
self,
byol_state,
global_step,
rng,
inputs,
):
byol_state: _ByolExperimentState,
global_step: jnp.ndarray,
rng: jnp.ndarray,
inputs: dataset.Batch,
) -> Tuple[_ByolExperimentState, LogsDict]:
"""Update online and target parameters.
Args:
@@ -352,9 +352,9 @@ class ByolExperiment:
def _make_initial_state(
self,
rng,
dummy_input,
):
rng: jnp.ndarray,
dummy_input: dataset.Batch,
) -> _ByolExperimentState:
"""BYOL's _ByolExperimentState initialization.
Args:
@@ -393,8 +393,8 @@ class ByolExperiment:
)
def step(self, *,
global_step,
rng):
global_step: jnp.ndarray,
rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
"""Performs a single training step."""
if self._train_input is None:
self._initialize_train()
@@ -410,11 +410,11 @@ class ByolExperiment:
return helpers.get_first(scalars)
def save_checkpoint(self, step, rng):
def save_checkpoint(self, step: int, rng: jnp.ndarray):
self._checkpointer.maybe_save_checkpoint(
self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)
def load_checkpoint(self):
def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
if checkpoint_data is None:
return None
@@ -444,7 +444,7 @@ class ByolExperiment:
self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
def _build_train_input(self):
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
"""Loads the (infinitely looping) dataset iterator."""
num_devices = jax.device_count()
global_batch_size = self._batch_size
@@ -463,10 +463,10 @@ class ByolExperiment:
def _eval_batch(
self,
params,
state,
batch,
):
params: hk.Params,
state: hk.State,
batch: dataset.Batch,
) -> Mapping[Text, jnp.ndarray]:
"""Evaluates a batch.
Args:
@@ -494,7 +494,7 @@ class ByolExperiment:
'top5_accuracy': top5_correct,
}
def _eval_epoch(self, subset, batch_size):
def _eval_epoch(self, subset: Text, batch_size: int):
"""Evaluates an epoch."""
num_samples = 0.
summed_scalars = None
+1 -1
View File
@@ -23,7 +23,7 @@ _WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6}
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
def get_config(num_epochs, batch_size):
def get_config(num_epochs: int, batch_size: int):
"""Return config object, containing all hyperparameters for training."""
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
+1 -1
View File
@@ -19,7 +19,7 @@ from typing import Text
from byol.utils import dataset
def get_config(checkpoint_to_evaluate, batch_size):
def get_config(checkpoint_to_evaluate: Text, batch_size: int):
"""Return config object for training."""
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
+47 -47
View File
@@ -53,19 +53,19 @@ class EvalExperiment:
def __init__(
self,
random_seed,
num_classes,
batch_size,
max_steps,
enable_double_transpose,
checkpoint_to_evaluate,
allow_train_from_scratch,
freeze_backbone,
network_config,
optimizer_config,
lr_schedule_config,
evaluation_config,
checkpointing_config):
random_seed: int,
num_classes: int,
batch_size: int,
max_steps: int,
enable_double_transpose: bool,
checkpoint_to_evaluate: Optional[Text],
allow_train_from_scratch: bool,
freeze_backbone: bool,
network_config: Mapping[Text, Any],
optimizer_config: Mapping[Text, Any],
lr_schedule_config: Mapping[Text, Any],
evaluation_config: Mapping[Text, Any],
checkpointing_config: Mapping[Text, Any]):
"""Constructs the experiment.
Args:
@@ -125,12 +125,12 @@ class EvalExperiment:
def _backbone_fn(
self,
inputs,
encoder_class,
encoder_config,
bn_decay_rate,
is_training,
):
inputs: dataset.Batch,
encoder_class: Text,
encoder_config: Mapping[Text, Any],
bn_decay_rate: float,
is_training: bool,
) -> jnp.ndarray:
"""Forward of the encoder (backbone)."""
bn_config = {'decay_rate': bn_decay_rate}
encoder = getattr(networks, encoder_class)
@@ -146,8 +146,8 @@ class EvalExperiment:
def _classif_fn(
self,
embeddings,
):
embeddings: jnp.ndarray,
) -> jnp.ndarray:
classifier = hk.Linear(output_size=self._num_classes)
return classifier(embeddings)
@@ -159,8 +159,8 @@ class EvalExperiment:
#
def step(self, *,
global_step,
rng):
global_step: jnp.ndarray,
rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
"""Performs a single training step."""
if self._train_input is None:
@@ -173,12 +173,12 @@ class EvalExperiment:
scalars = helpers.get_first(scalars)
return scalars
def save_checkpoint(self, step, rng):
def save_checkpoint(self, step: int, rng: jnp.ndarray):
self._checkpointer.maybe_save_checkpoint(
self._experiment_state, step=step, rng=rng,
is_final=step >= self._max_steps)
def load_checkpoint(self):
def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
if checkpoint_data is None:
return None
@@ -253,11 +253,11 @@ class EvalExperiment:
def _make_initial_state(
self,
rng,
dummy_input,
backbone_params,
backbone_state,
):
rng: jnp.ndarray,
dummy_input: dataset.Batch,
backbone_params: hk.Params,
backbone_state: hk.Params,
) -> _EvalExperimentState:
"""_EvalExperimentState initialization."""
# Initialize the backbone params
@@ -279,7 +279,7 @@ class EvalExperiment:
classif_opt_state=classif_opt_state,
)
def _build_train_input(self):
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
"""See base class."""
num_devices = jax.device_count()
global_batch_size = self._batch_size
@@ -296,17 +296,17 @@ class EvalExperiment:
transpose=self._should_transpose_images(),
batch_dims=[jax.local_device_count(), per_device_batch_size])
def _optimizer(self, learning_rate):
def _optimizer(self, learning_rate: float):
"""Build optimizer from config."""
return optax.sgd(learning_rate, **self._optimizer_config)
def _loss_fn(
self,
backbone_params,
classif_params,
backbone_state,
inputs,
):
backbone_params: hk.Params,
classif_params: hk.Params,
backbone_state: hk.State,
inputs: dataset.Batch,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
"""Compute the classification loss function.
Args:
@@ -333,10 +333,10 @@ class EvalExperiment:
def _update_func(
self,
experiment_state,
global_step,
inputs,
):
experiment_state: _EvalExperimentState,
global_step: jnp.ndarray,
inputs: dataset.Batch,
) -> Tuple[_EvalExperimentState, LogsDict]:
"""Applies an update to parameters and returns new state."""
# This function computes the gradient of the first output of loss_fn and
# passes through the other arguments unchanged.
@@ -421,11 +421,11 @@ class EvalExperiment:
def _eval_batch(
self,
backbone_params,
classif_params,
backbone_state,
inputs,
):
backbone_params: hk.Params,
classif_params: hk.Params,
backbone_state: hk.State,
inputs: dataset.Batch,
) -> LogsDict:
"""Evaluates a batch."""
embeddings, backbone_state = self.forward_backbone.apply(
backbone_params, backbone_state, inputs, is_training=False)
@@ -441,7 +441,7 @@ class EvalExperiment:
'top5_accuracy': top5_correct
}
def _eval_epoch(self, subset, batch_size):
def _eval_epoch(self, subset: Text, batch_size: int):
"""Evaluates an epoch."""
num_samples = 0.
summed_scalars = None
+2 -2
View File
@@ -47,7 +47,7 @@ Experiment = Union[
Type[eval_experiment.EvalExperiment]]
def train_loop(experiment_class, config):
def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
"""The main training loop.
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
@@ -95,7 +95,7 @@ def train_loop(experiment_class, config):
experiment.save_checkpoint(step, rng)
def eval_loop(experiment_class, config):
def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
"""The main evaluation loop.
This loop periodically loads a checkpoint and evaluates its performance on the
+7 -7
View File
@@ -66,14 +66,14 @@ augment_config = dict(
))
def postprocess(inputs, rng):
def postprocess(inputs: JaxBatch, rng: jnp.ndarray):
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
def _postprocess_image(
images,
rng,
presets,
):
images: jnp.ndarray,
rng: jnp.ndarray,
presets: ConfigDict,
) -> JaxBatch:
"""Applies augmentations in post-processing.
Args:
@@ -144,7 +144,7 @@ def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
expand_batch_dim = len(image.shape) == 3
if expand_batch_dim:
image = image[jnp.newaxis, Ellipsis]
image = image[jnp.newaxis, ...]
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
blurred = jnp.squeeze(blurred, axis=0)
@@ -284,7 +284,7 @@ def _random_hue(rgb_tuple, rng, max_delta):
def _to_grayscale(image):
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis]
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
+9 -9
View File
@@ -31,10 +31,10 @@ class Checkpointer:
def __init__(
self,
use_checkpointing,
checkpoint_dir,
save_checkpoint_interval,
filename):
use_checkpointing: bool,
checkpoint_dir: Text,
save_checkpoint_interval: int,
filename: Text):
if (not use_checkpointing or
checkpoint_dir is None or
save_checkpoint_interval <= 0):
@@ -51,10 +51,10 @@ class Checkpointer:
def maybe_save_checkpoint(
self,
experiment_state,
step,
rng,
is_final):
experiment_state: Mapping[Text, jnp.ndarray],
step: int,
rng: jnp.ndarray,
is_final: bool):
"""Saves a checkpoint if enough time has passed since the previous one."""
current_time = time.time()
if (not self._checkpoint_enabled or
@@ -80,7 +80,7 @@ class Checkpointer:
self._last_checkpoint_time = current_time
def maybe_load_checkpoint(
self):
self) -> Union[Tuple[Mapping[Text, jnp.ndarray], int, jnp.ndarray], None]:
"""Loads a checkpoint if any is found."""
checkpoint_data = load_checkpoint(self._checkpoint_path)
if checkpoint_data is None:
+17 -17
View File
@@ -34,7 +34,7 @@ class Split(enum.Enum):
TEST = 4
@classmethod
def from_string(cls, name):
def from_string(cls, name: Text) -> 'Split':
return {
'TRAIN': Split.TRAIN,
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
@@ -60,7 +60,7 @@ class PreprocessMode(enum.Enum):
EVAL = 3 # Generates a single center crop.
def normalize_images(images):
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
"""Normalize the image using ImageNet statistics."""
mean_rgb = (0.485, 0.456, 0.406)
stddev_rgb = (0.229, 0.224, 0.225)
@@ -69,12 +69,12 @@ def normalize_images(images):
return normed_images
def load(split,
def load(split: Split,
*,
preprocess_mode,
batch_dims,
transpose = False,
allow_caching = False):
preprocess_mode: PreprocessMode,
batch_dims: Sequence[int],
transpose: bool = False,
allow_caching: bool = False) -> Generator[Batch, None, None]:
"""Loads the given split of the dataset."""
start, end = _shard(split, jax.host_id(), jax.host_count())
@@ -153,7 +153,7 @@ def load(split,
yield from tfds.as_numpy(ds)
def _to_tfds_split(split):
def _to_tfds_split(split: Split) -> tfds.Split:
"""Returns the TFDS split appropriately sharded."""
# NOTE: Imagenet did not release labels for the test split used in the
# competition, we consider the VALID split the TEST split and reserve
@@ -165,7 +165,7 @@ def _to_tfds_split(split):
return tfds.Split.VALIDATION
def _shard(split, shard_index, num_shards):
def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
"""Returns [start, end) for the given shard index."""
assert shard_index < num_shards
arange = np.arange(split.num_examples)
@@ -180,9 +180,9 @@ def _shard(split, shard_index, num_shards):
def _preprocess_image(
image_bytes,
mode,
):
image_bytes: tf.Tensor,
mode: PreprocessMode,
) -> tf.Tensor:
"""Returns processed and resized images."""
if mode is PreprocessMode.PRETRAIN:
image = _decode_and_random_crop(image_bytes)
@@ -201,7 +201,7 @@ def _preprocess_image(
return image
def _decode_and_random_crop(image_bytes):
def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
"""Make a random crop of 224."""
img_size = tf.image.extract_jpeg_shape(image_bytes)
area = tf.cast(img_size[1] * img_size[0], tf.float32)
@@ -231,7 +231,7 @@ def _decode_and_random_crop(image_bytes):
return image
def transpose_images(batch):
def transpose_images(batch: Batch):
"""Transpose images for TPU training.."""
new_batch = dict(batch) # Avoid mutating in place.
if 'images' in batch:
@@ -243,9 +243,9 @@ def transpose_images(batch):
def _decode_and_center_crop(
image_bytes,
jpeg_shape = None,
):
image_bytes: tf.Tensor,
jpeg_shape: Optional[tf.Tensor] = None,
) -> tf.Tensor:
"""Crops to center of image with padding then scales."""
if jpeg_shape is None:
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
+14 -14
View File
@@ -21,11 +21,11 @@ import jax.numpy as jnp
def topk_accuracy(
logits,
labels,
topk,
ignore_label_above = None,
):
logits: jnp.ndarray,
labels: jnp.ndarray,
topk: int,
ignore_label_above: Optional[int] = None,
) -> jnp.ndarray:
"""Top-num_codes accuracy."""
assert len(labels.shape) == 1, 'topk expects 1d int labels.'
assert len(logits.shape) == 2, 'topk expects 2d logits.'
@@ -42,10 +42,10 @@ def topk_accuracy(
def softmax_cross_entropy(
logits,
labels,
reduction = 'mean',
):
logits: jnp.ndarray,
labels: jnp.ndarray,
reduction: Optional[Text] = 'mean',
) -> jnp.ndarray:
"""Computes softmax cross entropy given logits and one-hot class labels.
Args:
@@ -72,10 +72,10 @@ def softmax_cross_entropy(
def l2_normalize(
x,
axis = None,
epsilon = 1e-12,
):
x: jnp.ndarray,
axis: Optional[int] = None,
epsilon: float = 1e-12,
) -> jnp.ndarray:
"""l2 normalize a tensor on an axis with numerical stability."""
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
@@ -106,7 +106,7 @@ def l2_weight_regularizer(params):
return 0.5 * l2_norm
def regression_loss(x, y):
def regression_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Byol's regression loss. This is a simple cosine similarity."""
normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
return jnp.sum((normed_x - normed_y)**2, axis=-1)
+49 -49
View File
@@ -27,17 +27,17 @@ class MLP(hk.Module):
def __init__(
self,
name,
hidden_size,
output_size,
bn_config,
name: Text,
hidden_size: int,
output_size: int,
bn_config: Mapping[Text, Any],
):
super().__init__(name=name)
self._hidden_size = hidden_size
self._output_size = output_size
self._bn_config = bn_config
def __call__(self, inputs, is_training):
def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndarray:
out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
out = jax.nn.relu(out)
@@ -55,15 +55,15 @@ class ResNetTorso(hk.Module):
def __init__(
self,
blocks_per_group,
num_classes = None,
bn_config = None,
resnet_v2 = False,
bottleneck = True,
channels_per_group = (256, 512, 1024, 2048),
use_projection = (True, True, True, True),
width_multiplier = 1,
name = None,
blocks_per_group: Sequence[int],
num_classes: int = None,
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
bottleneck: bool = True,
channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
use_projection: Sequence[bool] = (True, True, True, True),
width_multiplier: int = 1,
name: Optional[str] = None,
):
"""Constructs a ResNet model.
@@ -155,11 +155,11 @@ class TinyResNet(ResNetTorso):
"""Tiny resnet for local runs and tests."""
def __init__(self,
num_classes = None,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int] = None,
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
@@ -185,11 +185,11 @@ class ResNet18(ResNetTorso):
"""ResNet18."""
def __init__(self,
num_classes = None,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int] = None,
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
@@ -215,11 +215,11 @@ class ResNet34(ResNetTorso):
"""ResNet34."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int],
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
@@ -245,11 +245,11 @@ class ResNet50(ResNetTorso):
"""ResNet50."""
def __init__(self,
num_classes = None,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int] = None,
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
@@ -274,11 +274,11 @@ class ResNet101(ResNetTorso):
"""ResNet101."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int],
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
@@ -303,11 +303,11 @@ class ResNet152(ResNetTorso):
"""ResNet152."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int],
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
@@ -332,11 +332,11 @@ class ResNet200(ResNetTorso):
"""ResNet200."""
def __init__(self,
num_classes,
bn_config = None,
resnet_v2 = False,
width_multiplier = 1,
name = None):
num_classes: Optional[int],
bn_config: Optional[Mapping[str, float]] = None,
resnet_v2: bool = False,
width_multiplier: int = 1,
name: Optional[str] = None):
"""Constructs a ResNet model.
Args:
+29 -29
View File
@@ -27,7 +27,7 @@ import tree as nest
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]
def exclude_bias_and_norm(path, val):
def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
"""Filter to exclude biaises and normalizations weights."""
del val
if path[-1] == "b" or "norm" in path[-2]:
@@ -35,10 +35,10 @@ def exclude_bias_and_norm(path, val):
return True
def _partial_update(updates,
new_updates,
params,
filter_fn = None):
def _partial_update(updates: optax.Updates,
new_updates: optax.Updates,
params: optax.Params,
filter_fn: Optional[FilterFn] = None) -> optax.Updates:
"""Returns new_update for params which filter_fn is True else updates."""
if filter_fn is None:
@@ -47,7 +47,7 @@ def _partial_update(updates,
wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)
def _update_fn(g, t, m):
def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
m = m.astype(g.dtype)
return g * (1. - m) + t * m
@@ -59,9 +59,9 @@ class ScaleByLarsState(NamedTuple):
def scale_by_lars(
momentum = 0.9,
eta = 0.001,
filter_fn = None):
momentum: float = 0.9,
eta: float = 0.001,
filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
"""Rescales updates according to the LARS algorithm.
Does not include weight decay.
@@ -77,17 +77,17 @@ def scale_by_lars(
An (init_fn, update_fn) tuple.
"""
def init_fn(params):
def init_fn(params: optax.Params) -> ScaleByLarsState:
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
return ScaleByLarsState(mu=mu)
def update_fn(updates, state,
params):
def update_fn(updates: optax.Updates, state: ScaleByLarsState,
params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]:
def lars_adaptation(
update,
param,
):
update: jnp.ndarray,
param: jnp.ndarray,
) -> jnp.ndarray:
param_norm = jnp.linalg.norm(param)
update_norm = jnp.linalg.norm(update)
return update * jnp.where(
@@ -110,8 +110,8 @@ class AddWeightDecayState(NamedTuple):
def add_weight_decay(
weight_decay,
filter_fn = None):
weight_decay: float,
filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
"""Adds a weight decay to the update.
Args:
@@ -122,14 +122,14 @@ def add_weight_decay(
An (init_fn, update_fn) tuple.
"""
def init_fn(_):
def init_fn(_) -> AddWeightDecayState:
return AddWeightDecayState()
def update_fn(
updates,
state,
params,
):
updates: optax.Updates,
state: AddWeightDecayState,
params: optax.Params,
) -> Tuple[optax.Updates, AddWeightDecayState]:
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
params)
new_updates = _partial_update(updates, new_updates, params, filter_fn)
@@ -142,13 +142,13 @@ LarsState = List # Type for the lars optimizer
def lars(
learning_rate,
weight_decay = 0.,
momentum = 0.9,
eta = 0.001,
weight_decay_filter = None,
lars_adaptation_filter = None,
):
learning_rate: float,
weight_decay: float = 0.,
momentum: float = 0.9,
eta: float = 0.001,
weight_decay_filter: Optional[FilterFn] = None,
lars_adaptation_filter: Optional[FilterFn] = None,
) -> optax.GradientTransformation:
"""Creates lars optimizer with weight decay.
References:
+11 -11
View File
@@ -17,18 +17,18 @@
import jax.numpy as jnp
def target_ema(global_step,
base_ema,
max_steps):
def target_ema(global_step: jnp.ndarray,
base_ema: float,
max_steps: int) -> jnp.ndarray:
decay = _cosine_decay(global_step, max_steps, 1.)
return 1. - (1. - base_ema) * decay
def learning_schedule(global_step,
batch_size,
base_learning_rate,
total_steps,
warmup_steps):
def learning_schedule(global_step: jnp.ndarray,
batch_size: int,
base_learning_rate: float,
total_steps: int,
warmup_steps: int) -> float:
"""Cosine learning rate scheduler."""
# Compute LR & Scaled LR
scaled_lr = base_learning_rate * batch_size / 256.
@@ -43,9 +43,9 @@ def learning_schedule(global_step,
scaled_lr))
def _cosine_decay(global_step,
max_steps,
initial_value):
def _cosine_decay(global_step: jnp.ndarray,
max_steps: int,
initial_value: float) -> jnp.ndarray:
"""Simple implementation of cosine decay from TF1."""
global_step = jnp.minimum(global_step, max_steps)
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
+2 -2
View File
@@ -46,7 +46,7 @@ def make_so_tangent(q):
for j in range(i+1, n):
a[i, j] = 1
a[j, i] = -1
dq[Ellipsis, ii] = a @ q # tangent vectors are skew-symmetric matrix times Q
dq[..., ii] = a @ q # tangent vectors are skew-symmetric matrix times Q
a[i, j] = 0
a[j, i] = 0
ii += 1
@@ -106,7 +106,7 @@ def make_product_manifold(specification, npts):
spec_array[1, i] = dim
latent_dim += dim
dat = np.random.randn(npts, dim+1)
dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[Ellipsis, None]),
dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[..., None]),
[1, dim+1])
elif so_spec is not None:
dim = int(so_spec.group(1))
+1 -1
View File
@@ -16,7 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import os
+10 -10
View File
@@ -20,7 +20,7 @@ The architecture and performance of this model is described in our publication:
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import functools
@@ -35,10 +35,10 @@ from typing import Any, Dict, Text, Tuple, Optional
def make_graph_from_static_structure(
positions,
types,
box,
edge_threshold):
positions: tf.Tensor,
types: tf.Tensor,
box: tf.Tensor,
edge_threshold: float) -> graphs.GraphsTuple:
"""Returns graph representing the static structure of the glass.
Each particle is represented by a node in the graph. The particle type is
@@ -81,7 +81,7 @@ def make_graph_from_static_structure(
)
def apply_random_rotation(graph):
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
"""Returns randomly rotated graph representation.
The rotation is an element of O(3) with rotation angles multiple of pi/2.
@@ -118,9 +118,9 @@ class GraphBasedModel(snt.AbstractModule):
"""
def __init__(self,
n_recurrences,
mlp_sizes,
mlp_kwargs = None,
n_recurrences: int,
mlp_sizes: Tuple[int],
mlp_kwargs: Optional[Dict[Text, Any]] = None,
name='Graph'):
"""Creates a new GraphBasedModel object.
@@ -168,7 +168,7 @@ class GraphBasedModel(snt.AbstractModule):
node_model_fn=final_model_fn,
edge_model_fn=model_fn)
def _build(self, graphs_tuple):
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
"""Connects the model into the tensorflow graph.
Args:
+32 -32
View File
@@ -16,7 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import collections
@@ -53,8 +53,8 @@ class ParticleType(enum.IntEnum):
def get_targets(
initial_positions,
trajectory_target_positions):
initial_positions: np.ndarray,
trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray:
"""Returns the averaged particle mobilities from the sampled trajectories.
Args:
@@ -70,9 +70,9 @@ def get_targets(
def load_data(
file_pattern,
time_index,
max_files_to_load = None):
file_pattern: Text,
time_index: int,
max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]:
"""Returns a dictionary containing the training or test dataset.
The dictionary contains:
@@ -108,9 +108,9 @@ def load_data(
def get_loss_ops(
prediction,
target,
types):
prediction: tf.Tensor,
target: tf.Tensor,
types: tf.Tensor) -> LossCollection:
"""Returns L1/L2 loss and correlation for type A particles.
Args:
@@ -132,9 +132,9 @@ def get_loss_ops(
def get_minimize_op(
loss,
learning_rate,
grad_clip = None):
loss: tf.Tensor,
learning_rate: float,
grad_clip: Optional[float] = None) -> tf.Tensor:
"""Returns minimization operation.
Args:
@@ -152,8 +152,8 @@ def get_minimize_op(
def _log_stats_and_return_mean_correlation(
label,
stats):
label: Text,
stats: Sequence[LossCollection]) -> float:
"""Logs performance statistics and returns mean correlation.
Args:
@@ -171,20 +171,20 @@ def _log_stats_and_return_mean_correlation(
return np.mean([s.correlation for s in stats])
def train_model(train_file_pattern,
test_file_pattern,
max_files_to_load = None,
n_epochs = 1000,
time_index = 9,
augment_data_using_rotations = True,
learning_rate = 1e-4,
grad_clip = 1.0,
n_recurrences = 7,
mlp_sizes = (64, 64),
mlp_kwargs = None,
edge_threshold = 2.0,
measurement_store_interval = 1000,
checkpoint_path = None):
def train_model(train_file_pattern: Text,
test_file_pattern: Text,
max_files_to_load: Optional[int] = None,
n_epochs: int = 1000,
time_index: int = 9,
augment_data_using_rotations: bool = True,
learning_rate: float = 1e-4,
grad_clip: Optional[float] = 1.0,
n_recurrences: int = 7,
mlp_sizes: Tuple[int] = (64, 64),
mlp_kwargs: Optional[Dict[Text, Any]] = None,
edge_threshold: float = 2.0,
measurement_store_interval: int = 1000,
checkpoint_path: Optional[Text] = None) -> float:
"""Trains GraphModel using tensorflow.
Args:
@@ -325,10 +325,10 @@ def train_model(train_file_pattern,
return best_so_far
def apply_model(checkpoint_path,
file_pattern,
max_files_to_load = None,
time_index = 9):
def apply_model(checkpoint_path: Text,
file_pattern: Text,
max_files_to_load: Optional[int] = None,
time_index: int = 9) -> List[np.ndarray]:
"""Applies trained GraphModel using tensorflow.
Args:
+1 -1
View File
@@ -16,7 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import os
+2 -2
View File
@@ -145,8 +145,8 @@ class _HierarchicalCore(snt.AbstractModule):
regularizers=self._regularizers,
)(decoder_features)
mu = mu_logsigma[Ellipsis, :latent_dim]
logsigma = mu_logsigma[Ellipsis, latent_dim:]
mu = mu_logsigma[..., :latent_dim]
logsigma = mu_logsigma[..., latent_dim:]
dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(logsigma))
distributions.append(dist)
+2 -2
View File
@@ -37,8 +37,8 @@ class ComponentDecoder(snt.AbstractModule):
pixel_params = self._pixel_decoder(z_flat).params
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
mask_params = pixel_params[Ellipsis, 0:1]
pixel_params = pixel_params[Ellipsis, 1:]
mask_params = pixel_params[..., 0:1]
pixel_params = pixel_params[..., 1:]
output = MixtureParameters(
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
+2 -2
View File
@@ -134,8 +134,8 @@ class LocScaleDistribution(DistributionModule):
n_channels = params.get_shape().as_list()[-1]
assert n_channels % 2 == 0
assert n_channels // 2 == self.output_shape[-1]
loc = params[Ellipsis, :n_channels // 2]
scale = params[Ellipsis, n_channels // 2:]
loc = params[..., :n_channels // 2]
scale = params[..., n_channels // 2:]
# apply activation functions
if self._scale != "fixed":
+3 -3
View File
@@ -84,7 +84,7 @@ class FactorRegressor(snt.AbstractModule):
for m in self._mapping:
with tf.name_scope(m.name):
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
pred = all_preds[Ellipsis, idx:idx + m.size]
pred = all_preds[..., idx:idx + m.size]
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
idx += m.size
@@ -165,7 +165,7 @@ class FactorRegressor(snt.AbstractModule):
@staticmethod
def one_hot(f, nr_categories):
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
return tf.one_hot(tf.cast(f[..., 0], tf.int32), depth=nr_categories)
@staticmethod
def angle_to_vector(theta):
@@ -194,7 +194,7 @@ def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
correct = tf.cast(tf.equal(labels, pred), tf.float32)
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
return tf.reduce_sum(correct * assignment[..., 0]) / num_vis
def r2(labels, pred, assignment, mean_var_tot, num_vis):
+10 -10
View File
@@ -325,12 +325,12 @@ class IODINE(snt.AbstractModule):
[get_components(xd) for xd in iterations["x_dist"]])
# metrics
tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
tm = tf.transpose(true_mask[..., 0], [0, 1, 3, 4, 2])
tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
pm = tf.transpose(pred_mask[..., 0], [0, 1, 3, 4, 2])
pm = tf.reshape(pm, sg["B * T, H * W, K"])
ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm), sg["B, T"])
ari_nobg = tf.reshape(adjusted_rand_index(tm[..., 1:], pm), sg["B, T"])
mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5])
@@ -387,7 +387,7 @@ class IODINE(snt.AbstractModule):
factor_info["assignment"].append(fass)
for k in fpred:
factor_info["predictions"][k].append(
tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
tf.reduce_sum(fpred[k] * fass[..., None], axis=2))
factor_info["metrics"][k].append(fscalars[k])
info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T")
@@ -496,7 +496,7 @@ class IODINE(snt.AbstractModule):
@staticmethod
def _get_mask_posterior(out_dist, img):
p_comp = out_dist.components_distribution.prob(img[Ellipsis, tf.newaxis, :])
p_comp = out_dist.components_distribution.prob(img[..., tf.newaxis, :])
posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6)
return tf.transpose(posterior, [0, 4, 2, 3, 1])
@@ -506,7 +506,7 @@ class IODINE(snt.AbstractModule):
dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask])
log_prob = sg.guard(
out_dist.log_prob(img)[Ellipsis, tf.newaxis], "B, 1, H, W, 1")
out_dist.log_prob(img)[..., tf.newaxis], "B, 1, H, W, 1")
counterfactual_log_probs = []
for k in range(0, self.num_components):
@@ -515,7 +515,7 @@ class IODINE(snt.AbstractModule):
pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]],
axis=1)
out_dist_k = self.output_dist(pixel, mask)
log_prob_k = out_dist_k.log_prob(img)[Ellipsis, tf.newaxis]
log_prob_k = out_dist_k.log_prob(img)[..., tf.newaxis]
counterfactual_log_probs.append(log_prob_k)
counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1)
@@ -608,7 +608,7 @@ class IODINE(snt.AbstractModule):
x_basis = tf.cos(valx * freqs[None, None, None, None, :, None])
y_basis = tf.cos(valy * freqs[None, None, None, None, None, :])
xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"])
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[Ellipsis, 1:]
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[..., 1:]
return coords
else:
raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type))
@@ -632,7 +632,7 @@ class IODINE(snt.AbstractModule):
# ########## Mask Monitoring #######
if "mask" in data:
true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1")
true_mask = tf.transpose(true_mask[:, -1, Ellipsis, 0], [0, 2, 3, 1])
true_mask = tf.transpose(true_mask[:, -1, ..., 0], [0, 2, 3, 1])
true_mask = self._sg.reshape(true_mask, "B, H*W, L")
else:
true_mask = None
@@ -648,6 +648,6 @@ class IODINE(snt.AbstractModule):
adjusted_rand_index(true_mask, pred_mask))
scalars["loss/ari_nobg"] = tf.reduce_mean(
adjusted_rand_index(true_mask[Ellipsis, 1:], pred_mask))
adjusted_rand_index(true_mask[..., 1:], pred_mask))
return scalars
+1 -1
View File
@@ -258,7 +258,7 @@ class BroadcastConv(snt.AbstractModule):
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[..., 1:]
return tf.concat([output, coords], axis=-1)
else:
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
+4 -4
View File
@@ -83,7 +83,7 @@ def show_mask(m, ax):
@optional_clean_ax
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
return ax.matshow(
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
@optional_clean_ax
@@ -115,7 +115,7 @@ def example_plot(rinfo,
show_img(image, ax=axes[0], color="#000000")
show_img(recons, ax=axes[1], color="#000000")
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
show_mask(pred_mask[..., 0], ax=axes[2], color="#000000")
for k in range(K):
mask = pred_mask[k] if mask_components else None
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
@@ -145,7 +145,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
for t in range(T):
show_img(recons[t, 0], ax=axes[t, 0])
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
show_mask(pred_mask[t, ..., 0], ax=axes[t, 1])
axes[t, 0].set_ylabel("iter {}".format(t))
for k in range(K):
mask = pred_mask[t, k] if mask_components else None
@@ -154,7 +154,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
axes[0, 0].set_title("Reconstruction")
axes[0, 1].set_title("Mask")
show_img(image[0], ax=axes[T, 0])
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
show_mask(true_mask[0, ..., 0], ax=axes[T, 1])
vmin = np.min(pred_mask_logits[T - 1])
vmax = np.max(pred_mask_logits[T - 1])
+3 -3
View File
@@ -123,7 +123,7 @@ def construct_diagnostic_image(
components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
if mask_components:
components *= masks[:nr_images, Ellipsis, tf.newaxis]
components *= masks[:nr_images, ..., tf.newaxis]
# Pad everything
no_pad, pad = (0, 0), (border_width, border_width)
@@ -415,7 +415,7 @@ def images_to_grid(
if max_grid_width is not None:
grid_width = min(max_grid_width, grid_width)
images = images[: grid_height * grid_width, Ellipsis]
images = images[: grid_height * grid_width, ...]
# Pad with extra blank frames if grid_height x grid_width is less than the
# number of frames provided.
@@ -460,7 +460,7 @@ def flatten_all_but_last(tensor, n_dims=1):
def ensure_3d(tensor):
if tensor.shape.ndims == 2:
return tensor[Ellipsis, None]
return tensor[..., None]
assert tensor.shape.ndims == 3
return tensor
+3 -3
View File
@@ -82,7 +82,7 @@ class Agent():
def option_values(values, policy):
return tf.tensordot(
values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])
values[:, policy, ...], self._policy_weights[policy], axes=[1, 0])
# Placeholders for policy.
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
@@ -103,8 +103,8 @@ class Agent():
qo_t = option_values(q_t, p)
a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1)
qa_t = _batched_index(q_t[:, p, Ellipsis], a_t)
qa_tm1 = _batched_index(q_tm1[:, p, ...], a_tm1)
qa_t = _batched_index(q_t[:, p, ...], a_t)
# TD error
g = additional_discount * tf.expand_dims(d_t, axis=-1)
+1 -1
View File
@@ -91,7 +91,7 @@ def make_face_model_dataset(
# Vertices are quantized. So convert to floats for input to face model
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
example['vertices_mask'] = tf.ones_like(
example['vertices'][Ellipsis, 0], dtype=tf.float32)
example['vertices'][..., 0], dtype=tf.float32)
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
return example
return ds.map(_face_model_map_fn)
+8 -8
View File
@@ -799,7 +799,7 @@ class VertexModel(snt.AbstractModule):
# Continuous vertex value embeddings
else:
vert_embeddings = tf.layers.dense(
dequantize_verts(vertices[Ellipsis, None], self.quantization_bits),
dequantize_verts(vertices[..., None], self.quantization_bits),
self.embedding_dim,
use_bias=True,
name='value_embeddings')
@@ -984,7 +984,7 @@ class VertexModel(snt.AbstractModule):
verts_dequantized = dequantize_verts(v, self.quantization_bits)
vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
vertices = tf.stack(
[vertices[Ellipsis, 2], vertices[Ellipsis, 1], vertices[Ellipsis, 0]], axis=-1)
[vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1)
# Pad samples to max sample length. This is required in order to concatenate
# Samples across different replicator instances. Pad with stopping tokens
@@ -998,14 +998,14 @@ class VertexModel(snt.AbstractModule):
if recenter_verts:
vert_max = tf.reduce_max(
vertices - 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1,
keepdims=True)
vert_min = tf.reduce_min(
vertices + 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
vertices + 1e10 * (1. - vertices_mask)[..., None], axis=1,
keepdims=True)
vert_centers = 0.5 * (vert_max + vert_min)
vertices -= vert_centers
vertices *= vertices_mask[Ellipsis, None]
vertices *= vertices_mask[..., None]
if only_return_complete:
vertices = tf.boolean_mask(vertices, completed)
@@ -1247,7 +1247,7 @@ class FaceModel(snt.AbstractModule):
sequential_context_embeddings = (
vertex_embeddings *
tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
constant_values=1)[Ellipsis, None])
constant_values=1)[..., None])
else:
sequential_context_embeddings = None
return (vertex_embeddings, global_context_embedding,
@@ -1266,11 +1266,11 @@ class FaceModel(snt.AbstractModule):
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='coord_{}'.format(c))(verts_quantized[Ellipsis, c])
name='coord_{}'.format(c))(verts_quantized[..., c])
else:
vertex_embeddings = tf.layers.dense(
vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
vertex_embeddings *= vertices_mask[Ellipsis, None]
vertex_embeddings *= vertices_mask[..., None]
# Pad vertex embeddings with learned embeddings for stopping and new face
# tokens
+16 -16
View File
@@ -106,7 +106,7 @@ TESTING_SUITE = [
ALL = TUNING_SUITE + TESTING_SUITE
def _decode_frames(pngs):
def _decode_frames(pngs: tf.Tensor):
"""Decode PNGs.
Args:
@@ -122,13 +122,13 @@ def _decode_frames(pngs):
return frames
def _make_reverb_sample(o_t,
a_t,
r_t,
d_t,
o_tp1,
a_tp1,
extras):
def _make_reverb_sample(o_t: tf.Tensor,
a_t: tf.Tensor,
r_t: tf.Tensor,
d_t: tf.Tensor,
o_tp1: tf.Tensor,
a_tp1: tf.Tensor,
extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample:
"""Create Reverb sample with offline data.
Args:
@@ -151,8 +151,8 @@ def _make_reverb_sample(o_t,
return reverb.ReplaySample(info=info, data=data)
def _tf_example_to_reverb_sample(tf_example
):
def _tf_example_to_reverb_sample(tf_example: tf.train.Example
) -> reverb.ReplaySample:
"""Create a Reverb replay sample from a TF example."""
# Parse tf.Example.
@@ -184,11 +184,11 @@ def _tf_example_to_reverb_sample(tf_example
extras)
def dataset(path,
game,
run,
num_shards = 100,
shuffle_buffer_size = 100000):
def dataset(path: str,
game: str,
run: int,
num_shards: int = 100,
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
"""TF dataset of Atari SARSA tuples."""
path = os.path.join(path, f'{game}/run_{run}')
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
@@ -243,7 +243,7 @@ class AtariDopamineWrapper(dm_env.Environment):
return specs.DiscreteArray(self._env.action_space.n)
def environment(game):
def environment(game: str) -> dm_env.Environment:
"""Atari environment."""
env = atari_lib.create_atari_environment(game_name=game,
sticky_actions=True)
+9 -9
View File
@@ -773,15 +773,15 @@ def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False):
drop_remainder=drop_remainder)
def dataset(root_path,
data_path,
shapes,
num_threads,
batch_size,
uint8_features = None,
num_shards = 100,
shuffle_buffer_size = 100000,
sarsa = True):
def dataset(root_path: str,
data_path: str,
shapes: Dict[str, Tuple[int]],
num_threads: int,
batch_size: int,
uint8_features: Set[str] = None,
num_shards: int = 100,
shuffle_buffer_size: int = 100000,
sarsa: bool = True) -> tf.data.Dataset:
"""Create tf dataset for training."""
uint8_features = uint8_features if uint8_features else {}
+17 -17
View File
@@ -55,7 +55,7 @@ DELIMITER = ':'
DEFAULT_NUM_TIMESTEPS = 1001
def _decombine_key(k, delimiter = DELIMITER):
def _decombine_key(k: str, delimiter: str = DELIMITER) -> Sequence[str]:
return k.split(delimiter)
@@ -79,7 +79,7 @@ def tf_example_to_feature_description(example,
def tree_deflatten_with_delimiter(
flat_dict, delimiter = DELIMITER):
flat_dict: Dict[str, Any], delimiter: str = DELIMITER) -> Dict[str, Any]:
"""De-flattens a dict to its originally nested structure.
Does the opposite of {combine_nested_keys(k) :v
@@ -102,12 +102,12 @@ def tree_deflatten_with_delimiter(
return dict(root)
def get_slice_of_nested(nested, start,
end):
def get_slice_of_nested(nested: Dict[str, Any], start: int,
end: int) -> Dict[str, Any]:
return tree.map_structure(lambda item: item[start:end], nested)
def repeat_last_and_append_to_nested(nested):
def repeat_last_and_append_to_nested(nested: Dict[str, Any]) -> Dict[str, Any]:
return tree.map_structure(
lambda item: tf.concat((item, item[-1:]), axis=0), nested)
@@ -133,13 +133,13 @@ def tf_example_to_reverb_sample(example,
return ret
def dataset(path,
combined_challenge,
domain,
task,
difficulty,
num_shards = 100,
shuffle_buffer_size = 100000):
def dataset(path: str,
combined_challenge: str,
domain: str,
task: str,
difficulty: str,
num_shards: int = 100,
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
"""TF dataset of RWRL SARSA tuples."""
path = os.path.join(
path,
@@ -178,11 +178,11 @@ def dataset(path,
def environment(
combined_challenge,
domain,
task,
log_output = None,
environment_kwargs = None):
combined_challenge: str,
domain: str,
task: str,
log_output: Optional[str] = None,
environment_kwargs: Optional[Dict[str, Any]] = None) -> dm_env.Environment:
"""RWRL environment."""
env = rwrl_envs.load(
domain_name=domain,
+2 -2
View File
@@ -106,8 +106,8 @@ class Transporter(snt.AbstractModule):
num_keypoints = image_a_keypoints["heatmaps"].shape[-1]
transported_features = image_a_features
for k in range(num_keypoints):
mask_a = image_a_keypoints["heatmaps"][Ellipsis, k, None]
mask_b = image_b_keypoints["heatmaps"][Ellipsis, k, None]
mask_a = image_a_keypoints["heatmaps"][..., k, None]
mask_b = image_b_keypoints["heatmaps"][..., k, None]
# suppress features from image a, around both keypoint locations.
transported_features = (