diff --git a/byol/README.md b/byol/README.md index a233b23..e404bb0 100644 --- a/byol/README.md +++ b/byol/README.md @@ -176,3 +176,37 @@ python -m byol.main_loop \ With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the *online* classifier) in roughly 4 hours. Note that the above parameters were not finely tuned and may not be optimal. + + +## Additional checkpoints + +Along with the pretrained ResNet-50 and ResNet-200 2x, we provide the +following checkpoints from our ablation study. They all correspond to a +ResNet-50 1x pre-trained over 300 epochs and were randomly selected among the +three seeds; file size is roughly 640MB each. + +- [Baseline](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_baseline.pkl) + +- Smaller batch sizes (figure 3a): + - [Batch size 2048](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_2048.pkl) + - [Batch size 1024](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_1024.pkl) + - [Batch size 512](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_512.pkl) + - [Batch size 256](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_256.pkl) + - [Batch size 128](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_128.pkl) + - [Batch size 64](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_64.pkl) + +- Ablation on transformations (figure 3b): + - [Remove grayscale](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_grayscale.pkl) + - [Remove color](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_color.pkl) + - [Crop and blur only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_blur_only.pkl) + - [Crop only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_only.pkl) + 
- (from Table 18) [Crop and color only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_color_only.pkl) + + +## License + +While the code is licensed under the Apache 2.0 License, the checkpoint weights +are made available for non-commercial use only under the terms of the +Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) +license. You can find details at: +https://creativecommons.org/licenses/by-nc/4.0/legalcode. diff --git a/byol/byol_experiment.py b/byol/byol_experiment.py index ef6143f..d6b76ae 100644 --- a/byol/byol_experiment.py +++ b/byol/byol_experiment.py @@ -56,17 +56,17 @@ class ByolExperiment: def __init__( self, - random_seed, - num_classes, - batch_size, - max_steps, - enable_double_transpose, - base_target_ema, - network_config, - optimizer_config, - lr_schedule_config, - evaluation_config, - checkpointing_config): + random_seed: int, + num_classes: int, + batch_size: int, + max_steps: int, + enable_double_transpose: bool, + base_target_ema: float, + network_config: Mapping[Text, Any], + optimizer_config: Mapping[Text, Any], + lr_schedule_config: Mapping[Text, Any], + evaluation_config: Mapping[Text, Any], + checkpointing_config: Mapping[Text, Any]): """Constructs the experiment. Args: @@ -115,15 +115,15 @@ class ByolExperiment: def _forward( self, - inputs, - projector_hidden_size, - projector_output_size, - predictor_hidden_size, - encoder_class, - encoder_config, - bn_config, - is_training, - ): + inputs: dataset.Batch, + projector_hidden_size: int, + projector_output_size: int, + predictor_hidden_size: int, + encoder_class: Text, + encoder_config: Mapping[Text, Any], + bn_config: Mapping[Text, Any], + is_training: bool, + ) -> Mapping[Text, jnp.ndarray]: """Forward application of byol's architecture. 
Args: @@ -163,7 +163,7 @@ class ByolExperiment: classifier = hk.Linear( output_size=self._num_classes, name='classifier') - def apply_once_fn(images, suffix = ''): + def apply_once_fn(images: jnp.ndarray, suffix: Text = ''): images = dataset.normalize_images(images) embedding = net(images, is_training=is_training) @@ -186,7 +186,7 @@ class ByolExperiment: else: return apply_once_fn(inputs['images'], '') - def _optimizer(self, learning_rate): + def _optimizer(self, learning_rate: float) -> optax.GradientTransformation: """Build optimizer from config.""" return optimizers.lars( learning_rate, @@ -196,13 +196,13 @@ class ByolExperiment: def loss_fn( self, - online_params, - target_params, - online_state, - target_state, - rng, - inputs, - ): + online_params: hk.Params, + target_params: hk.Params, + online_state: hk.State, + target_state: hk.Params, + rng: jnp.ndarray, + inputs: dataset.Batch, + ) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]: """Compute BYOL's loss function. Args: @@ -292,11 +292,11 @@ class ByolExperiment: def _update_fn( self, - byol_state, - global_step, - rng, - inputs, - ): + byol_state: _ByolExperimentState, + global_step: jnp.ndarray, + rng: jnp.ndarray, + inputs: dataset.Batch, + ) -> Tuple[_ByolExperimentState, LogsDict]: """Update online and target parameters. Args: @@ -352,9 +352,9 @@ class ByolExperiment: def _make_initial_state( self, - rng, - dummy_input, - ): + rng: jnp.ndarray, + dummy_input: dataset.Batch, + ) -> _ByolExperimentState: """BYOL's _ByolExperimentState initialization. 
Args: @@ -393,8 +393,8 @@ class ByolExperiment: ) def step(self, *, - global_step, - rng): + global_step: jnp.ndarray, + rng: jnp.ndarray) -> Mapping[Text, np.ndarray]: """Performs a single training step.""" if self._train_input is None: self._initialize_train() @@ -410,11 +410,11 @@ class ByolExperiment: return helpers.get_first(scalars) - def save_checkpoint(self, step, rng): + def save_checkpoint(self, step: int, rng: jnp.ndarray): self._checkpointer.maybe_save_checkpoint( self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps) - def load_checkpoint(self): + def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]: checkpoint_data = self._checkpointer.maybe_load_checkpoint() if checkpoint_data is None: return None @@ -444,7 +444,7 @@ class ByolExperiment: self._byol_state = init_byol(rng=init_rng, dummy_input=inputs) - def _build_train_input(self): + def _build_train_input(self) -> Generator[dataset.Batch, None, None]: """Loads the (infinitely looping) dataset iterator.""" num_devices = jax.device_count() global_batch_size = self._batch_size @@ -463,10 +463,10 @@ class ByolExperiment: def _eval_batch( self, - params, - state, - batch, - ): + params: hk.Params, + state: hk.State, + batch: dataset.Batch, + ) -> Mapping[Text, jnp.ndarray]: """Evaluates a batch. Args: @@ -494,7 +494,7 @@ class ByolExperiment: 'top5_accuracy': top5_correct, } - def _eval_epoch(self, subset, batch_size): + def _eval_epoch(self, subset: Text, batch_size: int): """Evaluates an epoch.""" num_samples = 0. 
summed_scalars = None diff --git a/byol/configs/byol.py b/byol/configs/byol.py index 7088324..b983b86 100644 --- a/byol/configs/byol.py +++ b/byol/configs/byol.py @@ -23,7 +23,7 @@ _WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6} _EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996} -def get_config(num_epochs, batch_size): +def get_config(num_epochs: int, batch_size: int): """Return config object, containing all hyperparameters for training.""" train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples diff --git a/byol/configs/eval.py b/byol/configs/eval.py index d01b6dc..8943e05 100644 --- a/byol/configs/eval.py +++ b/byol/configs/eval.py @@ -19,7 +19,7 @@ from typing import Text from byol.utils import dataset -def get_config(checkpoint_to_evaluate, batch_size): +def get_config(checkpoint_to_evaluate: Text, batch_size: int): """Return config object for training.""" train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples diff --git a/byol/eval_experiment.py b/byol/eval_experiment.py index 2755848..83b4fbf 100644 --- a/byol/eval_experiment.py +++ b/byol/eval_experiment.py @@ -53,19 +53,19 @@ class EvalExperiment: def __init__( self, - random_seed, - num_classes, - batch_size, - max_steps, - enable_double_transpose, - checkpoint_to_evaluate, - allow_train_from_scratch, - freeze_backbone, - network_config, - optimizer_config, - lr_schedule_config, - evaluation_config, - checkpointing_config): + random_seed: int, + num_classes: int, + batch_size: int, + max_steps: int, + enable_double_transpose: bool, + checkpoint_to_evaluate: Optional[Text], + allow_train_from_scratch: bool, + freeze_backbone: bool, + network_config: Mapping[Text, Any], + optimizer_config: Mapping[Text, Any], + lr_schedule_config: Mapping[Text, Any], + evaluation_config: Mapping[Text, Any], + checkpointing_config: Mapping[Text, Any]): """Constructs the experiment. 
Args: @@ -125,12 +125,12 @@ class EvalExperiment: def _backbone_fn( self, - inputs, - encoder_class, - encoder_config, - bn_decay_rate, - is_training, - ): + inputs: dataset.Batch, + encoder_class: Text, + encoder_config: Mapping[Text, Any], + bn_decay_rate: float, + is_training: bool, + ) -> jnp.ndarray: """Forward of the encoder (backbone).""" bn_config = {'decay_rate': bn_decay_rate} encoder = getattr(networks, encoder_class) @@ -146,8 +146,8 @@ class EvalExperiment: def _classif_fn( self, - embeddings, - ): + embeddings: jnp.ndarray, + ) -> jnp.ndarray: classifier = hk.Linear(output_size=self._num_classes) return classifier(embeddings) @@ -159,8 +159,8 @@ class EvalExperiment: # def step(self, *, - global_step, - rng): + global_step: jnp.ndarray, + rng: jnp.ndarray) -> Mapping[Text, np.ndarray]: """Performs a single training step.""" if self._train_input is None: @@ -173,12 +173,12 @@ class EvalExperiment: scalars = helpers.get_first(scalars) return scalars - def save_checkpoint(self, step, rng): + def save_checkpoint(self, step: int, rng: jnp.ndarray): self._checkpointer.maybe_save_checkpoint( self._experiment_state, step=step, rng=rng, is_final=step >= self._max_steps) - def load_checkpoint(self): + def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]: checkpoint_data = self._checkpointer.maybe_load_checkpoint() if checkpoint_data is None: return None @@ -253,11 +253,11 @@ class EvalExperiment: def _make_initial_state( self, - rng, - dummy_input, - backbone_params, - backbone_state, - ): + rng: jnp.ndarray, + dummy_input: dataset.Batch, + backbone_params: hk.Params, + backbone_state: hk.Params, + ) -> _EvalExperimentState: """_EvalExperimentState initialization.""" # Initialize the backbone params @@ -279,7 +279,7 @@ class EvalExperiment: classif_opt_state=classif_opt_state, ) - def _build_train_input(self): + def _build_train_input(self) -> Generator[dataset.Batch, None, None]: """See base class.""" num_devices = jax.device_count() 
global_batch_size = self._batch_size @@ -296,17 +296,17 @@ class EvalExperiment: transpose=self._should_transpose_images(), batch_dims=[jax.local_device_count(), per_device_batch_size]) - def _optimizer(self, learning_rate): + def _optimizer(self, learning_rate: float): """Build optimizer from config.""" return optax.sgd(learning_rate, **self._optimizer_config) def _loss_fn( self, - backbone_params, - classif_params, - backbone_state, - inputs, - ): + backbone_params: hk.Params, + classif_params: hk.Params, + backbone_state: hk.State, + inputs: dataset.Batch, + ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]: """Compute the classification loss function. Args: @@ -333,10 +333,10 @@ class EvalExperiment: def _update_func( self, - experiment_state, - global_step, - inputs, - ): + experiment_state: _EvalExperimentState, + global_step: jnp.ndarray, + inputs: dataset.Batch, + ) -> Tuple[_EvalExperimentState, LogsDict]: """Applies an update to parameters and returns new state.""" # This function computes the gradient of the first output of loss_fn and # passes through the other arguments unchanged. @@ -421,11 +421,11 @@ class EvalExperiment: def _eval_batch( self, - backbone_params, - classif_params, - backbone_state, - inputs, - ): + backbone_params: hk.Params, + classif_params: hk.Params, + backbone_state: hk.State, + inputs: dataset.Batch, + ) -> LogsDict: """Evaluates a batch.""" embeddings, backbone_state = self.forward_backbone.apply( backbone_params, backbone_state, inputs, is_training=False) @@ -441,7 +441,7 @@ class EvalExperiment: 'top5_accuracy': top5_correct } - def _eval_epoch(self, subset, batch_size): + def _eval_epoch(self, subset: Text, batch_size: int): """Evaluates an epoch.""" num_samples = 0. 
summed_scalars = None diff --git a/byol/main_loop.py b/byol/main_loop.py index 9dc5b4c..544bbf2 100644 --- a/byol/main_loop.py +++ b/byol/main_loop.py @@ -47,7 +47,7 @@ Experiment = Union[ Type[eval_experiment.EvalExperiment]] -def train_loop(experiment_class, config): +def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]): """The main training loop. This loop periodically saves a checkpoint to be evaluated in the eval_loop. @@ -95,7 +95,7 @@ def train_loop(experiment_class, config): experiment.save_checkpoint(step, rng) -def eval_loop(experiment_class, config): +def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]): """The main evaluation loop. This loop periodically loads a checkpoint and evaluates its performance on the diff --git a/byol/utils/augmentations.py b/byol/utils/augmentations.py index 9be3b49..0049309 100644 --- a/byol/utils/augmentations.py +++ b/byol/utils/augmentations.py @@ -66,14 +66,14 @@ augment_config = dict( )) -def postprocess(inputs, rng): +def postprocess(inputs: JaxBatch, rng: jnp.ndarray): """Apply the image augmentations to crops in inputs (view1 and view2).""" def _postprocess_image( - images, - rng, - presets, - ): + images: jnp.ndarray, + rng: jnp.ndarray, + presets: ConfigDict, + ) -> JaxBatch: """Applies augmentations in post-processing. Args: @@ -144,7 +144,7 @@ def _gaussian_blur_single_image(image, kernel_size, padding, sigma): blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels]) expand_batch_dim = len(image.shape) == 3 if expand_batch_dim: - image = image[jnp.newaxis, Ellipsis] + image = image[jnp.newaxis, ...] 
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding) blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding) blurred = jnp.squeeze(blurred, axis=0) @@ -284,7 +284,7 @@ def _random_hue(rgb_tuple, rng, max_delta): def _to_grayscale(image): rgb_weights = jnp.array([0.2989, 0.5870, 0.1140]) - grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis] + grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis] return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels. diff --git a/byol/utils/checkpointing.py b/byol/utils/checkpointing.py index 94b1825..0b02395 100644 --- a/byol/utils/checkpointing.py +++ b/byol/utils/checkpointing.py @@ -31,10 +31,10 @@ class Checkpointer: def __init__( self, - use_checkpointing, - checkpoint_dir, - save_checkpoint_interval, - filename): + use_checkpointing: bool, + checkpoint_dir: Text, + save_checkpoint_interval: int, + filename: Text): if (not use_checkpointing or checkpoint_dir is None or save_checkpoint_interval <= 0): @@ -51,10 +51,10 @@ class Checkpointer: def maybe_save_checkpoint( self, - experiment_state, - step, - rng, - is_final): + experiment_state: Mapping[Text, jnp.ndarray], + step: int, + rng: jnp.ndarray, + is_final: bool): """Saves a checkpoint if enough time has passed since the previous one.""" current_time = time.time() if (not self._checkpoint_enabled or @@ -80,7 +80,7 @@ class Checkpointer: self._last_checkpoint_time = current_time def maybe_load_checkpoint( - self): + self) -> Union[Tuple[Mapping[Text, jnp.ndarray], int, jnp.ndarray], None]: """Loads a checkpoint if any is found.""" checkpoint_data = load_checkpoint(self._checkpoint_path) if checkpoint_data is None: diff --git a/byol/utils/dataset.py b/byol/utils/dataset.py index cf58622..7386f88 100644 --- a/byol/utils/dataset.py +++ b/byol/utils/dataset.py @@ -34,7 +34,7 @@ class Split(enum.Enum): TEST = 4 @classmethod - def from_string(cls, name): + def from_string(cls, 
name: Text) -> 'Split': return { 'TRAIN': Split.TRAIN, 'TRAIN_AND_VALID': Split.TRAIN_AND_VALID, @@ -60,7 +60,7 @@ class PreprocessMode(enum.Enum): EVAL = 3 # Generates a single center crop. -def normalize_images(images): +def normalize_images(images: jnp.ndarray) -> jnp.ndarray: """Normalize the image using ImageNet statistics.""" mean_rgb = (0.485, 0.456, 0.406) stddev_rgb = (0.229, 0.224, 0.225) @@ -69,12 +69,12 @@ def normalize_images(images): return normed_images -def load(split, +def load(split: Split, *, - preprocess_mode, - batch_dims, - transpose = False, - allow_caching = False): + preprocess_mode: PreprocessMode, + batch_dims: Sequence[int], + transpose: bool = False, + allow_caching: bool = False) -> Generator[Batch, None, None]: """Loads the given split of the dataset.""" start, end = _shard(split, jax.host_id(), jax.host_count()) @@ -153,7 +153,7 @@ def load(split, yield from tfds.as_numpy(ds) -def _to_tfds_split(split): +def _to_tfds_split(split: Split) -> tfds.Split: """Returns the TFDS split appropriately sharded.""" # NOTE: Imagenet did not release labels for the test split used in the # competition, we consider the VALID split the TEST split and reserve @@ -165,7 +165,7 @@ def _to_tfds_split(split): return tfds.Split.VALIDATION -def _shard(split, shard_index, num_shards): +def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]: """Returns [start, end) for the given shard index.""" assert shard_index < num_shards arange = np.arange(split.num_examples) @@ -180,9 +180,9 @@ def _shard(split, shard_index, num_shards): def _preprocess_image( - image_bytes, - mode, -): + image_bytes: tf.Tensor, + mode: PreprocessMode, +) -> tf.Tensor: """Returns processed and resized images.""" if mode is PreprocessMode.PRETRAIN: image = _decode_and_random_crop(image_bytes) @@ -201,7 +201,7 @@ def _preprocess_image( return image -def _decode_and_random_crop(image_bytes): +def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor: """Make a 
random crop of 224.""" img_size = tf.image.extract_jpeg_shape(image_bytes) area = tf.cast(img_size[1] * img_size[0], tf.float32) @@ -231,7 +231,7 @@ def _decode_and_random_crop(image_bytes): return image -def transpose_images(batch): +def transpose_images(batch: Batch): """Transpose images for TPU training..""" new_batch = dict(batch) # Avoid mutating in place. if 'images' in batch: @@ -243,9 +243,9 @@ def transpose_images(batch): def _decode_and_center_crop( - image_bytes, - jpeg_shape = None, -): + image_bytes: tf.Tensor, + jpeg_shape: Optional[tf.Tensor] = None, +) -> tf.Tensor: """Crops to center of image with padding then scales.""" if jpeg_shape is None: jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) diff --git a/byol/utils/helpers.py b/byol/utils/helpers.py index 18d2cc2..81a9224 100644 --- a/byol/utils/helpers.py +++ b/byol/utils/helpers.py @@ -21,11 +21,11 @@ import jax.numpy as jnp def topk_accuracy( - logits, - labels, - topk, - ignore_label_above = None, -): + logits: jnp.ndarray, + labels: jnp.ndarray, + topk: int, + ignore_label_above: Optional[int] = None, +) -> jnp.ndarray: """Top-num_codes accuracy.""" assert len(labels.shape) == 1, 'topk expects 1d int labels.' assert len(logits.shape) == 2, 'topk expects 2d logits.' @@ -42,10 +42,10 @@ def topk_accuracy( def softmax_cross_entropy( - logits, - labels, - reduction = 'mean', -): + logits: jnp.ndarray, + labels: jnp.ndarray, + reduction: Optional[Text] = 'mean', +) -> jnp.ndarray: """Computes softmax cross entropy given logits and one-hot class labels. 
Args: @@ -72,10 +72,10 @@ def softmax_cross_entropy( def l2_normalize( - x, - axis = None, - epsilon = 1e-12, -): + x: jnp.ndarray, + axis: Optional[int] = None, + epsilon: float = 1e-12, +) -> jnp.ndarray: """l2 normalize a tensor on an axis with numerical stability.""" square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True) x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon)) @@ -106,7 +106,7 @@ def l2_weight_regularizer(params): return 0.5 * l2_norm -def regression_loss(x, y): +def regression_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Byol's regression loss. This is a simple cosine similarity.""" normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1) return jnp.sum((normed_x - normed_y)**2, axis=-1) diff --git a/byol/utils/networks.py b/byol/utils/networks.py index 3bd10e6..e2e4bf4 100644 --- a/byol/utils/networks.py +++ b/byol/utils/networks.py @@ -27,17 +27,17 @@ class MLP(hk.Module): def __init__( self, - name, - hidden_size, - output_size, - bn_config, + name: Text, + hidden_size: int, + output_size: int, + bn_config: Mapping[Text, Any], ): super().__init__(name=name) self._hidden_size = hidden_size self._output_size = output_size self._bn_config = bn_config - def __call__(self, inputs, is_training): + def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndarray: out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs) out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training) out = jax.nn.relu(out) @@ -55,15 +55,15 @@ class ResNetTorso(hk.Module): def __init__( self, - blocks_per_group, - num_classes = None, - bn_config = None, - resnet_v2 = False, - bottleneck = True, - channels_per_group = (256, 512, 1024, 2048), - use_projection = (True, True, True, True), - width_multiplier = 1, - name = None, + blocks_per_group: Sequence[int], + num_classes: int = None, + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + bottleneck: bool = True, + 
channels_per_group: Sequence[int] = (256, 512, 1024, 2048), + use_projection: Sequence[bool] = (True, True, True, True), + width_multiplier: int = 1, + name: Optional[str] = None, ): """Constructs a ResNet model. @@ -155,11 +155,11 @@ class TinyResNet(ResNetTorso): """Tiny resnet for local runs and tests.""" def __init__(self, - num_classes = None, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int] = None, + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. Args: @@ -185,11 +185,11 @@ class ResNet18(ResNetTorso): """ResNet18.""" def __init__(self, - num_classes = None, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int] = None, + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. Args: @@ -215,11 +215,11 @@ class ResNet34(ResNetTorso): """ResNet34.""" def __init__(self, - num_classes, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int], + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. Args: @@ -245,11 +245,11 @@ class ResNet50(ResNetTorso): """ResNet50.""" def __init__(self, - num_classes = None, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int] = None, + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. 
Args: @@ -274,11 +274,11 @@ class ResNet101(ResNetTorso): """ResNet101.""" def __init__(self, - num_classes, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int], + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. Args: @@ -303,11 +303,11 @@ class ResNet152(ResNetTorso): """ResNet152.""" def __init__(self, - num_classes, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int], + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. Args: @@ -332,11 +332,11 @@ class ResNet200(ResNetTorso): """ResNet200.""" def __init__(self, - num_classes, - bn_config = None, - resnet_v2 = False, - width_multiplier = 1, - name = None): + num_classes: Optional[int], + bn_config: Optional[Mapping[str, float]] = None, + resnet_v2: bool = False, + width_multiplier: int = 1, + name: Optional[str] = None): """Constructs a ResNet model. 
Args: diff --git a/byol/utils/optimizers.py b/byol/utils/optimizers.py index bc8f039..f80423c 100644 --- a/byol/utils/optimizers.py +++ b/byol/utils/optimizers.py @@ -27,7 +27,7 @@ import tree as nest FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray] -def exclude_bias_and_norm(path, val): +def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray: """Filter to exclude biaises and normalizations weights.""" del val if path[-1] == "b" or "norm" in path[-2]: @@ -35,10 +35,10 @@ def exclude_bias_and_norm(path, val): return True -def _partial_update(updates, - new_updates, - params, - filter_fn = None): +def _partial_update(updates: optax.Updates, + new_updates: optax.Updates, + params: optax.Params, + filter_fn: Optional[FilterFn] = None) -> optax.Updates: """Returns new_update for params which filter_fn is True else updates.""" if filter_fn is None: @@ -47,7 +47,7 @@ def _partial_update(updates, wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y)) params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params) - def _update_fn(g, t, m): + def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray: m = m.astype(g.dtype) return g * (1. - m) + t * m @@ -59,9 +59,9 @@ class ScaleByLarsState(NamedTuple): def scale_by_lars( - momentum = 0.9, - eta = 0.001, - filter_fn = None): + momentum: float = 0.9, + eta: float = 0.001, + filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation: """Rescales updates according to the LARS algorithm. Does not include weight decay. @@ -77,17 +77,17 @@ def scale_by_lars( An (init_fn, update_fn) tuple. 
""" - def init_fn(params): + def init_fn(params: optax.Params) -> ScaleByLarsState: mu = jax.tree_multimap(jnp.zeros_like, params) # momentum return ScaleByLarsState(mu=mu) - def update_fn(updates, state, - params): + def update_fn(updates: optax.Updates, state: ScaleByLarsState, + params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]: def lars_adaptation( - update, - param, - ): + update: jnp.ndarray, + param: jnp.ndarray, + ) -> jnp.ndarray: param_norm = jnp.linalg.norm(param) update_norm = jnp.linalg.norm(update) return update * jnp.where( @@ -110,8 +110,8 @@ class AddWeightDecayState(NamedTuple): def add_weight_decay( - weight_decay, - filter_fn = None): + weight_decay: float, + filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation: """Adds a weight decay to the update. Args: @@ -122,14 +122,14 @@ def add_weight_decay( An (init_fn, update_fn) tuple. """ - def init_fn(_): + def init_fn(_) -> AddWeightDecayState: return AddWeightDecayState() def update_fn( - updates, - state, - params, - ): + updates: optax.Updates, + state: AddWeightDecayState, + params: optax.Params, + ) -> Tuple[optax.Updates, AddWeightDecayState]: new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates, params) new_updates = _partial_update(updates, new_updates, params, filter_fn) @@ -142,13 +142,13 @@ LarsState = List # Type for the lars optimizer def lars( - learning_rate, - weight_decay = 0., - momentum = 0.9, - eta = 0.001, - weight_decay_filter = None, - lars_adaptation_filter = None, -): + learning_rate: float, + weight_decay: float = 0., + momentum: float = 0.9, + eta: float = 0.001, + weight_decay_filter: Optional[FilterFn] = None, + lars_adaptation_filter: Optional[FilterFn] = None, +) -> optax.GradientTransformation: """Creates lars optimizer with weight decay. 
References: diff --git a/byol/utils/schedules.py b/byol/utils/schedules.py index f8817a0..773d40c 100644 --- a/byol/utils/schedules.py +++ b/byol/utils/schedules.py @@ -17,18 +17,18 @@ import jax.numpy as jnp -def target_ema(global_step, - base_ema, - max_steps): +def target_ema(global_step: jnp.ndarray, + base_ema: float, + max_steps: int) -> jnp.ndarray: decay = _cosine_decay(global_step, max_steps, 1.) return 1. - (1. - base_ema) * decay -def learning_schedule(global_step, - batch_size, - base_learning_rate, - total_steps, - warmup_steps): +def learning_schedule(global_step: jnp.ndarray, + batch_size: int, + base_learning_rate: float, + total_steps: int, + warmup_steps: int) -> float: """Cosine learning rate scheduler.""" # Compute LR & Scaled LR scaled_lr = base_learning_rate * batch_size / 256. @@ -43,9 +43,9 @@ def learning_schedule(global_step, scaled_lr)) -def _cosine_decay(global_step, - max_steps, - initial_value): +def _cosine_decay(global_step: jnp.ndarray, + max_steps: int, + initial_value: float) -> jnp.ndarray: """Simple implementation of cosine decay from TF1.""" global_step = jnp.minimum(global_step, max_steps) cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps)) diff --git a/geomancer/train.py b/geomancer/train.py index 21cc0c7..2ebd541 100644 --- a/geomancer/train.py +++ b/geomancer/train.py @@ -46,7 +46,7 @@ def make_so_tangent(q): for j in range(i+1, n): a[i, j] = 1 a[j, i] = -1 - dq[Ellipsis, ii] = a @ q # tangent vectors are skew-symmetric matrix times Q + dq[..., ii] = a @ q # tangent vectors are skew-symmetric matrix times Q a[i, j] = 0 a[j, i] = 0 ii += 1 @@ -106,7 +106,7 @@ def make_product_manifold(specification, npts): spec_array[1, i] = dim latent_dim += dim dat = np.random.randn(npts, dim+1) - dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[Ellipsis, None]), + dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[..., None]), [1, dim+1]) elif so_spec is not None: dim = int(so_spec.group(1)) diff --git 
a/glassy_dynamics/apply_binary.py b/glassy_dynamics/apply_binary.py index a85fc08..137cc7f 100644 --- a/glassy_dynamics/apply_binary.py +++ b/glassy_dynamics/apply_binary.py @@ -16,7 +16,7 @@ from __future__ import absolute_import from __future__ import division - +from __future__ import google_type_annotations from __future__ import print_function import os diff --git a/glassy_dynamics/graph_model.py b/glassy_dynamics/graph_model.py index ddf9dea..83f8900 100644 --- a/glassy_dynamics/graph_model.py +++ b/glassy_dynamics/graph_model.py @@ -20,7 +20,7 @@ The architecture and performance of this model is described in our publication: from __future__ import absolute_import from __future__ import division - +from __future__ import google_type_annotations from __future__ import print_function import functools @@ -35,10 +35,10 @@ from typing import Any, Dict, Text, Tuple, Optional def make_graph_from_static_structure( - positions, - types, - box, - edge_threshold): + positions: tf.Tensor, + types: tf.Tensor, + box: tf.Tensor, + edge_threshold: float) -> graphs.GraphsTuple: """Returns graph representing the static structure of the glass. Each particle is represented by a node in the graph. The particle type is @@ -81,7 +81,7 @@ def make_graph_from_static_structure( ) -def apply_random_rotation(graph): +def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple: """Returns randomly rotated graph representation. The rotation is an element of O(3) with rotation angles multiple of pi/2. @@ -118,9 +118,9 @@ class GraphBasedModel(snt.AbstractModule): """ def __init__(self, - n_recurrences, - mlp_sizes, - mlp_kwargs = None, + n_recurrences: int, + mlp_sizes: Tuple[int], + mlp_kwargs: Optional[Dict[Text, Any]] = None, name='Graph'): """Creates a new GraphBasedModel object. 
@@ -168,7 +168,7 @@ class GraphBasedModel(snt.AbstractModule): node_model_fn=final_model_fn, edge_model_fn=model_fn) - def _build(self, graphs_tuple): + def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor: """Connects the model into the tensorflow graph. Args: diff --git a/glassy_dynamics/train.py b/glassy_dynamics/train.py index 8977042..d623ac9 100644 --- a/glassy_dynamics/train.py +++ b/glassy_dynamics/train.py @@ -16,7 +16,7 @@ from __future__ import absolute_import from __future__ import division - +from __future__ import google_type_annotations from __future__ import print_function import collections @@ -53,8 +53,8 @@ class ParticleType(enum.IntEnum): def get_targets( - initial_positions, - trajectory_target_positions): + initial_positions: np.ndarray, + trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray: """Returns the averaged particle mobilities from the sampled trajectories. Args: @@ -70,9 +70,9 @@ def get_targets( def load_data( - file_pattern, - time_index, - max_files_to_load = None): + file_pattern: Text, + time_index: int, + max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]: """Returns a dictionary containing the training or test dataset. The dictionary contains: @@ -108,9 +108,9 @@ def load_data( def get_loss_ops( - prediction, - target, - types): + prediction: tf.Tensor, + target: tf.Tensor, + types: tf.Tensor) -> LossCollection: """Returns L1/L2 loss and correlation for type A particles. Args: @@ -132,9 +132,9 @@ def get_loss_ops( def get_minimize_op( - loss, - learning_rate, - grad_clip = None): + loss: tf.Tensor, + learning_rate: float, + grad_clip: Optional[float] = None) -> tf.Tensor: """Returns minimization operation. Args: @@ -152,8 +152,8 @@ def get_minimize_op( def _log_stats_and_return_mean_correlation( - label, - stats): + label: Text, + stats: Sequence[LossCollection]) -> float: """Logs performance statistics and returns mean correlation. 
Args: @@ -171,20 +171,20 @@ def _log_stats_and_return_mean_correlation( return np.mean([s.correlation for s in stats]) -def train_model(train_file_pattern, - test_file_pattern, - max_files_to_load = None, - n_epochs = 1000, - time_index = 9, - augment_data_using_rotations = True, - learning_rate = 1e-4, - grad_clip = 1.0, - n_recurrences = 7, - mlp_sizes = (64, 64), - mlp_kwargs = None, - edge_threshold = 2.0, - measurement_store_interval = 1000, - checkpoint_path = None): +def train_model(train_file_pattern: Text, + test_file_pattern: Text, + max_files_to_load: Optional[int] = None, + n_epochs: int = 1000, + time_index: int = 9, + augment_data_using_rotations: bool = True, + learning_rate: float = 1e-4, + grad_clip: Optional[float] = 1.0, + n_recurrences: int = 7, + mlp_sizes: Tuple[int] = (64, 64), + mlp_kwargs: Optional[Dict[Text, Any]] = None, + edge_threshold: float = 2.0, + measurement_store_interval: int = 1000, + checkpoint_path: Optional[Text] = None) -> float: """Trains GraphModel using tensorflow. Args: @@ -325,10 +325,10 @@ def train_model(train_file_pattern, return best_so_far -def apply_model(checkpoint_path, - file_pattern, - max_files_to_load = None, - time_index = 9): +def apply_model(checkpoint_path: Text, + file_pattern: Text, + max_files_to_load: Optional[int] = None, + time_index: int = 9) -> List[np.ndarray]: """Applies trained GraphModel using tensorflow. 
Args: diff --git a/glassy_dynamics/train_binary.py b/glassy_dynamics/train_binary.py index fcdf0d5..1b9d42d 100644 --- a/glassy_dynamics/train_binary.py +++ b/glassy_dynamics/train_binary.py @@ -16,7 +16,7 @@ from __future__ import absolute_import from __future__ import division - +from __future__ import google_type_annotations from __future__ import print_function import os diff --git a/hierarchical_probabilistic_unet/model.py b/hierarchical_probabilistic_unet/model.py index c0fccf3..4e25ac5 100644 --- a/hierarchical_probabilistic_unet/model.py +++ b/hierarchical_probabilistic_unet/model.py @@ -145,8 +145,8 @@ class _HierarchicalCore(snt.AbstractModule): regularizers=self._regularizers, )(decoder_features) - mu = mu_logsigma[Ellipsis, :latent_dim] - logsigma = mu_logsigma[Ellipsis, latent_dim:] + mu = mu_logsigma[..., :latent_dim] + logsigma = mu_logsigma[..., latent_dim:] dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(logsigma)) distributions.append(dist) diff --git a/iodine/modules/decoder.py b/iodine/modules/decoder.py index 90c8827..c9f0560 100644 --- a/iodine/modules/decoder.py +++ b/iodine/modules/decoder.py @@ -37,8 +37,8 @@ class ComponentDecoder(snt.AbstractModule): pixel_params = self._pixel_decoder(z_flat).params self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp") - mask_params = pixel_params[Ellipsis, 0:1] - pixel_params = pixel_params[Ellipsis, 1:] + mask_params = pixel_params[..., 0:1] + pixel_params = pixel_params[..., 1:] output = MixtureParameters( pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"), diff --git a/iodine/modules/distributions.py b/iodine/modules/distributions.py index 9e12e97..4c84576 100644 --- a/iodine/modules/distributions.py +++ b/iodine/modules/distributions.py @@ -134,8 +134,8 @@ class LocScaleDistribution(DistributionModule): n_channels = params.get_shape().as_list()[-1] assert n_channels % 2 == 0 assert n_channels // 2 == self.output_shape[-1] - loc = params[Ellipsis, :n_channels // 2] - scale = 
params[Ellipsis, n_channels // 2:] + loc = params[..., :n_channels // 2] + scale = params[..., n_channels // 2:] # apply activation functions if self._scale != "fixed": diff --git a/iodine/modules/factor_eval.py b/iodine/modules/factor_eval.py index de28574..3d52464 100644 --- a/iodine/modules/factor_eval.py +++ b/iodine/modules/factor_eval.py @@ -84,7 +84,7 @@ class FactorRegressor(snt.AbstractModule): for m in self._mapping: with tf.name_scope(m.name): assert m.name in latent, "{} not in {}".format(m.name, latent.keys()) - pred = all_preds[Ellipsis, idx:idx + m.size] + pred = all_preds[..., idx:idx + m.size] predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size)) idx += m.size @@ -165,7 +165,7 @@ class FactorRegressor(snt.AbstractModule): @staticmethod def one_hot(f, nr_categories): - return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories) + return tf.one_hot(tf.cast(f[..., 0], tf.int32), depth=nr_categories) @staticmethod def angle_to_vector(theta): @@ -194,7 +194,7 @@ def accuracy(labels, logits, assignment, mean_var_tot, num_vis): pred = tf.argmax(logits, axis=-1, output_type=tf.int32) labels = tf.argmax(labels, axis=-1, output_type=tf.int32) correct = tf.cast(tf.equal(labels, pred), tf.float32) - return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis + return tf.reduce_sum(correct * assignment[..., 0]) / num_vis def r2(labels, pred, assignment, mean_var_tot, num_vis): diff --git a/iodine/modules/iodine.py b/iodine/modules/iodine.py index 30b368a..af7c4c1 100644 --- a/iodine/modules/iodine.py +++ b/iodine/modules/iodine.py @@ -325,12 +325,12 @@ class IODINE(snt.AbstractModule): [get_components(xd) for xd in iterations["x_dist"]]) # metrics - tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2]) + tm = tf.transpose(true_mask[..., 0], [0, 1, 3, 4, 2]) tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"]) - pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2]) + pm = tf.transpose(pred_mask[..., 
0], [0, 1, 3, 4, 2]) pm = tf.reshape(pm, sg["B * T, H * W, K"]) ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"]) - ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm), sg["B, T"]) + ari_nobg = tf.reshape(adjusted_rand_index(tm[..., 1:], pm), sg["B, T"]) mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5]) @@ -387,7 +387,7 @@ class IODINE(snt.AbstractModule): factor_info["assignment"].append(fass) for k in fpred: factor_info["predictions"][k].append( - tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2)) + tf.reduce_sum(fpred[k] * fass[..., None], axis=2)) factor_info["metrics"][k].append(fscalars[k]) info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T") @@ -496,7 +496,7 @@ class IODINE(snt.AbstractModule): @staticmethod def _get_mask_posterior(out_dist, img): - p_comp = out_dist.components_distribution.prob(img[Ellipsis, tf.newaxis, :]) + p_comp = out_dist.components_distribution.prob(img[..., tf.newaxis, :]) posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6) return tf.transpose(posterior, [0, 4, 2, 3, 1]) @@ -506,7 +506,7 @@ class IODINE(snt.AbstractModule): dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask]) log_prob = sg.guard( - out_dist.log_prob(img)[Ellipsis, tf.newaxis], "B, 1, H, W, 1") + out_dist.log_prob(img)[..., tf.newaxis], "B, 1, H, W, 1") counterfactual_log_probs = [] for k in range(0, self.num_components): @@ -515,7 +515,7 @@ class IODINE(snt.AbstractModule): pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]], axis=1) out_dist_k = self.output_dist(pixel, mask) - log_prob_k = out_dist_k.log_prob(img)[Ellipsis, tf.newaxis] + log_prob_k = out_dist_k.log_prob(img)[..., tf.newaxis] counterfactual_log_probs.append(log_prob_k) counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1) @@ -608,7 +608,7 @@ class IODINE(snt.AbstractModule): x_basis = tf.cos(valx * freqs[None, None, None, None, :, None]) y_basis = 
tf.cos(valy * freqs[None, None, None, None, None, :]) xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"]) - coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[Ellipsis, 1:] + coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[..., 1:] return coords else: raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type)) @@ -632,7 +632,7 @@ class IODINE(snt.AbstractModule): # ########## Mask Monitoring ####### if "mask" in data: true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1") - true_mask = tf.transpose(true_mask[:, -1, Ellipsis, 0], [0, 2, 3, 1]) + true_mask = tf.transpose(true_mask[:, -1, ..., 0], [0, 2, 3, 1]) true_mask = self._sg.reshape(true_mask, "B, H*W, L") else: true_mask = None @@ -648,6 +648,6 @@ class IODINE(snt.AbstractModule): adjusted_rand_index(true_mask, pred_mask)) scalars["loss/ari_nobg"] = tf.reduce_mean( - adjusted_rand_index(true_mask[Ellipsis, 1:], pred_mask)) + adjusted_rand_index(true_mask[..., 1:], pred_mask)) return scalars diff --git a/iodine/modules/networks.py b/iodine/modules/networks.py index e6e7d8e..6364cd2 100644 --- a/iodine/modules/networks.py +++ b/iodine/modules/networks.py @@ -258,7 +258,7 @@ class BroadcastConv(snt.AbstractModule): x_basis = tf.cos(valx * freqs[None, None, None, :, None]) y_basis = tf.cos(valy * freqs[None, None, None, None, :]) xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"]) - coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:] + coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[..., 1:] return tf.concat([output, coords], axis=-1) else: raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type)) diff --git a/iodine/modules/plotting.py b/iodine/modules/plotting.py index 0bdb9bc..990b083 100644 --- a/iodine/modules/plotting.py +++ b/iodine/modules/plotting.py @@ -83,7 +83,7 @@ def show_mask(m, ax): @optional_clean_ax def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"): return ax.matshow( - m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, 
interpolation="nearest") + m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest") @optional_clean_ax @@ -115,7 +115,7 @@ def example_plot(rinfo, show_img(image, ax=axes[0], color="#000000") show_img(recons, ax=axes[1], color="#000000") - show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000") + show_mask(pred_mask[..., 0], ax=axes[2], color="#000000") for k in range(K): mask = pred_mask[k] if mask_components else None show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask) @@ -145,7 +145,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2): nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size)) for t in range(T): show_img(recons[t, 0], ax=axes[t, 0]) - show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1]) + show_mask(pred_mask[t, ..., 0], ax=axes[t, 1]) axes[t, 0].set_ylabel("iter {}".format(t)) for k in range(K): mask = pred_mask[t, k] if mask_components else None @@ -154,7 +154,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2): axes[0, 0].set_title("Reconstruction") axes[0, 1].set_title("Mask") show_img(image[0], ax=axes[T, 0]) - show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1]) + show_mask(true_mask[0, ..., 0], ax=axes[T, 1]) vmin = np.min(pred_mask_logits[T - 1]) vmax = np.max(pred_mask_logits[T - 1]) diff --git a/iodine/modules/utils.py b/iodine/modules/utils.py index 04d29a7..dfedb46 100644 --- a/iodine/modules/utils.py +++ b/iodine/modules/utils.py @@ -123,7 +123,7 @@ def construct_diagnostic_image( components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3]) if mask_components: - components *= masks[:nr_images, Ellipsis, tf.newaxis] + components *= masks[:nr_images, ..., tf.newaxis] # Pad everything no_pad, pad = (0, 0), (border_width, border_width) @@ -415,7 +415,7 @@ def images_to_grid( if max_grid_width is not None: grid_width = min(max_grid_width, grid_width) - images = images[: grid_height * grid_width, Ellipsis] + images = images[: grid_height * grid_width, ...] 
# Pad with extra blank frames if grid_height x grid_width is less than the # number of frames provided. @@ -460,7 +460,7 @@ def flatten_all_but_last(tensor, n_dims=1): def ensure_3d(tensor): if tensor.shape.ndims == 2: - return tensor[Ellipsis, None] + return tensor[..., None] assert tensor.shape.ndims == 3 return tensor diff --git a/option_keyboard/keyboard_agent.py b/option_keyboard/keyboard_agent.py index 852f2b0..b0d1110 100644 --- a/option_keyboard/keyboard_agent.py +++ b/option_keyboard/keyboard_agent.py @@ -82,7 +82,7 @@ class Agent(): def option_values(values, policy): return tf.tensordot( - values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0]) + values[:, policy, ...], self._policy_weights[policy], axes=[1, 0]) # Placeholders for policy. o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype) @@ -103,8 +103,8 @@ class Agent(): qo_t = option_values(q_t, p) a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32) - qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1) - qa_t = _batched_index(q_t[:, p, Ellipsis], a_t) + qa_tm1 = _batched_index(q_tm1[:, p, ...], a_tm1) + qa_t = _batched_index(q_t[:, p, ...], a_t) # TD error g = additional_discount * tf.expand_dims(d_t, axis=-1) diff --git a/polygen/data_utils.py b/polygen/data_utils.py index 1450d40..425badb 100644 --- a/polygen/data_utils.py +++ b/polygen/data_utils.py @@ -91,7 +91,7 @@ def make_face_model_dataset( # Vertices are quantized. 
So convert to floats for input to face model example['vertices'] = modules.dequantize_verts(vertices, quantization_bits) example['vertices_mask'] = tf.ones_like( - example['vertices'][Ellipsis, 0], dtype=tf.float32) + example['vertices'][..., 0], dtype=tf.float32) example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32) return example return ds.map(_face_model_map_fn) diff --git a/polygen/modules.py b/polygen/modules.py index 8a848a3..9e4cfb7 100644 --- a/polygen/modules.py +++ b/polygen/modules.py @@ -799,7 +799,7 @@ class VertexModel(snt.AbstractModule): # Continuous vertex value embeddings else: vert_embeddings = tf.layers.dense( - dequantize_verts(vertices[Ellipsis, None], self.quantization_bits), + dequantize_verts(vertices[..., None], self.quantization_bits), self.embedding_dim, use_bias=True, name='value_embeddings') @@ -984,7 +984,7 @@ class VertexModel(snt.AbstractModule): verts_dequantized = dequantize_verts(v, self.quantization_bits) vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3]) vertices = tf.stack( - [vertices[Ellipsis, 2], vertices[Ellipsis, 1], vertices[Ellipsis, 0]], axis=-1) + [vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1) # Pad samples to max sample length. This is required in order to concatenate # Samples across different replicator instances. Pad with stopping tokens @@ -998,14 +998,14 @@ class VertexModel(snt.AbstractModule): if recenter_verts: vert_max = tf.reduce_max( - vertices - 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1, + vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1, keepdims=True) vert_min = tf.reduce_min( - vertices + 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1, + vertices + 1e10 * (1. 
- vertices_mask)[..., None], axis=1, keepdims=True) vert_centers = 0.5 * (vert_max + vert_min) vertices -= vert_centers - vertices *= vertices_mask[Ellipsis, None] + vertices *= vertices_mask[..., None] if only_return_complete: vertices = tf.boolean_mask(vertices, completed) @@ -1247,7 +1247,7 @@ class FaceModel(snt.AbstractModule): sequential_context_embeddings = ( vertex_embeddings * tf.pad(context['vertices_mask'], [[0, 0], [2, 0]], - constant_values=1)[Ellipsis, None]) + constant_values=1)[..., None]) else: sequential_context_embeddings = None return (vertex_embeddings, global_context_embedding, @@ -1266,11 +1266,11 @@ class FaceModel(snt.AbstractModule): embed_dim=self.embedding_dim, initializers={'embeddings': tf.glorot_uniform_initializer}, densify_gradients=True, - name='coord_{}'.format(c))(verts_quantized[Ellipsis, c]) + name='coord_{}'.format(c))(verts_quantized[..., c]) else: vertex_embeddings = tf.layers.dense( vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings') - vertex_embeddings *= vertices_mask[Ellipsis, None] + vertex_embeddings *= vertices_mask[..., None] # Pad vertex embeddings with learned embeddings for stopping and new face # tokens diff --git a/rl_unplugged/atari.py b/rl_unplugged/atari.py index f661bdd..bab73e4 100644 --- a/rl_unplugged/atari.py +++ b/rl_unplugged/atari.py @@ -106,7 +106,7 @@ TESTING_SUITE = [ ALL = TUNING_SUITE + TESTING_SUITE -def _decode_frames(pngs): +def _decode_frames(pngs: tf.Tensor): """Decode PNGs. Args: @@ -122,13 +122,13 @@ def _decode_frames(pngs): return frames -def _make_reverb_sample(o_t, - a_t, - r_t, - d_t, - o_tp1, - a_tp1, - extras): +def _make_reverb_sample(o_t: tf.Tensor, + a_t: tf.Tensor, + r_t: tf.Tensor, + d_t: tf.Tensor, + o_tp1: tf.Tensor, + a_tp1: tf.Tensor, + extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample: """Create Reverb sample with offline data. 
Args: @@ -151,8 +151,8 @@ def _make_reverb_sample(o_t, return reverb.ReplaySample(info=info, data=data) -def _tf_example_to_reverb_sample(tf_example - ): +def _tf_example_to_reverb_sample(tf_example: tf.train.Example + ) -> reverb.ReplaySample: """Create a Reverb replay sample from a TF example.""" # Parse tf.Example. @@ -184,11 +184,11 @@ def _tf_example_to_reverb_sample(tf_example extras) -def dataset(path, - game, - run, - num_shards = 100, - shuffle_buffer_size = 100000): +def dataset(path: str, + game: str, + run: int, + num_shards: int = 100, + shuffle_buffer_size: int = 100000) -> tf.data.Dataset: """TF dataset of Atari SARSA tuples.""" path = os.path.join(path, f'{game}/run_{run}') filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)] @@ -243,7 +243,7 @@ class AtariDopamineWrapper(dm_env.Environment): return specs.DiscreteArray(self._env.action_space.n) -def environment(game): +def environment(game: str) -> dm_env.Environment: """Atari environment.""" env = atari_lib.create_atari_environment(game_name=game, sticky_actions=True) diff --git a/rl_unplugged/dm_control_suite.py b/rl_unplugged/dm_control_suite.py index 2ba63f9..b59939e 100644 --- a/rl_unplugged/dm_control_suite.py +++ b/rl_unplugged/dm_control_suite.py @@ -773,15 +773,15 @@ def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False): drop_remainder=drop_remainder) -def dataset(root_path, - data_path, - shapes, - num_threads, - batch_size, - uint8_features = None, - num_shards = 100, - shuffle_buffer_size = 100000, - sarsa = True): +def dataset(root_path: str, + data_path: str, + shapes: Dict[str, Tuple[int]], + num_threads: int, + batch_size: int, + uint8_features: Set[str] = None, + num_shards: int = 100, + shuffle_buffer_size: int = 100000, + sarsa: bool = True) -> tf.data.Dataset: """Create tf dataset for training.""" uint8_features = uint8_features if uint8_features else {} diff --git a/rl_unplugged/rwrl.py b/rl_unplugged/rwrl.py index c720c94..c75546b 
100644 --- a/rl_unplugged/rwrl.py +++ b/rl_unplugged/rwrl.py @@ -55,7 +55,7 @@ DELIMITER = ':' DEFAULT_NUM_TIMESTEPS = 1001 -def _decombine_key(k, delimiter = DELIMITER): +def _decombine_key(k: str, delimiter: str = DELIMITER) -> Sequence[str]: return k.split(delimiter) @@ -79,7 +79,7 @@ def tf_example_to_feature_description(example, def tree_deflatten_with_delimiter( - flat_dict, delimiter = DELIMITER): + flat_dict: Dict[str, Any], delimiter: str = DELIMITER) -> Dict[str, Any]: """De-flattens a dict to its originally nested structure. Does the opposite of {combine_nested_keys(k) :v @@ -102,12 +102,12 @@ def tree_deflatten_with_delimiter( return dict(root) -def get_slice_of_nested(nested, start, - end): +def get_slice_of_nested(nested: Dict[str, Any], start: int, + end: int) -> Dict[str, Any]: return tree.map_structure(lambda item: item[start:end], nested) -def repeat_last_and_append_to_nested(nested): +def repeat_last_and_append_to_nested(nested: Dict[str, Any]) -> Dict[str, Any]: return tree.map_structure( lambda item: tf.concat((item, item[-1:]), axis=0), nested) @@ -133,13 +133,13 @@ def tf_example_to_reverb_sample(example, return ret -def dataset(path, - combined_challenge, - domain, - task, - difficulty, - num_shards = 100, - shuffle_buffer_size = 100000): +def dataset(path: str, + combined_challenge: str, + domain: str, + task: str, + difficulty: str, + num_shards: int = 100, + shuffle_buffer_size: int = 100000) -> tf.data.Dataset: """TF dataset of RWRL SARSA tuples.""" path = os.path.join( path, @@ -178,11 +178,11 @@ def dataset(path, def environment( - combined_challenge, - domain, - task, - log_output = None, - environment_kwargs = None): + combined_challenge: str, + domain: str, + task: str, + log_output: Optional[str] = None, + environment_kwargs: Optional[Dict[str, Any]] = None) -> dm_env.Environment: """RWRL environment.""" env = rwrl_envs.load( domain_name=domain, diff --git a/transporter/transporter.py b/transporter/transporter.py index 
cb14e90..d49614b 100644 --- a/transporter/transporter.py +++ b/transporter/transporter.py @@ -106,8 +106,8 @@ class Transporter(snt.AbstractModule): num_keypoints = image_a_keypoints["heatmaps"].shape[-1] transported_features = image_a_features for k in range(num_keypoints): - mask_a = image_a_keypoints["heatmaps"][Ellipsis, k, None] - mask_b = image_b_keypoints["heatmaps"][Ellipsis, k, None] + mask_a = image_a_keypoints["heatmaps"][..., k, None] + mask_b = image_b_keypoints["heatmaps"][..., k, None] # suppress features from image a, around both keypoint locations. transported_features = (