mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
@@ -176,3 +176,37 @@ python -m byol.main_loop \
|
||||
With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
|
||||
*online* classifier) in roughly 4 hours. Note that the above parameters were not
|
||||
finely tuned and may not be optimal.
|
||||
|
||||
|
||||
## Additional checkpoints
|
||||
|
||||
Alongside with the pretrained ResNet-50 and ResNet-200 2x, we provide the
|
||||
following checkpoints from our ablation study. They all correspond to a
|
||||
ResNet-50 1x pre-trained over 300 epochs and were randomly selected within the
|
||||
three seeds; file size is roughly 640MB each.
|
||||
|
||||
- [Baseline](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_baseline.pkl)
|
||||
|
||||
- Smaller batch sizes (figure 3a):
|
||||
- [Batch size 2048](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_2048.pkl)
|
||||
- [Batch size 1024](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_1024.pkl)
|
||||
- [Batch size 512](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_512.pkl)
|
||||
- [Batch size 256](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_256.pkl)
|
||||
- [Batch size 128](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_128.pkl)
|
||||
- [Batch size 64](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_64.pkl)
|
||||
|
||||
- Ablation on transformations (figure 3b):
|
||||
- [Remove grayscale](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_grayscale.pkl)
|
||||
- [Remove color](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_color.pkl)
|
||||
- [Crop and blur only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_blur_only.pkl)
|
||||
- [Crop only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_only.pkl)
|
||||
- (from Table 18) [Crop and color only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_color_only.pkl)
|
||||
|
||||
|
||||
## License
|
||||
|
||||
While the code is licensed under the Apache 2.0 License, the checkpoints weights
|
||||
are made available for non-commercial use only under the terms of the
|
||||
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
|
||||
license. You can find details at:
|
||||
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
|
||||
|
||||
+47
-47
@@ -56,17 +56,17 @@ class ByolExperiment:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
random_seed,
|
||||
num_classes,
|
||||
batch_size,
|
||||
max_steps,
|
||||
enable_double_transpose,
|
||||
base_target_ema,
|
||||
network_config,
|
||||
optimizer_config,
|
||||
lr_schedule_config,
|
||||
evaluation_config,
|
||||
checkpointing_config):
|
||||
random_seed: int,
|
||||
num_classes: int,
|
||||
batch_size: int,
|
||||
max_steps: int,
|
||||
enable_double_transpose: bool,
|
||||
base_target_ema: float,
|
||||
network_config: Mapping[Text, Any],
|
||||
optimizer_config: Mapping[Text, Any],
|
||||
lr_schedule_config: Mapping[Text, Any],
|
||||
evaluation_config: Mapping[Text, Any],
|
||||
checkpointing_config: Mapping[Text, Any]):
|
||||
"""Constructs the experiment.
|
||||
|
||||
Args:
|
||||
@@ -115,15 +115,15 @@ class ByolExperiment:
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
inputs,
|
||||
projector_hidden_size,
|
||||
projector_output_size,
|
||||
predictor_hidden_size,
|
||||
encoder_class,
|
||||
encoder_config,
|
||||
bn_config,
|
||||
is_training,
|
||||
):
|
||||
inputs: dataset.Batch,
|
||||
projector_hidden_size: int,
|
||||
projector_output_size: int,
|
||||
predictor_hidden_size: int,
|
||||
encoder_class: Text,
|
||||
encoder_config: Mapping[Text, Any],
|
||||
bn_config: Mapping[Text, Any],
|
||||
is_training: bool,
|
||||
) -> Mapping[Text, jnp.ndarray]:
|
||||
"""Forward application of byol's architecture.
|
||||
|
||||
Args:
|
||||
@@ -163,7 +163,7 @@ class ByolExperiment:
|
||||
classifier = hk.Linear(
|
||||
output_size=self._num_classes, name='classifier')
|
||||
|
||||
def apply_once_fn(images, suffix = ''):
|
||||
def apply_once_fn(images: jnp.ndarray, suffix: Text = ''):
|
||||
images = dataset.normalize_images(images)
|
||||
|
||||
embedding = net(images, is_training=is_training)
|
||||
@@ -186,7 +186,7 @@ class ByolExperiment:
|
||||
else:
|
||||
return apply_once_fn(inputs['images'], '')
|
||||
|
||||
def _optimizer(self, learning_rate):
|
||||
def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
|
||||
"""Build optimizer from config."""
|
||||
return optimizers.lars(
|
||||
learning_rate,
|
||||
@@ -196,13 +196,13 @@ class ByolExperiment:
|
||||
|
||||
def loss_fn(
|
||||
self,
|
||||
online_params,
|
||||
target_params,
|
||||
online_state,
|
||||
target_state,
|
||||
rng,
|
||||
inputs,
|
||||
):
|
||||
online_params: hk.Params,
|
||||
target_params: hk.Params,
|
||||
online_state: hk.State,
|
||||
target_state: hk.Params,
|
||||
rng: jnp.ndarray,
|
||||
inputs: dataset.Batch,
|
||||
) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
|
||||
"""Compute BYOL's loss function.
|
||||
|
||||
Args:
|
||||
@@ -292,11 +292,11 @@ class ByolExperiment:
|
||||
|
||||
def _update_fn(
|
||||
self,
|
||||
byol_state,
|
||||
global_step,
|
||||
rng,
|
||||
inputs,
|
||||
):
|
||||
byol_state: _ByolExperimentState,
|
||||
global_step: jnp.ndarray,
|
||||
rng: jnp.ndarray,
|
||||
inputs: dataset.Batch,
|
||||
) -> Tuple[_ByolExperimentState, LogsDict]:
|
||||
"""Update online and target parameters.
|
||||
|
||||
Args:
|
||||
@@ -352,9 +352,9 @@ class ByolExperiment:
|
||||
|
||||
def _make_initial_state(
|
||||
self,
|
||||
rng,
|
||||
dummy_input,
|
||||
):
|
||||
rng: jnp.ndarray,
|
||||
dummy_input: dataset.Batch,
|
||||
) -> _ByolExperimentState:
|
||||
"""BYOL's _ByolExperimentState initialization.
|
||||
|
||||
Args:
|
||||
@@ -393,8 +393,8 @@ class ByolExperiment:
|
||||
)
|
||||
|
||||
def step(self, *,
|
||||
global_step,
|
||||
rng):
|
||||
global_step: jnp.ndarray,
|
||||
rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
|
||||
"""Performs a single training step."""
|
||||
if self._train_input is None:
|
||||
self._initialize_train()
|
||||
@@ -410,11 +410,11 @@ class ByolExperiment:
|
||||
|
||||
return helpers.get_first(scalars)
|
||||
|
||||
def save_checkpoint(self, step, rng):
|
||||
def save_checkpoint(self, step: int, rng: jnp.ndarray):
|
||||
self._checkpointer.maybe_save_checkpoint(
|
||||
self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)
|
||||
|
||||
def load_checkpoint(self):
|
||||
def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
|
||||
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
||||
if checkpoint_data is None:
|
||||
return None
|
||||
@@ -444,7 +444,7 @@ class ByolExperiment:
|
||||
|
||||
self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
|
||||
|
||||
def _build_train_input(self):
|
||||
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
|
||||
"""Loads the (infinitely looping) dataset iterator."""
|
||||
num_devices = jax.device_count()
|
||||
global_batch_size = self._batch_size
|
||||
@@ -463,10 +463,10 @@ class ByolExperiment:
|
||||
|
||||
def _eval_batch(
|
||||
self,
|
||||
params,
|
||||
state,
|
||||
batch,
|
||||
):
|
||||
params: hk.Params,
|
||||
state: hk.State,
|
||||
batch: dataset.Batch,
|
||||
) -> Mapping[Text, jnp.ndarray]:
|
||||
"""Evaluates a batch.
|
||||
|
||||
Args:
|
||||
@@ -494,7 +494,7 @@ class ByolExperiment:
|
||||
'top5_accuracy': top5_correct,
|
||||
}
|
||||
|
||||
def _eval_epoch(self, subset, batch_size):
|
||||
def _eval_epoch(self, subset: Text, batch_size: int):
|
||||
"""Evaluates an epoch."""
|
||||
num_samples = 0.
|
||||
summed_scalars = None
|
||||
|
||||
@@ -23,7 +23,7 @@ _WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6}
|
||||
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
|
||||
|
||||
|
||||
def get_config(num_epochs, batch_size):
|
||||
def get_config(num_epochs: int, batch_size: int):
|
||||
"""Return config object, containing all hyperparameters for training."""
|
||||
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Text
|
||||
from byol.utils import dataset
|
||||
|
||||
|
||||
def get_config(checkpoint_to_evaluate, batch_size):
|
||||
def get_config(checkpoint_to_evaluate: Text, batch_size: int):
|
||||
"""Return config object for training."""
|
||||
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
||||
|
||||
|
||||
+47
-47
@@ -53,19 +53,19 @@ class EvalExperiment:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
random_seed,
|
||||
num_classes,
|
||||
batch_size,
|
||||
max_steps,
|
||||
enable_double_transpose,
|
||||
checkpoint_to_evaluate,
|
||||
allow_train_from_scratch,
|
||||
freeze_backbone,
|
||||
network_config,
|
||||
optimizer_config,
|
||||
lr_schedule_config,
|
||||
evaluation_config,
|
||||
checkpointing_config):
|
||||
random_seed: int,
|
||||
num_classes: int,
|
||||
batch_size: int,
|
||||
max_steps: int,
|
||||
enable_double_transpose: bool,
|
||||
checkpoint_to_evaluate: Optional[Text],
|
||||
allow_train_from_scratch: bool,
|
||||
freeze_backbone: bool,
|
||||
network_config: Mapping[Text, Any],
|
||||
optimizer_config: Mapping[Text, Any],
|
||||
lr_schedule_config: Mapping[Text, Any],
|
||||
evaluation_config: Mapping[Text, Any],
|
||||
checkpointing_config: Mapping[Text, Any]):
|
||||
"""Constructs the experiment.
|
||||
|
||||
Args:
|
||||
@@ -125,12 +125,12 @@ class EvalExperiment:
|
||||
|
||||
def _backbone_fn(
|
||||
self,
|
||||
inputs,
|
||||
encoder_class,
|
||||
encoder_config,
|
||||
bn_decay_rate,
|
||||
is_training,
|
||||
):
|
||||
inputs: dataset.Batch,
|
||||
encoder_class: Text,
|
||||
encoder_config: Mapping[Text, Any],
|
||||
bn_decay_rate: float,
|
||||
is_training: bool,
|
||||
) -> jnp.ndarray:
|
||||
"""Forward of the encoder (backbone)."""
|
||||
bn_config = {'decay_rate': bn_decay_rate}
|
||||
encoder = getattr(networks, encoder_class)
|
||||
@@ -146,8 +146,8 @@ class EvalExperiment:
|
||||
|
||||
def _classif_fn(
|
||||
self,
|
||||
embeddings,
|
||||
):
|
||||
embeddings: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
classifier = hk.Linear(output_size=self._num_classes)
|
||||
return classifier(embeddings)
|
||||
|
||||
@@ -159,8 +159,8 @@ class EvalExperiment:
|
||||
#
|
||||
|
||||
def step(self, *,
|
||||
global_step,
|
||||
rng):
|
||||
global_step: jnp.ndarray,
|
||||
rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
|
||||
"""Performs a single training step."""
|
||||
|
||||
if self._train_input is None:
|
||||
@@ -173,12 +173,12 @@ class EvalExperiment:
|
||||
scalars = helpers.get_first(scalars)
|
||||
return scalars
|
||||
|
||||
def save_checkpoint(self, step, rng):
|
||||
def save_checkpoint(self, step: int, rng: jnp.ndarray):
|
||||
self._checkpointer.maybe_save_checkpoint(
|
||||
self._experiment_state, step=step, rng=rng,
|
||||
is_final=step >= self._max_steps)
|
||||
|
||||
def load_checkpoint(self):
|
||||
def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
|
||||
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
||||
if checkpoint_data is None:
|
||||
return None
|
||||
@@ -253,11 +253,11 @@ class EvalExperiment:
|
||||
|
||||
def _make_initial_state(
|
||||
self,
|
||||
rng,
|
||||
dummy_input,
|
||||
backbone_params,
|
||||
backbone_state,
|
||||
):
|
||||
rng: jnp.ndarray,
|
||||
dummy_input: dataset.Batch,
|
||||
backbone_params: hk.Params,
|
||||
backbone_state: hk.Params,
|
||||
) -> _EvalExperimentState:
|
||||
"""_EvalExperimentState initialization."""
|
||||
|
||||
# Initialize the backbone params
|
||||
@@ -279,7 +279,7 @@ class EvalExperiment:
|
||||
classif_opt_state=classif_opt_state,
|
||||
)
|
||||
|
||||
def _build_train_input(self):
|
||||
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
|
||||
"""See base class."""
|
||||
num_devices = jax.device_count()
|
||||
global_batch_size = self._batch_size
|
||||
@@ -296,17 +296,17 @@ class EvalExperiment:
|
||||
transpose=self._should_transpose_images(),
|
||||
batch_dims=[jax.local_device_count(), per_device_batch_size])
|
||||
|
||||
def _optimizer(self, learning_rate):
|
||||
def _optimizer(self, learning_rate: float):
|
||||
"""Build optimizer from config."""
|
||||
return optax.sgd(learning_rate, **self._optimizer_config)
|
||||
|
||||
def _loss_fn(
|
||||
self,
|
||||
backbone_params,
|
||||
classif_params,
|
||||
backbone_state,
|
||||
inputs,
|
||||
):
|
||||
backbone_params: hk.Params,
|
||||
classif_params: hk.Params,
|
||||
backbone_state: hk.State,
|
||||
inputs: dataset.Batch,
|
||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
|
||||
"""Compute the classification loss function.
|
||||
|
||||
Args:
|
||||
@@ -333,10 +333,10 @@ class EvalExperiment:
|
||||
|
||||
def _update_func(
|
||||
self,
|
||||
experiment_state,
|
||||
global_step,
|
||||
inputs,
|
||||
):
|
||||
experiment_state: _EvalExperimentState,
|
||||
global_step: jnp.ndarray,
|
||||
inputs: dataset.Batch,
|
||||
) -> Tuple[_EvalExperimentState, LogsDict]:
|
||||
"""Applies an update to parameters and returns new state."""
|
||||
# This function computes the gradient of the first output of loss_fn and
|
||||
# passes through the other arguments unchanged.
|
||||
@@ -421,11 +421,11 @@ class EvalExperiment:
|
||||
|
||||
def _eval_batch(
|
||||
self,
|
||||
backbone_params,
|
||||
classif_params,
|
||||
backbone_state,
|
||||
inputs,
|
||||
):
|
||||
backbone_params: hk.Params,
|
||||
classif_params: hk.Params,
|
||||
backbone_state: hk.State,
|
||||
inputs: dataset.Batch,
|
||||
) -> LogsDict:
|
||||
"""Evaluates a batch."""
|
||||
embeddings, backbone_state = self.forward_backbone.apply(
|
||||
backbone_params, backbone_state, inputs, is_training=False)
|
||||
@@ -441,7 +441,7 @@ class EvalExperiment:
|
||||
'top5_accuracy': top5_correct
|
||||
}
|
||||
|
||||
def _eval_epoch(self, subset, batch_size):
|
||||
def _eval_epoch(self, subset: Text, batch_size: int):
|
||||
"""Evaluates an epoch."""
|
||||
num_samples = 0.
|
||||
summed_scalars = None
|
||||
|
||||
+2
-2
@@ -47,7 +47,7 @@ Experiment = Union[
|
||||
Type[eval_experiment.EvalExperiment]]
|
||||
|
||||
|
||||
def train_loop(experiment_class, config):
|
||||
def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
|
||||
"""The main training loop.
|
||||
|
||||
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
|
||||
@@ -95,7 +95,7 @@ def train_loop(experiment_class, config):
|
||||
experiment.save_checkpoint(step, rng)
|
||||
|
||||
|
||||
def eval_loop(experiment_class, config):
|
||||
def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
|
||||
"""The main evaluation loop.
|
||||
|
||||
This loop periodically loads a checkpoint and evaluates its performance on the
|
||||
|
||||
@@ -66,14 +66,14 @@ augment_config = dict(
|
||||
))
|
||||
|
||||
|
||||
def postprocess(inputs, rng):
|
||||
def postprocess(inputs: JaxBatch, rng: jnp.ndarray):
|
||||
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
|
||||
|
||||
def _postprocess_image(
|
||||
images,
|
||||
rng,
|
||||
presets,
|
||||
):
|
||||
images: jnp.ndarray,
|
||||
rng: jnp.ndarray,
|
||||
presets: ConfigDict,
|
||||
) -> JaxBatch:
|
||||
"""Applies augmentations in post-processing.
|
||||
|
||||
Args:
|
||||
@@ -144,7 +144,7 @@ def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
|
||||
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
|
||||
expand_batch_dim = len(image.shape) == 3
|
||||
if expand_batch_dim:
|
||||
image = image[jnp.newaxis, Ellipsis]
|
||||
image = image[jnp.newaxis, ...]
|
||||
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
|
||||
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
|
||||
blurred = jnp.squeeze(blurred, axis=0)
|
||||
@@ -284,7 +284,7 @@ def _random_hue(rgb_tuple, rng, max_delta):
|
||||
|
||||
def _to_grayscale(image):
|
||||
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
|
||||
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis]
|
||||
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
|
||||
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
|
||||
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ class Checkpointer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_checkpointing,
|
||||
checkpoint_dir,
|
||||
save_checkpoint_interval,
|
||||
filename):
|
||||
use_checkpointing: bool,
|
||||
checkpoint_dir: Text,
|
||||
save_checkpoint_interval: int,
|
||||
filename: Text):
|
||||
if (not use_checkpointing or
|
||||
checkpoint_dir is None or
|
||||
save_checkpoint_interval <= 0):
|
||||
@@ -51,10 +51,10 @@ class Checkpointer:
|
||||
|
||||
def maybe_save_checkpoint(
|
||||
self,
|
||||
experiment_state,
|
||||
step,
|
||||
rng,
|
||||
is_final):
|
||||
experiment_state: Mapping[Text, jnp.ndarray],
|
||||
step: int,
|
||||
rng: jnp.ndarray,
|
||||
is_final: bool):
|
||||
"""Saves a checkpoint if enough time has passed since the previous one."""
|
||||
current_time = time.time()
|
||||
if (not self._checkpoint_enabled or
|
||||
@@ -80,7 +80,7 @@ class Checkpointer:
|
||||
self._last_checkpoint_time = current_time
|
||||
|
||||
def maybe_load_checkpoint(
|
||||
self):
|
||||
self) -> Union[Tuple[Mapping[Text, jnp.ndarray], int, jnp.ndarray], None]:
|
||||
"""Loads a checkpoint if any is found."""
|
||||
checkpoint_data = load_checkpoint(self._checkpoint_path)
|
||||
if checkpoint_data is None:
|
||||
|
||||
+17
-17
@@ -34,7 +34,7 @@ class Split(enum.Enum):
|
||||
TEST = 4
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, name):
|
||||
def from_string(cls, name: Text) -> 'Split':
|
||||
return {
|
||||
'TRAIN': Split.TRAIN,
|
||||
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
|
||||
@@ -60,7 +60,7 @@ class PreprocessMode(enum.Enum):
|
||||
EVAL = 3 # Generates a single center crop.
|
||||
|
||||
|
||||
def normalize_images(images):
|
||||
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Normalize the image using ImageNet statistics."""
|
||||
mean_rgb = (0.485, 0.456, 0.406)
|
||||
stddev_rgb = (0.229, 0.224, 0.225)
|
||||
@@ -69,12 +69,12 @@ def normalize_images(images):
|
||||
return normed_images
|
||||
|
||||
|
||||
def load(split,
|
||||
def load(split: Split,
|
||||
*,
|
||||
preprocess_mode,
|
||||
batch_dims,
|
||||
transpose = False,
|
||||
allow_caching = False):
|
||||
preprocess_mode: PreprocessMode,
|
||||
batch_dims: Sequence[int],
|
||||
transpose: bool = False,
|
||||
allow_caching: bool = False) -> Generator[Batch, None, None]:
|
||||
"""Loads the given split of the dataset."""
|
||||
start, end = _shard(split, jax.host_id(), jax.host_count())
|
||||
|
||||
@@ -153,7 +153,7 @@ def load(split,
|
||||
yield from tfds.as_numpy(ds)
|
||||
|
||||
|
||||
def _to_tfds_split(split):
|
||||
def _to_tfds_split(split: Split) -> tfds.Split:
|
||||
"""Returns the TFDS split appropriately sharded."""
|
||||
# NOTE: Imagenet did not release labels for the test split used in the
|
||||
# competition, we consider the VALID split the TEST split and reserve
|
||||
@@ -165,7 +165,7 @@ def _to_tfds_split(split):
|
||||
return tfds.Split.VALIDATION
|
||||
|
||||
|
||||
def _shard(split, shard_index, num_shards):
|
||||
def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
|
||||
"""Returns [start, end) for the given shard index."""
|
||||
assert shard_index < num_shards
|
||||
arange = np.arange(split.num_examples)
|
||||
@@ -180,9 +180,9 @@ def _shard(split, shard_index, num_shards):
|
||||
|
||||
|
||||
def _preprocess_image(
|
||||
image_bytes,
|
||||
mode,
|
||||
):
|
||||
image_bytes: tf.Tensor,
|
||||
mode: PreprocessMode,
|
||||
) -> tf.Tensor:
|
||||
"""Returns processed and resized images."""
|
||||
if mode is PreprocessMode.PRETRAIN:
|
||||
image = _decode_and_random_crop(image_bytes)
|
||||
@@ -201,7 +201,7 @@ def _preprocess_image(
|
||||
return image
|
||||
|
||||
|
||||
def _decode_and_random_crop(image_bytes):
|
||||
def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
|
||||
"""Make a random crop of 224."""
|
||||
img_size = tf.image.extract_jpeg_shape(image_bytes)
|
||||
area = tf.cast(img_size[1] * img_size[0], tf.float32)
|
||||
@@ -231,7 +231,7 @@ def _decode_and_random_crop(image_bytes):
|
||||
return image
|
||||
|
||||
|
||||
def transpose_images(batch):
|
||||
def transpose_images(batch: Batch):
|
||||
"""Transpose images for TPU training.."""
|
||||
new_batch = dict(batch) # Avoid mutating in place.
|
||||
if 'images' in batch:
|
||||
@@ -243,9 +243,9 @@ def transpose_images(batch):
|
||||
|
||||
|
||||
def _decode_and_center_crop(
|
||||
image_bytes,
|
||||
jpeg_shape = None,
|
||||
):
|
||||
image_bytes: tf.Tensor,
|
||||
jpeg_shape: Optional[tf.Tensor] = None,
|
||||
) -> tf.Tensor:
|
||||
"""Crops to center of image with padding then scales."""
|
||||
if jpeg_shape is None:
|
||||
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
||||
|
||||
+14
-14
@@ -21,11 +21,11 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
def topk_accuracy(
|
||||
logits,
|
||||
labels,
|
||||
topk,
|
||||
ignore_label_above = None,
|
||||
):
|
||||
logits: jnp.ndarray,
|
||||
labels: jnp.ndarray,
|
||||
topk: int,
|
||||
ignore_label_above: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""Top-num_codes accuracy."""
|
||||
assert len(labels.shape) == 1, 'topk expects 1d int labels.'
|
||||
assert len(logits.shape) == 2, 'topk expects 2d logits.'
|
||||
@@ -42,10 +42,10 @@ def topk_accuracy(
|
||||
|
||||
|
||||
def softmax_cross_entropy(
|
||||
logits,
|
||||
labels,
|
||||
reduction = 'mean',
|
||||
):
|
||||
logits: jnp.ndarray,
|
||||
labels: jnp.ndarray,
|
||||
reduction: Optional[Text] = 'mean',
|
||||
) -> jnp.ndarray:
|
||||
"""Computes softmax cross entropy given logits and one-hot class labels.
|
||||
|
||||
Args:
|
||||
@@ -72,10 +72,10 @@ def softmax_cross_entropy(
|
||||
|
||||
|
||||
def l2_normalize(
|
||||
x,
|
||||
axis = None,
|
||||
epsilon = 1e-12,
|
||||
):
|
||||
x: jnp.ndarray,
|
||||
axis: Optional[int] = None,
|
||||
epsilon: float = 1e-12,
|
||||
) -> jnp.ndarray:
|
||||
"""l2 normalize a tensor on an axis with numerical stability."""
|
||||
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
|
||||
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
|
||||
@@ -106,7 +106,7 @@ def l2_weight_regularizer(params):
|
||||
return 0.5 * l2_norm
|
||||
|
||||
|
||||
def regression_loss(x, y):
|
||||
def regression_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Byol's regression loss. This is a simple cosine similarity."""
|
||||
normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
|
||||
return jnp.sum((normed_x - normed_y)**2, axis=-1)
|
||||
|
||||
+49
-49
@@ -27,17 +27,17 @@ class MLP(hk.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
hidden_size,
|
||||
output_size,
|
||||
bn_config,
|
||||
name: Text,
|
||||
hidden_size: int,
|
||||
output_size: int,
|
||||
bn_config: Mapping[Text, Any],
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._hidden_size = hidden_size
|
||||
self._output_size = output_size
|
||||
self._bn_config = bn_config
|
||||
|
||||
def __call__(self, inputs, is_training):
|
||||
def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndarray:
|
||||
out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
|
||||
out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
|
||||
out = jax.nn.relu(out)
|
||||
@@ -55,15 +55,15 @@ class ResNetTorso(hk.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
blocks_per_group,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
bottleneck = True,
|
||||
channels_per_group = (256, 512, 1024, 2048),
|
||||
use_projection = (True, True, True, True),
|
||||
width_multiplier = 1,
|
||||
name = None,
|
||||
blocks_per_group: Sequence[int],
|
||||
num_classes: int = None,
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
bottleneck: bool = True,
|
||||
channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
|
||||
use_projection: Sequence[bool] = (True, True, True, True),
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
@@ -155,11 +155,11 @@ class TinyResNet(ResNetTorso):
|
||||
"""Tiny resnet for local runs and tests."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int] = None,
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
@@ -185,11 +185,11 @@ class ResNet18(ResNetTorso):
|
||||
"""ResNet18."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int] = None,
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
@@ -215,11 +215,11 @@ class ResNet34(ResNetTorso):
|
||||
"""ResNet34."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int],
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
@@ -245,11 +245,11 @@ class ResNet50(ResNetTorso):
|
||||
"""ResNet50."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes = None,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int] = None,
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
@@ -274,11 +274,11 @@ class ResNet101(ResNetTorso):
|
||||
"""ResNet101."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int],
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
@@ -303,11 +303,11 @@ class ResNet152(ResNetTorso):
|
||||
"""ResNet152."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int],
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
@@ -332,11 +332,11 @@ class ResNet200(ResNetTorso):
|
||||
"""ResNet200."""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
bn_config = None,
|
||||
resnet_v2 = False,
|
||||
width_multiplier = 1,
|
||||
name = None):
|
||||
num_classes: Optional[int],
|
||||
bn_config: Optional[Mapping[str, float]] = None,
|
||||
resnet_v2: bool = False,
|
||||
width_multiplier: int = 1,
|
||||
name: Optional[str] = None):
|
||||
"""Constructs a ResNet model.
|
||||
|
||||
Args:
|
||||
|
||||
+29
-29
@@ -27,7 +27,7 @@ import tree as nest
|
||||
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]
|
||||
|
||||
|
||||
def exclude_bias_and_norm(path, val):
|
||||
def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Filter to exclude biaises and normalizations weights."""
|
||||
del val
|
||||
if path[-1] == "b" or "norm" in path[-2]:
|
||||
@@ -35,10 +35,10 @@ def exclude_bias_and_norm(path, val):
|
||||
return True
|
||||
|
||||
|
||||
def _partial_update(updates,
|
||||
new_updates,
|
||||
params,
|
||||
filter_fn = None):
|
||||
def _partial_update(updates: optax.Updates,
|
||||
new_updates: optax.Updates,
|
||||
params: optax.Params,
|
||||
filter_fn: Optional[FilterFn] = None) -> optax.Updates:
|
||||
"""Returns new_update for params which filter_fn is True else updates."""
|
||||
|
||||
if filter_fn is None:
|
||||
@@ -47,7 +47,7 @@ def _partial_update(updates,
|
||||
wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
|
||||
params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)
|
||||
|
||||
def _update_fn(g, t, m):
|
||||
def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
|
||||
m = m.astype(g.dtype)
|
||||
return g * (1. - m) + t * m
|
||||
|
||||
@@ -59,9 +59,9 @@ class ScaleByLarsState(NamedTuple):
|
||||
|
||||
|
||||
def scale_by_lars(
|
||||
momentum = 0.9,
|
||||
eta = 0.001,
|
||||
filter_fn = None):
|
||||
momentum: float = 0.9,
|
||||
eta: float = 0.001,
|
||||
filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
|
||||
"""Rescales updates according to the LARS algorithm.
|
||||
|
||||
Does not include weight decay.
|
||||
@@ -77,17 +77,17 @@ def scale_by_lars(
|
||||
An (init_fn, update_fn) tuple.
|
||||
"""
|
||||
|
||||
def init_fn(params):
|
||||
def init_fn(params: optax.Params) -> ScaleByLarsState:
|
||||
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
|
||||
return ScaleByLarsState(mu=mu)
|
||||
|
||||
def update_fn(updates, state,
|
||||
params):
|
||||
def update_fn(updates: optax.Updates, state: ScaleByLarsState,
|
||||
params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]:
|
||||
|
||||
def lars_adaptation(
|
||||
update,
|
||||
param,
|
||||
):
|
||||
update: jnp.ndarray,
|
||||
param: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
param_norm = jnp.linalg.norm(param)
|
||||
update_norm = jnp.linalg.norm(update)
|
||||
return update * jnp.where(
|
||||
@@ -110,8 +110,8 @@ class AddWeightDecayState(NamedTuple):
|
||||
|
||||
|
||||
def add_weight_decay(
|
||||
weight_decay,
|
||||
filter_fn = None):
|
||||
weight_decay: float,
|
||||
filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
|
||||
"""Adds a weight decay to the update.
|
||||
|
||||
Args:
|
||||
@@ -122,14 +122,14 @@ def add_weight_decay(
|
||||
An (init_fn, update_fn) tuple.
|
||||
"""
|
||||
|
||||
def init_fn(_):
|
||||
def init_fn(_) -> AddWeightDecayState:
|
||||
return AddWeightDecayState()
|
||||
|
||||
def update_fn(
|
||||
updates,
|
||||
state,
|
||||
params,
|
||||
):
|
||||
updates: optax.Updates,
|
||||
state: AddWeightDecayState,
|
||||
params: optax.Params,
|
||||
) -> Tuple[optax.Updates, AddWeightDecayState]:
|
||||
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
|
||||
params)
|
||||
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
||||
@@ -142,13 +142,13 @@ LarsState = List # Type for the lars optimizer
|
||||
|
||||
|
||||
def lars(
|
||||
learning_rate,
|
||||
weight_decay = 0.,
|
||||
momentum = 0.9,
|
||||
eta = 0.001,
|
||||
weight_decay_filter = None,
|
||||
lars_adaptation_filter = None,
|
||||
):
|
||||
learning_rate: float,
|
||||
weight_decay: float = 0.,
|
||||
momentum: float = 0.9,
|
||||
eta: float = 0.001,
|
||||
weight_decay_filter: Optional[FilterFn] = None,
|
||||
lars_adaptation_filter: Optional[FilterFn] = None,
|
||||
) -> optax.GradientTransformation:
|
||||
"""Creates lars optimizer with weight decay.
|
||||
|
||||
References:
|
||||
|
||||
+11
-11
@@ -17,18 +17,18 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def target_ema(global_step,
|
||||
base_ema,
|
||||
max_steps):
|
||||
def target_ema(global_step: jnp.ndarray,
|
||||
base_ema: float,
|
||||
max_steps: int) -> jnp.ndarray:
|
||||
decay = _cosine_decay(global_step, max_steps, 1.)
|
||||
return 1. - (1. - base_ema) * decay
|
||||
|
||||
|
||||
def learning_schedule(global_step,
|
||||
batch_size,
|
||||
base_learning_rate,
|
||||
total_steps,
|
||||
warmup_steps):
|
||||
def learning_schedule(global_step: jnp.ndarray,
|
||||
batch_size: int,
|
||||
base_learning_rate: float,
|
||||
total_steps: int,
|
||||
warmup_steps: int) -> float:
|
||||
"""Cosine learning rate scheduler."""
|
||||
# Compute LR & Scaled LR
|
||||
scaled_lr = base_learning_rate * batch_size / 256.
|
||||
@@ -43,9 +43,9 @@ def learning_schedule(global_step,
|
||||
scaled_lr))
|
||||
|
||||
|
||||
def _cosine_decay(global_step,
|
||||
max_steps,
|
||||
initial_value):
|
||||
def _cosine_decay(global_step: jnp.ndarray,
|
||||
max_steps: int,
|
||||
initial_value: float) -> jnp.ndarray:
|
||||
"""Simple implementation of cosine decay from TF1."""
|
||||
global_step = jnp.minimum(global_step, max_steps)
|
||||
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
|
||||
|
||||
+2
-2
@@ -46,7 +46,7 @@ def make_so_tangent(q):
|
||||
for j in range(i+1, n):
|
||||
a[i, j] = 1
|
||||
a[j, i] = -1
|
||||
dq[Ellipsis, ii] = a @ q # tangent vectors are skew-symmetric matrix times Q
|
||||
dq[..., ii] = a @ q # tangent vectors are skew-symmetric matrix times Q
|
||||
a[i, j] = 0
|
||||
a[j, i] = 0
|
||||
ii += 1
|
||||
@@ -106,7 +106,7 @@ def make_product_manifold(specification, npts):
|
||||
spec_array[1, i] = dim
|
||||
latent_dim += dim
|
||||
dat = np.random.randn(npts, dim+1)
|
||||
dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[Ellipsis, None]),
|
||||
dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[..., None]),
|
||||
[1, dim+1])
|
||||
elif so_spec is not None:
|
||||
dim = int(so_spec.group(1))
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
@@ -20,7 +20,7 @@ The architecture and performance of this model is described in our publication:
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
@@ -35,10 +35,10 @@ from typing import Any, Dict, Text, Tuple, Optional
|
||||
|
||||
|
||||
def make_graph_from_static_structure(
|
||||
positions,
|
||||
types,
|
||||
box,
|
||||
edge_threshold):
|
||||
positions: tf.Tensor,
|
||||
types: tf.Tensor,
|
||||
box: tf.Tensor,
|
||||
edge_threshold: float) -> graphs.GraphsTuple:
|
||||
"""Returns graph representing the static structure of the glass.
|
||||
|
||||
Each particle is represented by a node in the graph. The particle type is
|
||||
@@ -81,7 +81,7 @@ def make_graph_from_static_structure(
|
||||
)
|
||||
|
||||
|
||||
def apply_random_rotation(graph):
|
||||
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
|
||||
"""Returns randomly rotated graph representation.
|
||||
|
||||
The rotation is an element of O(3) with rotation angles multiple of pi/2.
|
||||
@@ -118,9 +118,9 @@ class GraphBasedModel(snt.AbstractModule):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_recurrences,
|
||||
mlp_sizes,
|
||||
mlp_kwargs = None,
|
||||
n_recurrences: int,
|
||||
mlp_sizes: Tuple[int],
|
||||
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
||||
name='Graph'):
|
||||
"""Creates a new GraphBasedModel object.
|
||||
|
||||
@@ -168,7 +168,7 @@ class GraphBasedModel(snt.AbstractModule):
|
||||
node_model_fn=final_model_fn,
|
||||
edge_model_fn=model_fn)
|
||||
|
||||
def _build(self, graphs_tuple):
|
||||
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
|
||||
"""Connects the model into the tensorflow graph.
|
||||
|
||||
Args:
|
||||
|
||||
+32
-32
@@ -16,7 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
@@ -53,8 +53,8 @@ class ParticleType(enum.IntEnum):
|
||||
|
||||
|
||||
def get_targets(
|
||||
initial_positions,
|
||||
trajectory_target_positions):
|
||||
initial_positions: np.ndarray,
|
||||
trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray:
|
||||
"""Returns the averaged particle mobilities from the sampled trajectories.
|
||||
|
||||
Args:
|
||||
@@ -70,9 +70,9 @@ def get_targets(
|
||||
|
||||
|
||||
def load_data(
|
||||
file_pattern,
|
||||
time_index,
|
||||
max_files_to_load = None):
|
||||
file_pattern: Text,
|
||||
time_index: int,
|
||||
max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]:
|
||||
"""Returns a dictionary containing the training or test dataset.
|
||||
|
||||
The dictionary contains:
|
||||
@@ -108,9 +108,9 @@ def load_data(
|
||||
|
||||
|
||||
def get_loss_ops(
|
||||
prediction,
|
||||
target,
|
||||
types):
|
||||
prediction: tf.Tensor,
|
||||
target: tf.Tensor,
|
||||
types: tf.Tensor) -> LossCollection:
|
||||
"""Returns L1/L2 loss and correlation for type A particles.
|
||||
|
||||
Args:
|
||||
@@ -132,9 +132,9 @@ def get_loss_ops(
|
||||
|
||||
|
||||
def get_minimize_op(
|
||||
loss,
|
||||
learning_rate,
|
||||
grad_clip = None):
|
||||
loss: tf.Tensor,
|
||||
learning_rate: float,
|
||||
grad_clip: Optional[float] = None) -> tf.Tensor:
|
||||
"""Returns minimization operation.
|
||||
|
||||
Args:
|
||||
@@ -152,8 +152,8 @@ def get_minimize_op(
|
||||
|
||||
|
||||
def _log_stats_and_return_mean_correlation(
|
||||
label,
|
||||
stats):
|
||||
label: Text,
|
||||
stats: Sequence[LossCollection]) -> float:
|
||||
"""Logs performance statistics and returns mean correlation.
|
||||
|
||||
Args:
|
||||
@@ -171,20 +171,20 @@ def _log_stats_and_return_mean_correlation(
|
||||
return np.mean([s.correlation for s in stats])
|
||||
|
||||
|
||||
def train_model(train_file_pattern,
|
||||
test_file_pattern,
|
||||
max_files_to_load = None,
|
||||
n_epochs = 1000,
|
||||
time_index = 9,
|
||||
augment_data_using_rotations = True,
|
||||
learning_rate = 1e-4,
|
||||
grad_clip = 1.0,
|
||||
n_recurrences = 7,
|
||||
mlp_sizes = (64, 64),
|
||||
mlp_kwargs = None,
|
||||
edge_threshold = 2.0,
|
||||
measurement_store_interval = 1000,
|
||||
checkpoint_path = None):
|
||||
def train_model(train_file_pattern: Text,
|
||||
test_file_pattern: Text,
|
||||
max_files_to_load: Optional[int] = None,
|
||||
n_epochs: int = 1000,
|
||||
time_index: int = 9,
|
||||
augment_data_using_rotations: bool = True,
|
||||
learning_rate: float = 1e-4,
|
||||
grad_clip: Optional[float] = 1.0,
|
||||
n_recurrences: int = 7,
|
||||
mlp_sizes: Tuple[int] = (64, 64),
|
||||
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
||||
edge_threshold: float = 2.0,
|
||||
measurement_store_interval: int = 1000,
|
||||
checkpoint_path: Optional[Text] = None) -> float:
|
||||
"""Trains GraphModel using tensorflow.
|
||||
|
||||
Args:
|
||||
@@ -325,10 +325,10 @@ def train_model(train_file_pattern,
|
||||
return best_so_far
|
||||
|
||||
|
||||
def apply_model(checkpoint_path,
|
||||
file_pattern,
|
||||
max_files_to_load = None,
|
||||
time_index = 9):
|
||||
def apply_model(checkpoint_path: Text,
|
||||
file_pattern: Text,
|
||||
max_files_to_load: Optional[int] = None,
|
||||
time_index: int = 9) -> List[np.ndarray]:
|
||||
"""Applies trained GraphModel using tensorflow.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
@@ -145,8 +145,8 @@ class _HierarchicalCore(snt.AbstractModule):
|
||||
regularizers=self._regularizers,
|
||||
)(decoder_features)
|
||||
|
||||
mu = mu_logsigma[Ellipsis, :latent_dim]
|
||||
logsigma = mu_logsigma[Ellipsis, latent_dim:]
|
||||
mu = mu_logsigma[..., :latent_dim]
|
||||
logsigma = mu_logsigma[..., latent_dim:]
|
||||
dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(logsigma))
|
||||
distributions.append(dist)
|
||||
|
||||
|
||||
@@ -37,8 +37,8 @@ class ComponentDecoder(snt.AbstractModule):
|
||||
pixel_params = self._pixel_decoder(z_flat).params
|
||||
|
||||
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
||||
mask_params = pixel_params[Ellipsis, 0:1]
|
||||
pixel_params = pixel_params[Ellipsis, 1:]
|
||||
mask_params = pixel_params[..., 0:1]
|
||||
pixel_params = pixel_params[..., 1:]
|
||||
|
||||
output = MixtureParameters(
|
||||
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
||||
|
||||
@@ -134,8 +134,8 @@ class LocScaleDistribution(DistributionModule):
|
||||
n_channels = params.get_shape().as_list()[-1]
|
||||
assert n_channels % 2 == 0
|
||||
assert n_channels // 2 == self.output_shape[-1]
|
||||
loc = params[Ellipsis, :n_channels // 2]
|
||||
scale = params[Ellipsis, n_channels // 2:]
|
||||
loc = params[..., :n_channels // 2]
|
||||
scale = params[..., n_channels // 2:]
|
||||
|
||||
# apply activation functions
|
||||
if self._scale != "fixed":
|
||||
|
||||
@@ -84,7 +84,7 @@ class FactorRegressor(snt.AbstractModule):
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
||||
pred = all_preds[Ellipsis, idx:idx + m.size]
|
||||
pred = all_preds[..., idx:idx + m.size]
|
||||
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
||||
idx += m.size
|
||||
|
||||
@@ -165,7 +165,7 @@ class FactorRegressor(snt.AbstractModule):
|
||||
|
||||
@staticmethod
|
||||
def one_hot(f, nr_categories):
|
||||
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
|
||||
return tf.one_hot(tf.cast(f[..., 0], tf.int32), depth=nr_categories)
|
||||
|
||||
@staticmethod
|
||||
def angle_to_vector(theta):
|
||||
@@ -194,7 +194,7 @@ def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
|
||||
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
||||
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
||||
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
|
||||
return tf.reduce_sum(correct * assignment[..., 0]) / num_vis
|
||||
|
||||
|
||||
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
||||
|
||||
+10
-10
@@ -325,12 +325,12 @@ class IODINE(snt.AbstractModule):
|
||||
[get_components(xd) for xd in iterations["x_dist"]])
|
||||
|
||||
# metrics
|
||||
tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
|
||||
tm = tf.transpose(true_mask[..., 0], [0, 1, 3, 4, 2])
|
||||
tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
|
||||
pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
|
||||
pm = tf.transpose(pred_mask[..., 0], [0, 1, 3, 4, 2])
|
||||
pm = tf.reshape(pm, sg["B * T, H * W, K"])
|
||||
ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
|
||||
ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm), sg["B, T"])
|
||||
ari_nobg = tf.reshape(adjusted_rand_index(tm[..., 1:], pm), sg["B, T"])
|
||||
|
||||
mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5])
|
||||
|
||||
@@ -387,7 +387,7 @@ class IODINE(snt.AbstractModule):
|
||||
factor_info["assignment"].append(fass)
|
||||
for k in fpred:
|
||||
factor_info["predictions"][k].append(
|
||||
tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
|
||||
tf.reduce_sum(fpred[k] * fass[..., None], axis=2))
|
||||
factor_info["metrics"][k].append(fscalars[k])
|
||||
|
||||
info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T")
|
||||
@@ -496,7 +496,7 @@ class IODINE(snt.AbstractModule):
|
||||
|
||||
@staticmethod
|
||||
def _get_mask_posterior(out_dist, img):
|
||||
p_comp = out_dist.components_distribution.prob(img[Ellipsis, tf.newaxis, :])
|
||||
p_comp = out_dist.components_distribution.prob(img[..., tf.newaxis, :])
|
||||
posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6)
|
||||
return tf.transpose(posterior, [0, 4, 2, 3, 1])
|
||||
|
||||
@@ -506,7 +506,7 @@ class IODINE(snt.AbstractModule):
|
||||
dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask])
|
||||
|
||||
log_prob = sg.guard(
|
||||
out_dist.log_prob(img)[Ellipsis, tf.newaxis], "B, 1, H, W, 1")
|
||||
out_dist.log_prob(img)[..., tf.newaxis], "B, 1, H, W, 1")
|
||||
|
||||
counterfactual_log_probs = []
|
||||
for k in range(0, self.num_components):
|
||||
@@ -515,7 +515,7 @@ class IODINE(snt.AbstractModule):
|
||||
pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]],
|
||||
axis=1)
|
||||
out_dist_k = self.output_dist(pixel, mask)
|
||||
log_prob_k = out_dist_k.log_prob(img)[Ellipsis, tf.newaxis]
|
||||
log_prob_k = out_dist_k.log_prob(img)[..., tf.newaxis]
|
||||
counterfactual_log_probs.append(log_prob_k)
|
||||
counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1)
|
||||
|
||||
@@ -608,7 +608,7 @@ class IODINE(snt.AbstractModule):
|
||||
x_basis = tf.cos(valx * freqs[None, None, None, None, :, None])
|
||||
y_basis = tf.cos(valy * freqs[None, None, None, None, None, :])
|
||||
xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"])
|
||||
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[Ellipsis, 1:]
|
||||
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[..., 1:]
|
||||
return coords
|
||||
else:
|
||||
raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type))
|
||||
@@ -632,7 +632,7 @@ class IODINE(snt.AbstractModule):
|
||||
# ########## Mask Monitoring #######
|
||||
if "mask" in data:
|
||||
true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1")
|
||||
true_mask = tf.transpose(true_mask[:, -1, Ellipsis, 0], [0, 2, 3, 1])
|
||||
true_mask = tf.transpose(true_mask[:, -1, ..., 0], [0, 2, 3, 1])
|
||||
true_mask = self._sg.reshape(true_mask, "B, H*W, L")
|
||||
else:
|
||||
true_mask = None
|
||||
@@ -648,6 +648,6 @@ class IODINE(snt.AbstractModule):
|
||||
adjusted_rand_index(true_mask, pred_mask))
|
||||
|
||||
scalars["loss/ari_nobg"] = tf.reduce_mean(
|
||||
adjusted_rand_index(true_mask[Ellipsis, 1:], pred_mask))
|
||||
adjusted_rand_index(true_mask[..., 1:], pred_mask))
|
||||
|
||||
return scalars
|
||||
|
||||
@@ -258,7 +258,7 @@ class BroadcastConv(snt.AbstractModule):
|
||||
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
||||
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
||||
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
|
||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[..., 1:]
|
||||
return tf.concat([output, coords], axis=-1)
|
||||
else:
|
||||
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
||||
|
||||
@@ -83,7 +83,7 @@ def show_mask(m, ax):
|
||||
@optional_clean_ax
|
||||
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
||||
return ax.matshow(
|
||||
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||
m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
@@ -115,7 +115,7 @@ def example_plot(rinfo,
|
||||
|
||||
show_img(image, ax=axes[0], color="#000000")
|
||||
show_img(recons, ax=axes[1], color="#000000")
|
||||
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
|
||||
show_mask(pred_mask[..., 0], ax=axes[2], color="#000000")
|
||||
for k in range(K):
|
||||
mask = pred_mask[k] if mask_components else None
|
||||
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
||||
@@ -145,7 +145,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||
for t in range(T):
|
||||
show_img(recons[t, 0], ax=axes[t, 0])
|
||||
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
|
||||
show_mask(pred_mask[t, ..., 0], ax=axes[t, 1])
|
||||
axes[t, 0].set_ylabel("iter {}".format(t))
|
||||
for k in range(K):
|
||||
mask = pred_mask[t, k] if mask_components else None
|
||||
@@ -154,7 +154,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||
axes[0, 0].set_title("Reconstruction")
|
||||
axes[0, 1].set_title("Mask")
|
||||
show_img(image[0], ax=axes[T, 0])
|
||||
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
|
||||
show_mask(true_mask[0, ..., 0], ax=axes[T, 1])
|
||||
vmin = np.min(pred_mask_logits[T - 1])
|
||||
vmax = np.max(pred_mask_logits[T - 1])
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ def construct_diagnostic_image(
|
||||
components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
|
||||
|
||||
if mask_components:
|
||||
components *= masks[:nr_images, Ellipsis, tf.newaxis]
|
||||
components *= masks[:nr_images, ..., tf.newaxis]
|
||||
|
||||
# Pad everything
|
||||
no_pad, pad = (0, 0), (border_width, border_width)
|
||||
@@ -415,7 +415,7 @@ def images_to_grid(
|
||||
if max_grid_width is not None:
|
||||
grid_width = min(max_grid_width, grid_width)
|
||||
|
||||
images = images[: grid_height * grid_width, Ellipsis]
|
||||
images = images[: grid_height * grid_width, ...]
|
||||
|
||||
# Pad with extra blank frames if grid_height x grid_width is less than the
|
||||
# number of frames provided.
|
||||
@@ -460,7 +460,7 @@ def flatten_all_but_last(tensor, n_dims=1):
|
||||
|
||||
def ensure_3d(tensor):
|
||||
if tensor.shape.ndims == 2:
|
||||
return tensor[Ellipsis, None]
|
||||
return tensor[..., None]
|
||||
|
||||
assert tensor.shape.ndims == 3
|
||||
return tensor
|
||||
|
||||
@@ -82,7 +82,7 @@ class Agent():
|
||||
|
||||
def option_values(values, policy):
|
||||
return tf.tensordot(
|
||||
values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])
|
||||
values[:, policy, ...], self._policy_weights[policy], axes=[1, 0])
|
||||
|
||||
# Placeholders for policy.
|
||||
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
|
||||
@@ -103,8 +103,8 @@ class Agent():
|
||||
qo_t = option_values(q_t, p)
|
||||
|
||||
a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
|
||||
qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1)
|
||||
qa_t = _batched_index(q_t[:, p, Ellipsis], a_t)
|
||||
qa_tm1 = _batched_index(q_tm1[:, p, ...], a_tm1)
|
||||
qa_t = _batched_index(q_t[:, p, ...], a_t)
|
||||
|
||||
# TD error
|
||||
g = additional_discount * tf.expand_dims(d_t, axis=-1)
|
||||
|
||||
@@ -91,7 +91,7 @@ def make_face_model_dataset(
|
||||
# Vertices are quantized. So convert to floats for input to face model
|
||||
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
|
||||
example['vertices_mask'] = tf.ones_like(
|
||||
example['vertices'][Ellipsis, 0], dtype=tf.float32)
|
||||
example['vertices'][..., 0], dtype=tf.float32)
|
||||
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
|
||||
return example
|
||||
return ds.map(_face_model_map_fn)
|
||||
|
||||
+8
-8
@@ -799,7 +799,7 @@ class VertexModel(snt.AbstractModule):
|
||||
# Continuous vertex value embeddings
|
||||
else:
|
||||
vert_embeddings = tf.layers.dense(
|
||||
dequantize_verts(vertices[Ellipsis, None], self.quantization_bits),
|
||||
dequantize_verts(vertices[..., None], self.quantization_bits),
|
||||
self.embedding_dim,
|
||||
use_bias=True,
|
||||
name='value_embeddings')
|
||||
@@ -984,7 +984,7 @@ class VertexModel(snt.AbstractModule):
|
||||
verts_dequantized = dequantize_verts(v, self.quantization_bits)
|
||||
vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
|
||||
vertices = tf.stack(
|
||||
[vertices[Ellipsis, 2], vertices[Ellipsis, 1], vertices[Ellipsis, 0]], axis=-1)
|
||||
[vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1)
|
||||
|
||||
# Pad samples to max sample length. This is required in order to concatenate
|
||||
# Samples across different replicator instances. Pad with stopping tokens
|
||||
@@ -998,14 +998,14 @@ class VertexModel(snt.AbstractModule):
|
||||
|
||||
if recenter_verts:
|
||||
vert_max = tf.reduce_max(
|
||||
vertices - 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
|
||||
vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1,
|
||||
keepdims=True)
|
||||
vert_min = tf.reduce_min(
|
||||
vertices + 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
|
||||
vertices + 1e10 * (1. - vertices_mask)[..., None], axis=1,
|
||||
keepdims=True)
|
||||
vert_centers = 0.5 * (vert_max + vert_min)
|
||||
vertices -= vert_centers
|
||||
vertices *= vertices_mask[Ellipsis, None]
|
||||
vertices *= vertices_mask[..., None]
|
||||
|
||||
if only_return_complete:
|
||||
vertices = tf.boolean_mask(vertices, completed)
|
||||
@@ -1247,7 +1247,7 @@ class FaceModel(snt.AbstractModule):
|
||||
sequential_context_embeddings = (
|
||||
vertex_embeddings *
|
||||
tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
|
||||
constant_values=1)[Ellipsis, None])
|
||||
constant_values=1)[..., None])
|
||||
else:
|
||||
sequential_context_embeddings = None
|
||||
return (vertex_embeddings, global_context_embedding,
|
||||
@@ -1266,11 +1266,11 @@ class FaceModel(snt.AbstractModule):
|
||||
embed_dim=self.embedding_dim,
|
||||
initializers={'embeddings': tf.glorot_uniform_initializer},
|
||||
densify_gradients=True,
|
||||
name='coord_{}'.format(c))(verts_quantized[Ellipsis, c])
|
||||
name='coord_{}'.format(c))(verts_quantized[..., c])
|
||||
else:
|
||||
vertex_embeddings = tf.layers.dense(
|
||||
vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
|
||||
vertex_embeddings *= vertices_mask[Ellipsis, None]
|
||||
vertex_embeddings *= vertices_mask[..., None]
|
||||
|
||||
# Pad vertex embeddings with learned embeddings for stopping and new face
|
||||
# tokens
|
||||
|
||||
+16
-16
@@ -106,7 +106,7 @@ TESTING_SUITE = [
|
||||
ALL = TUNING_SUITE + TESTING_SUITE
|
||||
|
||||
|
||||
def _decode_frames(pngs):
|
||||
def _decode_frames(pngs: tf.Tensor):
|
||||
"""Decode PNGs.
|
||||
|
||||
Args:
|
||||
@@ -122,13 +122,13 @@ def _decode_frames(pngs):
|
||||
return frames
|
||||
|
||||
|
||||
def _make_reverb_sample(o_t,
|
||||
a_t,
|
||||
r_t,
|
||||
d_t,
|
||||
o_tp1,
|
||||
a_tp1,
|
||||
extras):
|
||||
def _make_reverb_sample(o_t: tf.Tensor,
|
||||
a_t: tf.Tensor,
|
||||
r_t: tf.Tensor,
|
||||
d_t: tf.Tensor,
|
||||
o_tp1: tf.Tensor,
|
||||
a_tp1: tf.Tensor,
|
||||
extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample:
|
||||
"""Create Reverb sample with offline data.
|
||||
|
||||
Args:
|
||||
@@ -151,8 +151,8 @@ def _make_reverb_sample(o_t,
|
||||
return reverb.ReplaySample(info=info, data=data)
|
||||
|
||||
|
||||
def _tf_example_to_reverb_sample(tf_example
|
||||
):
|
||||
def _tf_example_to_reverb_sample(tf_example: tf.train.Example
|
||||
) -> reverb.ReplaySample:
|
||||
"""Create a Reverb replay sample from a TF example."""
|
||||
|
||||
# Parse tf.Example.
|
||||
@@ -184,11 +184,11 @@ def _tf_example_to_reverb_sample(tf_example
|
||||
extras)
|
||||
|
||||
|
||||
def dataset(path,
|
||||
game,
|
||||
run,
|
||||
num_shards = 100,
|
||||
shuffle_buffer_size = 100000):
|
||||
def dataset(path: str,
|
||||
game: str,
|
||||
run: int,
|
||||
num_shards: int = 100,
|
||||
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
|
||||
"""TF dataset of Atari SARSA tuples."""
|
||||
path = os.path.join(path, f'{game}/run_{run}')
|
||||
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
|
||||
@@ -243,7 +243,7 @@ class AtariDopamineWrapper(dm_env.Environment):
|
||||
return specs.DiscreteArray(self._env.action_space.n)
|
||||
|
||||
|
||||
def environment(game):
|
||||
def environment(game: str) -> dm_env.Environment:
|
||||
"""Atari environment."""
|
||||
env = atari_lib.create_atari_environment(game_name=game,
|
||||
sticky_actions=True)
|
||||
|
||||
@@ -773,15 +773,15 @@ def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False):
|
||||
drop_remainder=drop_remainder)
|
||||
|
||||
|
||||
def dataset(root_path,
|
||||
data_path,
|
||||
shapes,
|
||||
num_threads,
|
||||
batch_size,
|
||||
uint8_features = None,
|
||||
num_shards = 100,
|
||||
shuffle_buffer_size = 100000,
|
||||
sarsa = True):
|
||||
def dataset(root_path: str,
|
||||
data_path: str,
|
||||
shapes: Dict[str, Tuple[int]],
|
||||
num_threads: int,
|
||||
batch_size: int,
|
||||
uint8_features: Set[str] = None,
|
||||
num_shards: int = 100,
|
||||
shuffle_buffer_size: int = 100000,
|
||||
sarsa: bool = True) -> tf.data.Dataset:
|
||||
"""Create tf dataset for training."""
|
||||
|
||||
uint8_features = uint8_features if uint8_features else {}
|
||||
|
||||
+17
-17
@@ -55,7 +55,7 @@ DELIMITER = ':'
|
||||
DEFAULT_NUM_TIMESTEPS = 1001
|
||||
|
||||
|
||||
def _decombine_key(k, delimiter = DELIMITER):
|
||||
def _decombine_key(k: str, delimiter: str = DELIMITER) -> Sequence[str]:
|
||||
return k.split(delimiter)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ def tf_example_to_feature_description(example,
|
||||
|
||||
|
||||
def tree_deflatten_with_delimiter(
|
||||
flat_dict, delimiter = DELIMITER):
|
||||
flat_dict: Dict[str, Any], delimiter: str = DELIMITER) -> Dict[str, Any]:
|
||||
"""De-flattens a dict to its originally nested structure.
|
||||
|
||||
Does the opposite of {combine_nested_keys(k) :v
|
||||
@@ -102,12 +102,12 @@ def tree_deflatten_with_delimiter(
|
||||
return dict(root)
|
||||
|
||||
|
||||
def get_slice_of_nested(nested, start,
|
||||
end):
|
||||
def get_slice_of_nested(nested: Dict[str, Any], start: int,
|
||||
end: int) -> Dict[str, Any]:
|
||||
return tree.map_structure(lambda item: item[start:end], nested)
|
||||
|
||||
|
||||
def repeat_last_and_append_to_nested(nested):
|
||||
def repeat_last_and_append_to_nested(nested: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return tree.map_structure(
|
||||
lambda item: tf.concat((item, item[-1:]), axis=0), nested)
|
||||
|
||||
@@ -133,13 +133,13 @@ def tf_example_to_reverb_sample(example,
|
||||
return ret
|
||||
|
||||
|
||||
def dataset(path,
|
||||
combined_challenge,
|
||||
domain,
|
||||
task,
|
||||
difficulty,
|
||||
num_shards = 100,
|
||||
shuffle_buffer_size = 100000):
|
||||
def dataset(path: str,
|
||||
combined_challenge: str,
|
||||
domain: str,
|
||||
task: str,
|
||||
difficulty: str,
|
||||
num_shards: int = 100,
|
||||
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
|
||||
"""TF dataset of RWRL SARSA tuples."""
|
||||
path = os.path.join(
|
||||
path,
|
||||
@@ -178,11 +178,11 @@ def dataset(path,
|
||||
|
||||
|
||||
def environment(
|
||||
combined_challenge,
|
||||
domain,
|
||||
task,
|
||||
log_output = None,
|
||||
environment_kwargs = None):
|
||||
combined_challenge: str,
|
||||
domain: str,
|
||||
task: str,
|
||||
log_output: Optional[str] = None,
|
||||
environment_kwargs: Optional[Dict[str, Any]] = None) -> dm_env.Environment:
|
||||
"""RWRL environment."""
|
||||
env = rwrl_envs.load(
|
||||
domain_name=domain,
|
||||
|
||||
@@ -106,8 +106,8 @@ class Transporter(snt.AbstractModule):
|
||||
num_keypoints = image_a_keypoints["heatmaps"].shape[-1]
|
||||
transported_features = image_a_features
|
||||
for k in range(num_keypoints):
|
||||
mask_a = image_a_keypoints["heatmaps"][Ellipsis, k, None]
|
||||
mask_b = image_b_keypoints["heatmaps"][Ellipsis, k, None]
|
||||
mask_a = image_a_keypoints["heatmaps"][..., k, None]
|
||||
mask_b = image_b_keypoints["heatmaps"][..., k, None]
|
||||
|
||||
# suppress features from image a, around both keypoint locations.
|
||||
transported_features = (
|
||||
|
||||
Reference in New Issue
Block a user