Release of IODINE

PiperOrigin-RevId: 299101887
Diego de Las Casas
2020-03-05 15:52:20 +00:00
parent a5efafff3a
commit afcdc77239
23 changed files with 7600 additions and 0 deletions
@@ -10,6 +10,7 @@ env:
matrix:
- PROJECT="tvt"
- PROJECT="cs_gan"
- PROJECT="iodine"
- PROJECT="transporter"
before_script:
- sudo apt-get update -qq
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [Multi-Object Representation Learning with Iterative Variational Inference (IODINE)](iodine)
* [AlphaFold CASP13](alphafold_casp13), Nature 2020
* [Unrestricted Adversarial Challenge](unrestricted_advx)
* [Hierarchical Probabilistic U-Net (HPU-Net)](hierarchical_probabilistic_unet)
@@ -0,0 +1,142 @@
# IODINE
Reference implementation for the paper ["Multi-Object Representation Learning with Iterative Variational Inference"](https://arxiv.org/abs/1903.00450).
This repository contains:
* An IODINE implementation in TensorFlow v1.
* Configurations used in the paper (checkpoints available in Cloud Storage) for:
  * CLEVR
  * Multi-dSprites
  * Tetrominoes
* A notebook for running and inspecting the model and plotting the results.
## Installation
1. Clone the DeepMind research repository:
```bash
git clone https://github.com/deepmind/deepmind-research.git
cd deepmind-research
```
2. Download the checkpoints from GCP. A shell script is provided:
```bash
./iodine/download_checkpoints.sh
```
On platforms without wget, the files can be downloaded from [this webpage](https://console.cloud.google.com/storage/browser/deepmind-research-iodine?pli=1)
and the unzipped `checkpoints/` folder should be placed in
`deepmind-research/iodine/checkpoints`.
3. Prepare a Python 3 environment - virtualenv is recommended.
```bash
python3 -m venv iodine_venv
source iodine_venv/bin/activate
```
4. Install dependencies:
```bash
pip3 install -r iodine/requirements.txt
```
5. The `multi_object_datasets` package installed via `requirements.txt` provides the Python code to read the data files, but not the data files themselves.
Download the desired datasets either manually from [Google Cloud Storage](https://console.cloud.google.com/storage/browser/multi-object-datasets) or with the commands below:
```bash
pushd iodine/multi_object_datasets
# CLEVR
wget https://storage.googleapis.com/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords
# Multi-dSprites
wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_grayscale.tfrecords
# Tetrominoes
wget https://storage.googleapis.com/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords
# Get back to location containing 'iodine' directory
popd
```
See [multi_object_datasets repository](https://github.com/deepmind/multi_object_datasets)
for further details.
6. Make sure that you have CUDA 10 and cuDNN 7 installed.
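
Before training, it can help to confirm that the dataset files from step 5 ended up where the configs expect them. The following is a minimal sketch, assuming the files were downloaded into `iodine/multi_object_datasets` as in the commands above; the helper `missing_datasets` and the `EXPECTED_FILES` table are hypothetical and not part of the repository.

```python
from pathlib import Path

# File names copied from the wget commands in step 5. The keys are the
# named configs that read each file.
EXPECTED_FILES = {
    "clevr6": "clevr_with_masks_train.tfrecords",
    "multi_dsprites": "multi_dsprites_colored_on_grayscale.tfrecords",
    "tetrominoes": "tetrominoes_train.tfrecords",
}


def missing_datasets(data_dir="iodine/multi_object_datasets"):
    """Return the config names whose .tfrecords file is absent (hypothetical helper)."""
    root = Path(data_dir)
    return sorted(
        name for name, filename in EXPECTED_FILES.items()
        if not (root / filename).is_file()
    )


if __name__ == "__main__":
    # Run from the directory that contains the 'iodine' folder.
    print("Missing datasets:", missing_datasets() or "none")
```

Only the datasets you intend to train on need to be present, so a non-empty result is not necessarily an error.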
## Interact with a Model
Use the Jupyter notebook `Eval.ipynb` to load and run one of the checkpoints.
It also contains code to plot the outputs and latent traversals.
## Train a Model
To train your own model, use the [Sacred](https://github.com/IDSIA/sacred) experiment defined in `main.py`.
The configurations used in the paper for the different datasets are available as [named configs](https://sacred.readthedocs.io/en/latest/configuration.html#named-configurations) in `configurations.py`.
### Train a new model
* CLEVR6
```bash
python3 -m iodine.main -f with clevr6
```
* Multi-dSprites
```bash
python3 -m iodine.main -f with multi_dsprites
```
* Tetrominoes
```bash
python3 -m iodine.main -f with tetrominoes
```
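
The named configs derive `learn_rate` from `batch_size` with a square-root rule, because the paper runs used 8 GPUs and therefore a larger total batch size. The sketch below reproduces that arithmetic with the base learning rates and batch sizes listed in the three configs. The helper name `scaled_learning_rate` and the `PAPER_SETTINGS` table are illustrative and do not exist in the repository.

```python
import math


def scaled_learning_rate(base_lr, batch_size, reference_batch_size):
    """Scale the learning rate by sqrt(batch_size / reference_batch_size)."""
    return base_lr * math.sqrt(batch_size / reference_batch_size)


# (base learning rate, default single-run batch size, total batch size in the paper),
# copied from the clevr6, multi_dsprites and tetrominoes configs.
PAPER_SETTINGS = {
    "clevr6": (0.001, 4, 32),
    "multi_dsprites": (0.0003, 16, 128),
    "tetrominoes": (0.0003, 128, 256),
}

for name, (base_lr, batch_size, reference) in PAPER_SETTINGS.items():
    lr = scaled_learning_rate(base_lr, batch_size, reference)
    print("{}: learn_rate = {:.6f}".format(name, lr))
```

If you change `batch_size` further, the same rule suggests lowering the learning rate accordingly, although the best value may still need tuning.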
It is recommended to add an observer to your run so that Sacred records the details of the run.
To add a [FileStorageObserver](https://sacred.readthedocs.io/en/latest/command_line.html#filestorage-observer), pass `-F my_storage_dir`; for a [MongoObserver](https://sacred.readthedocs.io/en/latest/command_line.html#mongodb-observer), pass `-m my_db_name`.
### Adjusting Config Values
The experiment has a configuration that can be printed and adjusted from the command line. For example:
```bash
# print configuration
python3 -m iodine.main -f print_config with clevr6
# run experiment after adjusting batch_size and the size of the shuffle buffer
python3 -m iodine.main -f with clevr6 batch_size=2 data.shuffle_buffer=100
```
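
Assignments such as `data.shuffle_buffer=100` address nested entries of the configuration by their dotted path. The sketch below illustrates those semantics on a plain dictionary. It is not how Sacred implements config updates, and `apply_override` is a hypothetical helper written only for this illustration.

```python
import ast
from copy import deepcopy


def apply_override(config, assignment):
    """Return a copy of config with one 'dotted.key=value' assignment applied."""
    key_path, raw_value = assignment.split("=", 1)
    try:
        # Interpret numbers, booleans, lists, etc. as Python literals.
        value = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError):
        value = raw_value  # keep plain strings such as "linear" unchanged
    updated = deepcopy(config)
    node = updated
    *parents, leaf = key_path.split(".")
    for part in parents:
        # Walk down the nested dicts, creating levels that do not exist yet.
        node = node.setdefault(part, {})
    node[leaf] = value
    return updated


config = {"batch_size": 4, "data": {"shuffle_buffer": 1000}}
for assignment in ["batch_size=2", "data.shuffle_buffer=100"]:
    config = apply_override(config, assignment)
print(config)
```

The original dictionary is never mutated, which keeps the default configuration reusable across several overrides.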
### Tensorboard
Each run stores checkpoints and summaries in the directory specified by `checkpoint_dir`, to which a suffix based on the `run_id` is appended.
If an observer is added, the `run_id` is set automatically. Otherwise it should be set manually, e.g. with `run_id=5`.
Summaries can be viewed with TensorBoard. For example, for clevr6 (assuming `run_id=1`):
```bash
tensorboard --logdir iodine/checkpoints/clevr6_1
```
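
The directory name follows a simple rule that can be read off `get_checkpoint_dir` in `main.py`: a new run appends `_{run_id}` to `checkpoint_dir`, while a continued run uses the path unchanged. The function below is a simplified restatement of that rule for illustration. It omits the existence check and the logging, and `resolve_checkpoint_dir` is a hypothetical name.

```python
def resolve_checkpoint_dir(checkpoint_dir, run_id=None, continue_run=False):
    """Simplified restatement of the naming rule in main.py (illustrative)."""
    if continue_run:
        # A continued run reads from and writes to the given directory as-is.
        return checkpoint_dir
    if run_id is None:
        # main.py falls back to run_id=1 when no observer provides an id.
        run_id = 1
    return "{}_{}".format(checkpoint_dir, run_id)


print(resolve_checkpoint_dir("iodine/checkpoints/clevr6", run_id=1))
print(resolve_checkpoint_dir("iodine/checkpoints/clevr6_1", continue_run=True))
```

Both calls resolve to the same directory, which is the path to pass to TensorBoard.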
### Continue Previous Run
To continue a previous run, pass `continue_run=True` together with the path of its checkpoints:
```bash
python3 -m iodine.main -f with clevr6 continue_run=True checkpoint_dir=iodine/checkpoints/clevr6_1
```
## Code Structure
The main experiment defined in `main.py` uses `sacred`. The configurations for the different datasets are added as named configs and can be found in `configurations.py`.
The model implementation can be found in the `modules` directory and is based on `tensorflow` and `sonnet`:
* `iodine.py` The main IODINE module that assembles the decoder, refinement network, distributions and factor regressor.
* `decoder.py` The ComponentDecoder which is a wrapper around networks that takes care of splitting the output channels into means and masks.
* `refinement.py` The refinement component, which assembles the encoder network, LSTM and refinement head.
* `networks.py` Different standard networks such as CNN, BroadcastCNN, and LSTM.
* `distributions.py` Definition of the latent and pixel distributions.
* `factor_eval.py` Contains the factor regressor which predicts the true factors from the inferred object latents.
* `data.py` Dataset wrappers around `multi_object_datasets` that take care of shuffling, batching and preprocessing.
* `plotting.py` Helper functions for plotting results.
* `utils.py` General helper functions.
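
The configurations describe every component as a dictionary whose `"constructor"` entry holds a dotted import path, with the remaining entries acting as keyword arguments. One way such a dictionary could be turned into an object is sketched below. This is an assumption about the general pattern, not the actual `utils.build` from the repository, and `build_from_config` is a hypothetical name. Standard-library classes stand in for the IODINE modules so that the sketch runs without TensorFlow.

```python
import importlib


def build_from_config(config):
    """Recursively instantiate dicts that carry a "constructor" key (illustrative)."""
    if isinstance(config, dict):
        # Build nested components first so they can be passed as arguments.
        kwargs = {
            key: build_from_config(value)
            for key, value in config.items()
            if key != "constructor"
        }
        if "constructor" not in config:
            return kwargs  # a plain options dict, e.g. "cnn_opt"
        module_name, _, attribute = config["constructor"].rpartition(".")
        factory = getattr(importlib.import_module(module_name), attribute)
        return factory(**kwargs)
    if isinstance(config, (list, tuple)):
        return type(config)(build_from_config(item) for item in config)
    return config


# Stand-in config: types.SimpleNamespace accepts arbitrary keyword arguments.
example = {
    "constructor": "types.SimpleNamespace",
    "n_z": 64,
    "decoder": {
        "constructor": "types.SimpleNamespace",
        "cnn_opt": {"kernel_shapes": [3], "activation": "elu"},
    },
}
model = build_from_config(example)
print(model.n_z, model.decoder.cnn_opt["kernel_shapes"])
```

Keeping the constructor path in the config means a component can be swapped from the command line without touching the model code.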
---
**DISCLAIMER**
This is not an officially supported Google product.
---
@@ -0,0 +1,370 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for IODINE."""
# pylint: disable=missing-docstring, unused-variable
import math
def clevr6():
n_z = 64 # number of latent dimensions
num_components = 7 # number of components (K)
num_iters = 5
checkpoint_dir = "iodine/checkpoints/clevr6"
# For the paper we used 8 GPUs with a batch size of 4 each.
# This means a total batch size of 32, which is too large for a single GPU.
# When reducing the batch size, the learning rate should also be lowered.
batch_size = 4
learn_rate = 0.001 * math.sqrt(batch_size / 32)
data = {
"constructor": "iodine.modules.data.CLEVR",
"batch_size": batch_size,
"path": "iodine/multi_object_datasets/clevr_with_masks_train.tfrecords",
"max_num_objects": 6,
}
model = {
"constructor": "iodine.modules.iodine.IODINE",
"n_z": n_z,
"num_components": num_components,
"num_iters": num_iters,
"iter_loss_weight": "linspace",
"coord_type": "linear",
"decoder": {
"constructor": "iodine.modules.decoder.ComponentDecoder",
"pixel_decoder": {
"constructor": "iodine.modules.networks.BroadcastConv",
"cnn_opt": {
# Final channels is irrelevant with target_output_shape
"output_channels": [64, 64, 64, 64, None],
"kernel_shapes": [3],
"strides": [1],
"activation": "elu",
},
"coord_type": "linear",
},
},
"refinement_core": {
"constructor": "iodine.modules.refinement.RefinementCore",
"encoder_net": {
"constructor": "iodine.modules.networks.CNN",
"mode": "avg_pool",
"cnn_opt": {
"output_channels": [64, 64, 64, 64],
"strides": [2],
"kernel_shapes": [3],
"activation": "elu",
},
"mlp_opt": {
"output_sizes": [256, 256],
"activation": "elu"
},
},
"recurrent_net": {
"constructor": "iodine.modules.networks.LSTM",
"hidden_sizes": [256],
},
"refinement_head": {
"constructor": "iodine.modules.refinement.ResHead"
},
},
"latent_dist": {
"constructor": "iodine.modules.distributions.LocScaleDistribution",
"dist": "normal",
"scale_act": "softplus",
"scale": "var",
"name": "latent_dist",
},
"output_dist": {
"constructor": "iodine.modules.distributions.MaskedMixture",
"num_components": num_components,
"component_dist": {
"constructor":
"iodine.modules.distributions.LocScaleDistribution",
"dist":
"logistic",
"scale":
"fixed",
"scale_val":
0.03,
"name":
"pixel_distribution",
},
},
"factor_evaluator": {
"constructor":
"iodine.modules.factor_eval.FactorRegressor",
"mapping": [
("color", 9, "categorical"),
("shape", 4, "categorical"),
("size", 3, "categorical"),
("position", 3, "scalar"),
],
},
}
optimizer = {
"constructor": "tensorflow.train.AdamOptimizer",
"learning_rate": {
"constructor": "tensorflow.train.exponential_decay",
"learning_rate": learn_rate,
"global_step": {
"constructor": "tensorflow.train.get_or_create_global_step"
},
"decay_steps": 1000000,
"decay_rate": 0.1,
},
"beta1": 0.95,
}
def multi_dsprites():
n_z = 16 # number of latent dimensions
num_components = 6 # number of components (K)
num_iters = 5
checkpoint_dir = "iodine/checkpoints/multi_dsprites"
# For the paper we used 8 GPUs with a batch size of 16 each.
# This means a total batch size of 128, which is too large for a single GPU.
# When reducing the batch size, the learning rate should also be lowered.
batch_size = 16
learn_rate = 0.0003 * math.sqrt(batch_size / 128)
data = {
"constructor":
"iodine.modules.data.MultiDSprites",
"batch_size":
batch_size,
"path":
"iodine/multi_object_datasets/multi_dsprites_colored_on_grayscale.tfrecords",
"dataset_variant":
"colored_on_grayscale",
"min_num_objs":
3,
"max_num_objs":
3,
}
model = {
"constructor": "iodine.modules.iodine.IODINE",
"n_z": n_z,
"num_components": num_components,
"num_iters": num_iters,
"iter_loss_weight": "linspace",
"coord_type": "cos",
"coord_freqs": 3,
"decoder": {
"constructor": "iodine.modules.decoder.ComponentDecoder",
"pixel_decoder": {
"constructor": "iodine.modules.networks.BroadcastConv",
"cnn_opt": {
# Final channels is irrelevant with target_output_shape
"output_channels": [32, 32, 32, 32, None],
"kernel_shapes": [5],
"strides": [1],
"activation": "elu",
},
"coord_type": "linear",
},
},
"refinement_core": {
"constructor": "iodine.modules.refinement.RefinementCore",
"encoder_net": {
"constructor": "iodine.modules.networks.CNN",
"mode": "avg_pool",
"cnn_opt": {
"output_channels": [32, 32, 32],
"strides": [2],
"kernel_shapes": [5],
"activation": "elu",
},
"mlp_opt": {
"output_sizes": [128],
"activation": "elu"
},
},
"recurrent_net": {
"constructor": "iodine.modules.networks.LSTM",
"hidden_sizes": [128],
},
"refinement_head": {
"constructor": "iodine.modules.refinement.ResHead"
},
},
"latent_dist": {
"constructor": "iodine.modules.distributions.LocScaleDistribution",
"dist": "normal",
"scale_act": "softplus",
"scale": "var",
"name": "latent_dist",
},
"output_dist": {
"constructor": "iodine.modules.distributions.MaskedMixture",
"num_components": num_components,
"component_dist": {
"constructor":
"iodine.modules.distributions.LocScaleDistribution",
"dist":
"logistic",
"scale":
"fixed",
"scale_val":
0.03,
"name":
"pixel_distribution",
},
},
"factor_evaluator": {
"constructor":
"iodine.modules.factor_eval.FactorRegressor",
"mapping": [
("color", 3, "scalar"),
("shape", 4, "categorical"),
("scale", 1, "scalar"),
("x", 1, "scalar"),
("y", 1, "scalar"),
("orientation", 2, "angle"),
],
},
}
optimizer = {
"constructor": "tensorflow.train.AdamOptimizer",
"learning_rate": {
"constructor": "tensorflow.train.exponential_decay",
"learning_rate": learn_rate,
"global_step": {
"constructor": "tensorflow.train.get_or_create_global_step"
},
"decay_steps": 1000000,
"decay_rate": 0.1,
},
"beta1": 0.95,
}
def tetrominoes():
n_z = 32 # number of latent dimensions
num_components = 4 # number of components (K)
num_iters = 5
checkpoint_dir = "iodine/checkpoints/tetrominoes"
# For the paper we used 8 GPUs with a batch size of 32 each.
# This means a total batch size of 256, which is too large for a single GPU.
# When reducing the batch size, the learning rate should also be lowered.
batch_size = 128
learn_rate = 0.0003 * math.sqrt(batch_size / 256)
data = {
"constructor": "iodine.modules.data.Tetrominoes",
"batch_size": batch_size,
"path": "iodine/multi_object_datasets/tetrominoes_train.tfrecords",
}
model = {
"constructor": "iodine.modules.iodine.IODINE",
"n_z": n_z,
"num_components": num_components,
"num_iters": num_iters,
"iter_loss_weight": "linspace",
"coord_type": "cos",
"coord_freqs": 3,
"decoder": {
"constructor": "iodine.modules.decoder.ComponentDecoder",
"pixel_decoder": {
"constructor": "iodine.modules.networks.BroadcastConv",
"cnn_opt": {
# Final channels is irrelevant with target_output_shape
"output_channels": [32, 32, 32, 32, None],
"kernel_shapes": [5],
"strides": [1],
"activation": "elu",
},
"coord_type": "linear",
"coord_freqs": 3,
},
},
"refinement_core": {
"constructor": "iodine.modules.refinement.RefinementCore",
"encoder_net": {
"constructor": "iodine.modules.networks.CNN",
"mode": "avg_pool",
"cnn_opt": {
"output_channels": [32, 32, 32],
"strides": [2],
"kernel_shapes": [5],
"activation": "elu",
},
"mlp_opt": {
"output_sizes": [128],
"activation": "elu"
},
},
"recurrent_net": {
"constructor": "iodine.modules.networks.LSTM",
"hidden_sizes": [], # No recurrent layer used for this dataset
},
"refinement_head": {
"constructor": "iodine.modules.refinement.ResHead"
},
},
"latent_dist": {
"constructor": "iodine.modules.distributions.LocScaleDistribution",
"dist": "normal",
"scale_act": "softplus",
"scale": "var",
"name": "latent_dist",
},
"output_dist": {
"constructor": "iodine.modules.distributions.MaskedMixture",
"num_components": num_components,
"component_dist": {
"constructor":
"iodine.modules.distributions.LocScaleDistribution",
"dist":
"logistic",
"scale":
"fixed",
"scale_val":
0.03,
"name":
"pixel_distribution",
},
},
"factor_evaluator": {
"constructor":
"iodine.modules.factor_eval.FactorRegressor",
"mapping": [
("position", 2, "scalar"),
("color", 3, "scalar"),
("shape", 20, "categorical"),
],
},
}
optimizer = {
"constructor": "tensorflow.train.AdamOptimizer",
"learning_rate": {
"constructor": "tensorflow.train.exponential_decay",
"learning_rate": learn_rate,
"global_step": {
"constructor": "tensorflow.train.get_or_create_global_step"
},
"decay_steps": 1000000,
"decay_rate": 0.1,
},
"beta1": 0.95,
}
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pushd iodine
wget http://storage.googleapis.com/deepmind-research-iodine/iodine_checkpoints.zip
unzip iodine_checkpoints.zip
popd
@@ -0,0 +1,202 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=g-importing-member, g-multiple-import, g-import-not-at-top
# pylint: disable=protected-access, g-bad-import-order, missing-docstring
# pylint: disable=unused-variable, invalid-name, no-value-for-parameter
from copy import deepcopy
import os.path
import warnings
from absl import logging
import numpy as np
from sacred import Experiment, SETTINGS
# Ignore all tensorflow deprecation warnings
logging._warn_preinit_stderr = 0
warnings.filterwarnings("ignore", module=".*tensorflow.*")
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import sonnet as snt
from sacred.stflow import LogFileWriter
from iodine.modules import utils
from iodine import configurations
SETTINGS.CONFIG.READ_ONLY_CONFIG = False
ex = Experiment("iodine")
@ex.config
def default_config():
continue_run = False # set to continue experiment from an existing checkpoint
# if continue_run is False, "_{run_id}" will be appended
checkpoint_dir = "checkpoints/iodine"
save_summaries_steps = 10
save_checkpoint_steps = 1000
n_z = 64 # number of latent dimensions
num_components = 7 # number of components (K)
num_iters = 5
learn_rate = 0.001
batch_size = 4
stop_after_steps = int(1e6)
# Details for the dataset, model and optimizer are left empty here.
# They can be found in the configurations for individual datasets,
# which are provided in configurations.py and added as named configs.
data = {} # Dataset details will go here
model = {} # Model details will go here
optimizer = {} # Optimizer details will go here
ex.named_config(configurations.clevr6)
ex.named_config(configurations.multi_dsprites)
ex.named_config(configurations.tetrominoes)
@ex.capture
def build(identifier, _config):
config_copy = deepcopy(_config[identifier])
return utils.build(config_copy, identifier=identifier)
def get_train_step(model, dataset, optimizer):
loss, scalars, _ = model(dataset("train"))
global_step = tf.train.get_or_create_global_step()
grads = optimizer.compute_gradients(loss)
gradients, variables = zip(*grads)
global_norm = tf.global_norm(gradients)
gradients, global_norm = tf.clip_by_global_norm(
gradients, 5.0, use_norm=global_norm)
grads = zip(gradients, variables)
train_op = optimizer.apply_gradients(grads, global_step=global_step)
with tf.control_dependencies([train_op]):
overview = model.get_overview_images(dataset("summary"))
scalars["debug/global_grad_norm"] = global_norm
summaries = {
k: tf.summary.scalar(k, v) for k, v in scalars.items()
}
summaries.update(
{k: tf.summary.image(k, v) for k, v in overview.items()})
return tf.identity(global_step), scalars, train_op
@ex.capture
def get_checkpoint_dir(continue_run, checkpoint_dir, _run, _log):
if continue_run:
assert os.path.exists(checkpoint_dir)
_log.info("Continuing run from checkpoint at {}".format(checkpoint_dir))
return checkpoint_dir
run_id = _run._id
if run_id is None: # then no observer was added that provided an _id
if not _run.unobserved:
_log.warning(
"No run_id given or provided by an Observer. (Re-)using run_id=1.")
run_id = 1
checkpoint_dir = checkpoint_dir + "_{run_id}".format(run_id=run_id)
_log.info(
"Starting a new run using checkpoint dir: '{}'".format(checkpoint_dir))
return checkpoint_dir
@ex.capture
def get_session(chkp_dir, loss, stop_after_steps, save_summaries_steps,
save_checkpoint_steps):
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
hooks = [
tf.train.StopAtStepHook(last_step=stop_after_steps),
tf.train.NanTensorHook(loss),
]
return tf.train.MonitoredTrainingSession(
hooks=hooks,
config=config,
checkpoint_dir=chkp_dir,
save_summaries_steps=save_summaries_steps,
save_checkpoint_steps=save_checkpoint_steps,
)
@ex.command(unobserved=True)
def load_checkpoint(use_placeholder=False, session=None):
dataset = build("data")
model = build("model")
if use_placeholder:
inputs = dataset.get_placeholders()
else:
inputs = dataset()
info = model.eval(inputs)
if session is None:
session = tf.Session()
saver = tf.train.Saver()
checkpoint_dir = get_checkpoint_dir()
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
saver.restore(session, checkpoint_file)
print('Successfully restored Checkpoint "{}"'.format(checkpoint_file))
# print variables
variables = tf.global_variables() + tf.local_variables()
for row in snt.format_variables(variables, join_lines=False):
print(row)
return {
"session": session,
"model": model,
"info": info,
"inputs": inputs,
"dataset": dataset,
}
@ex.automain
@LogFileWriter(ex)
def main(save_summaries_steps):
checkpoint_dir = get_checkpoint_dir()
dataset = build("data")
model = build("model")
optimizer = build("optimizer")
gstep, train_step_exports, train_op = get_train_step(model, dataset,
optimizer)
loss, ari = [], []
with get_session(checkpoint_dir, train_step_exports["loss/total"]) as sess:
while not sess.should_stop():
out = sess.run({
"step": gstep,
"loss": train_step_exports["loss/total"],
"ari": train_step_exports["loss/ari_nobg"],
"train": train_op,
})
loss.append(out["loss"])
ari.append(out["ari"])
step = out["step"]
if step % save_summaries_steps == 0:
mean_loss = np.mean(loss)
mean_ari = np.mean(ari)
ex.log_scalar("loss", mean_loss, step)
ex.log_scalar("ari", mean_ari, step)
print("{step:>6d} Loss: {loss: >12.2f}\t\tARI-nobg:{ari: >6.2f}".format(
step=step, loss=mean_loss, ari=mean_ari))
loss, ari = [], []
@@ -0,0 +1,13 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,264 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loading functionality for IODINE."""
# pylint: disable=g-multiple-import, missing-docstring, unused-import
import os.path
from iodine.modules.utils import flatten_all_but_last, ensure_3d
from multi_object_datasets import (
clevr_with_masks,
multi_dsprites,
tetrominoes,
objects_room,
)
from shapeguard import ShapeGuard
import sonnet as snt
import tensorflow.compat.v1 as tf
class IODINEDataset(snt.AbstractModule):
num_true_objects = 1
num_channels = 3
factors = {}
def __init__(
self,
path,
batch_size,
image_dim,
crop_region=None,
shuffle_buffer=1000,
max_num_objects=None,
min_num_objects=None,
grayscale=False,
name="dataset",
**kwargs,
):
super().__init__(name=name)
self.path = os.path.abspath(os.path.expanduser(path))
self.batch_size = batch_size
self.crop_region = crop_region
self.image_dim = image_dim
self.shuffle_buffer = shuffle_buffer
self.max_num_objects = max_num_objects
self.min_num_objects = min_num_objects
self.grayscale = grayscale
self.dataset = None
def _build(self, subset="train"):
dataset = self.dataset
# filter by number of objects
if self.max_num_objects is not None or self.min_num_objects is not None:
dataset = self.dataset.filter(self.filter_by_num_objects)
if subset == "train":
# normal mode returns a shuffled dataset iterator
if self.shuffle_buffer is not None:
dataset = dataset.shuffle(self.shuffle_buffer)
elif subset == "summary":
# for generating summaries and overview images
# returns a single fixed batch
dataset = dataset.take(self.batch_size)
# repeat and batch
dataset = dataset.repeat().batch(self.batch_size, drop_remainder=True)
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()
# preprocess the data to ensure correct format, scale images etc.
data = self.preprocess(data)
return data
def filter_by_num_objects(self, d):
if "visibility" not in d:
return tf.constant(True)
min_num_objects = self.min_num_objects or 0
max_num_objects = self.max_num_objects or 6
min_predicate = tf.greater_equal(
tf.reduce_sum(d["visibility"]),
tf.constant(min_num_objects - 1e-5, dtype=tf.float32),
)
max_predicate = tf.less_equal(
tf.reduce_sum(d["visibility"]),
tf.constant(max_num_objects + 1e-5, dtype=tf.float32),
)
return tf.logical_and(min_predicate, max_predicate)
def preprocess(self, data):
sg = ShapeGuard(dims={
"B": self.batch_size,
"H": self.image_dim[0],
"W": self.image_dim[1]
})
image = sg.guard(data["image"], "B, h, w, C")
mask = sg.guard(data["mask"], "B, L, h, w, 1")
# to float
image = tf.cast(image, tf.float32) / 255.0
mask = tf.cast(mask, tf.float32) / 255.0
# crop
if self.crop_region is not None:
height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
image = image[:, height_slice, width_slice, :]
mask = mask[:, :, height_slice, width_slice, :]
flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)
# rescale
size = tf.constant(
self.image_dim, dtype=tf.int32, shape=[2], verify_shape=True)
image = tf.image.resize_images(
image, size, method=tf.image.ResizeMethod.BILINEAR)
mask = tf.image.resize_images(
flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
if self.grayscale:
image = tf.reduce_mean(image, axis=-1, keepdims=True)
output = {
"image": sg.guard(image[:, None], "B, T, H, W, C"),
"mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
"factors": self.preprocess_factors(data, sg),
}
if "visibility" in data:
output["visibility"] = sg.guard(data["visibility"], "B, L")
else:
output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)
return output
def preprocess_factors(self, data, sg):
return {
name: sg.guard(ensure_3d(data[name]), "B, L, *")
for name in self.factors
}
def get_placeholders(self, batch_size=None):
batch_size = batch_size or self.batch_size
sg = ShapeGuard(
dims={
"B": batch_size,
"H": self.image_dim[0],
"W": self.image_dim[1],
"L": self.num_true_objects,
"C": 3,
"T": 1,
})
return {
"image": tf.placeholder(dtype=tf.float32, shape=sg["B, T, H, W, C"]),
"mask": tf.placeholder(dtype=tf.float32, shape=sg["B, T, L, H, W, 1"]),
"visibility": tf.placeholder(dtype=tf.float32, shape=sg["B, L"]),
"factors": {
name:
tf.placeholder(dtype=dtype, shape=sg["B, L, {}".format(size)])
for name, (dtype, size) in self.factors.items()
},
}
class CLEVR(IODINEDataset):
num_true_objects = 11
num_channels = 3
factors = {
"color": (tf.uint8, 1),
"shape": (tf.uint8, 1),
"size": (tf.uint8, 1),
"position": (tf.float32, 3),
"rotation": (tf.float32, 1),
}
def __init__(
self,
path,
crop_region=((29, 221), (64, 256)),
image_dim=(128, 128),
name="clevr",
**kwargs,
):
super().__init__(
path=path,
crop_region=crop_region,
image_dim=image_dim,
name=name,
**kwargs)
self.dataset = clevr_with_masks.dataset(self.path)
def preprocess_factors(self, data, sg):
return {
"color": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
"size": sg.guard(ensure_3d(data["size"]), "B, L, 1"),
"position": sg.guard(ensure_3d(data["pixel_coords"]), "B, L, 3"),
"rotation": sg.guard(ensure_3d(data["rotation"]), "B, L, 1"),
}
class MultiDSprites(IODINEDataset):
num_true_objects = 6
num_channels = 3
factors = {
"color": (tf.float32, 3),
"shape": (tf.uint8, 1),
"scale": (tf.float32, 1),
"x": (tf.float32, 1),
"y": (tf.float32, 1),
"orientation": (tf.float32, 1),
}
def __init__(
self,
path,
# variant from ['binarized', 'colored_on_grayscale', 'colored_on_colored']
dataset_variant="colored_on_grayscale",
image_dim=(64, 64),
name="multi_dsprites",
**kwargs,
):
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
self.dataset_variant = dataset_variant
self.dataset = multi_dsprites.dataset(self.path, self.dataset_variant)
class Tetrominoes(IODINEDataset):
num_true_objects = 6
num_channels = 3
factors = {
"color": (tf.uint8, 3),
"shape": (tf.uint8, 1),
"position": (tf.float32, 2),
}
def __init__(self, path, image_dim=(35, 35), name="tetrominoes", **kwargs):
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
self.dataset = tetrominoes.dataset(self.path)
def preprocess_factors(self, data, sg):
pos_x = ensure_3d(data["x"])
pos_y = ensure_3d(data["y"])
position = tf.concat([pos_x, pos_y], axis=2)
return {
"color": sg.guard(ensure_3d(data["color"]), "B, L, 3"),
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
"position": sg.guard(ensure_3d(position), "B, L, 2"),
}
@@ -0,0 +1,49 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoders for rendering images."""
# pylint: disable=missing-docstring
from iodine.modules.distributions import MixtureParameters
import shapeguard
import sonnet as snt
class ComponentDecoder(snt.AbstractModule):
def __init__(self, pixel_decoder, name="component_decoder"):
super().__init__(name=name)
self._pixel_decoder = pixel_decoder
self._sg = shapeguard.ShapeGuard()
def set_output_shapes(self, pixel, mask):
self._sg.guard(pixel, "K, H, W, Cp")
self._sg.guard(mask, "K, H, W, 1")
self._pixel_decoder.set_output_shapes(self._sg["H, W, 1 + Cp"])
def _build(self, z):
self._sg.guard(z, "B, K, Z")
z_flat = self._sg.reshape(z, "B*K, Z")
pixel_params = self._pixel_decoder(z_flat).params
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
mask_params = pixel_params[Ellipsis, 0:1]
pixel_params = pixel_params[Ellipsis, 1:]
output = MixtureParameters(
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
mask=self._sg.reshape(mask_params, "B, K, H, W, 1"),
)
del self._sg.B
return output
@@ -0,0 +1,223 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collection of sonnet modules that wrap useful distributions."""
# pylint: disable=missing-docstring, g-doc-args, g-short-docstring-punctuation
# pylint: disable=g-space-before-docstring-summary
# pylint: disable=g-no-space-after-docstring-summary
import collections
from iodine.modules.utils import get_act_func
from iodine.modules.utils import get_distribution
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
FlatParameters = collections.namedtuple("ParameterOut", ["params"])
MixtureParameters = collections.namedtuple("MixtureOut", ["pixel", "mask"])
class DistributionModule(snt.AbstractModule):
"""Distribution Base class supporting shape inference & default priors."""
def __init__(self, name="distribution"):
super().__init__(name=name)
self._output_shape = None
def set_output_shape(self, shape):
self._output_shape = shape
@property
def output_shape(self):
return self._output_shape
@property
def input_shapes(self):
raise NotImplementedError()
def get_default_prior(self, batch_dim=(1,)):
return self(
tf.zeros(list(batch_dim) + self.input_shapes.params, dtype=tf.float32))
class BernoulliOutput(DistributionModule):
def __init__(self, name="bernoulli_output"):
super().__init__(name=name)
@property
def input_shapes(self):
return FlatParameters(self.output_shape)
def _build(self, params):
return tfd.Independent(
tfd.Bernoulli(logits=params, dtype=tf.float32),
reinterpreted_batch_ndims=1)
class LocScaleDistribution(DistributionModule):
  """Generic IID location / scale distribution.

  Input parameters are the concatenation of location and scale, with shape
  (2*Z,). Supports optional activation functions for scale and location.

  Args:
    dist: Distribution or str Kind of distribution used. Supports Normal,
      Logistic, Laplace, and StudentT distributions.
    dist_kwargs: dict custom keyword arguments for the distribution
    scale_act: function or str or None activation function to be applied to
      the scale input
    scale: str
      different modes for computing the scale:
        * stddev: scale is computed as scale_act(s)
        * var: scale is computed as sqrt(scale_act(s))
        * prec: scale is computed as 1./scale_act(s)
        * fixed: scale is a global value (same for all pixels). If
          scale_val==-1. it is a trainable variable initialized to 0.1,
          otherwise it is fixed to scale_val. The input shape is only (Z,)
          in this case.
    scale_val: float determines the scale value (only used if scale=='fixed').
    loc_act: function or str or None activation function to be applied to the
      location input.
  """
def __init__(
self,
dist=tfd.Normal,
dist_kwargs=None,
scale_act=tf.exp,
scale="stddev",
scale_val=1.0,
loc_act=None,
name="loc_scale_dist",
):
super().__init__(name=name)
self._scale_act = get_act_func(scale_act)
self._loc_act = get_act_func(loc_act)
# supports Normal, Logistic, Laplace, StudentT
self._dist = get_distribution(dist)
self._dist_kwargs = dist_kwargs or {}
assert scale in ["stddev", "var", "prec", "fixed"], scale
self._scale = scale
self._scale_val = scale_val
@property
def input_shapes(self):
if self._scale == "fixed":
param_shape = self.output_shape
else:
param_shape = self.output_shape[:-1] + [self.output_shape[-1] * 2]
return FlatParameters(param_shape)
def _build(self, params):
if self._scale == "fixed":
loc = params
scale = None # set later
else:
n_channels = params.get_shape().as_list()[-1]
assert n_channels % 2 == 0
assert n_channels // 2 == self.output_shape[-1]
loc = params[Ellipsis, :n_channels // 2]
scale = params[Ellipsis, n_channels // 2:]
# apply activation functions
if self._scale != "fixed":
scale = self._scale_act(scale)
loc = self._loc_act(loc)
# apply the correct parametrization
if self._scale == "var":
scale = tf.sqrt(scale)
elif self._scale == "prec":
scale = tf.reciprocal(scale)
elif self._scale == "fixed":
if self._scale_val == -1.0:
scale_val = tf.get_variable(
"scale", initializer=tf.constant(0.1, dtype=tf.float32))
else:
scale_val = self._scale_val
scale = tf.ones_like(loc) * scale_val
# else 'stddev'
dist = self._dist(loc=loc, scale=scale, **self._dist_kwargs)
return tfd.Independent(dist, reinterpreted_batch_ndims=1)
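# Illustrative sketch, not part of the original commit: the four scale modes
# of LocScaleDistribution, applied to a single scalar with the standard
# library only. The function name compute_scale is hypothetical, and the
# trainable-variable case of the "fixed" mode (scale_val == -1) is
# deliberately left out.

```python
import math


def compute_scale(raw, mode="stddev", scale_act=math.exp, scale_val=1.0):
  """Maps a raw scale input to a standard deviation for the given mode."""
  if mode == "fixed":
    # The raw input is ignored; the scale is the same for every element.
    return scale_val
  activated = scale_act(raw)
  if mode == "stddev":
    return activated
  if mode == "var":
    return math.sqrt(activated)
  if mode == "prec":
    return 1.0 / activated
  raise ValueError("Unknown scale mode: {}".format(mode))


# With the default exp activation, a raw input of log(4) is read as a standard
# deviation of 4, a variance of 4, or a value of 4 that is inverted in "prec"
# mode.
raw_input = math.log(4.0)
as_stddev = compute_scale(raw_input, "stddev")
as_var = compute_scale(raw_input, "var")
as_prec = compute_scale(raw_input, "prec")
```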
class MaskedMixture(DistributionModule):
def __init__(
self,
num_components,
component_dist,
mask_activation=None,
name="masked_mixture",
):
"""
Spatial Mixture Model composed of a categorical masking distribution and
a custom pixel-wise component distribution (usually logistic or
gaussian).
Args:
num_components: int Number of mixture components >= 2
component_dist: the distribution to use for the individual components
mask_activation: str or function or None activation function that
should be applied to the mask before the softmax.
name: str
"""
super().__init__(name=name)
self._num_components = num_components
self._dist = component_dist
self._mask_activation = get_act_func(mask_activation)
def set_output_shape(self, shape):
super().set_output_shape(shape)
self._dist.set_output_shape(shape)
def _build(self, pixel, mask):
sg = shapeguard.ShapeGuard()
# MASKING
sg.guard(mask, "B, K, H, W, 1")
mask = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
mask = sg.reshape(mask, "B, H, W, K")
mask = self._mask_activation(mask)
mask = mask[:, tf.newaxis] # add K=1 axis since K is removed by mixture
mix_dist = tfd.Categorical(logits=mask)
# COMPONENTS
sg.guard(pixel, "B, K, H, W, Cp")
params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
params = params[:, tf.newaxis] # add K=1 axis since K is removed by mixture
dist = self._dist(params)
return tfd.MixtureSameFamily(
mixture_distribution=mix_dist, components_distribution=dist)
@property
def input_shapes(self):
pixel = [self._num_components] + self._dist.input_shapes.params
mask = pixel[:-1] + [1]
return MixtureParameters(pixel, mask)
def get_default_prior(self, batch_dim=(1,)):
pixel = tf.zeros(
list(batch_dim) + self.input_shapes.pixel, dtype=tf.float32)
mask = tf.zeros(list(batch_dim) + self.input_shapes.mask, dtype=tf.float32)
return self(pixel, mask)
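# Illustrative sketch, not part of the original commit: the per-pixel
# likelihood that a masked mixture assigns to a value, written for a single
# scalar pixel with Normal components. The mask logits are turned into mixing
# weights by a softmax (as tfd.Categorical(logits=...) does). All function
# names below are hypothetical.

```python
import math


def softmax(logits):
  """Numerically stable softmax over a list of floats."""
  shift = max(logits)
  exps = [math.exp(l - shift) for l in logits]
  total = sum(exps)
  return [e / total for e in exps]


def normal_pdf(x, loc, scale):
  z = (x - loc) / scale
  return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


def mixture_likelihood(x, mask_logits, locs, scale):
  """Sum over components k of softmax(mask_logits)[k] * Normal(x; locs[k])."""
  weights = softmax(mask_logits)
  return sum(w * normal_pdf(x, loc, scale) for w, loc in zip(weights, locs))


# Two components with equal mask logits and identical means: the mixture
# reduces to a single Normal density.
likelihood = mixture_likelihood(0.0, [0.0, 0.0], [0.0, 0.0], 1.0)
```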
+206
@@ -0,0 +1,206 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factor Evaluation Module."""
# pylint: disable=unused-variable
import collections
import functools
from iodine.modules import utils
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
Factor = collections.namedtuple("Factor", ["name", "size", "type"])
class FactorRegressor(snt.AbstractModule):
"""Assess representations by learning a linear mapping to latents."""
def __init__(self, mapping=None, name="repres_content"):
super().__init__(name=name)
if mapping is None:
self._mapping = [
Factor("color", 3, "scalar"),
Factor("shape", 4, "categorical"),
Factor("scale", 1, "scalar"),
Factor("x", 1, "scalar"),
Factor("y", 1, "scalar"),
Factor("orientation", 2, "angle"),
]
else:
self._mapping = [Factor(*m) for m in mapping]
def _build(self, z, latent, visibility, pred_mask, true_mask):
sg = shapeguard.ShapeGuard()
z = sg.guard(z, "B, K, Z")
pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
true_mask = sg.guard(true_mask, "B, L, H, W, 1")
visibility = sg.guard(visibility, "B, L")
num_visible_obj = tf.reduce_sum(visibility)
# Map z to predictions for all latents
sg.M = sum([m.size for m in self._mapping])
self.predictor = snt.Linear(sg.M, name="predict_latents")
z_flat = sg.reshape(z, "B*K, Z")
all_preds = sg.guard(self.predictor(z_flat), "B*K, M")
all_preds = sg.reshape(all_preds, "B, 1, K, M")
all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])
# prepare latents
latents = {}
mean_var_tot = {}
for m in self._mapping:
with tf.name_scope(m.name):
# preprocess, reshape, and tile
lat_preprocess = self.get_preprocessing(m)
lat = sg.guard(
lat_preprocess(latent[m.name]), "B, L, {}".format(m.size))
# track the mean and variance of the latent with an online estimator
if m.type in {"scalar", "angle"}:
mvt = utils.OnlineMeanVarEstimator(
axis=[0, 1], ddof=1, name="{}_mean_var".format(m.name))
mean_var_tot[m.name] = mvt(lat, visibility[:, :, tf.newaxis])
lat = tf.reshape(lat, sg["B, L, 1"] + [-1])
lat = tf.tile(lat, sg["1, 1, K, 1"])
latents[m.name] = lat
# prepare predictions
idx = 0
predictions = {}
for m in self._mapping:
with tf.name_scope(m.name):
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
pred = all_preds[Ellipsis, idx:idx + m.size]
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
idx += m.size
# compute error
total_pairwise_errors = None
for m in self._mapping:
with tf.name_scope(m.name):
error_fn = self.get_error_func(m)
sg.guard(latents[m.name], "B, L, K, {}".format(m.size))
sg.guard(predictions[m.name], "B, L, K, {}".format(m.size))
err = error_fn(latents[m.name], predictions[m.name])
sg.guard(err, "B, L, K")
if total_pairwise_errors is None:
total_pairwise_errors = err
else:
total_pairwise_errors += err
# determine best assignment by comparing masks
obj_mask = true_mask[:, :, tf.newaxis]
pred_mask = pred_mask[:, tf.newaxis]
pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
assignment = tf.one_hot(best_match, sg.K)
assignment *= visibility[:, :, tf.newaxis] # Mask non-visible objects
# total error
total_error = (
tf.reduce_sum(assignment * total_pairwise_errors) / num_visible_obj)
# compute scalars
monitored_scalars = {}
for m in self._mapping:
with tf.name_scope(m.name):
metric = self.get_metric(m)
scalar = metric(
latents[m.name],
predictions[m.name],
assignment[:, :, :, tf.newaxis],
mean_var_tot.get(m.name),
num_visible_obj,
)
monitored_scalars[m.name] = scalar
return total_error, monitored_scalars, mean_var_tot, predictions, assignment
@snt.reuse_variables
def predict(self, z):
sg = shapeguard.ShapeGuard()
z = sg.guard(z, "B, Z")
all_preds = sg.guard(self.predictor(z), "B, M")
idx = 0
predictions = {}
for m in self._mapping:
with tf.name_scope(m.name):
pred = all_preds[:, idx:idx + m.size]
predictions[m.name] = sg.guard(pred, "B, {}".format(m.size))
idx += m.size
return predictions
@staticmethod
def get_error_func(factor):
if factor.type in {"scalar", "angle"}:
return sse
elif factor.type == "categorical":
return functools.partial(
tf.losses.softmax_cross_entropy, reduction="none")
else:
raise KeyError(factor.type)
@staticmethod
def get_metric(factor):
if factor.type in {"scalar", "angle"}:
return r2
elif factor.type == "categorical":
return accuracy
else:
raise KeyError(factor.type)
@staticmethod
def one_hot(f, nr_categories):
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
@staticmethod
def angle_to_vector(theta):
return tf.concat([tf.math.cos(theta), tf.math.sin(theta)], axis=-1)
@staticmethod
def get_preprocessing(factor):
if factor.type == "scalar":
return tf.identity
elif factor.type == "categorical":
return functools.partial(
FactorRegressor.one_hot, nr_categories=factor.size)
elif factor.type == "angle":
return FactorRegressor.angle_to_vector
else:
raise KeyError(factor.type)
def sse(true, pred):
  # run our own sum squared error because we want to reduce sum over last dim
  return tf.reduce_sum(tf.square(true - pred), axis=-1)


def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
  del mean_var_tot  # unused
  pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
  labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
  correct = tf.cast(tf.equal(labels, pred), tf.float32)
  return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis


def r2(labels, pred, assignment, mean_var_tot, num_vis):
  del num_vis  # unused
  mean, var, _ = mean_var_tot
  # labels, pred: (B, L, K, n)
  ss_res = tf.reduce_sum(tf.square(labels - pred) * assignment, axis=2)
  ss_tot = var[tf.newaxis, tf.newaxis, :]  # (1, 1, n)
  return tf.reduce_mean(1.0 - ss_res / ss_tot)
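# Illustrative sketch, not part of the original commit: FactorRegressor
# assigns each ground-truth object to the predicted component whose mask
# overlaps it most (sum of the elementwise product, then argmax over
# components). The helper name best_match and the flattened list masks are
# assumptions made for illustration.

```python
def best_match(true_masks, pred_masks):
  """Returns, for each true object, the index of the best-overlapping component.

  Args:
    true_masks: list of L flattened ground-truth masks.
    pred_masks: list of K flattened predicted masks of the same length.
  """
  matches = []
  for true_mask in true_masks:
    overlaps = [
        sum(t * p for t, p in zip(true_mask, pred_mask))
        for pred_mask in pred_masks
    ]
    # Like an argmax, ties resolve to the first index.
    matches.append(overlaps.index(max(overlaps)))
  return matches


# Two objects on a 4-pixel image; the predicted components come in swapped
# order relative to the ground truth.
true_masks = [[1, 1, 0, 0], [0, 0, 1, 1]]
pred_masks = [[0.1, 0.2, 0.9, 0.8], [0.9, 0.8, 0.1, 0.2]]
matches = best_match(true_masks, pred_masks)
```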
File diff suppressed because it is too large
+300
@@ -0,0 +1,300 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network modules."""
# pylint: disable=g-multiple-import, g-doc-args, g-short-docstring-punctuation
# pylint: disable=g-no-space-after-docstring-summary
from iodine.modules.distributions import FlatParameters
from iodine.modules.utils import flatten_all_but_last, get_act_func
import numpy as np
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
class CNN(snt.AbstractModule):
"""ConvNet2D followed by an MLP.
This is a typical encoder architecture for VAEs, and has been found to work
well. One small improvement is to append coordinate channels on the input,
though for most datasets the improvement obtained is negligible.
"""
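def __init__(self, cnn_opt, mlp_opt, mode="flatten", name="cnn"):
  """Constructor.

  Args:
    cnn_opt: Dictionary. Kwargs for the cnn. See snt.nets.ConvNet2D for
      details.
    mlp_opt: Dictionary. Kwargs for the mlp. See snt.nets.MLP for details.
    mode: String. How the CNN output is reduced before the MLP. "flatten"
      reshapes it into a vector, "avg_pool" averages over the spatial
      dimensions.
    name: String. Optional name.
  """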
super().__init__(name=name)
if "activation" in cnn_opt:
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
self._cnn_opt = cnn_opt
if "activation" in mlp_opt:
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
self._mlp_opt = mlp_opt
self._mode = mode
def set_output_shapes(self, shape):
# assert self._mlp_opt['output_sizes'][-1] is None, self._mlp_opt
sg = shapeguard.ShapeGuard()
sg.guard(shape, "1, Y")
self._mlp_opt["output_sizes"][-1] = sg.Y
def _build(self, image):
"""Connect model to TensorFlow graph."""
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
sg = shapeguard.ShapeGuard()
flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
sg.guard(flat_image, "B, H, W, C")
cnn = snt.nets.ConvNet2D(
activate_final=True,
paddings=("SAME",),
normalize_final=False,
**self._cnn_opt)
mlp = snt.nets.MLP(**self._mlp_opt)
# run CNN
net = cnn(flat_image)
if self._mode == "flatten":
# flatten
net_shape = net.get_shape().as_list()
flat_shape = net_shape[:-3] + [np.prod(net_shape[-3:])]
net = tf.reshape(net, flat_shape)
elif self._mode == "avg_pool":
net = tf.reduce_mean(net, axis=[1, 2])
else:
raise KeyError('Unknown mode "{}"'.format(self._mode))
# run MLP
output = sg.guard(mlp(net), "B, Y")
return FlatParameters(unflatten(output))
class MLP(snt.AbstractModule):
"""MLP."""
def __init__(self, name="mlp", **mlp_opt):
super().__init__(name=name)
if "activation" in mlp_opt:
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
self._mlp_opt = mlp_opt
assert mlp_opt["output_sizes"][-1] is None, mlp_opt
def set_output_shapes(self, shape):
sg = shapeguard.ShapeGuard()
sg.guard(shape, "1, Y")
self._mlp_opt["output_sizes"][-1] = sg.Y
def _build(self, data):
"""Connect model to TensorFlow graph."""
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
sg = shapeguard.ShapeGuard()
flat_data, unflatten = flatten_all_but_last(data)
sg.guard(flat_data, "B, N")
mlp = snt.nets.MLP(**self._mlp_opt)
# run MLP
output = sg.guard(mlp(flat_data), "B, Y")
return FlatParameters(unflatten(output))
class DeConv(snt.AbstractModule):
"""MLP followed by Deconv net.
This decoder is commonly used by vanilla VAE models. However, in practice
BroadcastConv (see below) seems to disentangle slightly better.
"""
def __init__(self, mlp_opt, cnn_opt, name="deconv"):
  """Constructor.

  Args:
    mlp_opt: Dictionary. Kwargs for snt.nets.MLP.
    cnn_opt: Dictionary. Kwargs for snt.nets.ConvNet2DTranspose.
    name: Optional name.
  """
super().__init__(name=name)
assert cnn_opt["output_channels"][-1] is None, cnn_opt
if "activation" in cnn_opt:
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
self._cnn_opt = cnn_opt
if mlp_opt and "activation" in mlp_opt:
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
self._mlp_opt = mlp_opt
self._target_out_shape = None
def set_output_shapes(self, shape):
self._target_out_shape = shape
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
def _build(self, z):
"""Connect model to TensorFlow graph."""
sg = shapeguard.ShapeGuard()
flat_z, unflatten = flatten_all_but_last(z)
sg.guard(flat_z, "B, Z")
sg.guard(self._target_out_shape, "H, W, C")
mlp = snt.nets.MLP(**self._mlp_opt)
cnn = snt.nets.ConvNet2DTranspose(
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
net = mlp(flat_z)
output = sg.guard(cnn(net), "B, H, W, C")
return FlatParameters(unflatten(output))
class BroadcastConv(snt.AbstractModule):
"""MLP followed by a broadcast convolution.
This decoder takes a latent vector z, (optionally) applies an MLP to it,
then tiles the resulting vector across space to have dimension [B, H, W, C]
i.e. tiles across H and W. Then coordinate channels are appended and a
convolutional layer is applied.
"""
def __init__(
self,
cnn_opt,
mlp_opt=None,
coord_type="linear",
coord_freqs=3,
name="broadcast_conv",
):
  """Args:

    cnn_opt: dict Kwargs for snt.nets.ConvNet2D for the CNN.
    mlp_opt: None or dict If dictionary, then kwargs for snt.nets.MLP. If
      None, then the model will not process the latent vector by an mlp.
    coord_type: ["linear", "cos", None] type of coordinate channels to
      add.
      None: add no coordinate channels.
      linear: two channels with values linearly spaced from -1. to 1. in
        the H and W dimension respectively.
      cos: coord_freqs^2 - 1 channels containing cosine basis functions
        (the constant basis function is dropped).
    coord_freqs: int number of frequencies used to construct the cosine
      basis functions (only for coord_type=="cos")
    name: Optional name.
  """
super().__init__(name=name)
assert cnn_opt["output_channels"][-1] is None, cnn_opt
if "activation" in cnn_opt:
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
self._cnn_opt = cnn_opt
if mlp_opt and "activation" in mlp_opt:
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
self._mlp_opt = mlp_opt
self._target_out_shape = None
self._coord_type = coord_type
self._coord_freqs = coord_freqs
def set_output_shapes(self, shape):
self._target_out_shape = shape
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
def _build(self, z):
"""Connect model to TensorFlow graph."""
assert self._target_out_shape is not None, "Call set_output_shapes"
# reshape components into batch dimension before processing them
sg = shapeguard.ShapeGuard()
flat_z, unflatten = flatten_all_but_last(z)
sg.guard(flat_z, "B, Z")
sg.guard(self._target_out_shape, "H, W, C")
if self._mlp_opt is None:
mlp = tf.identity
else:
mlp = snt.nets.MLP(activate_final=True, **self._mlp_opt)
mlp_output = sg.guard(mlp(flat_z), "B, hidden")
# tile MLP output spatially and append coordinate channels
broadcast_mlp_output = tf.tile(
mlp_output[:, tf.newaxis, tf.newaxis],
multiples=tf.constant(sg["1, H, W, 1"]),
) # B, H, W, hidden
dec_cnn_inputs = self.append_coordinate_channels(broadcast_mlp_output)
cnn = snt.nets.ConvNet2D(
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
cnn_outputs = cnn(dec_cnn_inputs)
sg.guard(cnn_outputs, "B, H, W, C")
return FlatParameters(unflatten(cnn_outputs))
def append_coordinate_channels(self, output):
sg = shapeguard.ShapeGuard()
sg.guard(output, "B, H, W, C")
if self._coord_type is None:
return output
if self._coord_type == "linear":
w_coords = tf.linspace(-1.0, 1.0, sg.W)[None, None, :, None]
h_coords = tf.linspace(-1.0, 1.0, sg.H)[None, :, None, None]
w_coords = tf.tile(w_coords, sg["B, H, 1, 1"])
h_coords = tf.tile(h_coords, sg["B, 1, W, 1"])
return tf.concat([output, h_coords, w_coords], axis=-1)
elif self._coord_type == "cos":
freqs = sg.guard(tf.range(0.0, self._coord_freqs), "F")
valx = tf.linspace(0.0, np.pi, sg.W)[None, None, :, None, None]
valy = tf.linspace(0.0, np.pi, sg.H)[None, :, None, None, None]
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
return tf.concat([output, coords], axis=-1)
else:
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
class LSTM(snt.RNNCore):
  """Wrapper around snt.LSTM that supports multiple layers and runs K components
  in parallel.

  Expects input data of shape (B, K, H) and outputs data of shape (B, K, Y).
  """
def __init__(self, hidden_sizes, name="lstm"):
super().__init__(name=name)
self._hidden_sizes = hidden_sizes
with self._enter_variable_scope():
self._lstm_layers = [snt.LSTM(hidden_size=h) for h in self._hidden_sizes]
def initial_state(self, batch_size, **kwargs):
return [
lstm.initial_state(batch_size, **kwargs) for lstm in self._lstm_layers
]
def _build(self, data, prev_states):
assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
assert len(prev_states) == len(self._hidden_sizes)
sg = shapeguard.ShapeGuard()
sg.guard(data, "B, K, H")
data = sg.reshape(data, "B*K, H")
out = data
new_states = []
for lstm, pstate in zip(self._lstm_layers, prev_states):
out, nstate = lstm(out, pstate)
new_states.append(nstate)
sg.guard(out, "B*K, Y")
out = sg.reshape(out, "B, K, Y")
return out, new_states
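# Illustrative sketch, not part of the original commit: the "linear"
# coordinate channels that BroadcastConv appends, computed for a single image
# with the standard library only. Each pixel gets an (h, w) pair with values
# linearly spaced from -1 to 1, in the same channel order as the concat in
# append_coordinate_channels. The helper names linspace and linear_coords are
# hypothetical.

```python
def linspace(start, stop, num):
  """Returns num evenly spaced floats from start to stop (inclusive)."""
  if num == 1:
    return [start]
  step = (stop - start) / (num - 1)
  return [start + i * step for i in range(num)]


def linear_coords(height, width):
  """Returns an H x W grid of (h_coord, w_coord) pairs in [-1, 1]."""
  h_coords = linspace(-1.0, 1.0, height)
  w_coords = linspace(-1.0, 1.0, width)
  return [[(h, w) for w in w_coords] for h in h_coords]


# A 2 x 3 image: rows sit at h = -1 and h = 1, columns at w = -1, 0 and 1.
grid = linear_coords(2, 3)
```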
+226
@@ -0,0 +1,226 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plotting tools for IODINE."""
# pylint: disable=unused-import, missing-docstring, unused-variable
# pylint: disable=invalid-name, unexpected-keyword-arg
import functools
from iodine.modules.utils import get_mask_plot_colors
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import numpy as np
__all__ = ("get_mask_plot_colors", "example_plot", "iterations_plot",
"inputs_plot")
def clean_ax(ax, color=None, lw=4.0):
ax.set_xticks([])
ax.set_yticks([])
if color is not None:
for spine in ax.spines.values():
spine.set_linewidth(lw)
spine.set_color(color)
def optional_ax(fn):
def _wrapped(*args, **kwargs):
if kwargs.get("ax", None) is None:
figsize = kwargs.pop("figsize", (4, 4))
fig, ax = plt.subplots(figsize=figsize)
kwargs["ax"] = ax
return fn(*args, **kwargs)
return _wrapped
def optional_clean_ax(fn):
def _wrapped(*args, **kwargs):
if kwargs.get("ax", None) is None:
figsize = kwargs.pop("figsize", (4, 4))
fig, ax = plt.subplots(figsize=figsize)
kwargs["ax"] = ax
color = kwargs.pop("color", None)
lw = kwargs.pop("lw", 4.0)
res = fn(*args, **kwargs)
clean_ax(kwargs["ax"], color, lw)
return res
return _wrapped
@optional_clean_ax
def show_img(img, mask=None, ax=None, norm=False):
if norm:
vmin, vmax = np.min(img), np.max(img)
img = (img - vmin) / (vmax - vmin)
if mask is not None:
img = img * mask + np.ones_like(img) * (1.0 - mask)
return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
@optional_clean_ax
def show_mask(m, ax):
color_conv = get_mask_plot_colors(m.shape[0])
color_mask = np.dot(np.transpose(m, [1, 2, 0]), color_conv)
return ax.imshow(color_mask.clip(0.0, 1.0), interpolation="nearest")
@optional_clean_ax
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
return ax.matshow(
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
@optional_clean_ax
def show_coords(m, ax):
vmin, vmax = np.min(m), np.max(m)
m = (m - vmin) / (vmax - vmin)
color_conv = get_mask_plot_colors(m.shape[-1])
color_mask = np.dot(m, color_conv)
return ax.imshow(color_mask, interpolation="nearest")
def example_plot(rinfo,
b=0,
t=-1,
mask_components=False,
size=2,
column_titles=True):
image = rinfo["data"]["image"][b, 0]
recons = rinfo["outputs"]["recons"][b, t, 0]
pred_mask = rinfo["outputs"]["pred_mask"][b, t]
components = rinfo["outputs"]["components"][b, t]
K, H, W, C = components.shape
colors = get_mask_plot_colors(K)
nrows = 1
ncols = 3 + K
fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, nrows * size))
show_img(image, ax=axes[0], color="#000000")
show_img(recons, ax=axes[1], color="#000000")
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
for k in range(K):
mask = pred_mask[k] if mask_components else None
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
if column_titles:
labels = ["Image", "Recons.", "Mask"
] + ["Component {}".format(k + 1) for k in range(K)]
for ax, title in zip(axes, labels):
ax.set_title(title)
plt.subplots_adjust(hspace=0.03, wspace=0.035)
return fig
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
image = rinfo["data"]["image"][b]
true_mask = rinfo["data"]["true_mask"][b]
recons = rinfo["outputs"]["recons"][b]
pred_mask = rinfo["outputs"]["pred_mask"][b]
pred_mask_logits = rinfo["outputs"]["pred_mask_logits"][b]
components = rinfo["outputs"]["components"][b]
T, K, H, W, C = components.shape
colors = get_mask_plot_colors(K)
nrows = T + 1
ncols = 2 + K
fig, axes = plt.subplots(
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
for t in range(T):
show_img(recons[t, 0], ax=axes[t, 0])
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
axes[t, 0].set_ylabel("iter {}".format(t))
for k in range(K):
mask = pred_mask[t, k] if mask_components else None
show_img(components[t, k], ax=axes[t, k + 2], color=colors[k], mask=mask)
axes[0, 0].set_title("Reconstruction")
axes[0, 1].set_title("Mask")
show_img(image[0], ax=axes[T, 0])
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
vmin = np.min(pred_mask_logits[T - 1])
vmax = np.max(pred_mask_logits[T - 1])
for k in range(K):
axes[0, k + 2].set_title("Component {}".format(k + 1)) # , color=colors[k])
show_mat(
pred_mask_logits[T - 1, k], ax=axes[T, k + 2], vmin=vmin, vmax=vmax)
axes[T, k + 2].set_xlabel(
"Mask Logits for\nComponent {}".format(k + 1)) # , color=colors[k])
axes[T, 0].set_xlabel("Input Image")
axes[T, 1].set_xlabel("Ground Truth Mask")
plt.subplots_adjust(wspace=0.05, hspace=0.05)
return fig
def inputs_plot(rinfo, b=0, t=0, size=2):
B, T, K, H, W, C = rinfo["outputs"]["components"].shape
colors = get_mask_plot_colors(K)
inputs = rinfo["inputs"]["spatial"]
rows = [
("image", show_img, False),
("components", show_img, False),
("dcomponents", functools.partial(show_img, norm=True), False),
("mask", show_mat, True),
("pred_mask", show_mat, True),
("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
("posterior", show_mat, True),
("log_prob", show_mat, True),
("counterfactual", show_mat, True),
("coordinates", show_coords, False),
]
rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
nrows = len(rows)
ncols = K + 1
fig, axes = plt.subplots(
nrows=nrows,
ncols=ncols,
figsize=(ncols * size - size * 0.9, nrows * size),
gridspec_kw={"width_ratios": [1] * K + [0.1]},
)
for r, (name, plot_fn, make_cbar) in enumerate(rows):
axes[r, 0].set_ylabel(name)
if make_cbar:
vmin = np.min(inputs[name][b, t])
vmax = np.max(inputs[name][b, t])
if np.abs(vmin - vmax) < 1e-6:
vmin -= 0.1
vmax += 0.1
plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
# print("range of {:<16}: [{:0.2f}, {:0.2f}]".format(name, vmin, vmax))
for k in range(K):
if inputs[name].shape[2] == 1:
m = inputs[name][b, t, 0]
color = (0.0, 0.0, 0.0)
else:
m = inputs[name][b, t, k]
color = colors[k]
mappable = plot_fn(m, ax=axes[r, k], color=color)
if make_cbar:
fig.colorbar(mappable, cax=axes[r, K])
else:
axes[r, K].set_visible(False)
for k in range(K):
axes[0, k].set_title("Component {}".format(k + 1)) # , color=colors[k])
plt.subplots_adjust(hspace=0.05, wspace=0.05)
return fig
+163
@@ -0,0 +1,163 @@
# Lint as: python3
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Iterative refinement modules."""
# pylint: disable=g-doc-bad-indent, unused-variable
from iodine.modules import utils
import shapeguard
import sonnet as snt
import tensorflow.compat.v1 as tf
class RefinementCore(snt.RNNCore):
  """Recurrent Refinement Module.

  Refinement modules take as inputs:
    * previous state (which could be an arbitrary nested structure)
    * current inputs which include
      * image-space inputs like pixel-based errors, or mask-posteriors
      * latent-space inputs like the previous z_dist, or dz

  They use these inputs to produce:
    * output (usually a new z_dist)
    * new_state
  """
def __init__(self,
encoder_net,
recurrent_net,
refinement_head,
name="refinement"):
super().__init__(name=name)
self._encoder_net = encoder_net
self._recurrent_net = recurrent_net
self._refinement_head = refinement_head
self._sg = shapeguard.ShapeGuard()
def initial_state(self, batch_size, **unused_kwargs):
return self._recurrent_net.initial_state(batch_size)
def _build(self, inputs, prev_state):
sg = self._sg
assert "spatial" in inputs, inputs.keys()
assert "flat" in inputs, inputs.keys()
assert "zp" in inputs["flat"], inputs["flat"].keys()
zp = sg.guard(inputs["flat"]["zp"], "B, K, Zp")
x = sg.guard(self.prepare_spatial_inputs(inputs["spatial"]), "B*K, H, W, C")
h1 = sg.guard(self._encoder_net(x).params, "B*K, H1")
h2 = sg.guard(self.prepare_flat_inputs(h1, inputs["flat"]), "B*K, H2")
h2_unflattened = sg.reshape(h2, "B, K, H2")
h3, next_state = self._recurrent_net(h2_unflattened, prev_state)
sg.guard(h3, "B, K, H3")
outputs = sg.guard(self._refinement_head(zp, h3), "B, K, Y")
del self._sg.B
return outputs, next_state
def prepare_spatial_inputs(self, inputs):
values = []
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
if val.shape.as_list()[1] == 1:
self._sg.guard(val, "B, 1, H, W, _C")
val = tf.tile(val, self._sg["1, K, 1, 1, 1"])
else:
self._sg.guard(val, "B, K, H, W, _C")
values.append(val)
concat_inputs = self._sg.guard(tf.concat(values, axis=-1), "B, K, H, W, C")
return self._sg.reshape(concat_inputs, "B*K, H, W, C")
def prepare_flat_inputs(self, hidden, inputs):
values = [self._sg.guard(hidden, "B*K, H1")]
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
self._sg.guard(val, "B, K, _")
val_flat = tf.reshape(val, self._sg["B*K"] + [-1])
values.append(val_flat)
return tf.concat(values, axis=-1)
class ResHead(snt.AbstractModule):
"""Updates Zp using a residual mechanism."""
def __init__(self, name="residual_head"):
super().__init__(name=name)
def _build(self, zp_old, inputs):
sg = shapeguard.ShapeGuard()
sg.guard(zp_old, "B, K, Zp")
sg.guard(inputs, "B, K, H")
update = snt.Linear(sg.Zp)
flat_zp = sg.reshape(zp_old, "B*K, Zp")
flat_inputs = sg.reshape(inputs, "B*K, H")
zp = flat_zp + update(flat_inputs)
return sg.reshape(zp, "B, K, Zp")
class PredictorCorrectorHead(snt.AbstractModule):
  """This refinement head is used for sequential data.

  At every step it computes a prediction from the λ of the previous timestep
  and an update from the refinement network of the current timestep.

  The next step λ' is computed as a gated combination of both:
  λ' = g_pred * λ_pred + (1 - g_pred) * λ_corr
  where g_pred is the sigmoid prediction gate and λ_corr is itself a gated
  combination of the previous λ and the update.
  """
def __init__(
self,
hidden_sizes=(64,),
pred_gate_bias=0.0,
corrector_gate_bias=0.0,
activation=tf.nn.elu,
name="predcorr_head",
):
super().__init__(name=name)
self._hidden_sizes = hidden_sizes
self._activation = utils.get_act_func(activation)
self._pred_gate_bias = pred_gate_bias
self._corrector_gate_bias = corrector_gate_bias

  def _build(self, zp_old, inputs):
    sg = shapeguard.ShapeGuard()
    sg.guard(zp_old, "B, K, Zp")
    sg.guard(inputs, "B, K, H")
    update = snt.Linear(sg.Zp)
    update_gate = snt.Linear(sg.Zp)
    predict = snt.nets.MLP(
        output_sizes=list(self._hidden_sizes) + [sg.Zp * 2],
        activation=self._activation,
    )

    flat_zp = sg.reshape(zp_old, "B*K, Zp")
    flat_inputs = sg.reshape(inputs, "B*K, H")

    g = tf.nn.sigmoid(update_gate(flat_inputs) + self._corrector_gate_bias)
    u = update(flat_inputs)

    # a slightly more efficient way of computing the gated update
    # (1-g) * flat_zp + g * u
    zp_corrected = flat_zp + g * (u - flat_zp)

    # the MLP output holds the predicted value in its first Zp units
    # and the logits of the prediction gate in its last Zp units
    predicted = predict(flat_zp)
    pred_up = predicted[:, :sg.Zp]
    pred_gate = tf.nn.sigmoid(predicted[:, sg.Zp:] + self._pred_gate_bias)

    # the same rewriting applied to the prediction gate:
    # (1-pred_gate) * zp_corrected + pred_gate * pred_up
    zp = zp_corrected + pred_gate * (pred_up - zp_corrected)

    return sg.reshape(zp, "B, K, Zp")
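
The two gated updates in `PredictorCorrectorHead._build` can be checked with plain numbers. The sketch below is an illustration under simplifying assumptions, not the module itself: scalars stand in for the per-element tensors, the proposal `u`, the predicted value `pred_up` and both gate logits are passed in directly instead of being produced by `snt.Linear` and `snt.nets.MLP`, and the helper names `gated_update` and `predictor_corrector_step` are hypothetical.

```python
import math


def sigmoid(x):
  return 1.0 / (1.0 + math.exp(-x))


def gated_update(old, new, gate):
  """Returns (1 - gate) * old + gate * new using a single multiplication."""
  return old + gate * (new - old)


def predictor_corrector_step(zp, u, gate_logit, pred_up, pred_gate_logit):
  # Correction: move zp towards the refinement proposal u.
  g = sigmoid(gate_logit)
  zp_corrected = gated_update(zp, u, g)
  # Prediction: move the corrected value towards the predicted value.
  pred_gate = sigmoid(pred_gate_logit)
  return gated_update(zp_corrected, pred_up, pred_gate)


# With zero logits both gates equal 0.5.
step = predictor_corrector_step(
    zp=1.0, u=3.0, gate_logit=0.0, pred_up=-1.0, pred_gate_logit=0.0)
print(step)  # 0.5
```

In this example the correction gives 1.0 + 0.5 * (3.0 - 1.0) = 2.0, and the prediction then gives 2.0 + 0.5 * (-1.0 - 2.0) = 0.5. A strongly negative gate bias pushes the corresponding gate towards 0, so that stage leaves its input nearly unchanged.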
+9
View File
@@ -0,0 +1,9 @@
tensorflow-gpu==1.14.0
tensorflow-probability==0.7.0
dm-sonnet==1.35
sacred>=0.7,<0.8
shapeguard
seaborn
pymongo
jupyterlab
git+https://github.com/deepmind/multi_object_datasets.git
Executable
+36
View File
@@ -0,0 +1,36 @@
#!/bin/sh
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
echo "downloading checkpoints from GCP"
iodine/download_checkpoints.sh
python3 -m venv iodine_venv
. iodine_venv/bin/activate
pip3 install --upgrade setuptools wheel
pip3 install -r iodine/requirements.txt
# Get some fake data and put it where the real multi_object_datasets files live.
mkdir -p iodine/multi_object_datasets
cp iodine/test_data/tetrominoes_mini.tfrecords iodine/multi_object_datasets/tetrominoes_train.tfrecords
# Run training with a cut down size.
python3 -m iodine.main \
-f with tetrominoes \
data.shuffle_buffer=2 \
data.batch_size=2 \
n_z=4 \
num_components=3 \
stop_after_steps=11