mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-06-04 13:52:13 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
@@ -176,3 +176,37 @@ python -m byol.main_loop \
|
|||||||
With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
|
With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
|
||||||
*online* classifier) in roughly 4 hours. Note that the above parameters were not
|
*online* classifier) in roughly 4 hours. Note that the above parameters were not
|
||||||
finely tuned and may not be optimal.
|
finely tuned and may not be optimal.
|
||||||
|
|
||||||
|
|
||||||
|
## Additional checkpoints
|
||||||
|
|
||||||
|
Alongside the pretrained ResNet-50 and ResNet-200 2x, we provide the
|
||||||
|
following checkpoints from our ablation study. They all correspond to a
|
||||||
|
ResNet-50 1x pre-trained over 300 epochs and were randomly selected among the
|
||||||
|
three seeds; file size is roughly 640MB each.
|
||||||
|
|
||||||
|
- [Baseline](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_baseline.pkl)
|
||||||
|
|
||||||
|
- Smaller batch sizes (figure 3a):
|
||||||
|
- [Batch size 2048](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_2048.pkl)
|
||||||
|
- [Batch size 1024](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_1024.pkl)
|
||||||
|
- [Batch size 512](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_512.pkl)
|
||||||
|
- [Batch size 256](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_256.pkl)
|
||||||
|
- [Batch size 128](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_128.pkl)
|
||||||
|
- [Batch size 64](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_64.pkl)
|
||||||
|
|
||||||
|
- Ablation on transformations (figure 3b):
|
||||||
|
- [Remove grayscale](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_grayscale.pkl)
|
||||||
|
- [Remove color](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_color.pkl)
|
||||||
|
- [Crop and blur only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_blur_only.pkl)
|
||||||
|
- [Crop only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_only.pkl)
|
||||||
|
- (from Table 18) [Crop and color only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_color_only.pkl)
|
||||||
|
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
While the code is licensed under the Apache 2.0 License, the checkpoint weights
|
||||||
|
are made available for non-commercial use only under the terms of the
|
||||||
|
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
|
||||||
|
license. You can find details at:
|
||||||
|
https://creativecommons.org/licenses/by-nc/4.0/legalcode.
|
||||||
|
|||||||
+47
-47
@@ -56,17 +56,17 @@ class ByolExperiment:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
random_seed,
|
random_seed: int,
|
||||||
num_classes,
|
num_classes: int,
|
||||||
batch_size,
|
batch_size: int,
|
||||||
max_steps,
|
max_steps: int,
|
||||||
enable_double_transpose,
|
enable_double_transpose: bool,
|
||||||
base_target_ema,
|
base_target_ema: float,
|
||||||
network_config,
|
network_config: Mapping[Text, Any],
|
||||||
optimizer_config,
|
optimizer_config: Mapping[Text, Any],
|
||||||
lr_schedule_config,
|
lr_schedule_config: Mapping[Text, Any],
|
||||||
evaluation_config,
|
evaluation_config: Mapping[Text, Any],
|
||||||
checkpointing_config):
|
checkpointing_config: Mapping[Text, Any]):
|
||||||
"""Constructs the experiment.
|
"""Constructs the experiment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -115,15 +115,15 @@ class ByolExperiment:
|
|||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
projector_hidden_size,
|
projector_hidden_size: int,
|
||||||
projector_output_size,
|
projector_output_size: int,
|
||||||
predictor_hidden_size,
|
predictor_hidden_size: int,
|
||||||
encoder_class,
|
encoder_class: Text,
|
||||||
encoder_config,
|
encoder_config: Mapping[Text, Any],
|
||||||
bn_config,
|
bn_config: Mapping[Text, Any],
|
||||||
is_training,
|
is_training: bool,
|
||||||
):
|
) -> Mapping[Text, jnp.ndarray]:
|
||||||
"""Forward application of byol's architecture.
|
"""Forward application of byol's architecture.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -163,7 +163,7 @@ class ByolExperiment:
|
|||||||
classifier = hk.Linear(
|
classifier = hk.Linear(
|
||||||
output_size=self._num_classes, name='classifier')
|
output_size=self._num_classes, name='classifier')
|
||||||
|
|
||||||
def apply_once_fn(images, suffix = ''):
|
def apply_once_fn(images: jnp.ndarray, suffix: Text = ''):
|
||||||
images = dataset.normalize_images(images)
|
images = dataset.normalize_images(images)
|
||||||
|
|
||||||
embedding = net(images, is_training=is_training)
|
embedding = net(images, is_training=is_training)
|
||||||
@@ -186,7 +186,7 @@ class ByolExperiment:
|
|||||||
else:
|
else:
|
||||||
return apply_once_fn(inputs['images'], '')
|
return apply_once_fn(inputs['images'], '')
|
||||||
|
|
||||||
def _optimizer(self, learning_rate):
|
def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
|
||||||
"""Build optimizer from config."""
|
"""Build optimizer from config."""
|
||||||
return optimizers.lars(
|
return optimizers.lars(
|
||||||
learning_rate,
|
learning_rate,
|
||||||
@@ -196,13 +196,13 @@ class ByolExperiment:
|
|||||||
|
|
||||||
def loss_fn(
|
def loss_fn(
|
||||||
self,
|
self,
|
||||||
online_params,
|
online_params: hk.Params,
|
||||||
target_params,
|
target_params: hk.Params,
|
||||||
online_state,
|
online_state: hk.State,
|
||||||
target_state,
|
target_state: hk.Params,
|
||||||
rng,
|
rng: jnp.ndarray,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
):
|
) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
|
||||||
"""Compute BYOL's loss function.
|
"""Compute BYOL's loss function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -292,11 +292,11 @@ class ByolExperiment:
|
|||||||
|
|
||||||
def _update_fn(
|
def _update_fn(
|
||||||
self,
|
self,
|
||||||
byol_state,
|
byol_state: _ByolExperimentState,
|
||||||
global_step,
|
global_step: jnp.ndarray,
|
||||||
rng,
|
rng: jnp.ndarray,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
):
|
) -> Tuple[_ByolExperimentState, LogsDict]:
|
||||||
"""Update online and target parameters.
|
"""Update online and target parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -352,9 +352,9 @@ class ByolExperiment:
|
|||||||
|
|
||||||
def _make_initial_state(
|
def _make_initial_state(
|
||||||
self,
|
self,
|
||||||
rng,
|
rng: jnp.ndarray,
|
||||||
dummy_input,
|
dummy_input: dataset.Batch,
|
||||||
):
|
) -> _ByolExperimentState:
|
||||||
"""BYOL's _ByolExperimentState initialization.
|
"""BYOL's _ByolExperimentState initialization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -393,8 +393,8 @@ class ByolExperiment:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def step(self, *,
|
def step(self, *,
|
||||||
global_step,
|
global_step: jnp.ndarray,
|
||||||
rng):
|
rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
|
||||||
"""Performs a single training step."""
|
"""Performs a single training step."""
|
||||||
if self._train_input is None:
|
if self._train_input is None:
|
||||||
self._initialize_train()
|
self._initialize_train()
|
||||||
@@ -410,11 +410,11 @@ class ByolExperiment:
|
|||||||
|
|
||||||
return helpers.get_first(scalars)
|
return helpers.get_first(scalars)
|
||||||
|
|
||||||
def save_checkpoint(self, step, rng):
|
def save_checkpoint(self, step: int, rng: jnp.ndarray):
|
||||||
self._checkpointer.maybe_save_checkpoint(
|
self._checkpointer.maybe_save_checkpoint(
|
||||||
self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)
|
self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)
|
||||||
|
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
|
||||||
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
||||||
if checkpoint_data is None:
|
if checkpoint_data is None:
|
||||||
return None
|
return None
|
||||||
@@ -444,7 +444,7 @@ class ByolExperiment:
|
|||||||
|
|
||||||
self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
|
self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
|
||||||
|
|
||||||
def _build_train_input(self):
|
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
|
||||||
"""Loads the (infinitely looping) dataset iterator."""
|
"""Loads the (infinitely looping) dataset iterator."""
|
||||||
num_devices = jax.device_count()
|
num_devices = jax.device_count()
|
||||||
global_batch_size = self._batch_size
|
global_batch_size = self._batch_size
|
||||||
@@ -463,10 +463,10 @@ class ByolExperiment:
|
|||||||
|
|
||||||
def _eval_batch(
|
def _eval_batch(
|
||||||
self,
|
self,
|
||||||
params,
|
params: hk.Params,
|
||||||
state,
|
state: hk.State,
|
||||||
batch,
|
batch: dataset.Batch,
|
||||||
):
|
) -> Mapping[Text, jnp.ndarray]:
|
||||||
"""Evaluates a batch.
|
"""Evaluates a batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -494,7 +494,7 @@ class ByolExperiment:
|
|||||||
'top5_accuracy': top5_correct,
|
'top5_accuracy': top5_correct,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _eval_epoch(self, subset, batch_size):
|
def _eval_epoch(self, subset: Text, batch_size: int):
|
||||||
"""Evaluates an epoch."""
|
"""Evaluates an epoch."""
|
||||||
num_samples = 0.
|
num_samples = 0.
|
||||||
summed_scalars = None
|
summed_scalars = None
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ _WD_PRESETS = {40: 1e-6, 100: 1e-6, 300: 1e-6, 1000: 1.5e-6}
|
|||||||
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
|
_EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
|
||||||
|
|
||||||
|
|
||||||
def get_config(num_epochs, batch_size):
|
def get_config(num_epochs: int, batch_size: int):
|
||||||
"""Return config object, containing all hyperparameters for training."""
|
"""Return config object, containing all hyperparameters for training."""
|
||||||
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from typing import Text
|
|||||||
from byol.utils import dataset
|
from byol.utils import dataset
|
||||||
|
|
||||||
|
|
||||||
def get_config(checkpoint_to_evaluate, batch_size):
|
def get_config(checkpoint_to_evaluate: Text, batch_size: int):
|
||||||
"""Return config object for training."""
|
"""Return config object for training."""
|
||||||
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
|
||||||
|
|
||||||
|
|||||||
+47
-47
@@ -53,19 +53,19 @@ class EvalExperiment:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
random_seed,
|
random_seed: int,
|
||||||
num_classes,
|
num_classes: int,
|
||||||
batch_size,
|
batch_size: int,
|
||||||
max_steps,
|
max_steps: int,
|
||||||
enable_double_transpose,
|
enable_double_transpose: bool,
|
||||||
checkpoint_to_evaluate,
|
checkpoint_to_evaluate: Optional[Text],
|
||||||
allow_train_from_scratch,
|
allow_train_from_scratch: bool,
|
||||||
freeze_backbone,
|
freeze_backbone: bool,
|
||||||
network_config,
|
network_config: Mapping[Text, Any],
|
||||||
optimizer_config,
|
optimizer_config: Mapping[Text, Any],
|
||||||
lr_schedule_config,
|
lr_schedule_config: Mapping[Text, Any],
|
||||||
evaluation_config,
|
evaluation_config: Mapping[Text, Any],
|
||||||
checkpointing_config):
|
checkpointing_config: Mapping[Text, Any]):
|
||||||
"""Constructs the experiment.
|
"""Constructs the experiment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -125,12 +125,12 @@ class EvalExperiment:
|
|||||||
|
|
||||||
def _backbone_fn(
|
def _backbone_fn(
|
||||||
self,
|
self,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
encoder_class,
|
encoder_class: Text,
|
||||||
encoder_config,
|
encoder_config: Mapping[Text, Any],
|
||||||
bn_decay_rate,
|
bn_decay_rate: float,
|
||||||
is_training,
|
is_training: bool,
|
||||||
):
|
) -> jnp.ndarray:
|
||||||
"""Forward of the encoder (backbone)."""
|
"""Forward of the encoder (backbone)."""
|
||||||
bn_config = {'decay_rate': bn_decay_rate}
|
bn_config = {'decay_rate': bn_decay_rate}
|
||||||
encoder = getattr(networks, encoder_class)
|
encoder = getattr(networks, encoder_class)
|
||||||
@@ -146,8 +146,8 @@ class EvalExperiment:
|
|||||||
|
|
||||||
def _classif_fn(
|
def _classif_fn(
|
||||||
self,
|
self,
|
||||||
embeddings,
|
embeddings: jnp.ndarray,
|
||||||
):
|
) -> jnp.ndarray:
|
||||||
classifier = hk.Linear(output_size=self._num_classes)
|
classifier = hk.Linear(output_size=self._num_classes)
|
||||||
return classifier(embeddings)
|
return classifier(embeddings)
|
||||||
|
|
||||||
@@ -159,8 +159,8 @@ class EvalExperiment:
|
|||||||
#
|
#
|
||||||
|
|
||||||
def step(self, *,
|
def step(self, *,
|
||||||
global_step,
|
global_step: jnp.ndarray,
|
||||||
rng):
|
rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
|
||||||
"""Performs a single training step."""
|
"""Performs a single training step."""
|
||||||
|
|
||||||
if self._train_input is None:
|
if self._train_input is None:
|
||||||
@@ -173,12 +173,12 @@ class EvalExperiment:
|
|||||||
scalars = helpers.get_first(scalars)
|
scalars = helpers.get_first(scalars)
|
||||||
return scalars
|
return scalars
|
||||||
|
|
||||||
def save_checkpoint(self, step, rng):
|
def save_checkpoint(self, step: int, rng: jnp.ndarray):
|
||||||
self._checkpointer.maybe_save_checkpoint(
|
self._checkpointer.maybe_save_checkpoint(
|
||||||
self._experiment_state, step=step, rng=rng,
|
self._experiment_state, step=step, rng=rng,
|
||||||
is_final=step >= self._max_steps)
|
is_final=step >= self._max_steps)
|
||||||
|
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
|
||||||
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
checkpoint_data = self._checkpointer.maybe_load_checkpoint()
|
||||||
if checkpoint_data is None:
|
if checkpoint_data is None:
|
||||||
return None
|
return None
|
||||||
@@ -253,11 +253,11 @@ class EvalExperiment:
|
|||||||
|
|
||||||
def _make_initial_state(
|
def _make_initial_state(
|
||||||
self,
|
self,
|
||||||
rng,
|
rng: jnp.ndarray,
|
||||||
dummy_input,
|
dummy_input: dataset.Batch,
|
||||||
backbone_params,
|
backbone_params: hk.Params,
|
||||||
backbone_state,
|
backbone_state: hk.Params,
|
||||||
):
|
) -> _EvalExperimentState:
|
||||||
"""_EvalExperimentState initialization."""
|
"""_EvalExperimentState initialization."""
|
||||||
|
|
||||||
# Initialize the backbone params
|
# Initialize the backbone params
|
||||||
@@ -279,7 +279,7 @@ class EvalExperiment:
|
|||||||
classif_opt_state=classif_opt_state,
|
classif_opt_state=classif_opt_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_train_input(self):
|
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
num_devices = jax.device_count()
|
num_devices = jax.device_count()
|
||||||
global_batch_size = self._batch_size
|
global_batch_size = self._batch_size
|
||||||
@@ -296,17 +296,17 @@ class EvalExperiment:
|
|||||||
transpose=self._should_transpose_images(),
|
transpose=self._should_transpose_images(),
|
||||||
batch_dims=[jax.local_device_count(), per_device_batch_size])
|
batch_dims=[jax.local_device_count(), per_device_batch_size])
|
||||||
|
|
||||||
def _optimizer(self, learning_rate):
|
def _optimizer(self, learning_rate: float):
|
||||||
"""Build optimizer from config."""
|
"""Build optimizer from config."""
|
||||||
return optax.sgd(learning_rate, **self._optimizer_config)
|
return optax.sgd(learning_rate, **self._optimizer_config)
|
||||||
|
|
||||||
def _loss_fn(
|
def _loss_fn(
|
||||||
self,
|
self,
|
||||||
backbone_params,
|
backbone_params: hk.Params,
|
||||||
classif_params,
|
classif_params: hk.Params,
|
||||||
backbone_state,
|
backbone_state: hk.State,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
):
|
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
|
||||||
"""Compute the classification loss function.
|
"""Compute the classification loss function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -333,10 +333,10 @@ class EvalExperiment:
|
|||||||
|
|
||||||
def _update_func(
|
def _update_func(
|
||||||
self,
|
self,
|
||||||
experiment_state,
|
experiment_state: _EvalExperimentState,
|
||||||
global_step,
|
global_step: jnp.ndarray,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
):
|
) -> Tuple[_EvalExperimentState, LogsDict]:
|
||||||
"""Applies an update to parameters and returns new state."""
|
"""Applies an update to parameters and returns new state."""
|
||||||
# This function computes the gradient of the first output of loss_fn and
|
# This function computes the gradient of the first output of loss_fn and
|
||||||
# passes through the other arguments unchanged.
|
# passes through the other arguments unchanged.
|
||||||
@@ -421,11 +421,11 @@ class EvalExperiment:
|
|||||||
|
|
||||||
def _eval_batch(
|
def _eval_batch(
|
||||||
self,
|
self,
|
||||||
backbone_params,
|
backbone_params: hk.Params,
|
||||||
classif_params,
|
classif_params: hk.Params,
|
||||||
backbone_state,
|
backbone_state: hk.State,
|
||||||
inputs,
|
inputs: dataset.Batch,
|
||||||
):
|
) -> LogsDict:
|
||||||
"""Evaluates a batch."""
|
"""Evaluates a batch."""
|
||||||
embeddings, backbone_state = self.forward_backbone.apply(
|
embeddings, backbone_state = self.forward_backbone.apply(
|
||||||
backbone_params, backbone_state, inputs, is_training=False)
|
backbone_params, backbone_state, inputs, is_training=False)
|
||||||
@@ -441,7 +441,7 @@ class EvalExperiment:
|
|||||||
'top5_accuracy': top5_correct
|
'top5_accuracy': top5_correct
|
||||||
}
|
}
|
||||||
|
|
||||||
def _eval_epoch(self, subset, batch_size):
|
def _eval_epoch(self, subset: Text, batch_size: int):
|
||||||
"""Evaluates an epoch."""
|
"""Evaluates an epoch."""
|
||||||
num_samples = 0.
|
num_samples = 0.
|
||||||
summed_scalars = None
|
summed_scalars = None
|
||||||
|
|||||||
+2
-2
@@ -47,7 +47,7 @@ Experiment = Union[
|
|||||||
Type[eval_experiment.EvalExperiment]]
|
Type[eval_experiment.EvalExperiment]]
|
||||||
|
|
||||||
|
|
||||||
def train_loop(experiment_class, config):
|
def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
|
||||||
"""The main training loop.
|
"""The main training loop.
|
||||||
|
|
||||||
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
|
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
|
||||||
@@ -95,7 +95,7 @@ def train_loop(experiment_class, config):
|
|||||||
experiment.save_checkpoint(step, rng)
|
experiment.save_checkpoint(step, rng)
|
||||||
|
|
||||||
|
|
||||||
def eval_loop(experiment_class, config):
|
def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
|
||||||
"""The main evaluation loop.
|
"""The main evaluation loop.
|
||||||
|
|
||||||
This loop periodically loads a checkpoint and evaluates its performance on the
|
This loop periodically loads a checkpoint and evaluates its performance on the
|
||||||
|
|||||||
@@ -66,14 +66,14 @@ augment_config = dict(
|
|||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
def postprocess(inputs, rng):
|
def postprocess(inputs: JaxBatch, rng: jnp.ndarray):
|
||||||
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
|
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
|
||||||
|
|
||||||
def _postprocess_image(
|
def _postprocess_image(
|
||||||
images,
|
images: jnp.ndarray,
|
||||||
rng,
|
rng: jnp.ndarray,
|
||||||
presets,
|
presets: ConfigDict,
|
||||||
):
|
) -> JaxBatch:
|
||||||
"""Applies augmentations in post-processing.
|
"""Applies augmentations in post-processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -144,7 +144,7 @@ def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
|
|||||||
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
|
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
|
||||||
expand_batch_dim = len(image.shape) == 3
|
expand_batch_dim = len(image.shape) == 3
|
||||||
if expand_batch_dim:
|
if expand_batch_dim:
|
||||||
image = image[jnp.newaxis, Ellipsis]
|
image = image[jnp.newaxis, ...]
|
||||||
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
|
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
|
||||||
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
|
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
|
||||||
blurred = jnp.squeeze(blurred, axis=0)
|
blurred = jnp.squeeze(blurred, axis=0)
|
||||||
@@ -284,7 +284,7 @@ def _random_hue(rgb_tuple, rng, max_delta):
|
|||||||
|
|
||||||
def _to_grayscale(image):
|
def _to_grayscale(image):
|
||||||
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
|
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
|
||||||
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[Ellipsis, jnp.newaxis]
|
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
|
||||||
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
|
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,10 +31,10 @@ class Checkpointer:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_checkpointing,
|
use_checkpointing: bool,
|
||||||
checkpoint_dir,
|
checkpoint_dir: Text,
|
||||||
save_checkpoint_interval,
|
save_checkpoint_interval: int,
|
||||||
filename):
|
filename: Text):
|
||||||
if (not use_checkpointing or
|
if (not use_checkpointing or
|
||||||
checkpoint_dir is None or
|
checkpoint_dir is None or
|
||||||
save_checkpoint_interval <= 0):
|
save_checkpoint_interval <= 0):
|
||||||
@@ -51,10 +51,10 @@ class Checkpointer:
|
|||||||
|
|
||||||
def maybe_save_checkpoint(
|
def maybe_save_checkpoint(
|
||||||
self,
|
self,
|
||||||
experiment_state,
|
experiment_state: Mapping[Text, jnp.ndarray],
|
||||||
step,
|
step: int,
|
||||||
rng,
|
rng: jnp.ndarray,
|
||||||
is_final):
|
is_final: bool):
|
||||||
"""Saves a checkpoint if enough time has passed since the previous one."""
|
"""Saves a checkpoint if enough time has passed since the previous one."""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if (not self._checkpoint_enabled or
|
if (not self._checkpoint_enabled or
|
||||||
@@ -80,7 +80,7 @@ class Checkpointer:
|
|||||||
self._last_checkpoint_time = current_time
|
self._last_checkpoint_time = current_time
|
||||||
|
|
||||||
def maybe_load_checkpoint(
|
def maybe_load_checkpoint(
|
||||||
self):
|
self) -> Union[Tuple[Mapping[Text, jnp.ndarray], int, jnp.ndarray], None]:
|
||||||
"""Loads a checkpoint if any is found."""
|
"""Loads a checkpoint if any is found."""
|
||||||
checkpoint_data = load_checkpoint(self._checkpoint_path)
|
checkpoint_data = load_checkpoint(self._checkpoint_path)
|
||||||
if checkpoint_data is None:
|
if checkpoint_data is None:
|
||||||
|
|||||||
+17
-17
@@ -34,7 +34,7 @@ class Split(enum.Enum):
|
|||||||
TEST = 4
|
TEST = 4
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, name):
|
def from_string(cls, name: Text) -> 'Split':
|
||||||
return {
|
return {
|
||||||
'TRAIN': Split.TRAIN,
|
'TRAIN': Split.TRAIN,
|
||||||
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
|
'TRAIN_AND_VALID': Split.TRAIN_AND_VALID,
|
||||||
@@ -60,7 +60,7 @@ class PreprocessMode(enum.Enum):
|
|||||||
EVAL = 3 # Generates a single center crop.
|
EVAL = 3 # Generates a single center crop.
|
||||||
|
|
||||||
|
|
||||||
def normalize_images(images):
|
def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Normalize the image using ImageNet statistics."""
|
"""Normalize the image using ImageNet statistics."""
|
||||||
mean_rgb = (0.485, 0.456, 0.406)
|
mean_rgb = (0.485, 0.456, 0.406)
|
||||||
stddev_rgb = (0.229, 0.224, 0.225)
|
stddev_rgb = (0.229, 0.224, 0.225)
|
||||||
@@ -69,12 +69,12 @@ def normalize_images(images):
|
|||||||
return normed_images
|
return normed_images
|
||||||
|
|
||||||
|
|
||||||
def load(split,
|
def load(split: Split,
|
||||||
*,
|
*,
|
||||||
preprocess_mode,
|
preprocess_mode: PreprocessMode,
|
||||||
batch_dims,
|
batch_dims: Sequence[int],
|
||||||
transpose = False,
|
transpose: bool = False,
|
||||||
allow_caching = False):
|
allow_caching: bool = False) -> Generator[Batch, None, None]:
|
||||||
"""Loads the given split of the dataset."""
|
"""Loads the given split of the dataset."""
|
||||||
start, end = _shard(split, jax.host_id(), jax.host_count())
|
start, end = _shard(split, jax.host_id(), jax.host_count())
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ def load(split,
|
|||||||
yield from tfds.as_numpy(ds)
|
yield from tfds.as_numpy(ds)
|
||||||
|
|
||||||
|
|
||||||
def _to_tfds_split(split):
|
def _to_tfds_split(split: Split) -> tfds.Split:
|
||||||
"""Returns the TFDS split appropriately sharded."""
|
"""Returns the TFDS split appropriately sharded."""
|
||||||
# NOTE: Imagenet did not release labels for the test split used in the
|
# NOTE: Imagenet did not release labels for the test split used in the
|
||||||
# competition, we consider the VALID split the TEST split and reserve
|
# competition, we consider the VALID split the TEST split and reserve
|
||||||
@@ -165,7 +165,7 @@ def _to_tfds_split(split):
|
|||||||
return tfds.Split.VALIDATION
|
return tfds.Split.VALIDATION
|
||||||
|
|
||||||
|
|
||||||
def _shard(split, shard_index, num_shards):
|
def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
|
||||||
"""Returns [start, end) for the given shard index."""
|
"""Returns [start, end) for the given shard index."""
|
||||||
assert shard_index < num_shards
|
assert shard_index < num_shards
|
||||||
arange = np.arange(split.num_examples)
|
arange = np.arange(split.num_examples)
|
||||||
@@ -180,9 +180,9 @@ def _shard(split, shard_index, num_shards):
|
|||||||
|
|
||||||
|
|
||||||
def _preprocess_image(
|
def _preprocess_image(
|
||||||
image_bytes,
|
image_bytes: tf.Tensor,
|
||||||
mode,
|
mode: PreprocessMode,
|
||||||
):
|
) -> tf.Tensor:
|
||||||
"""Returns processed and resized images."""
|
"""Returns processed and resized images."""
|
||||||
if mode is PreprocessMode.PRETRAIN:
|
if mode is PreprocessMode.PRETRAIN:
|
||||||
image = _decode_and_random_crop(image_bytes)
|
image = _decode_and_random_crop(image_bytes)
|
||||||
@@ -201,7 +201,7 @@ def _preprocess_image(
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def _decode_and_random_crop(image_bytes):
|
def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
|
||||||
"""Make a random crop of 224."""
|
"""Make a random crop of 224."""
|
||||||
img_size = tf.image.extract_jpeg_shape(image_bytes)
|
img_size = tf.image.extract_jpeg_shape(image_bytes)
|
||||||
area = tf.cast(img_size[1] * img_size[0], tf.float32)
|
area = tf.cast(img_size[1] * img_size[0], tf.float32)
|
||||||
@@ -231,7 +231,7 @@ def _decode_and_random_crop(image_bytes):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def transpose_images(batch):
|
def transpose_images(batch: Batch):
|
||||||
"""Transpose images for TPU training.."""
|
"""Transpose images for TPU training.."""
|
||||||
new_batch = dict(batch) # Avoid mutating in place.
|
new_batch = dict(batch) # Avoid mutating in place.
|
||||||
if 'images' in batch:
|
if 'images' in batch:
|
||||||
@@ -243,9 +243,9 @@ def transpose_images(batch):
|
|||||||
|
|
||||||
|
|
||||||
def _decode_and_center_crop(
|
def _decode_and_center_crop(
|
||||||
image_bytes,
|
image_bytes: tf.Tensor,
|
||||||
jpeg_shape = None,
|
jpeg_shape: Optional[tf.Tensor] = None,
|
||||||
):
|
) -> tf.Tensor:
|
||||||
"""Crops to center of image with padding then scales."""
|
"""Crops to center of image with padding then scales."""
|
||||||
if jpeg_shape is None:
|
if jpeg_shape is None:
|
||||||
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
|
||||||
|
|||||||
+14
-14
@@ -21,11 +21,11 @@ import jax.numpy as jnp
|
|||||||
|
|
||||||
|
|
||||||
def topk_accuracy(
|
def topk_accuracy(
|
||||||
logits,
|
logits: jnp.ndarray,
|
||||||
labels,
|
labels: jnp.ndarray,
|
||||||
topk,
|
topk: int,
|
||||||
ignore_label_above = None,
|
ignore_label_above: Optional[int] = None,
|
||||||
):
|
) -> jnp.ndarray:
|
||||||
"""Top-num_codes accuracy."""
|
"""Top-num_codes accuracy."""
|
||||||
assert len(labels.shape) == 1, 'topk expects 1d int labels.'
|
assert len(labels.shape) == 1, 'topk expects 1d int labels.'
|
||||||
assert len(logits.shape) == 2, 'topk expects 2d logits.'
|
assert len(logits.shape) == 2, 'topk expects 2d logits.'
|
||||||
@@ -42,10 +42,10 @@ def topk_accuracy(
|
|||||||
|
|
||||||
|
|
||||||
def softmax_cross_entropy(
|
def softmax_cross_entropy(
|
||||||
logits,
|
logits: jnp.ndarray,
|
||||||
labels,
|
labels: jnp.ndarray,
|
||||||
reduction = 'mean',
|
reduction: Optional[Text] = 'mean',
|
||||||
):
|
) -> jnp.ndarray:
|
||||||
"""Computes softmax cross entropy given logits and one-hot class labels.
|
"""Computes softmax cross entropy given logits and one-hot class labels.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -72,10 +72,10 @@ def softmax_cross_entropy(
|
|||||||
|
|
||||||
|
|
||||||
def l2_normalize(
|
def l2_normalize(
|
||||||
x,
|
x: jnp.ndarray,
|
||||||
axis = None,
|
axis: Optional[int] = None,
|
||||||
epsilon = 1e-12,
|
epsilon: float = 1e-12,
|
||||||
):
|
) -> jnp.ndarray:
|
||||||
"""l2 normalize a tensor on an axis with numerical stability."""
|
"""l2 normalize a tensor on an axis with numerical stability."""
|
||||||
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
|
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
|
||||||
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
|
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
|
||||||
@@ -106,7 +106,7 @@ def l2_weight_regularizer(params):
|
|||||||
return 0.5 * l2_norm
|
return 0.5 * l2_norm
|
||||||
|
|
||||||
|
|
||||||
def regression_loss(x, y):
|
def regression_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Byol's regression loss. This is a simple cosine similarity."""
|
"""Byol's regression loss. This is a simple cosine similarity."""
|
||||||
normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
|
normed_x, normed_y = l2_normalize(x, axis=-1), l2_normalize(y, axis=-1)
|
||||||
return jnp.sum((normed_x - normed_y)**2, axis=-1)
|
return jnp.sum((normed_x - normed_y)**2, axis=-1)
|
||||||
|
|||||||
+49
-49
@@ -27,17 +27,17 @@ class MLP(hk.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name,
|
name: Text,
|
||||||
hidden_size,
|
hidden_size: int,
|
||||||
output_size,
|
output_size: int,
|
||||||
bn_config,
|
bn_config: Mapping[Text, Any],
|
||||||
):
|
):
|
||||||
super().__init__(name=name)
|
super().__init__(name=name)
|
||||||
self._hidden_size = hidden_size
|
self._hidden_size = hidden_size
|
||||||
self._output_size = output_size
|
self._output_size = output_size
|
||||||
self._bn_config = bn_config
|
self._bn_config = bn_config
|
||||||
|
|
||||||
def __call__(self, inputs, is_training):
|
def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndarray:
|
||||||
out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
|
out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
|
||||||
out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
|
out = hk.BatchNorm(**self._bn_config)(out, is_training=is_training)
|
||||||
out = jax.nn.relu(out)
|
out = jax.nn.relu(out)
|
||||||
@@ -55,15 +55,15 @@ class ResNetTorso(hk.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
blocks_per_group,
|
blocks_per_group: Sequence[int],
|
||||||
num_classes = None,
|
num_classes: int = None,
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
bottleneck = True,
|
bottleneck: bool = True,
|
||||||
channels_per_group = (256, 512, 1024, 2048),
|
channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
|
||||||
use_projection = (True, True, True, True),
|
use_projection: Sequence[bool] = (True, True, True, True),
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None,
|
name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
@@ -155,11 +155,11 @@ class TinyResNet(ResNetTorso):
|
|||||||
"""Tiny resnet for local runs and tests."""
|
"""Tiny resnet for local runs and tests."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes = None,
|
num_classes: Optional[int] = None,
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -185,11 +185,11 @@ class ResNet18(ResNetTorso):
|
|||||||
"""ResNet18."""
|
"""ResNet18."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes = None,
|
num_classes: Optional[int] = None,
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -215,11 +215,11 @@ class ResNet34(ResNetTorso):
|
|||||||
"""ResNet34."""
|
"""ResNet34."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes,
|
num_classes: Optional[int],
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -245,11 +245,11 @@ class ResNet50(ResNetTorso):
|
|||||||
"""ResNet50."""
|
"""ResNet50."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes = None,
|
num_classes: Optional[int] = None,
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -274,11 +274,11 @@ class ResNet101(ResNetTorso):
|
|||||||
"""ResNet101."""
|
"""ResNet101."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes,
|
num_classes: Optional[int],
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -303,11 +303,11 @@ class ResNet152(ResNetTorso):
|
|||||||
"""ResNet152."""
|
"""ResNet152."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes,
|
num_classes: Optional[int],
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -332,11 +332,11 @@ class ResNet200(ResNetTorso):
|
|||||||
"""ResNet200."""
|
"""ResNet200."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes,
|
num_classes: Optional[int],
|
||||||
bn_config = None,
|
bn_config: Optional[Mapping[str, float]] = None,
|
||||||
resnet_v2 = False,
|
resnet_v2: bool = False,
|
||||||
width_multiplier = 1,
|
width_multiplier: int = 1,
|
||||||
name = None):
|
name: Optional[str] = None):
|
||||||
"""Constructs a ResNet model.
|
"""Constructs a ResNet model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
+29
-29
@@ -27,7 +27,7 @@ import tree as nest
|
|||||||
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]
|
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]
|
||||||
|
|
||||||
|
|
||||||
def exclude_bias_and_norm(path, val):
|
def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Filter to exclude biaises and normalizations weights."""
|
"""Filter to exclude biaises and normalizations weights."""
|
||||||
del val
|
del val
|
||||||
if path[-1] == "b" or "norm" in path[-2]:
|
if path[-1] == "b" or "norm" in path[-2]:
|
||||||
@@ -35,10 +35,10 @@ def exclude_bias_and_norm(path, val):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _partial_update(updates,
|
def _partial_update(updates: optax.Updates,
|
||||||
new_updates,
|
new_updates: optax.Updates,
|
||||||
params,
|
params: optax.Params,
|
||||||
filter_fn = None):
|
filter_fn: Optional[FilterFn] = None) -> optax.Updates:
|
||||||
"""Returns new_update for params which filter_fn is True else updates."""
|
"""Returns new_update for params which filter_fn is True else updates."""
|
||||||
|
|
||||||
if filter_fn is None:
|
if filter_fn is None:
|
||||||
@@ -47,7 +47,7 @@ def _partial_update(updates,
|
|||||||
wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
|
wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
|
||||||
params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)
|
params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)
|
||||||
|
|
||||||
def _update_fn(g, t, m):
|
def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
|
||||||
m = m.astype(g.dtype)
|
m = m.astype(g.dtype)
|
||||||
return g * (1. - m) + t * m
|
return g * (1. - m) + t * m
|
||||||
|
|
||||||
@@ -59,9 +59,9 @@ class ScaleByLarsState(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
def scale_by_lars(
|
def scale_by_lars(
|
||||||
momentum = 0.9,
|
momentum: float = 0.9,
|
||||||
eta = 0.001,
|
eta: float = 0.001,
|
||||||
filter_fn = None):
|
filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
|
||||||
"""Rescales updates according to the LARS algorithm.
|
"""Rescales updates according to the LARS algorithm.
|
||||||
|
|
||||||
Does not include weight decay.
|
Does not include weight decay.
|
||||||
@@ -77,17 +77,17 @@ def scale_by_lars(
|
|||||||
An (init_fn, update_fn) tuple.
|
An (init_fn, update_fn) tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def init_fn(params):
|
def init_fn(params: optax.Params) -> ScaleByLarsState:
|
||||||
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
|
mu = jax.tree_multimap(jnp.zeros_like, params) # momentum
|
||||||
return ScaleByLarsState(mu=mu)
|
return ScaleByLarsState(mu=mu)
|
||||||
|
|
||||||
def update_fn(updates, state,
|
def update_fn(updates: optax.Updates, state: ScaleByLarsState,
|
||||||
params):
|
params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]:
|
||||||
|
|
||||||
def lars_adaptation(
|
def lars_adaptation(
|
||||||
update,
|
update: jnp.ndarray,
|
||||||
param,
|
param: jnp.ndarray,
|
||||||
):
|
) -> jnp.ndarray:
|
||||||
param_norm = jnp.linalg.norm(param)
|
param_norm = jnp.linalg.norm(param)
|
||||||
update_norm = jnp.linalg.norm(update)
|
update_norm = jnp.linalg.norm(update)
|
||||||
return update * jnp.where(
|
return update * jnp.where(
|
||||||
@@ -110,8 +110,8 @@ class AddWeightDecayState(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
def add_weight_decay(
|
def add_weight_decay(
|
||||||
weight_decay,
|
weight_decay: float,
|
||||||
filter_fn = None):
|
filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
|
||||||
"""Adds a weight decay to the update.
|
"""Adds a weight decay to the update.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -122,14 +122,14 @@ def add_weight_decay(
|
|||||||
An (init_fn, update_fn) tuple.
|
An (init_fn, update_fn) tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def init_fn(_):
|
def init_fn(_) -> AddWeightDecayState:
|
||||||
return AddWeightDecayState()
|
return AddWeightDecayState()
|
||||||
|
|
||||||
def update_fn(
|
def update_fn(
|
||||||
updates,
|
updates: optax.Updates,
|
||||||
state,
|
state: AddWeightDecayState,
|
||||||
params,
|
params: optax.Params,
|
||||||
):
|
) -> Tuple[optax.Updates, AddWeightDecayState]:
|
||||||
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
|
new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
|
||||||
params)
|
params)
|
||||||
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
new_updates = _partial_update(updates, new_updates, params, filter_fn)
|
||||||
@@ -142,13 +142,13 @@ LarsState = List # Type for the lars optimizer
|
|||||||
|
|
||||||
|
|
||||||
def lars(
|
def lars(
|
||||||
learning_rate,
|
learning_rate: float,
|
||||||
weight_decay = 0.,
|
weight_decay: float = 0.,
|
||||||
momentum = 0.9,
|
momentum: float = 0.9,
|
||||||
eta = 0.001,
|
eta: float = 0.001,
|
||||||
weight_decay_filter = None,
|
weight_decay_filter: Optional[FilterFn] = None,
|
||||||
lars_adaptation_filter = None,
|
lars_adaptation_filter: Optional[FilterFn] = None,
|
||||||
):
|
) -> optax.GradientTransformation:
|
||||||
"""Creates lars optimizer with weight decay.
|
"""Creates lars optimizer with weight decay.
|
||||||
|
|
||||||
References:
|
References:
|
||||||
|
|||||||
+11
-11
@@ -17,18 +17,18 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
def target_ema(global_step,
|
def target_ema(global_step: jnp.ndarray,
|
||||||
base_ema,
|
base_ema: float,
|
||||||
max_steps):
|
max_steps: int) -> jnp.ndarray:
|
||||||
decay = _cosine_decay(global_step, max_steps, 1.)
|
decay = _cosine_decay(global_step, max_steps, 1.)
|
||||||
return 1. - (1. - base_ema) * decay
|
return 1. - (1. - base_ema) * decay
|
||||||
|
|
||||||
|
|
||||||
def learning_schedule(global_step,
|
def learning_schedule(global_step: jnp.ndarray,
|
||||||
batch_size,
|
batch_size: int,
|
||||||
base_learning_rate,
|
base_learning_rate: float,
|
||||||
total_steps,
|
total_steps: int,
|
||||||
warmup_steps):
|
warmup_steps: int) -> float:
|
||||||
"""Cosine learning rate scheduler."""
|
"""Cosine learning rate scheduler."""
|
||||||
# Compute LR & Scaled LR
|
# Compute LR & Scaled LR
|
||||||
scaled_lr = base_learning_rate * batch_size / 256.
|
scaled_lr = base_learning_rate * batch_size / 256.
|
||||||
@@ -43,9 +43,9 @@ def learning_schedule(global_step,
|
|||||||
scaled_lr))
|
scaled_lr))
|
||||||
|
|
||||||
|
|
||||||
def _cosine_decay(global_step,
|
def _cosine_decay(global_step: jnp.ndarray,
|
||||||
max_steps,
|
max_steps: int,
|
||||||
initial_value):
|
initial_value: float) -> jnp.ndarray:
|
||||||
"""Simple implementation of cosine decay from TF1."""
|
"""Simple implementation of cosine decay from TF1."""
|
||||||
global_step = jnp.minimum(global_step, max_steps)
|
global_step = jnp.minimum(global_step, max_steps)
|
||||||
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
|
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
|
||||||
|
|||||||
+2
-2
@@ -46,7 +46,7 @@ def make_so_tangent(q):
|
|||||||
for j in range(i+1, n):
|
for j in range(i+1, n):
|
||||||
a[i, j] = 1
|
a[i, j] = 1
|
||||||
a[j, i] = -1
|
a[j, i] = -1
|
||||||
dq[Ellipsis, ii] = a @ q # tangent vectors are skew-symmetric matrix times Q
|
dq[..., ii] = a @ q # tangent vectors are skew-symmetric matrix times Q
|
||||||
a[i, j] = 0
|
a[i, j] = 0
|
||||||
a[j, i] = 0
|
a[j, i] = 0
|
||||||
ii += 1
|
ii += 1
|
||||||
@@ -106,7 +106,7 @@ def make_product_manifold(specification, npts):
|
|||||||
spec_array[1, i] = dim
|
spec_array[1, i] = dim
|
||||||
latent_dim += dim
|
latent_dim += dim
|
||||||
dat = np.random.randn(npts, dim+1)
|
dat = np.random.randn(npts, dim+1)
|
||||||
dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[Ellipsis, None]),
|
dat /= np.tile(np.sqrt(np.sum(dat**2, axis=1)[..., None]),
|
||||||
[1, dim+1])
|
[1, dim+1])
|
||||||
elif so_spec is not None:
|
elif so_spec is not None:
|
||||||
dim = int(so_spec.group(1))
|
dim = int(so_spec.group(1))
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
from __future__ import google_type_annotations
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ The architecture and performance of this model is described in our publication:
|
|||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
from __future__ import google_type_annotations
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
@@ -35,10 +35,10 @@ from typing import Any, Dict, Text, Tuple, Optional
|
|||||||
|
|
||||||
|
|
||||||
def make_graph_from_static_structure(
|
def make_graph_from_static_structure(
|
||||||
positions,
|
positions: tf.Tensor,
|
||||||
types,
|
types: tf.Tensor,
|
||||||
box,
|
box: tf.Tensor,
|
||||||
edge_threshold):
|
edge_threshold: float) -> graphs.GraphsTuple:
|
||||||
"""Returns graph representing the static structure of the glass.
|
"""Returns graph representing the static structure of the glass.
|
||||||
|
|
||||||
Each particle is represented by a node in the graph. The particle type is
|
Each particle is represented by a node in the graph. The particle type is
|
||||||
@@ -81,7 +81,7 @@ def make_graph_from_static_structure(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_random_rotation(graph):
|
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
|
||||||
"""Returns randomly rotated graph representation.
|
"""Returns randomly rotated graph representation.
|
||||||
|
|
||||||
The rotation is an element of O(3) with rotation angles multiple of pi/2.
|
The rotation is an element of O(3) with rotation angles multiple of pi/2.
|
||||||
@@ -118,9 +118,9 @@ class GraphBasedModel(snt.AbstractModule):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
n_recurrences,
|
n_recurrences: int,
|
||||||
mlp_sizes,
|
mlp_sizes: Tuple[int],
|
||||||
mlp_kwargs = None,
|
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
||||||
name='Graph'):
|
name='Graph'):
|
||||||
"""Creates a new GraphBasedModel object.
|
"""Creates a new GraphBasedModel object.
|
||||||
|
|
||||||
@@ -168,7 +168,7 @@ class GraphBasedModel(snt.AbstractModule):
|
|||||||
node_model_fn=final_model_fn,
|
node_model_fn=final_model_fn,
|
||||||
edge_model_fn=model_fn)
|
edge_model_fn=model_fn)
|
||||||
|
|
||||||
def _build(self, graphs_tuple):
|
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
|
||||||
"""Connects the model into the tensorflow graph.
|
"""Connects the model into the tensorflow graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
+32
-32
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
from __future__ import google_type_annotations
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
@@ -53,8 +53,8 @@ class ParticleType(enum.IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
def get_targets(
|
def get_targets(
|
||||||
initial_positions,
|
initial_positions: np.ndarray,
|
||||||
trajectory_target_positions):
|
trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray:
|
||||||
"""Returns the averaged particle mobilities from the sampled trajectories.
|
"""Returns the averaged particle mobilities from the sampled trajectories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -70,9 +70,9 @@ def get_targets(
|
|||||||
|
|
||||||
|
|
||||||
def load_data(
|
def load_data(
|
||||||
file_pattern,
|
file_pattern: Text,
|
||||||
time_index,
|
time_index: int,
|
||||||
max_files_to_load = None):
|
max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]:
|
||||||
"""Returns a dictionary containing the training or test dataset.
|
"""Returns a dictionary containing the training or test dataset.
|
||||||
|
|
||||||
The dictionary contains:
|
The dictionary contains:
|
||||||
@@ -108,9 +108,9 @@ def load_data(
|
|||||||
|
|
||||||
|
|
||||||
def get_loss_ops(
|
def get_loss_ops(
|
||||||
prediction,
|
prediction: tf.Tensor,
|
||||||
target,
|
target: tf.Tensor,
|
||||||
types):
|
types: tf.Tensor) -> LossCollection:
|
||||||
"""Returns L1/L2 loss and correlation for type A particles.
|
"""Returns L1/L2 loss and correlation for type A particles.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -132,9 +132,9 @@ def get_loss_ops(
|
|||||||
|
|
||||||
|
|
||||||
def get_minimize_op(
|
def get_minimize_op(
|
||||||
loss,
|
loss: tf.Tensor,
|
||||||
learning_rate,
|
learning_rate: float,
|
||||||
grad_clip = None):
|
grad_clip: Optional[float] = None) -> tf.Tensor:
|
||||||
"""Returns minimization operation.
|
"""Returns minimization operation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -152,8 +152,8 @@ def get_minimize_op(
|
|||||||
|
|
||||||
|
|
||||||
def _log_stats_and_return_mean_correlation(
|
def _log_stats_and_return_mean_correlation(
|
||||||
label,
|
label: Text,
|
||||||
stats):
|
stats: Sequence[LossCollection]) -> float:
|
||||||
"""Logs performance statistics and returns mean correlation.
|
"""Logs performance statistics and returns mean correlation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -171,20 +171,20 @@ def _log_stats_and_return_mean_correlation(
|
|||||||
return np.mean([s.correlation for s in stats])
|
return np.mean([s.correlation for s in stats])
|
||||||
|
|
||||||
|
|
||||||
def train_model(train_file_pattern,
|
def train_model(train_file_pattern: Text,
|
||||||
test_file_pattern,
|
test_file_pattern: Text,
|
||||||
max_files_to_load = None,
|
max_files_to_load: Optional[int] = None,
|
||||||
n_epochs = 1000,
|
n_epochs: int = 1000,
|
||||||
time_index = 9,
|
time_index: int = 9,
|
||||||
augment_data_using_rotations = True,
|
augment_data_using_rotations: bool = True,
|
||||||
learning_rate = 1e-4,
|
learning_rate: float = 1e-4,
|
||||||
grad_clip = 1.0,
|
grad_clip: Optional[float] = 1.0,
|
||||||
n_recurrences = 7,
|
n_recurrences: int = 7,
|
||||||
mlp_sizes = (64, 64),
|
mlp_sizes: Tuple[int] = (64, 64),
|
||||||
mlp_kwargs = None,
|
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
||||||
edge_threshold = 2.0,
|
edge_threshold: float = 2.0,
|
||||||
measurement_store_interval = 1000,
|
measurement_store_interval: int = 1000,
|
||||||
checkpoint_path = None):
|
checkpoint_path: Optional[Text] = None) -> float:
|
||||||
"""Trains GraphModel using tensorflow.
|
"""Trains GraphModel using tensorflow.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -325,10 +325,10 @@ def train_model(train_file_pattern,
|
|||||||
return best_so_far
|
return best_so_far
|
||||||
|
|
||||||
|
|
||||||
def apply_model(checkpoint_path,
|
def apply_model(checkpoint_path: Text,
|
||||||
file_pattern,
|
file_pattern: Text,
|
||||||
max_files_to_load = None,
|
max_files_to_load: Optional[int] = None,
|
||||||
time_index = 9):
|
time_index: int = 9) -> List[np.ndarray]:
|
||||||
"""Applies trained GraphModel using tensorflow.
|
"""Applies trained GraphModel using tensorflow.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
from __future__ import google_type_annotations
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -145,8 +145,8 @@ class _HierarchicalCore(snt.AbstractModule):
|
|||||||
regularizers=self._regularizers,
|
regularizers=self._regularizers,
|
||||||
)(decoder_features)
|
)(decoder_features)
|
||||||
|
|
||||||
mu = mu_logsigma[Ellipsis, :latent_dim]
|
mu = mu_logsigma[..., :latent_dim]
|
||||||
logsigma = mu_logsigma[Ellipsis, latent_dim:]
|
logsigma = mu_logsigma[..., latent_dim:]
|
||||||
dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(logsigma))
|
dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(logsigma))
|
||||||
distributions.append(dist)
|
distributions.append(dist)
|
||||||
|
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ class ComponentDecoder(snt.AbstractModule):
|
|||||||
pixel_params = self._pixel_decoder(z_flat).params
|
pixel_params = self._pixel_decoder(z_flat).params
|
||||||
|
|
||||||
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
||||||
mask_params = pixel_params[Ellipsis, 0:1]
|
mask_params = pixel_params[..., 0:1]
|
||||||
pixel_params = pixel_params[Ellipsis, 1:]
|
pixel_params = pixel_params[..., 1:]
|
||||||
|
|
||||||
output = MixtureParameters(
|
output = MixtureParameters(
|
||||||
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
||||||
|
|||||||
@@ -134,8 +134,8 @@ class LocScaleDistribution(DistributionModule):
|
|||||||
n_channels = params.get_shape().as_list()[-1]
|
n_channels = params.get_shape().as_list()[-1]
|
||||||
assert n_channels % 2 == 0
|
assert n_channels % 2 == 0
|
||||||
assert n_channels // 2 == self.output_shape[-1]
|
assert n_channels // 2 == self.output_shape[-1]
|
||||||
loc = params[Ellipsis, :n_channels // 2]
|
loc = params[..., :n_channels // 2]
|
||||||
scale = params[Ellipsis, n_channels // 2:]
|
scale = params[..., n_channels // 2:]
|
||||||
|
|
||||||
# apply activation functions
|
# apply activation functions
|
||||||
if self._scale != "fixed":
|
if self._scale != "fixed":
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class FactorRegressor(snt.AbstractModule):
|
|||||||
for m in self._mapping:
|
for m in self._mapping:
|
||||||
with tf.name_scope(m.name):
|
with tf.name_scope(m.name):
|
||||||
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
||||||
pred = all_preds[Ellipsis, idx:idx + m.size]
|
pred = all_preds[..., idx:idx + m.size]
|
||||||
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
||||||
idx += m.size
|
idx += m.size
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ class FactorRegressor(snt.AbstractModule):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def one_hot(f, nr_categories):
|
def one_hot(f, nr_categories):
|
||||||
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
|
return tf.one_hot(tf.cast(f[..., 0], tf.int32), depth=nr_categories)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def angle_to_vector(theta):
|
def angle_to_vector(theta):
|
||||||
@@ -194,7 +194,7 @@ def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
|
|||||||
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||||
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
||||||
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
||||||
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
|
return tf.reduce_sum(correct * assignment[..., 0]) / num_vis
|
||||||
|
|
||||||
|
|
||||||
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
||||||
|
|||||||
+10
-10
@@ -325,12 +325,12 @@ class IODINE(snt.AbstractModule):
|
|||||||
[get_components(xd) for xd in iterations["x_dist"]])
|
[get_components(xd) for xd in iterations["x_dist"]])
|
||||||
|
|
||||||
# metrics
|
# metrics
|
||||||
tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
|
tm = tf.transpose(true_mask[..., 0], [0, 1, 3, 4, 2])
|
||||||
tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
|
tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
|
||||||
pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
|
pm = tf.transpose(pred_mask[..., 0], [0, 1, 3, 4, 2])
|
||||||
pm = tf.reshape(pm, sg["B * T, H * W, K"])
|
pm = tf.reshape(pm, sg["B * T, H * W, K"])
|
||||||
ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
|
ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
|
||||||
ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm), sg["B, T"])
|
ari_nobg = tf.reshape(adjusted_rand_index(tm[..., 1:], pm), sg["B, T"])
|
||||||
|
|
||||||
mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5])
|
mse = tf.reduce_mean(tf.square(recons - image[:, None]), axis=[2, 3, 4, 5])
|
||||||
|
|
||||||
@@ -387,7 +387,7 @@ class IODINE(snt.AbstractModule):
|
|||||||
factor_info["assignment"].append(fass)
|
factor_info["assignment"].append(fass)
|
||||||
for k in fpred:
|
for k in fpred:
|
||||||
factor_info["predictions"][k].append(
|
factor_info["predictions"][k].append(
|
||||||
tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
|
tf.reduce_sum(fpred[k] * fass[..., None], axis=2))
|
||||||
factor_info["metrics"][k].append(fscalars[k])
|
factor_info["metrics"][k].append(fscalars[k])
|
||||||
|
|
||||||
info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T")
|
info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]), "T")
|
||||||
@@ -496,7 +496,7 @@ class IODINE(snt.AbstractModule):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_mask_posterior(out_dist, img):
|
def _get_mask_posterior(out_dist, img):
|
||||||
p_comp = out_dist.components_distribution.prob(img[Ellipsis, tf.newaxis, :])
|
p_comp = out_dist.components_distribution.prob(img[..., tf.newaxis, :])
|
||||||
posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6)
|
posterior = p_comp / (tf.reduce_sum(p_comp, axis=-1, keepdims=True) + 1e-6)
|
||||||
return tf.transpose(posterior, [0, 4, 2, 3, 1])
|
return tf.transpose(posterior, [0, 4, 2, 3, 1])
|
||||||
|
|
||||||
@@ -506,7 +506,7 @@ class IODINE(snt.AbstractModule):
|
|||||||
dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask])
|
dzp, dxp, dmp = tf.gradients(loss, [zp, out_params.pixel, out_params.mask])
|
||||||
|
|
||||||
log_prob = sg.guard(
|
log_prob = sg.guard(
|
||||||
out_dist.log_prob(img)[Ellipsis, tf.newaxis], "B, 1, H, W, 1")
|
out_dist.log_prob(img)[..., tf.newaxis], "B, 1, H, W, 1")
|
||||||
|
|
||||||
counterfactual_log_probs = []
|
counterfactual_log_probs = []
|
||||||
for k in range(0, self.num_components):
|
for k in range(0, self.num_components):
|
||||||
@@ -515,7 +515,7 @@ class IODINE(snt.AbstractModule):
|
|||||||
pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]],
|
pixel = tf.concat([out_params.pixel[:, :k], out_params.pixel[:, k + 1:]],
|
||||||
axis=1)
|
axis=1)
|
||||||
out_dist_k = self.output_dist(pixel, mask)
|
out_dist_k = self.output_dist(pixel, mask)
|
||||||
log_prob_k = out_dist_k.log_prob(img)[Ellipsis, tf.newaxis]
|
log_prob_k = out_dist_k.log_prob(img)[..., tf.newaxis]
|
||||||
counterfactual_log_probs.append(log_prob_k)
|
counterfactual_log_probs.append(log_prob_k)
|
||||||
counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1)
|
counterfactual = log_prob - tf.concat(counterfactual_log_probs, axis=1)
|
||||||
|
|
||||||
@@ -608,7 +608,7 @@ class IODINE(snt.AbstractModule):
|
|||||||
x_basis = tf.cos(valx * freqs[None, None, None, None, :, None])
|
x_basis = tf.cos(valx * freqs[None, None, None, None, :, None])
|
||||||
y_basis = tf.cos(valy * freqs[None, None, None, None, None, :])
|
y_basis = tf.cos(valy * freqs[None, None, None, None, None, :])
|
||||||
xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"])
|
xy_basis = tf.reshape(x_basis * y_basis, self._sg["1, 1, H, W, F*F"])
|
||||||
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[Ellipsis, 1:]
|
coords = tf.tile(xy_basis, self._sg["B, 1, 1, 1, 1"])[..., 1:]
|
||||||
return coords
|
return coords
|
||||||
else:
|
else:
|
||||||
raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type))
|
raise KeyError('Unknown coord_type: "{}"'.format(self.coord_type))
|
||||||
@@ -632,7 +632,7 @@ class IODINE(snt.AbstractModule):
|
|||||||
# ########## Mask Monitoring #######
|
# ########## Mask Monitoring #######
|
||||||
if "mask" in data:
|
if "mask" in data:
|
||||||
true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1")
|
true_mask = self._sg.guard(data["mask"], "B, T, L, H, W, 1")
|
||||||
true_mask = tf.transpose(true_mask[:, -1, Ellipsis, 0], [0, 2, 3, 1])
|
true_mask = tf.transpose(true_mask[:, -1, ..., 0], [0, 2, 3, 1])
|
||||||
true_mask = self._sg.reshape(true_mask, "B, H*W, L")
|
true_mask = self._sg.reshape(true_mask, "B, H*W, L")
|
||||||
else:
|
else:
|
||||||
true_mask = None
|
true_mask = None
|
||||||
@@ -648,6 +648,6 @@ class IODINE(snt.AbstractModule):
|
|||||||
adjusted_rand_index(true_mask, pred_mask))
|
adjusted_rand_index(true_mask, pred_mask))
|
||||||
|
|
||||||
scalars["loss/ari_nobg"] = tf.reduce_mean(
|
scalars["loss/ari_nobg"] = tf.reduce_mean(
|
||||||
adjusted_rand_index(true_mask[Ellipsis, 1:], pred_mask))
|
adjusted_rand_index(true_mask[..., 1:], pred_mask))
|
||||||
|
|
||||||
return scalars
|
return scalars
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ class BroadcastConv(snt.AbstractModule):
|
|||||||
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
||||||
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
||||||
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
||||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
|
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[..., 1:]
|
||||||
return tf.concat([output, coords], axis=-1)
|
return tf.concat([output, coords], axis=-1)
|
||||||
else:
|
else:
|
||||||
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def show_mask(m, ax):
|
|||||||
@optional_clean_ax
|
@optional_clean_ax
|
||||||
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
||||||
return ax.matshow(
|
return ax.matshow(
|
||||||
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
m[..., 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||||
|
|
||||||
|
|
||||||
@optional_clean_ax
|
@optional_clean_ax
|
||||||
@@ -115,7 +115,7 @@ def example_plot(rinfo,
|
|||||||
|
|
||||||
show_img(image, ax=axes[0], color="#000000")
|
show_img(image, ax=axes[0], color="#000000")
|
||||||
show_img(recons, ax=axes[1], color="#000000")
|
show_img(recons, ax=axes[1], color="#000000")
|
||||||
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
|
show_mask(pred_mask[..., 0], ax=axes[2], color="#000000")
|
||||||
for k in range(K):
|
for k in range(K):
|
||||||
mask = pred_mask[k] if mask_components else None
|
mask = pred_mask[k] if mask_components else None
|
||||||
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
||||||
@@ -145,7 +145,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
|||||||
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
show_img(recons[t, 0], ax=axes[t, 0])
|
show_img(recons[t, 0], ax=axes[t, 0])
|
||||||
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
|
show_mask(pred_mask[t, ..., 0], ax=axes[t, 1])
|
||||||
axes[t, 0].set_ylabel("iter {}".format(t))
|
axes[t, 0].set_ylabel("iter {}".format(t))
|
||||||
for k in range(K):
|
for k in range(K):
|
||||||
mask = pred_mask[t, k] if mask_components else None
|
mask = pred_mask[t, k] if mask_components else None
|
||||||
@@ -154,7 +154,7 @@ def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
|||||||
axes[0, 0].set_title("Reconstruction")
|
axes[0, 0].set_title("Reconstruction")
|
||||||
axes[0, 1].set_title("Mask")
|
axes[0, 1].set_title("Mask")
|
||||||
show_img(image[0], ax=axes[T, 0])
|
show_img(image[0], ax=axes[T, 0])
|
||||||
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
|
show_mask(true_mask[0, ..., 0], ax=axes[T, 1])
|
||||||
vmin = np.min(pred_mask_logits[T - 1])
|
vmin = np.min(pred_mask_logits[T - 1])
|
||||||
vmax = np.max(pred_mask_logits[T - 1])
|
vmax = np.max(pred_mask_logits[T - 1])
|
||||||
|
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ def construct_diagnostic_image(
|
|||||||
components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
|
components = tf.tile(components[:nr_images], [1, 1, 1, 1, 3])
|
||||||
|
|
||||||
if mask_components:
|
if mask_components:
|
||||||
components *= masks[:nr_images, Ellipsis, tf.newaxis]
|
components *= masks[:nr_images, ..., tf.newaxis]
|
||||||
|
|
||||||
# Pad everything
|
# Pad everything
|
||||||
no_pad, pad = (0, 0), (border_width, border_width)
|
no_pad, pad = (0, 0), (border_width, border_width)
|
||||||
@@ -415,7 +415,7 @@ def images_to_grid(
|
|||||||
if max_grid_width is not None:
|
if max_grid_width is not None:
|
||||||
grid_width = min(max_grid_width, grid_width)
|
grid_width = min(max_grid_width, grid_width)
|
||||||
|
|
||||||
images = images[: grid_height * grid_width, Ellipsis]
|
images = images[: grid_height * grid_width, ...]
|
||||||
|
|
||||||
# Pad with extra blank frames if grid_height x grid_width is less than the
|
# Pad with extra blank frames if grid_height x grid_width is less than the
|
||||||
# number of frames provided.
|
# number of frames provided.
|
||||||
@@ -460,7 +460,7 @@ def flatten_all_but_last(tensor, n_dims=1):
|
|||||||
|
|
||||||
def ensure_3d(tensor):
|
def ensure_3d(tensor):
|
||||||
if tensor.shape.ndims == 2:
|
if tensor.shape.ndims == 2:
|
||||||
return tensor[Ellipsis, None]
|
return tensor[..., None]
|
||||||
|
|
||||||
assert tensor.shape.ndims == 3
|
assert tensor.shape.ndims == 3
|
||||||
return tensor
|
return tensor
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class Agent():
|
|||||||
|
|
||||||
def option_values(values, policy):
|
def option_values(values, policy):
|
||||||
return tf.tensordot(
|
return tf.tensordot(
|
||||||
values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])
|
values[:, policy, ...], self._policy_weights[policy], axes=[1, 0])
|
||||||
|
|
||||||
# Placeholders for policy.
|
# Placeholders for policy.
|
||||||
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
|
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
|
||||||
@@ -103,8 +103,8 @@ class Agent():
|
|||||||
qo_t = option_values(q_t, p)
|
qo_t = option_values(q_t, p)
|
||||||
|
|
||||||
a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
|
a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
|
||||||
qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1)
|
qa_tm1 = _batched_index(q_tm1[:, p, ...], a_tm1)
|
||||||
qa_t = _batched_index(q_t[:, p, Ellipsis], a_t)
|
qa_t = _batched_index(q_t[:, p, ...], a_t)
|
||||||
|
|
||||||
# TD error
|
# TD error
|
||||||
g = additional_discount * tf.expand_dims(d_t, axis=-1)
|
g = additional_discount * tf.expand_dims(d_t, axis=-1)
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ def make_face_model_dataset(
|
|||||||
# Vertices are quantized. So convert to floats for input to face model
|
# Vertices are quantized. So convert to floats for input to face model
|
||||||
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
|
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
|
||||||
example['vertices_mask'] = tf.ones_like(
|
example['vertices_mask'] = tf.ones_like(
|
||||||
example['vertices'][Ellipsis, 0], dtype=tf.float32)
|
example['vertices'][..., 0], dtype=tf.float32)
|
||||||
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
|
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
|
||||||
return example
|
return example
|
||||||
return ds.map(_face_model_map_fn)
|
return ds.map(_face_model_map_fn)
|
||||||
|
|||||||
+8
-8
@@ -799,7 +799,7 @@ class VertexModel(snt.AbstractModule):
|
|||||||
# Continuous vertex value embeddings
|
# Continuous vertex value embeddings
|
||||||
else:
|
else:
|
||||||
vert_embeddings = tf.layers.dense(
|
vert_embeddings = tf.layers.dense(
|
||||||
dequantize_verts(vertices[Ellipsis, None], self.quantization_bits),
|
dequantize_verts(vertices[..., None], self.quantization_bits),
|
||||||
self.embedding_dim,
|
self.embedding_dim,
|
||||||
use_bias=True,
|
use_bias=True,
|
||||||
name='value_embeddings')
|
name='value_embeddings')
|
||||||
@@ -984,7 +984,7 @@ class VertexModel(snt.AbstractModule):
|
|||||||
verts_dequantized = dequantize_verts(v, self.quantization_bits)
|
verts_dequantized = dequantize_verts(v, self.quantization_bits)
|
||||||
vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
|
vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
|
||||||
vertices = tf.stack(
|
vertices = tf.stack(
|
||||||
[vertices[Ellipsis, 2], vertices[Ellipsis, 1], vertices[Ellipsis, 0]], axis=-1)
|
[vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1)
|
||||||
|
|
||||||
# Pad samples to max sample length. This is required in order to concatenate
|
# Pad samples to max sample length. This is required in order to concatenate
|
||||||
# Samples across different replicator instances. Pad with stopping tokens
|
# Samples across different replicator instances. Pad with stopping tokens
|
||||||
@@ -998,14 +998,14 @@ class VertexModel(snt.AbstractModule):
|
|||||||
|
|
||||||
if recenter_verts:
|
if recenter_verts:
|
||||||
vert_max = tf.reduce_max(
|
vert_max = tf.reduce_max(
|
||||||
vertices - 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
|
vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1,
|
||||||
keepdims=True)
|
keepdims=True)
|
||||||
vert_min = tf.reduce_min(
|
vert_min = tf.reduce_min(
|
||||||
vertices + 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
|
vertices + 1e10 * (1. - vertices_mask)[..., None], axis=1,
|
||||||
keepdims=True)
|
keepdims=True)
|
||||||
vert_centers = 0.5 * (vert_max + vert_min)
|
vert_centers = 0.5 * (vert_max + vert_min)
|
||||||
vertices -= vert_centers
|
vertices -= vert_centers
|
||||||
vertices *= vertices_mask[Ellipsis, None]
|
vertices *= vertices_mask[..., None]
|
||||||
|
|
||||||
if only_return_complete:
|
if only_return_complete:
|
||||||
vertices = tf.boolean_mask(vertices, completed)
|
vertices = tf.boolean_mask(vertices, completed)
|
||||||
@@ -1247,7 +1247,7 @@ class FaceModel(snt.AbstractModule):
|
|||||||
sequential_context_embeddings = (
|
sequential_context_embeddings = (
|
||||||
vertex_embeddings *
|
vertex_embeddings *
|
||||||
tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
|
tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
|
||||||
constant_values=1)[Ellipsis, None])
|
constant_values=1)[..., None])
|
||||||
else:
|
else:
|
||||||
sequential_context_embeddings = None
|
sequential_context_embeddings = None
|
||||||
return (vertex_embeddings, global_context_embedding,
|
return (vertex_embeddings, global_context_embedding,
|
||||||
@@ -1266,11 +1266,11 @@ class FaceModel(snt.AbstractModule):
|
|||||||
embed_dim=self.embedding_dim,
|
embed_dim=self.embedding_dim,
|
||||||
initializers={'embeddings': tf.glorot_uniform_initializer},
|
initializers={'embeddings': tf.glorot_uniform_initializer},
|
||||||
densify_gradients=True,
|
densify_gradients=True,
|
||||||
name='coord_{}'.format(c))(verts_quantized[Ellipsis, c])
|
name='coord_{}'.format(c))(verts_quantized[..., c])
|
||||||
else:
|
else:
|
||||||
vertex_embeddings = tf.layers.dense(
|
vertex_embeddings = tf.layers.dense(
|
||||||
vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
|
vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
|
||||||
vertex_embeddings *= vertices_mask[Ellipsis, None]
|
vertex_embeddings *= vertices_mask[..., None]
|
||||||
|
|
||||||
# Pad vertex embeddings with learned embeddings for stopping and new face
|
# Pad vertex embeddings with learned embeddings for stopping and new face
|
||||||
# tokens
|
# tokens
|
||||||
|
|||||||
+16
-16
@@ -106,7 +106,7 @@ TESTING_SUITE = [
|
|||||||
ALL = TUNING_SUITE + TESTING_SUITE
|
ALL = TUNING_SUITE + TESTING_SUITE
|
||||||
|
|
||||||
|
|
||||||
def _decode_frames(pngs):
|
def _decode_frames(pngs: tf.Tensor):
|
||||||
"""Decode PNGs.
|
"""Decode PNGs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -122,13 +122,13 @@ def _decode_frames(pngs):
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
def _make_reverb_sample(o_t,
|
def _make_reverb_sample(o_t: tf.Tensor,
|
||||||
a_t,
|
a_t: tf.Tensor,
|
||||||
r_t,
|
r_t: tf.Tensor,
|
||||||
d_t,
|
d_t: tf.Tensor,
|
||||||
o_tp1,
|
o_tp1: tf.Tensor,
|
||||||
a_tp1,
|
a_tp1: tf.Tensor,
|
||||||
extras):
|
extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample:
|
||||||
"""Create Reverb sample with offline data.
|
"""Create Reverb sample with offline data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -151,8 +151,8 @@ def _make_reverb_sample(o_t,
|
|||||||
return reverb.ReplaySample(info=info, data=data)
|
return reverb.ReplaySample(info=info, data=data)
|
||||||
|
|
||||||
|
|
||||||
def _tf_example_to_reverb_sample(tf_example
|
def _tf_example_to_reverb_sample(tf_example: tf.train.Example
|
||||||
):
|
) -> reverb.ReplaySample:
|
||||||
"""Create a Reverb replay sample from a TF example."""
|
"""Create a Reverb replay sample from a TF example."""
|
||||||
|
|
||||||
# Parse tf.Example.
|
# Parse tf.Example.
|
||||||
@@ -184,11 +184,11 @@ def _tf_example_to_reverb_sample(tf_example
|
|||||||
extras)
|
extras)
|
||||||
|
|
||||||
|
|
||||||
def dataset(path,
|
def dataset(path: str,
|
||||||
game,
|
game: str,
|
||||||
run,
|
run: int,
|
||||||
num_shards = 100,
|
num_shards: int = 100,
|
||||||
shuffle_buffer_size = 100000):
|
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
|
||||||
"""TF dataset of Atari SARSA tuples."""
|
"""TF dataset of Atari SARSA tuples."""
|
||||||
path = os.path.join(path, f'{game}/run_{run}')
|
path = os.path.join(path, f'{game}/run_{run}')
|
||||||
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
|
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
|
||||||
@@ -243,7 +243,7 @@ class AtariDopamineWrapper(dm_env.Environment):
|
|||||||
return specs.DiscreteArray(self._env.action_space.n)
|
return specs.DiscreteArray(self._env.action_space.n)
|
||||||
|
|
||||||
|
|
||||||
def environment(game):
|
def environment(game: str) -> dm_env.Environment:
|
||||||
"""Atari environment."""
|
"""Atari environment."""
|
||||||
env = atari_lib.create_atari_environment(game_name=game,
|
env = atari_lib.create_atari_environment(game_name=game,
|
||||||
sticky_actions=True)
|
sticky_actions=True)
|
||||||
|
|||||||
@@ -773,15 +773,15 @@ def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False):
|
|||||||
drop_remainder=drop_remainder)
|
drop_remainder=drop_remainder)
|
||||||
|
|
||||||
|
|
||||||
def dataset(root_path,
|
def dataset(root_path: str,
|
||||||
data_path,
|
data_path: str,
|
||||||
shapes,
|
shapes: Dict[str, Tuple[int]],
|
||||||
num_threads,
|
num_threads: int,
|
||||||
batch_size,
|
batch_size: int,
|
||||||
uint8_features = None,
|
uint8_features: Set[str] = None,
|
||||||
num_shards = 100,
|
num_shards: int = 100,
|
||||||
shuffle_buffer_size = 100000,
|
shuffle_buffer_size: int = 100000,
|
||||||
sarsa = True):
|
sarsa: bool = True) -> tf.data.Dataset:
|
||||||
"""Create tf dataset for training."""
|
"""Create tf dataset for training."""
|
||||||
|
|
||||||
uint8_features = uint8_features if uint8_features else {}
|
uint8_features = uint8_features if uint8_features else {}
|
||||||
|
|||||||
+17
-17
@@ -55,7 +55,7 @@ DELIMITER = ':'
|
|||||||
DEFAULT_NUM_TIMESTEPS = 1001
|
DEFAULT_NUM_TIMESTEPS = 1001
|
||||||
|
|
||||||
|
|
||||||
def _decombine_key(k, delimiter = DELIMITER):
|
def _decombine_key(k: str, delimiter: str = DELIMITER) -> Sequence[str]:
|
||||||
return k.split(delimiter)
|
return k.split(delimiter)
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ def tf_example_to_feature_description(example,
|
|||||||
|
|
||||||
|
|
||||||
def tree_deflatten_with_delimiter(
|
def tree_deflatten_with_delimiter(
|
||||||
flat_dict, delimiter = DELIMITER):
|
flat_dict: Dict[str, Any], delimiter: str = DELIMITER) -> Dict[str, Any]:
|
||||||
"""De-flattens a dict to its originally nested structure.
|
"""De-flattens a dict to its originally nested structure.
|
||||||
|
|
||||||
Does the opposite of {combine_nested_keys(k) :v
|
Does the opposite of {combine_nested_keys(k) :v
|
||||||
@@ -102,12 +102,12 @@ def tree_deflatten_with_delimiter(
|
|||||||
return dict(root)
|
return dict(root)
|
||||||
|
|
||||||
|
|
||||||
def get_slice_of_nested(nested, start,
|
def get_slice_of_nested(nested: Dict[str, Any], start: int,
|
||||||
end):
|
end: int) -> Dict[str, Any]:
|
||||||
return tree.map_structure(lambda item: item[start:end], nested)
|
return tree.map_structure(lambda item: item[start:end], nested)
|
||||||
|
|
||||||
|
|
||||||
def repeat_last_and_append_to_nested(nested):
|
def repeat_last_and_append_to_nested(nested: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
return tree.map_structure(
|
return tree.map_structure(
|
||||||
lambda item: tf.concat((item, item[-1:]), axis=0), nested)
|
lambda item: tf.concat((item, item[-1:]), axis=0), nested)
|
||||||
|
|
||||||
@@ -133,13 +133,13 @@ def tf_example_to_reverb_sample(example,
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def dataset(path,
|
def dataset(path: str,
|
||||||
combined_challenge,
|
combined_challenge: str,
|
||||||
domain,
|
domain: str,
|
||||||
task,
|
task: str,
|
||||||
difficulty,
|
difficulty: str,
|
||||||
num_shards = 100,
|
num_shards: int = 100,
|
||||||
shuffle_buffer_size = 100000):
|
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
|
||||||
"""TF dataset of RWRL SARSA tuples."""
|
"""TF dataset of RWRL SARSA tuples."""
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
path,
|
path,
|
||||||
@@ -178,11 +178,11 @@ def dataset(path,
|
|||||||
|
|
||||||
|
|
||||||
def environment(
|
def environment(
|
||||||
combined_challenge,
|
combined_challenge: str,
|
||||||
domain,
|
domain: str,
|
||||||
task,
|
task: str,
|
||||||
log_output = None,
|
log_output: Optional[str] = None,
|
||||||
environment_kwargs = None):
|
environment_kwargs: Optional[Dict[str, Any]] = None) -> dm_env.Environment:
|
||||||
"""RWRL environment."""
|
"""RWRL environment."""
|
||||||
env = rwrl_envs.load(
|
env = rwrl_envs.load(
|
||||||
domain_name=domain,
|
domain_name=domain,
|
||||||
|
|||||||
@@ -106,8 +106,8 @@ class Transporter(snt.AbstractModule):
|
|||||||
num_keypoints = image_a_keypoints["heatmaps"].shape[-1]
|
num_keypoints = image_a_keypoints["heatmaps"].shape[-1]
|
||||||
transported_features = image_a_features
|
transported_features = image_a_features
|
||||||
for k in range(num_keypoints):
|
for k in range(num_keypoints):
|
||||||
mask_a = image_a_keypoints["heatmaps"][Ellipsis, k, None]
|
mask_a = image_a_keypoints["heatmaps"][..., k, None]
|
||||||
mask_b = image_b_keypoints["heatmaps"][Ellipsis, k, None]
|
mask_b = image_b_keypoints["heatmaps"][..., k, None]
|
||||||
|
|
||||||
# suppress features from image a, around both keypoint locations.
|
# suppress features from image a, around both keypoint locations.
|
||||||
transported_features = (
|
transported_features = (
|
||||||
|
|||||||
Reference in New Issue
Block a user